TensorFlow 三种模型保存格式（ckpt / pb / SavedModel）及其相互转换

首先介绍三种格式的模型导出方式：

  • tf.train.Saver()

用于保存和恢复 Variable。它可以非常方便地保存当前模型的变量，或者导入之前训练好的变量。一个最简单的用法：

# Minimal Saver round trip: write the variables of the current session to a
# checkpoint file, then load the stored values back into the same session.
ckpt_path = "/tmp/test.ckpt"
saver = tf.train.Saver()
saver.save(sess, ckpt_path)     # Save the variables to disk.
saver.restore(sess, ckpt_path)  # Restore variables from disk.

1. ckpt 格式

# Save a checkpoint (ckpt format). Uncomment to run inside a session `sess`.
# global_step=1 appends "-1" to the '../tf-model/' prefix; this is why the
# conversion scripts below look for '<ckpt_dir>/-1.meta'. write_meta_graph=True
# also writes the graph structure (the .meta file) next to the variable values.
#saver.save(sess, '../tf-model/', global_step=1, write_meta_graph=True)

2. pb 格式

首先在定义模型时给输出张量指定名字（name='output'）：

# Output head of the "whichPun" task. Ops created under this variable_scope are
# prefixed with "whichPun/", so the reshape below becomes the node
# "whichPun/output" that the freezing/export code looks up by name.
with tf.variable_scope("whichPun"):
    # 5 output units, no activation (raw logits).
    # NOTE(review): trainable=False keeps this layer out of the trainable
    # variables, so a default optimizer never updates it — confirm intended.
    task_2 = tf.layers.dense(bert_output, units=5, activation=None, trainable=False)
    print("bert_output== ", task_2)
    task_2 = tf.cast(task_2, tf.float32)
    # Explicit name so the tensor can be found after freezing/export.
    # NOTE(review): assumes self.input_shape[1] is the sequence length — confirm.
    self.logit = tf.reshape(task_2, [-1, self.input_shape[1], 5], name='output')

在输出所在的 scope 里找到输出节点的名字（此处为 whichPun/output），然后在 sess 中用两行代码冻结图并保存为 pb：

 # 2. Save as a frozen .pb graph; run these lines inside the session `sess`
# that holds the trained variables.
# convert_variables_to_constants replaces every variable with a constant that
# holds its current value in `sess`. The last argument must be a LIST of output
# node names (op names, without the ":0" suffix).
frozen_graph_def = tf.graph_util.convert_variables_to_constants(
    sess,
    tf.get_default_graph().as_graph_def(),
    ['whichPun/output'])
# tf.gfile.GFile is the non-deprecated replacement for FastGFile.
with tf.gfile.GFile('graph.pb', mode='wb') as f:
    f.write(frozen_graph_def.SerializeToString())

3. TF Serving 格式（SavedModel 模块）

可以用 saved_model_cli 命令查看导出的 SavedModel（这里的目录为 ./6）：

# Inspect the SavedModel stored in ./6; --all prints all of its tag-sets and signatures.
saved_model_cli show --dir ./6 --all

旧版 Exporter 接口的基本使用方式是：

1）传入一个 Saver 实例；

2）调用 init，定义模型的 graph 以及 input/output；

3）使用 Exporter 导出模型。

下面的代码使用的是 SavedModelBuilder，步骤类似：先用 build_tensor_info 和 build_signature_def 定义输入/输出签名，再调用 add_meta_graph_and_variables 和 save 导出。

        # 3. Save as a TF Serving model (SavedModel) from the frozen graph of step 2.
import shutil

