wake-up-neo.com

Speichern Sie den Tensorflow-Prüfpunkt in einer .pb-Protokolldatei

Ich habe ein pix2pix Modell auf Tensorflow trainiert und das Modell wurde in Form von Checkpoints mit den folgenden Dateien gespeichert:

model-15000.meta, model-15000.index, model-15000.data-00000-of-00001, graph.pbtxt, checkpoint.

Jetzt möchte ich es zu Bereitstellungszwecken in eine Protobuf-Datei (.pb) konvertieren. Ich bin dazu auf das Skript freeze_graph.py gestoßen, habe aber Probleme mit einem der Argumente, nämlich output_node_names.

Ich habe einige Ebenennamen ausprobiert, erhalte jedoch die folgende Fehlermeldung:

AssertionError: generator/decoder_2/batchnorm/scale/gradients ist nicht in der Grafik enthalten

Unsicher, wie die output_node_names zu finden sind

4
Blue

Versuchen Sie den folgenden Code, um Meta in Pb-Datei zu konvertieren:

import tensorflow as tf
#Step 1 
#import the model metagraph
saver = tf.train.import_meta_graph('./model.meta', clear_devices=True)
#make that as the default graph
graph = tf.get_default_graph()
input_graph_def = graph.as_graph_def()
sess = tf.Session()
#now restore the variables
saver.restore(sess, "./model")

#Step 2
# Find the output name
graph = tf.get_default_graph()
for op in graph.get_operations(): 
  print (op.name)

#Step 3
from tensorflow.python.platform import gfile
from tensorflow.python.framework import graph_util

output_node_names="predictions_mod/Sigmoid"
output_graph_def = graph_util.convert_variables_to_constants(
        sess, # The session
        input_graph_def, # input_graph_def is useful for retrieving the nodes 
        output_node_names.split(",")  )    

#Step 4
#output folder
output_fld ='./'
#output pb file name
output_model_file = 'model.pb'
from tensorflow.python.framework import graph_io
#write the graph
graph_io.write_graph(output_graph_def, output_fld, output_model_file, as_text=False)

Hoffe das klappt !!!

1
ReInvent_IO

Ich habe das gleiche Problem, wenn ich versuche, das Modell einzufrieren.

AssertionError: pose:0 is not in graph

Ich benutze dieses Skript, um alle Tensornamen auszudrucken, aber ich erhalte immer noch den Fehler.

import tensorflow as tf

from tensorflow.python.tools import inspect_checkpoint as chkp


meta_path = './data/trained_variables.ckpt.meta' # Your .meta file

with tf.Session() as sess:

# Restore the graph
saver = tf.train.import_meta_graph(meta_path)

# Load weights
saver.restore(sess,"/Users/me/Desktop/data/trained_variables.ckpt")

## Print tensors
chkp.print_tensors_in_checkpoint_file(file_name="/Users/me/Desktop/data/trained_variables.ckpt",
                                      tensor_name='',
                                      all_tensors=False,
                                      all_tensor_names=True)

Probieren Sie es aus und versuchen Sie, den richtigen Namen zu finden. Lassen Sie es mich wissen, ich stehe vor dem gleichen Problem.

0
steve