TensorFlow 保存模型

SavedModel（tf.saved_model）模块

TFS（TensorFlow Serving）：可用下面的命令查看导出的 SavedModel 的全部 tag 与签名信息：

saved_model_cli show --dir ./6 --all

# coding=utf-8
# Restore the TF1 checkpoint under model/ and export it.
#
# * Always: a SavedModel in EXPORT_PATH, tagged SERVING, with one PREDICT
#   signature named "ac_forward" (inputs "input"/"mask", outputs
#   "smooth_output"/"punc_output").
# * Only when EXPORT_FROZEN_PB is True: a frozen GraphDef (variables folded
#   into constants) written to FROZEN_PB_PATH.
import os
import shutil
import sys

import tensorflow as tf
from tensorflow.python.framework.graph_util import convert_variables_to_constants
from tensorflow.python.framework import graph_util

# Older output tensor name, kept for reference only (not used below).
# output_node_names = "whichPun/output:0"

META_GRAPH_PATH = 'model/-1.meta'
CHECKPOINT_DIR = 'model/'
EXPORT_PATH = 'saved_model'
FROZEN_PB_PATH = 'graph.pb'

# The original script called exit() right after writing the SavedModel, which
# left the frozen-graph export as unreachable code. It is now an explicit
# switch; the default (False) keeps the original output: SavedModel only.
EXPORT_FROZEN_PB = False

# Output *node* names (no ":0" output index) of the two heads, in the form
# convert_variables_to_constants expects.
OUTPUT_NODE_NAMES = ["whichPun/whichPun_output", "smooth/smooth_output"]


def _export_saved_model(sess, inputs, outputs, export_path):
    """Save the graph and variable values of `sess` as a SavedModel.

    Args:
        sess: session holding the restored graph and variable values.
        inputs: dict mapping signature input keys to graph tensors.
        outputs: dict mapping signature output keys to graph tensors.
        export_path: target directory; an existing one is deleted first.
    """
    # SavedModelBuilder refuses to write into an existing non-empty directory,
    # so clear any previous export. shutil.rmtree avoids building a shell
    # command line out of a path (the original used os.system("rm -rf ...")).
    if os.path.exists(export_path):
        shutil.rmtree(export_path)

    builder = tf.saved_model.builder.SavedModelBuilder(export_path)
    # The signature keys are free-form names for clients; they need not match
    # the tensor names inside the graph.
    prediction_signature = tf.saved_model.signature_def_utils.build_signature_def(
        inputs={key: tf.saved_model.utils.build_tensor_info(tensor)
                for key, tensor in inputs.items()},
        outputs={key: tf.saved_model.utils.build_tensor_info(tensor)
                 for key, tensor in outputs.items()},
        method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME)
    # Add the graph plus the current variable values under the SERVING tag.
    builder.add_meta_graph_and_variables(
        sess, [tf.saved_model.tag_constants.SERVING],
        signature_def_map={'ac_forward': prediction_signature})
    builder.save()


def _export_frozen_graph(sess, graph_def, output_node_names, pb_path):
    """Freeze `graph_def` with the variable values in `sess` and write it.

    Only the subgraph needed to compute `output_node_names` is kept, with
    variables replaced by constants; the result is serialized to `pb_path`.
    """
    frozen_graph_def = graph_util.convert_variables_to_constants(
        sess, graph_def, output_node_names)
    # tf.gfile.FastGFile is deprecated in favour of tf.gfile.GFile.
    with tf.gfile.GFile(pb_path, mode='wb') as f:
        f.write(frozen_graph_def.SerializeToString())