with tf.Graph().as_default() as graph:
    # Import the frozen GraphDef into a fresh graph; name="" keeps the original
    # node names (no "import/" prefix).
    tf.import_graph_def(frozen_graph_def, name="")
    # Separate name so the outer training session `sess` is not rebound to this
    # (soon closed) export session.
    with tf.Session() as export_sess:
        export_path = "savedmodel"
        # SavedModelBuilder needs a directory that does not exist yet, so remove
        # any previous export first.
        shutil.rmtree(export_path, ignore_errors=True)
        builder = tf.saved_model.builder.SavedModelBuilder(export_path)

        # model.* are tensors of the training graph. Resolve their names in the
        # imported graph so the signature describes tensors that really exist in
        # the exported graph (get_tensor_by_name fails loudly if one is missing).
        inids = tf.saved_model.utils.build_tensor_info(
            graph.get_tensor_by_name(model.input_ids.name))
        inmask = tf.saved_model.utils.build_tensor_info(
            graph.get_tensor_by_name(model.input_mask.name))
        poutput = tf.saved_model.utils.build_tensor_info(
            graph.get_tensor_by_name(model.logit.name))

        # signature_def bundles the input/output tensor info; the keys
        # ('input', 'mask', 'punc_output') are free-form names chosen here.
        prediction_signature = tf.saved_model.signature_def_utils.build_signature_def(
            inputs={'input': inids, 'mask': inmask},
            outputs={'punc_output': poutput},
            method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME)

        # Store the graph under the SERVING tag (weights are already constants).
        builder.add_meta_graph_and_variables(
            export_sess, [tf.saved_model.tag_constants.SERVING],
            signature_def_map={
                'ac_forward': prediction_signature,
            })

        builder.save()

模型相互转换

ckpt2pb.py：把 ckpt 转换为冻结后的 pb 文件，用法为 `python ckpt2pb.py <ckpt目录/> <输出.pb>`。

import tensorflow as tf
from sys import argv

import os

# Node names (no ":0" suffix) of the outputs kept in the frozen graph;
# other heads, e.g. 'smooth/output', can be appended to this list.
OUTPUT_NODE_NAMES = ['whichPun/output']


def meta_graph_path(ckpt_dir, global_step=1):
    """Return the .meta path written by saving into ``ckpt_dir`` with ``global_step``.

    Matches ``saver.save(sess, '<ckpt_dir>/', global_step=1, write_meta_graph=True)``,
    which produces '<ckpt_dir>/-1.meta'. os.path.join makes this work whether or
    not ``ckpt_dir`` ends with a slash.
    """
    return os.path.join(ckpt_dir, "-%d.meta" % global_step)


def main():
    """Freeze the checkpoint in argv[1] and write the GraphDef to argv[2].

    Usage: python ckpt2pb.py <ckpt_dir> <output.pb>
    """
    if len(argv) < 3:
        raise SystemExit("usage: python ckpt2pb.py <ckpt_dir> <output.pb>")
    ckpt_model_path, output_path = argv[1], argv[2]

    ckpt = tf.train.latest_checkpoint(ckpt_model_path)
    if ckpt is None:
        raise SystemExit("no checkpoint found in %s" % ckpt_model_path)

    config = tf.ConfigProto(allow_soft_placement=True)
    with tf.Session(config=config) as sess:
        # clear_devices drops the device placements recorded in the meta graph.
        saver = tf.train.import_meta_graph(meta_graph_path(ckpt_model_path),
                                           clear_devices=True)
        saver.restore(sess, ckpt)
        # Replace every variable with a constant holding its restored value.
        output_graph_def = tf.graph_util.convert_variables_to_constants(
            sess,
            sess.graph.as_graph_def(),
            OUTPUT_NODE_NAMES,
        )

    # Serialize the frozen graph to disk.
    with tf.gfile.GFile(output_path, "wb") as f:
        f.write(output_graph_def.SerializeToString())


if __name__ == "__main__":
    main()

pb2tfs.py：把冻结后的 pb 文件转换为 SavedModel（导出到 savedmodel 目录），用法为 `python pb2tfs.py <模型.pb>`。

#! encoding: utf-8

import numpy as np
from tensorflow.python.platform import gfile
import time
import os
import datetime
import tensorflow as tf
from sys import argv

from tensorflow.python.framework.graph_util import convert_variables_to_constants
from tensorflow.python.framework import graph_util

import shutil


