看完文章將會學到
- 如何在 TensorFlow 2 將訓練好的 Model 保存下來
- 如何在 TensorFlow 2 讀取之前保存下來的 Model
Github 程式碼
Save Model 保存模型
在 Tensorflow 裡面,Save Model 可以代表兩件事:
Checkpoints
Checkpoints 包含了在 Model 裡面每個參數( tf.Variable 物件)的 “值”,但是不包含這個 Model 的架構,因此會使用 Checkpoints 大多是用在 “source code” 可以取得的時候,可以重建出整個 Model 的情況
SavedModel
SavedModel 除了參數的值之外,還包含了整個 Model 的結構,因此非常神奇非常方便的地方是: 在電腦上 train 好的 Model 可以透過 SavedModel 的方法 deploy 到其他平台上面,包含TensorFlow Serving, TensorFlow Lite, TensorFlow.js…甚至是使用其他的語言 (C, C++, Java, Go, Rust, C# 等等)請見下圖
SavedModel 範例
在這裡最重要的 API 就是:
tf.saved_model.save(
obj,
export_dir,
signatures=None
)
tf.saved_model.load(
export_dir,
tags=None
)
在這個 Save Model 的教學裡面,我將會使用我們之前的數字辨識模型,請參考:
TensorFlow 2 教學:Keras–MNIST–自訂模型
在訓練完我們的模型之後 , 我們可以將他保存下來:
儲存 model,會產生一個叫做 “Grandma” 的資料夾來放 model
tf.saved_model.save(model, "Grandma")
從儲存的 “Grandma” 資料夾讀取 model 並取名叫 new_model
new_model = tf.saved_model.load("grandma")
定義一個evaluate model 的 function
def evaluate(model):
for images, labels in test_dataset:
predictions = model(images)
t_loss = loss_object(labels, predictions)
test_loss(t_loss)
test_accuracy(labels, predictions)
print("test loss : {}".format(test_loss.result()))
print("test accuracy : {}".format(test_accuracy.result()))
test_loss.reset_states()
test_accuracy.reset_states()
測試原來 model的 loss 及 accuracy
evaluate(model)
test loss : 0.05755476520098324
test accuracy : 0.9846
測試剛剛讀取 model的 loss 及 accuracy
evaluate(new_model)
test loss : 0.05755476520098324
test accuracy : 0.9846
進階學習
TensorFlow doc : Training checkpoints
TensorFlow doc : Using the SavedModel format