TensorFlow 中的回调函数

回调函数是 TensorFlow 训练之中非常重要的一部分,我们在之前的学习之中或多或少地用到了回调函数。比如在之前的过拟合一节之中,我们就曾经用到了早停回调。那么这节课我们就来学习以下 TensorFlow 之中的回调函数。

1. 什么是回调函数

简单来说,回调函数就是在训练到一定阶段的时候而执行的函数,我们最常采用的策略是每个Epoch结束之后执行一次回调函数

回调函数的绝大多数 API 集中在 tf.keras.callbacks 之中,也就是说这是 Keras 之中的一个 API 。由于之前已经学习过早停回调,这节课我们来学习一下其他的几个常用的回调:

  • 模型保存回调:tf.keras.callbacks.ModelCheckpoint;
  • 学习率回调;tf.keras.callbacks.LearningRateScheduler;
  • 自定义回调:tf.keras.callbacks.CallBack。

对于回调的使用方法,也是非常简单的,假设以下的数组之中定义了我们所需要的全部回调函数:

callbacks = [......]

那么我们在使用回调的时候,之中只需要在训练函数中指定回调即可:

model.fit(..., ..., callbacks=callbacks)

对于要介绍的回调,我们会首先给出介绍,然后再在统一的代码之中示例使用。

2. 模型保存回调

模型保存的回调函数为:

tf.keras.callbacks.ModelCheckpoint(
    path, monitor='val_loss', verbose=0, save_best_only=False,
    save_weights_only=False, save_freq='epoch')

这里只列出来了我们常用的参数,对于其中的每个参数,它们的作用如下:

  • path: 保存模型的路径;
  • monitor: 用哪个指标来评价模型的好坏,默认是验证集上的损失;
  • verbose: 输出日志的等级,只能为 0 或 1;
  • save_best_only: 是否只保存最好的模型,模型的好坏由 monitor 指定;
  • save_weights_only: 是否只保存权重,默认 False ,也就是保存整个模型;
  • save_freq: 保存的频率,可以为 ‘Epoch’ 或者一个整数,默认为每个 Epoch 保存一次模型;若是一个整数N,则是每训练 N 个 Batch 保存一次模型。

3. 学习率回调

学习率回调函数为:

tf.keras.callbacks.LearningRateScheduler(
    schedule, verbose=0
)

其中 verbose 参数仍然是日志输出的等级,默认为 0 ;而 schedule 则是一个函数,用来定义一个学习率的变化。其中 schedule 函数的一个示例如下所示:

def my_schedule(epoch, lr):
  if epoch < 20:
    return lr
  else:
    return lr * 0.1

该学习率回调是在 20 个 Epoch 之前学习率保持不变,而在 20 个 Epoch 之后,每个 Epoch 学习率变为原来的 0.1 。

可以看出,该 schedule 函数由严格的形式,其中第一个参数为训练的 Epoch ,第二个参数为当前的学习率。

4. 自定义回调

我们在使用回调的过程之中难免会遇到要自定义回调的情况,这时我们便需要编写类来继承 tf.keras.callbacks.CallBack 类,从而实现我们的自定义回调

在自定义回调的过程之中,你可以覆写不同的函数,从而可以实现在不同的时间来运行我们自定义的函数,这些函数包括:

  • on_train_begin(self, logs=None): 在训练开始时调用;
  • on_test_begin(self, logs=None): 在测试开始时调用;
  • on_predict_begin(self, logs=None): 在预测开始时调用;
  • on_train_end(self, logs=None) 在训练结束时调用;
  • on_test_end(self, logs=None) 在测试结束时调用;
  • on_predict_end(self, logs=None) 在预测结束时调用;
  • on_train_batch_begin(self, batch, logs=None) 在训练期间的每个批次之前调用;
  • on_test_batch_begin(self, batch, logs=None) 在测试期间的每个批次之前调用;
  • on_predict_batch_begin(self, batch, logs=None) 在预测期间的每个批次之前调用;
  • on_train_batch_end(self, batch, logs=None) 在训练期间的每个批次之后调用;
  • on_test_batch_end(self, batch, logs=None) 在测试期间的每个批次之后调用;
  • on_predict_batch_end(self, batch, logs=None) 在预测期间的每个批次之后调用;
  • on_epoch_begin(self, epoch, logs=None) 在每次迭代训练开始时调用;
  • on_epoch_end(self, epoch, logs=None) 在每次迭代训练结束时调用。

我们可以来使用其中两个简单的函数来做一个简单的示例:

class MyCallback(tf.keras.callbacks.Callback):

    def on_epoch_begin(self, epoch, logs=None):
        print("Start epoch {}.".format(epoch))

    def on_train_begin(self, logs=None):
        print("Starting training.")

这个样子,我们便可以在每次训练开始,以及每个 Epoch 开始之时进行输出日志。

5. 程序示例

在这里,我们将同时使用模型保存回调、学习率回调以及自定义回调来做一个简单的示例:

import tensorflow as tf

model = tf.keras.Sequential([
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(64, activation='relu'),
      tf.keras.layers.Dense(10)
])

