Keras, графики потерь и точности при обучении и валидации модели
Ниже представлен код функции, которая создает картинку с графиками потерь и точности на этапах обучения и валидации модели нейронной сети.
Данные о потерях и точности берутся из объекта History
, который возвращает метод fit
в Keras.
import matplotlib.pyplot as plt
...
def create_train_charts(filename, history):
"""
Создает графики потерь и точности при обучении и валидации модели, построенной на базе Keras с обучением и валидацией с помощью метода fit.
Результат сохраняет в файл.
Предполагается, что:
1. метрика точности имела название 'accuracy'.
Параметры:
history: атрибут history объекта History, который возвращает метод fit из Keras;
filename: путь к файлу, в который нужно сохранить график.
"""
# Список с данными для построения графиков.
data = []
# Данные точности.
data.append({
'label': 'Точность при обучении',
'title': 'Точность',
'val_label': 'Точность при валидации',
'val_values': history.get('val_accuracy'),
'values': history['accuracy'],
})
# Данные потерь.
data.append({
'label': 'Потери при обучении',
'title': 'Потери',
'val_label': 'Потери при валидации',
'val_values': history.get('val_loss'),
'values': history['loss'],
})
# Определение количества эпох обучения для краткости.
epochs = range(1, len(data[0]['values'])+1)
# Создание основной фигуры и нужного количества холстов для графиков.
figure, axes = plt.subplots(len(data), 1, figsize=(7, 10))
# Корректировка расстояния между разными графиками.
plt.subplots_adjust(hspace=.4)
# Перебор данных для разных графиков.
for i, axis in enumerate(axes):
# Сетка под графиком.
axis.grid(b=True, color='lightgray', which='both', zorder=0)
# График основных данных в виде зеленых точек.
axis.plot(
epochs,
data[i]['values'],
'.',
label=data[i]['label'],
color='g',
zorder=3
)
# Если данные по валидации есть...
if data[i]['val_values']:
# ...рисуется их график в виде красной линии.
axis.plot(
epochs,
data[i]['val_values'],
label=data[i]['val_label'],
color='r',
zorder=3
)
# Общий заголовок графика.
axis.set_title(data[i]['title'])
# Подпись оси Х.
axis.set_xlabel('Эпохи')
# Подпись оси Y.
axis.set_ylabel(data[i]['title'])
# Отображение легенды.
axis.legend()
# Сохранение графика в файл.
figure.savefig(filename)
plt.close()