如何在 TensorFlow 中加载保存的模型?

数据挖掘 机器学习 Python 张量流 迁移学习
2022-03-13 17:59:49

这是我在 Python 中的代码:

from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
from matplotlib import pyplot as plt
import numpy as np

我使用以下代码检查了保存的模型是否存在:

tf.compat.v1.saved_model.contains_saved_model(
    '/Link_to_the_saved_model_directory/'
)

返回True
,据我了解,我可以使用以下代码进一步确保模型正确保存:

tf.saved_model.Asset(
    '/Link_to_the_saved_model_directory/'
)

它返回这个:

<tensorflow.python.training.tracking.tracking.Asset at 0x2aad125e5710>

所以,一切看起来都很好。但是,当我使用以下脚本加载模型时,出现错误。

LoadedModel = tf.saved_model.load(
    export_dir='/Link_to_the_saved_model_directory/', tags=None
)

错误:

OSError: Cannot parse file b'/Link_to_the_saved_model_directory/saved_model.pbtxt': 1:1 : Message type "tensorflow.SavedModel" has no field named "node"..

/Link_to_the_saved_model_directory/ 中的文件如下所示:

['saved_model.pbtxt',
 'saved_model.ckpt-0.data-00000-of-00001',
 'saved_model.ckpt-0.meta',
 'checkpoint',
 'saved_model.ckpt-0.index']

非常感谢任何关于如何加载模型的建议,以便我可以重用它进行迁移学习。也可能是模型部分是用以前版本的 TensorFlow(例如 TensorFlow 1.x)编写的,因此这个错误是由于兼容性问题造成的,但我还没有找到解决方案。

更新:我尝试了以下方法,但它不起作用(tf 是使用兼容版本导入的import tensorflow.compat.v1. as tf):

with tf.Session() as sess:
    saver = tf.train.import_meta_graph('/dir_to_the_model_files/saved_model.ckpt-0.meta')
    saver.restore(sess, "/dir_to_the_model_files/saved_model.ckpt-0")
    loaded = tf.saved_model.load(sess,tags=None,export_dir="/dir_to_the_model_files",import_scope=None)

它返回以下警告和错误:

WARNING:tensorflow:The saved meta_graph is possibly from an older release:
'metric_variables' collection should be of type 'byte_list', but instead is of type 'node_list'.
INFO:tensorflow:Restoring parameters from /dir_to_the_model_files/saved_model.ckpt-0
<tensorflow.python.training.saver.Saver object at 0x2aaab4824a50>
WARNING:tensorflow:From <ipython-input-3-b8fd24f6b841>:9: load (from tensorflow.python.saved_model.loader_impl) is deprecated and will be removed in a future version.
Instructions for updating:
This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.loader.load or tf.compat.v1.saved_model.load. There will be a new function for importing SavedModels in Tensorflow 2.0.

OSError: Cannot parse file b'/dir_to_the_model_files/saved_model.pbtxt': 1:1 : Message type "tensorflow.SavedModel" has no field named "node"..
1个回答

TensorFlow 文档tf.saved_model.load可能会有所帮助:

来自 tf.estimator.Estimator 或 1.x SavedModel API 的 SavedModel 具有平面图,而不是 tf.function 对象。这些 SavedModel 将具有与其在 .signatures 属性中的签名相对应的函数,而且还有一个 .prune 方法,允许您为新子图提取函数。这相当于在 TensorFlow 1.x 的会话中导入 SavedModel 并命名提要和提取。

您可能必须使用已弃用的 v1 api 调用 https://www.tensorflow.org/api_docs/python/tf/compat/v1/saved_model/load