lr = 0.01

model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=lr),
    loss="mse"
)

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()


def my_schedule(epoch, lr):
  print('Learning rate: ' + str(lr))
  if epoch < 5:
    return lr
  else:
    return lr * 0.1

lr_callback = tf.keras.callbacks.LearningRateScheduler(my_schedule)

save_model_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath='/model/', save_weights_only=True, verbose=1,
    monitor='val_loss', mode='min', save_best_only=True)

class MyCallback(tf.keras.callbacks.Callback):

    def on_epoch_begin(self, epoch, logs=None):
        print("Start epoch {}.".format(epoch))

    def on_train_begin(self, logs=None):
        print("Starting training.")

model.fit(x_train, y_train,
    batch_size=64, epochs=10,
    validation_data=(x_test, y_test),
    callbacks=[MyCallback(), lr_callback, save_model_callback],
)

在这里,我们按照之前学习的方法定义了三个回调函数,分别是模型保存回调、学习率回调、以及自定义回调。其中模型保存回调会在每次训练后保存模型、学习率回调会在第五个 Epoch 之后便每个 Epoch 变为原来的 0.1 ,而自定义回调会在训练开始之前、每个 Epoch 开始之前输出相应的信息。

于是我们可以得到输出:

Starting training.
Start epoch 0.
Learning rate: 0.009999999776482582
Epoch 1/10
931/938 [============================>.] - ETA: 0s - loss: 556.1402
Epoch 00001: val_loss improved from inf to 15.96259, saving model to /model/
938/938 [==============================] - 2s 2ms/step - loss: 552.3954 - val_loss: 15.9626
Start epoch 1.
Learning rate: 0.009999999776482582
Epoch 2/10
927/938 [============================>.] - ETA: 0s - loss: 12.4227
Epoch 00002: val_loss improved from 15.96259 to 10.01533, saving model to /model/
938/938 [==============================] - 2s 2ms/step - loss: 12.3927 - val_loss: 10.0153
Start epoch 2.
Learning rate: 0.009999999776482582
Epoch 3/10
914/938 [============================>.] - ETA: 0s - loss: 9.0919
Epoch 00003: val_loss improved from 10.01533 to 8.50834, saving model to /model/
938/938 [==============================] - 2s 2ms/step - loss: 9.0744 - val_loss: 8.5083
Start epoch 3.
Learning rate: 0.009999999776482582
Epoch 4/10
913/938 [============================>.] - ETA: 0s - loss: 8.3514
Epoch 00004: val_loss improved from 8.50834 to 8.26637, saving model to /model/
938/938 [==============================] - 2s 2ms/step - loss: 8.3450 - val_loss: 8.2664
Start epoch 4.
Learning rate: 0.009999999776482582
Epoch 5/10
920/938 [============================>.] - ETA: 0s - loss: 8.2481
Epoch 00005: val_loss improved from 8.26637 to 8.25048, saving model to /model/
938/938 [==============================] - 2s 2ms/step - loss: 8.2544 - val_loss: 8.2505
Start epoch 5.
Learning rate: 0.009999999776482582
Epoch 6/10
933/938 [============================>.] - ETA: 0s - loss: 8.2504
Epoch 00006: val_loss improved from 8.25048 to 8.25035, saving model to /model/
938/938 [==============================] - 2s 2ms/step - loss: 8.2502 - val_loss: 8.2504
Start epoch 6.
Learning rate: 0.0009999999310821295
Epoch 7/10
932/938 [============================>.] - ETA: 0s - loss: 8.2509
Epoch 00007: val_loss improved from 8.25035 to 8.25034, saving model to /model/
938/938 [==============================] - 2s 2ms/step - loss: 8.2501 - val_loss: 8.2503
Start epoch 7.
Learning rate: 9.99999901978299e-05
Epoch 8/10
916/938 [============================>.] - ETA: 0s - loss: 8.2600
Epoch 00008: val_loss improved from 8.25034 to 8.25034, saving model to /model/
938/938 [==============================] - 2s 2ms/step - loss: 8.2501 - val_loss: 8.2503
Start epoch 8.
Learning rate: 9.99999883788405e-06
Epoch 9/10
914/938 [============================>.] - ETA: 0s - loss: 8.2541
Epoch 00009: val_loss did not improve from 8.25034
938/938 [==============================] - 2s 2ms/step - loss: 8.2501 - val_loss: 8.2503
Start epoch 9.
Learning rate: 9.99999883788405e-07
Epoch 10/10
925/938 [============================>.] - ETA: 0s - loss: 8.2446
Epoch 00010: val_loss did not improve from 8.25034
938/938 [==============================] - 2s 2ms/step - loss: 8.2501 - val_loss: 8.2503
<tensorflow.python.keras.callbacks.History at 0x7eff7317f748>

可以看到,我们的三个回调函数都能正确地输出相应的信息,说明我们的回调函数已经成功生效。

6. 小结

在这节课之中,我们学习了什么是回调函数、模型保存回调、学习率回调以及如何自定义回调。同时我们又通过相应的示例演示了如何使用回调。