Original author: 章立. First published on this site with the author's authorization.

Checkpoints serve two main purposes:

  1. If something unexpected happens during training, you can recover quickly from a checkpoint.
  2. Checkpoints make early stopping possible, which can improve the final result.

Keras

In Keras, checkpoints are created with the Model.save_weights method.

For this method to work, the Model's layers must be assigned to member attributes, typically in the constructor.

Also, a checkpoint created by Model.save_weights has to be loaded with Model.load_weights; it cannot be loaded with tf.train.Checkpoint.restore.
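The pairing looks roughly like this; a minimal sketch, where the model, attribute names, and path are illustrative rather than taken from the original post:

import tensorflow as tf

class TinyModel(tf.keras.Model):
    def __init__(self):
        super(TinyModel, self).__init__()
        # the layer is assigned to a member attribute, so it is tracked
        self.dense = tf.keras.layers.Dense(3)

    def call(self, x):
        return self.dense(x)

m = TinyModel()
m(tf.zeros([1, 2]))                      # build the variables with a dummy call
m.save_weights('./keras_weights/ckpt')   # illustrative path prefix

m2 = TinyModel()
m2(tf.zeros([1, 2]))
m2.load_weights('./keras_weights/ckpt')  # must pair with save_weights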

The API documentation recommends using tf.train.Checkpoint to create checkpoints instead.

tf.train.Checkpoint

This is the part of TensorFlow responsible for managing the full lifecycle of checkpoints, including:

  1. Defining the checkpoint-saving policy
  2. Managing checkpoint restoration
# imports
from __future__ import absolute_import, division, print_function, \
    unicode_literals
import tensorflow as tf

Defining the checkpoint-saving policy

Before working with checkpoints, we first define a simple network and a simple input pipeline, built the same way as in quickstart2.

class Net(tf.keras.Model):
    """Just a simple linear model."""

    def __init__(self):
        super(Net, self).__init__()
        self.l1 = tf.keras.layers.Dense(5)

    def call(self, x):
        return self.l1(x)

def toy_dataset():
    inputs = tf.range(10.)[:, None]
    labels = inputs * 5. + tf.range(5.)[None, :]
    return tf.data.Dataset.from_tensor_slices(
        dict(x=inputs, y=labels)).repeat(10).batch(2)

def train_step(net, example, optimizer):
    """Trains `net` on `example` using `optimizer`; returns the loss."""
    with tf.GradientTape() as tape:
        output = net(example['x'])
        loss = tf.reduce_mean(tf.abs(output - example['y']))
    variables = net.trainable_variables
    gradients = tape.gradient(loss, variables)
    optimizer.apply_gradients(zip(gradients, variables))
    return loss

Now, suppose I want to create checkpoints while training this network. What do I need to do?

First, it is important to understand that TensorFlow's core objects, such as tf.Variable, are objects with internal state. What we checkpoint is the state of an object, and what we restore is also that state; we do not restore the object itself.
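A minimal sketch of this point (the path is illustrative): restore writes the saved state back into an already existing object.

v = tf.Variable(3.0)
ckpt = tf.train.Checkpoint(v=v)
path = ckpt.save('./state_demo/ckpt')  # saves the *state* of v

v.assign(7.0)                          # mutate the object
ckpt.restore(path)                     # writes the saved state back into v
print(v.numpy())                       # 3.0 -- same object, restored state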

With that in mind, three questions arise:

  1. How do we make an object checkpointable?
  2. Which objects can be checkpointed?
  3. How do we restore a checkpoint into objects?

tf.train.Checkpoint is new in TF 2.0; in TF 1.x this was handled by tf.train.Saver, which we will not cover here. In TF 2.0, Checkpoint serializes state based on Python objects.

The constructor of tf.train.Checkpoint is __init__(**kwargs): you pass in the objects you want checkpointed at construction time, and every subsequent save checkpoints exactly those objects. The following routine code answers all three questions.

opt = tf.keras.optimizers.Adam(0.1)
net = Net()
ckpt = tf.train.Checkpoint(
    step=tf.Variable(1),  # iteration counter
    optimizer=opt,        # optimizer state
    net=net               # network state
)  # ckpt acts as a container: it fixes the set of objects passed to the
   # constructor and keeps references to them
manager = tf.train.CheckpointManager(
    checkpoint=ckpt,                     # the checkpoint to manage
    directory='./tf_ckpts_test',         # storage path
    max_to_keep=3,                       # maximum number of checkpoints to keep
    keep_checkpoint_every_n_hours=None,  # interval for preserving extra checkpoints
    checkpoint_name='test'               # checkpoint file name prefix
)
for example in toy_dataset():
    ckpt.step.assign_add(1)
    ckpt.restore(manager.latest_checkpoint)  # restore the latest checkpoint (no-op if none exists yet)
    loss = train_step(net, example, opt)
    print("loss {:1.2f}".format(loss.numpy()))
    manager.save()  # save the current state of the objects in the container
loss 2.08
loss 0.98
loss 2.09
loss 3.38
loss 4.45
loss 2.53
loss 1.36
loss 1.28
loss 2.03
loss 2.66
loss 2.78
loss 1.96
loss 1.24
loss 0.82
loss 0.78
loss 2.79
loss 2.07
loss 1.34
loss 0.85
loss 0.68
loss 2.61
loss 1.78
loss 1.08
loss 0.75
loss 1.13
loss 2.37
loss 1.46
loss 0.68
loss 0.55
loss 1.33
loss 2.20
loss 1.37
loss 0.59
loss 0.49
loss 1.13
loss 2.07
loss 1.30
loss 0.54
loss 0.42
loss 1.05
loss 1.90
loss 1.20
loss 0.59
loss 0.32
loss 0.83
loss 1.76
loss 1.16
loss 0.60
loss 0.26
loss 0.71

Note that checkpoints preserved because of keep_checkpoint_every_n_hours are not counted against the max_to_keep limit; max_to_keep only bounds the checkpoints written by explicit save() calls.
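A hypothetical sketch of the two parameters together, reusing the ckpt container from above (the directory is illustrative): explicit saves rotate under max_to_keep, while hour-based keeps live outside that rotation.

manager2 = tf.train.CheckpointManager(
    checkpoint=ckpt,
    directory='./tf_ckpts_test2',     # illustrative path
    max_to_keep=3,                    # only the 3 newest save() results are rotated
    keep_checkpoint_every_n_hours=1   # plus up to one preserved checkpoint per hour
)
for _ in range(5):
    manager2.save()
print(manager2.checkpoints)  # the managed (rotated) checkpoints, at most 3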

Q1 How do we make an object checkpointable?

Pass it to the Checkpoint constructor. Lists and dictionaries can also be tracked, which makes constructing a Checkpoint flexible: assign them to attributes of the Checkpoint (the listed/mapped pattern), as sketched below.
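A minimal sketch of list and dictionary tracking (the attribute names listed and mapped and the path are illustrative): assigning a list or dict to a Checkpoint attribute wraps it in a trackable container.

save = tf.train.Checkpoint()
save.listed = [tf.Variable(1.)]        # the list is wrapped into a trackable container
save.listed.append(tf.Variable(2.))    # later appends are tracked too
save.mapped = {'one': save.listed[0]}  # dicts are wrapped the same way
save_path = save.save('./tf_list_example/ckpt')

restore = tf.train.Checkpoint()
restore.listed = [tf.Variable(0.), tf.Variable(0.)]
restore.restore(save_path)
print([v.numpy() for v in restore.listed])  # [1.0, 2.0]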

Q2 Which objects can be checkpointed?

Only objects derived from TrackableBase can be checkpointed.

testList = []
ckpt_error = tf.train.Checkpoint(
    test=testList,            # a plain Python list is not trackable
    step=tf.Variable(1),      # iteration counter
    optimizer=opt,            # optimizer state
    net=net                   # network state
)
---------------------------------------------------------------------------

ValueError                                Traceback (most recent call last)

<ipython-input-4-ce172f4ab95a> in <module>()
      4     step=tf.Variable(1),      # iteration counter
      5     optimizer=opt,            # optimizer state
----> 6     net=net                   # network state
      7 )

~/anaconda3/lib/python3.6/site-packages/tensorflow_core/python/training/tracking/util.py in __init__(self, **kwargs)
   1777              "object should be trackable (i.e. it is part of the "
   1778              "TensorFlow Python API and manages state), please open an issue.")
-> 1779             % (v,))
   1780       setattr(self, k, v)
   1781     self._save_counter = None  # Created lazily for restore-on-create.

ValueError: `Checkpoint` was expecting a trackable object (an object derived from `TrackableBase`), got []. If you believe this object should be trackable (i.e. it is part of the TensorFlow Python API and manages state), please open an issue.

Q3 How do we restore a checkpoint into objects?

Define objects with the same parameters, build a Checkpoint from them, and then call restore to load state from the given path.

# Restoring repeatedly will produce a warning
opt3 = tf.keras.optimizers.Adam(0.1)
net3 = Net()
step3 = tf.Variable(1)
print("before restore: {}".format(step3.numpy()))
ckpt3 = tf.train.Checkpoint(
    optimizer=opt3,
    step=step3,
    net=net3)
file = tf.train.latest_checkpoint('./tf_ckpts_test')
print(file)
status = ckpt3.restore(file)
print("after restore: {}".format(step3.numpy()))
before restore: 1
./tf_ckpts_test/test-100
after restore: 51

Q3.1 Why does it look as if net3 was not restored? Calling trainable_variables gives different results for net3 and net.

net3.trainable_variables
[]
net.trainable_variables
[<tf.Variable 'net/dense/kernel:0' shape=(1, 5) dtype=float32, numpy=
 array([[4.503945 , 4.5462313, 4.852895 , 4.7684402, 4.9965386]],
       dtype=float32)>,
 <tf.Variable 'net/dense/bias:0' shape=(5,) dtype=float32, numpy=
 array([3.375819 , 3.8449383, 2.7898471, 4.4520674, 4.1617193],
       dtype=float32)>]

The reason is that restore performs delayed restorations. A Layer object defers creating its internal Variables until its first call, so the checkpointed values cannot be written into them until then.
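A minimal sketch of triggering the deferred restore, continuing with net3 from above (the input tensor is illustrative):

# the first call builds the Dense layer's variables; the pending
# restore then writes the checkpointed values into them
net3(tf.constant([[1.]]))
print(net3.trainable_variables)  # now populated with the restored values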

Extra checkpoint support in Estimator

Estimator ships with a default CheckpointManager. As long as you construct a Checkpoint object inside model_fn, the model is saved periodically during training, with older checkpoints rotated out (the default config shown below keeps the 5 most recent).

25
26
import tensorflow.compat.v1 as tf_compat

def model_fn(features, labels, mode):
    net = Net()                          # define the model
    opt = tf.keras.optimizers.Adam(0.1)  # define the optimizer
    ckpt = tf.train.Checkpoint(step=tf_compat.train.get_global_step(),
                               optimizer=opt, net=net)  # define the Checkpoint
    # no CheckpointManager needed: Estimator ships with a default one
    # training step
    with tf.GradientTape() as tape:
        output = net(features['x'])
        loss = tf.reduce_mean(tf.abs(output - features['y']))
    variables = net.trainable_variables
    gradients = tape.gradient(loss, variables)
    return tf.estimator.EstimatorSpec(
        mode,
        loss=loss,
        train_op=tf.group(opt.apply_gradients(zip(gradients, variables)),
                          ckpt.step.assign_add(1)),
        # Tell the Estimator to save "ckpt" in an object-based format.
        scaffold=tf_compat.train.Scaffold(saver=ckpt))

tf.keras.backend.clear_session()
est = tf.estimator.Estimator(model_fn,                 # the model
                             './tf_estimator_example/' # checkpoint and model dir
                            )
est.train(toy_dataset, steps=10)
INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_model_dir': './tf_estimator_example/', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x1837233860>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from ./tf_estimator_example/model.ckpt-10
WARNING:tensorflow:From /Users/ki/anaconda3/lib/python3.6/site-packages/tensorflow_core/python/training/saver.py:1069: get_checkpoint_mtimes (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version.
Instructions for updating:
Use standard file utilities to get mtimes.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Saving checkpoints for 10 into ./tf_estimator_example/model.ckpt.
INFO:tensorflow:loss = 3.5265698, step = 11
WARNING:tensorflow:It seems that global step (tf.train.get_global_step) has not been increased. Current value (could be stable): 12 vs previous value: 12. You could increase the global step by passing tf.train.get_global_step() to Optimizer.apply_gradients or Optimizer.minimize.
WARNING:tensorflow:It seems that global step (tf.train.get_global_step) has not been increased. Current value (could be stable): 15 vs previous value: 15. You could increase the global step by passing tf.train.get_global_step() to Optimizer.apply_gradients or Optimizer.minimize.
WARNING:tensorflow:It seems that global step (tf.train.get_global_step) has not been increased. Current value (could be stable): 18 vs previous value: 18. You could increase the global step by passing tf.train.get_global_step() to Optimizer.apply_gradients or Optimizer.minimize.
INFO:tensorflow:Saving checkpoints for 20 into ./tf_estimator_example/model.ckpt.
INFO:tensorflow:Loss for final step: 33.14527.
# Restore the checkpoint written by the Estimator
opt = tf.keras.optimizers.Adam(0.1)
net = Net()
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net)
ckpt.restore(tf.train.latest_checkpoint('./tf_estimator_example/'))