with tf.Session() as sess:
    saver = tf.train.import_meta_graph(META_GRAPH_PATH)
    checkpoint_path = tf.train.latest_checkpoint(CHECKPOINT_DIR)
    if checkpoint_path is None:
        # saver.restore would reject None too, but with a less helpful message.
        raise ValueError("No checkpoint found in %r" % CHECKPOINT_DIR)
    saver.restore(sess, checkpoint_path)

    # Snapshot the GraphDef before the SavedModel export, which may add its
    # own saver ops to the graph; only this snapshot is ever frozen.
    graph_def = sess.graph.as_graph_def()

    input_x = sess.graph.get_tensor_by_name("inputs/input_ids:0")
    print(input_x)
    input_mask = sess.graph.get_tensor_by_name("inputs/input_mask:0")
    print(input_mask)
    # Other tensor names noted for reference (not exported):
    # bert/encoder/layer_7/attention/output/LayerNorm/batchnorm/add_1:0 att
    # bert/encoder/layer_7/intermediate/dense/mul_3:0 inter
    # bert/encoder/layer_7/output/LayerNorm/batchnorm/add_1:0 output
    # bert/encoder/layer_7/attention/self/Reshape_5:0 self_att
    # bert/encoder/layer_7/attention/self/Reshape_3:0 att_probs
    automark_out = sess.graph.get_tensor_by_name("whichPun/whichPun_output:0")
    smooth_out = sess.graph.get_tensor_by_name("smooth/smooth_output:0")
    print(automark_out)
    print(smooth_out)

    # Disabled lookup-table initialisation, kept for reference:
    # sess.run(sess.graph.get_operation_by_name('Inputs/string_to_index/hash_table/table_init'))
    # sess.run(sess.graph.get_operation_by_name('Inputs/string_to_index_1/hash_table/table_init'))
    # table_init = tf.group(tf.tables_initializer(), name='legacy_init_op')

    _export_saved_model(
        sess,
        inputs={'input': input_x, 'mask': input_mask},
        outputs={'smooth_output': smooth_out, 'punc_output': automark_out},
        export_path=EXPORT_PATH)

    if EXPORT_FROZEN_PB:
        _export_frozen_graph(sess, graph_def, OUTPUT_NODE_NAMES, FROZEN_PB_PATH)

# Stop here, as the original script did, so nothing after this point runs.
sys.exit()
# NOTE(review): The triple-quoted string below is disabled code kept only for
# reference. It is a bare string literal, so it never executes, and the script
# exits before reaching it anyway. It sketches an alternative flow:
# list the graph's ops, freeze the graph, re-import the frozen GraphDef into a
# fresh graph, and export that as a SavedModel whose signature uses the keys
# 'ids' / 'mask' / 'output'.
# Its inline Chinese notes read: 固化模型 = "freeze the model";
# 模型保存成.pb格式 = "save the model in .pb format"; 保存图模型 = "save the
# graph model"; the signature_def note = "signature_def wraps the input/output
# info; the tensors may be given any names when building the signature";
# 导入graph与变量信息 = "add the graph and variable information".
# If this is ever revived, as written it would fail:
#   * `output_node_names` exists only in a commented-out line, so it is
#     undefined here. Its value also carries a ":0" suffix, which
#     convert_variables_to_constants presumably rejects because it expects
#     node names (TODO confirm).
#   * get_tensor_by_name("inputs/input_ids") lacks the ":0" output index, so
#     it names an operation rather than a tensor.
'''
# writer = tf.summary.FileWriter("logs/", sess.graph)
for op in graph.get_operations():
print(op.name)

# 固化模型
frozen_graph_def = graph_util.convert_variables_to_constants(sess,
tf.get_default_graph().as_graph_def(),
[output_node_names])
# 模型保存成.pb格式
# with tf.gfile.FastGFile('graph.pb', mode='wb') as f:
# f.write(frozen_graph_def.SerializeToString())

with tf.Graph().as_default() as graph:
tf.import_graph_def(frozen_graph_def, name="", )
with tf.Session() as sess:
graph = sess.graph
# for op in graph.get_operations():
# print(op.name)

input_x = sess.graph.get_tensor_by_name("inputs/input_ids")
input_leng = sess.graph.get_tensor_by_name("inputs/input_mask:0")
final_out = sess.graph.get_tensor_by_name("whichPun/output:0")
print('out1', final_out)

export_path = "saved_model"
# 保存图模型
if export_path:
os.system("rm -rf " + export_path)

builder = tf.saved_model.builder.SavedModelBuilder(export_path)
data_in = tf.saved_model.utils.build_tensor_info(input_x)
data_length = tf.saved_model.utils.build_tensor_info(input_leng)
data_out = tf.saved_model.utils.build_tensor_info(final_out)

# signature_def将输入输出信息进行封装,在构建模型阶段可以随便给tensor命名
prediction_signature = (
tf.saved_model.signature_def_utils.build_signature_def(
inputs={'ids': data_in, 'mask': data_length},
outputs={'output': data_out},
method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME))
# 导入graph与变量信息
builder.add_meta_graph_and_variables(
sess, [tf.saved_model.tag_constants.SERVING],
signature_def_map={
'ac_forward': prediction_signature,
})

builder.save()
'''