TensorFlow 2 教學:SavedModel 儲存模型

tensorflow

看完文章將會學到

  • 如何在 TensorFlow 2 將訓練好的 Model 保存下來
  • 如何在 TensorFlow 2 讀取之前保存下來的 Model

Github 程式碼

Save Model 保存模型

在 Tensorflow 裡面,Save Model 可以代表兩件事:

Checkpoints

Checkpoints 包含了在 Model 裡面每個參數( tf.Variable 物件)的 “值”,但是不包含這個 Model 的架構,因此會使用 Checkpoints 大多是用在 “source code” 可以取得的時候,可以重建出整個 Model 的情況

SavedModel

SavedModel 除了參數的值之外,還包含了整個 Model 的結構,因此非常神奇非常方便的地方是: 在電腦上 train 好的 Model 可以透過 SavedModel 的方法 deploy 到其他平台上面,包含TensorFlow Serving, TensorFlow Lite, TensorFlow.js…甚至是使用其他的語言  (C, C++, Java, Go, Rust, C# 等等)請見下圖

Tensorflow_SavedModel

SavedModel 範例

在這裡最重要的 API 就是:

tf.saved_model.save()

tf.saved_model.load()

tf.saved_model.save(
    obj,
    export_dir,
    signatures=None
)
tf.saved_model.load(
    export_dir,
    tags=None
)

在這個 Save Model 的教學裡面,我將會使用我們之前的數字辨識模型,請參考:

TensorFlow 2 教學:Keras–MNIST–自訂模型


在訓練完我們的模型之後 , 我們可以將他保存下來:

儲存 model,會產生一個叫做 “Grandma” 的資料夾來放 model

tf.saved_model.save(model, "Grandma")

從儲存的 “Grandma” 資料夾讀取 model 並取名叫 new_model

new_model = tf.saved_model.load("grandma")

定義一個evaluate model 的 function

def evaluate(model):
    for images, labels in test_dataset:
        predictions = model(images)
        t_loss = loss_object(labels, predictions)

        test_loss(t_loss)
        test_accuracy(labels, predictions)
    
        print("test loss : {}".format(test_loss.result()))        
        print("test accuracy : {}".format(test_accuracy.result()))
    test_loss.reset_states()
    test_accuracy.reset_states()

測試原來 model的 loss 及 accuracy

evaluate(model)

test loss : 0.05755476520098324
test accuracy : 0.9846

測試剛剛讀取 model的 loss 及 accuracy

evaluate(new_model)

test loss : 0.05755476520098324
test accuracy : 0.9846

進階學習

TensorFlow doc : Training checkpoints
TensorFlow doc : Using the SavedModel format

留言討論區