Page Menu
Home
Phabricator (Chris)
Search
Configure Global Search
Log In
Files
F118092
No One
Temporary
Actions
View File
Edit File
Delete File
View Transforms
Subscribe
Flag For Later
Award Token
Authored By
Unknown
Size
4 KB
Referenced Files
None
Subscribers
None
View Options
diff --git a/tools/export_graph.py b/tools/export_graph.py
index 21a4fda..8381016 100644
--- a/tools/export_graph.py
+++ b/tools/export_graph.py
@@ -1,113 +1,116 @@
import os
import sys
import argparse

import tensorflow as tf
from tensorflow.python.tools import freeze_graph
from tensorflow.python.tools import optimize_for_inference_lib

# Make the repository root importable so the local `model` module resolves
# regardless of the directory this script is launched from.
sys.path.append(os.path.normpath(
    os.path.join(os.path.dirname(os.path.realpath(__file__)), '..')))

from model import OpenNsfwModel, InputType

"""Exports the graph so it can be imported via import_graph_def

The exported model takes a base64 encoded string tensor as input
"""

if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument("target", help="output directory")
    parser.add_argument("-m", "--model_weights", required=True,
                        help="Path to trained model weights file")
    parser.add_argument("-o", "--optimize", action='store_true',
                        required=False, default=False,
                        help="Optimize graph for inference")
    parser.add_argument("-f", "--freeze", action='store_true',
                        required=False, default=False,
                        help="Freeze graph: convert variables to ops")
    parser.add_argument("-t", "--text", action='store_true',
                        required=False, default=False,
                        help="Write graph as text (.pbtxt) "
                             "instead of binary (.pb)")

    args = parser.parse_args()

    model = OpenNsfwModel()

    export_base_path = args.target
    do_freeze = args.freeze
    do_optimize = args.optimize
    # BUG FIX: the -b/--binary flag was renamed to -t/--text, but this line
    # still read the removed `args.binary` attribute, so the script crashed
    # with AttributeError before exporting anything. Binary output is now
    # the default; --text opts into the pbtxt form.
    as_binary = not args.text

    input_node_name = 'input'
    output_node_name = 'predictions'

    base_name = 'open_nsfw'

    checkpoint_path = os.path.join(export_base_path, base_name + '.ckpt')

    if as_binary:
        graph_name = base_name + '.pb'
    else:
        graph_name = base_name + '.pbtxt'

    graph_path = os.path.join(export_base_path, graph_name)
    frozen_graph_path = os.path.join(export_base_path,
                                     'frozen_' + graph_name)
    optimized_graph_path = os.path.join(export_base_path,
                                        'optimized_' + graph_name)

    with tf.Session() as sess:
        # Build the model with base64-JPEG input so the exported graph
        # accepts an encoded image string tensor directly.
        model.build(weights_path=args.model_weights,
                    input_type=InputType.BASE64_JPEG)

        sess.run(tf.global_variables_initializer())

        # Always export a checkpoint + GraphDef pair; freeze/optimize
        # stages below consume these files.
        saver = tf.train.Saver()
        saver.save(sess, save_path=checkpoint_path)
        print('Checkpoint exported to {}'.format(checkpoint_path))

        tf.train.write_graph(sess.graph_def, export_base_path, graph_name,
                             as_text=not as_binary)
        print('Graph exported to {}'.format(graph_path))

        if do_freeze:
            # Convert variables into constant ops so the graph is
            # self-contained (no checkpoint needed at inference time).
            print('Freezing graph...')
            freeze_graph.freeze_graph(
                input_graph=graph_path, input_saver='',
                input_binary=as_binary, input_checkpoint=checkpoint_path,
                output_node_names=output_node_name,
                restore_op_name='save/restore_all',
                filename_tensor_name='save/Const:0',
                output_graph=frozen_graph_path, clear_devices=True,
                initializer_nodes='')
            print('Frozen graph exported to {}'.format(frozen_graph_path))
            # Subsequent stages operate on the frozen graph.
            graph_path = frozen_graph_path

        if do_optimize:
            # Strip training-only nodes for faster inference.
            print('Optimizing graph...')
            input_graph_def = tf.GraphDef()

            with tf.gfile.Open(graph_path, 'rb') as f:
                input_graph_def.ParseFromString(f.read())

            output_graph_def =\
                optimize_for_inference_lib.optimize_for_inference(
                    input_graph_def,
                    [input_node_name],
                    [output_node_name],
                    tf.float32.as_datatype_enum)

            # BUG FIX: the output file handle was previously never closed;
            # use a context manager so the write is flushed deterministically.
            with tf.gfile.FastGFile(optimized_graph_path, 'wb') as f:
                f.write(output_graph_def.SerializeToString())

            print('Optimized graph exported to {}'
                  .format(optimized_graph_path))
File Metadata
Details
Attached
Mime Type
text/x-diff
Expires
Fri, May 15, 8:47 AM (13 h, 49 m ago)
Storage Engine
blob
Storage Format
Raw Data
Storage Handle
63890
Default Alt Text
(4 KB)
Attached To
Mode
R62 open_nsfw
Attached
Detach File
Event Timeline