Keras, графики потерь и точности при обучении и валидации модели

25.01.2020

Ниже представлен код функции, которая создает картинку с графиками потерь и точности на этапах обучения и валидации модели нейронной сети.

Данные о потерях и точности берутся из объекта History, который возвращает метод fit в Keras.

import matplotlib.pyplot as plt

...

def create_train_charts(filename, history):
    """
    Создает графики потерь и точности при обучении и валидации модели, построенной на базе Keras с обучением и валидацией с помощью метода fit.
    Результат сохраняет в файл.

    Предполагается, что:
        1. метрика точности имела название 'accuracy'.

    Параметры:
        history:    атрибут history объекта History, который возвращает метод fit из Keras;
        filename:   путь к файлу, в который нужно сохранить график.
    """

    # Список с данными для построения графиков.
    data = []
    # Данные точности.
    data.append({
        'label': 'Точность при обучении',
        'title': 'Точность',
        'val_label': 'Точность при валидации',
        'val_values': history.get('val_accuracy'),
        'values': history['accuracy'],
    })
    # Данные потерь.
    data.append({
        'label': 'Потери при обучении',
        'title': 'Потери',
        'val_label': 'Потери при валидации',
        'val_values': history.get('val_loss'),
        'values': history['loss'],
    })

    # Определение количества эпох обучения для краткости.
    epochs = range(1, len(data[0]['values'])+1)

    # Создание основной фигуры и нужного количества холстов для графиков.
    figure, axes = plt.subplots(len(data), 1, figsize=(7, 10))
    # Корректировка расстояния между разными графиками.
    plt.subplots_adjust(hspace=.4)

    # Перебор данных для разных графиков.
    for i, axis in enumerate(axes):
        # Сетка под графиком.
        axis.grid(b=True, color='lightgray', which='both', zorder=0)
        # График основных данных в виде зеленых точек.
        axis.plot(
            epochs,
            data[i]['values'],
            '.',
            label=data[i]['label'],
            color='g',
            zorder=3
        )
        # Если данные по валидации есть...
        if data[i]['val_values']:
            # ...рисуется их график в виде красной линии.
            axis.plot(
                epochs,
                data[i]['val_values'],
                label=data[i]['val_label'],
                color='r',
                zorder=3
            )
        # Общий заголовок графика.
        axis.set_title(data[i]['title'])
        # Подпись оси Х.
        axis.set_xlabel('Эпохи')
        # Подпись оси Y.
        axis.set_ylabel(data[i]['title'])
        # Отображение легенды.
        axis.legend()

    # Сохранение графика в файл.
    figure.savefig(filename)

    plt.close()