def chmod_recursive(root, mode):
    """Set permission bits ``mode`` on ``root`` and everything beneath it.

    Pure-Python equivalent of ``chmod -R <mode> root``. os.walk is top-down, so
    each directory is updated before the walk descends into it.
    """
    os.chmod(root, mode)
    for dirpath, dirnames, filenames in os.walk(root):
        for name in dirnames + filenames:
            os.chmod(os.path.join(dirpath, name), mode)


def main():
    """Wrap the frozen graph in argv[1] into a SavedModel under ./savedmodel.

    Usage: python pb2tfs.py <frozen_graph.pb>

    The 'ac_forward' signature maps inputs 'input' / 'mask' to the tensors
    inputs/input_ids:0 / inputs/input_mask:0, and output 'output' to
    whichPun/output:0.
    """
    if len(argv) < 2:
        raise SystemExit("usage: python pb2tfs.py <frozen_graph.pb>")

    # Parse the serialized (frozen) GraphDef.
    graph_def = tf.GraphDef()
    with gfile.GFile(argv[1], "rb") as f:
        graph_def.ParseFromString(f.read())

    export_path = "savedmodel"
    # Import the graph exactly once into a dedicated graph and run one session on
    # that same graph, so every tensor in the signature belongs to the graph that
    # add_meta_graph_and_variables exports.
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name="")
        input_ids = graph.get_tensor_by_name("inputs/input_ids:0")
        input_mask = graph.get_tensor_by_name("inputs/input_mask:0")
        output = graph.get_tensor_by_name("whichPun/output:0")
        print(output)

        with tf.Session(graph=graph) as sess:
            # SavedModelBuilder needs a directory that does not exist yet.
            shutil.rmtree(export_path, ignore_errors=True)
            builder = tf.saved_model.builder.SavedModelBuilder(export_path)
            data_in = tf.saved_model.utils.build_tensor_info(input_ids)
            data_in2 = tf.saved_model.utils.build_tensor_info(input_mask)
            data_out = tf.saved_model.utils.build_tensor_info(output)

            # signature_def bundles the input/output tensor info; the keys
            # ('input', 'mask', 'output') are free-form names chosen here.
            prediction_signature = tf.saved_model.signature_def_utils.build_signature_def(
                inputs={'input': data_in, 'mask': data_in2},
                outputs={'output': data_out},
                method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME)

            # Store the graph under the SERVING tag (weights are already constants).
            builder.add_meta_graph_and_variables(
                sess, [tf.saved_model.tag_constants.SERVING],
                signature_def_map={
                    'ac_forward': prediction_signature,
                })
            builder.save()

    # Same effect as `chmod -R 755 savedmodel`; presumably so the serving
    # process can read the export — confirm.
    chmod_recursive(export_path, 0o755)


if __name__ == "__main__":
    main()

graph2tfs.py：不冻结图，直接从 ckpt 导出 SavedModel（导出到 saved_model_no_freeze 目录），用法为 `python graph2tfs.py <ckpt目录>`。

# coding=utf-8
import tensorflow as tf
import os
from sys import argv
from tensorflow.python.framework.graph_util import convert_variables_to_constants
from tensorflow.python.framework import graph_util

# output_node_names = "whichPun/output:0"

import shutil


