Когда мы обучаем модели машинного обучения, почти всегда возникает один и тот же вопрос:
Обычно мы смотрим на графики метрик и пытаемся вручную интерпретировать происходящее:
Модель недообучена
Модель переобучена
Имбаланс датасета.
Сильно шумные данные.
Можно посмотреть на learning curves и понять, что происходит:
Но этот анализ почти всегда выполняется вручную или с помощью простейших эвристических правил. А ведь сколько времени, сил и нервов можно было бы сэкономить, если обучить до 100 эпохи а не до 500 (см картинка выше) :-(
Но можно задать интересный вопрос:
А можно ли автоматически определить состояние обучения модели?
А что если научить отдельную модель, которая будет автоматически определять состояние обучения?
То есть вместо ручного анализа мы обучаем модель, которая делает это автоматически. Но насколько это эффективно, на проде, интерпретируемы ли эти результаты для разных типов задач и т. д.
Чтобы обучить такой классификатор, нужен датасет с различными сценариями обучения.
Я решил сгенерировать его программно.
В качестве базового датасета использовался MNIST — классический набор изображений рукописных цифр.
Эксперименты проводились с несколькими типами моделей:
logistic regression
небольшой MLP
большой MLP
маленькая CNN
большая CNN
Для каждого эксперимента варьировались параметры:
размер обучающей выборки
случайный seed
наличие дисбаланса классов
тип сдвига данных
По итогу я обучил 270 моделей и посмотрел их после 1, 5, 6, 11,16,21,26 эпох. По каждой записи были сохранены:
|
Столбец |
Тип |
Описание |
|
model |
str |
Название модели, использованной для обучения (logreg, mlp_small, mlp_large, cnn_small, cnn_large). |
|
train_size |
int |
Размер выборки для обучения в конкретном эксперименте. |
|
seed |
int |
Значение random seed для воспроизводимости случайной выборки. |
|
imbalance |
bool |
Флаг, указывающий, использовался ли искусственный дисбаланс классов (True) или нет (False). |
|
shift_type |
str |
Тип сдвига данных на тестовой выборке (none, noise, invert). |
|
train_acc |
float |
Точность модели на тренировочной выборке после текущей эпохи. |
|
val_acc |
float |
Точность модели на валидационной выборке после текущей эпохи. |
|
test_acc |
float |
Точность модели на тестовой выборке (с учетом возможного сдвига данных). |
|
gap |
float |
Разница между тренировочной и валидационной точностью (train_acc - val_acc). Используется для диагностики переобучения. |
|
epochs |
int |
Количество эпох обучения (для функции train_and_evaluate) — либо номер эпохи в train_with_history. |
|
val_curve |
list of list |
История точности на валидационной выборке по эпохам до текущей. |
|
epoch |
int |
Номер текущей эпохи обучения (используется при пошаговом train_with_history). |
|
underfitting |
int (0/1) |
Диагностический флаг: модель недообучена, если |
|
overfitting |
int (0/1) |
Диагностический флаг: модель переобучена, если |
|
dataset_shift |
int (0/1) |
Диагностический флаг: есть смещение тестовых данных, если |
С мериками получилось сложно, нельзя точно сказать, что при val_acс 0.9 нет переобучения, однако, в рамках работы я просто тестил всё на test_dataset и ставил метки по нему. правила для меток:
def diagnose(metrics): return { "underfitting": int(metrics["train_acc"] < 0.7), "overfitting": int(metrics["gap"] > 0.15), "dataset_shift": int(metrics["val_acc"] - metrics["test_acc"] > 0.15) }
В итоге в датасете я получил:
Касаемо качества датасета, меня устаивает, есть как и ужасные модели, так и неплохие, acc достиг 0.9.
Одним из самых интересных источников информации является форма learning curve. Я вытащил из него много признаков, все признаки на которых я делал метрики (подразумеваются как недоступные я удалил из обучения)
df["curve_start"] = df["val_curve"].apply(lambda x: x[0]) df["curve_mid"] = df["val_curve"].apply(lambda x: x[len(x)//2]) df["curve_end"] = df["val_curve"].apply(lambda x: x[-1]) df["curve_growth"] = df["curve_end"] - df["curve_start"] df["curve_stability"] = df["val_curve"].apply(np.std)
Для классификации были протестированы несколько алгоритмов:
Random Forest
XGBoost
Logistic Regression
ансамбль моделей
Поскольку задача имеет несколько независимых меток, использовался MultiOutputClassifier.
rf = RandomForestClassifier( n_estimators=200, random_state=42 ) model = MultiOutputClassifier(rf) model.fit(X_train, y_train) pred = model.predict(X_test)
Итоги после обучения:
precision recall f1-score support 0 0.94 0.89 0.91 177 1 0.96 0.97 0.96 593 2 0.97 0.88 0.92 233 3 0.75 0.73 0.74 419 micro avg 0.90 0.87 0.89 1422 macro avg 0.90 0.87 0.88 1422 weighted avg 0.90 0.87 0.88 1422 samples avg 0.86 0.84 0.83 1422
Лучшие результаты показал Random Forest.
Он хорошо определял:
underfitting
dataset shift
Логистическая регрессия показала более низкое качество — что ожидаемо, так как она является линейным классификатором. Ансамбль моделей практически не улучшил результат.
Этот подход можно использовать в ML-pipeline.
Это может позволить:
автоматически выявлять переобучение
обнаруживать проблемы с данными
останавливать обучение раньше
экономить вычислительные ресурсы
На этом всё, если кому-то интересно потрогать руками напишите в комментах, на гитхабе надо убраться, буду рад критике, есть что добавить и так, планирую дописать 2 часть. Вот мой гитхаб, в целом там иногда есть что-то интересное.
Спасибо, всем хорошего дня.
Источник