def main():
    """Export the checkpoint in argv[1] as a SavedModel without freezing it.

    Usage: python graph2tfs.py <ckpt_dir>

    <ckpt_dir> must contain the '-1.meta' graph written with global_step=1.
    Variables stay variables; the result is written to ./saved_model_no_freeze.
    To produce a frozen graph.pb instead, use ckpt2pb.py.
    """
    if len(argv) < 2:
        raise SystemExit("usage: python graph2tfs.py <ckpt_dir>")
    ckpt_dir = argv[1]
    ckpt = tf.train.latest_checkpoint(ckpt_dir)
    if ckpt is None:
        raise SystemExit("no checkpoint found in %s" % ckpt_dir)

    with tf.Session() as sess:
        # os.path.join accepts <ckpt_dir> with or without a trailing slash.
        saver = tf.train.import_meta_graph(os.path.join(ckpt_dir, '-1.meta'))
        saver.restore(sess, ckpt)

        input_x = sess.graph.get_tensor_by_name("inputs/input_ids:0")
        print(input_x)
        input_mask = sess.graph.get_tensor_by_name("inputs/input_mask:0")
        print(input_mask)
        # Further heads (e.g. "smooth/output:0") can be added to the outputs below.
        punc_out = sess.graph.get_tensor_by_name("whichPun/output:0")

        export_path = 'saved_model_no_freeze'
        # SavedModelBuilder needs a directory that does not exist yet.
        shutil.rmtree(export_path, ignore_errors=True)

        builder = tf.saved_model.builder.SavedModelBuilder(export_path)
        data_in = tf.saved_model.utils.build_tensor_info(input_x)
        data_mask = tf.saved_model.utils.build_tensor_info(input_mask)
        data_out_1 = tf.saved_model.utils.build_tensor_info(punc_out)

        # signature_def bundles the input/output tensor info; the keys
        # ('input', 'mask', 'output') are free-form names chosen here.
        prediction_signature = tf.saved_model.signature_def_utils.build_signature_def(
            inputs={'input': data_in, 'mask': data_mask},
            outputs={'output': data_out_1},
            method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME)

        # Save the graph together with the restored variable values.
        # NOTE: if the graph uses lookup tables, an initializer such as
        # tf.group(tf.tables_initializer(), name='legacy_init_op') would also
        # have to be passed here (legacy_init_op / main_op argument).
        builder.add_meta_graph_and_variables(
            sess, [tf.saved_model.tag_constants.SERVING],
            signature_def_map={
                'ac_forward': prediction_signature,
            })
        builder.save()


if __name__ == "__main__":
    main()

TF Serving 模型固化

在 sess 里先冻结图（把变量转换为常量），然后用 import_graph_def() 导入到一个新图中，再导出为 SavedModel：

# Freeze the model: in the graph of the current session `sess`, every variable
# is replaced by a constant holding its current value.
import shutil

# NOTE(review): output_node_names is defined outside this snippet; it must be an
# op name WITHOUT the ":0" suffix (e.g. "whichPun/output").
frozen_graph_def = graph_util.convert_variables_to_constants(
    sess,
    tf.get_default_graph().as_graph_def(),
    [output_node_names])
# Optionally also save it in .pb format:
# with tf.gfile.GFile('graph.pb', mode='wb') as f:
#     f.write(frozen_graph_def.SerializeToString())

with tf.Graph().as_default() as graph:
    tf.import_graph_def(frozen_graph_def, name="")
    # Separate name so the outer `sess` is not rebound to this export session.
    with tf.Session() as export_sess:
        # Debug aid: `for op in graph.get_operations(): print(op.name)`.

        # Tensor names need the ":<output_index>" suffix; a bare op name such as
        # "inputs/input_ids" makes get_tensor_by_name raise ValueError.
        input_x = graph.get_tensor_by_name("inputs/input_ids:0")
        input_leng = graph.get_tensor_by_name("inputs/input_mask:0")
        final_out = graph.get_tensor_by_name("whichPun/output:0")
        print('out1', final_out)

        export_path = "saved_model"
        # SavedModelBuilder needs a directory that does not exist yet.
        shutil.rmtree(export_path, ignore_errors=True)

        builder = tf.saved_model.builder.SavedModelBuilder(export_path)
        data_in = tf.saved_model.utils.build_tensor_info(input_x)
        data_length = tf.saved_model.utils.build_tensor_info(input_leng)
        data_out = tf.saved_model.utils.build_tensor_info(final_out)

        # signature_def bundles the input/output tensor info; the keys
        # ('ids', 'mask', 'output') are free-form names chosen here.
        prediction_signature = tf.saved_model.signature_def_utils.build_signature_def(
            inputs={'ids': data_in, 'mask': data_length},
            outputs={'output': data_out},
            method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME)

        # Store the graph under the SERVING tag (weights are already constants).
        builder.add_meta_graph_and_variables(
            export_sess, [tf.saved_model.tag_constants.SERVING],
            signature_def_map={
                'ac_forward': prediction_signature,
            })

        builder.save()