wake-up-neo.com

Tensorflow: Wie bekomme ich einen Tensor mit Namen?

Ich habe Probleme, einen Tensor beim Namen zu finden, ich weiß nicht einmal, ob es möglich ist.

Ich habe eine Funktion, die mein Diagramm erstellt:

def create_structure(tf, x, input_size,dropout):    
 with tf.variable_scope("scale_1") as scope:
  W_S1_conv1 = deep_dive.weight_variable_scaling([7,7,3,64], name='W_S1_conv1')
  b_S1_conv1 = deep_dive.bias_variable([64])
  S1_conv1 = tf.nn.relu(deep_dive.conv2d(x_image, W_S1_conv1,strides=[1, 2, 2, 1], padding='SAME') + b_S1_conv1, name="Scale1_first_relu")
.
.
.
return S3_conv1,regularizer

Ich möchte auf die Variable S1_conv1 außerhalb dieser Funktion zugreifen. Ich habe es versucht:

with tf.variable_scope('scale_1') as scope_conv: 
 tf.get_variable_scope().reuse_variables()
 ft=tf.get_variable('Scale1_first_relu')

Aber das gibt mir einen Fehler:

ValueError: Under-sharing: Variable scale_1/Scale1_first_relu ist nicht vorhanden, nicht zulässig. Wollten Sie reuse = None in VarScope setzen?

Das funktioniert aber:

with tf.variable_scope('scale_1') as scope_conv: 
 tf.get_variable_scope().reuse_variables()
 ft=tf.get_variable('W_S1_conv1')

Ich kann das mit umgehen

return S3_conv1,regularizer, S1_conv1

aber das will ich nicht.

Ich denke, mein Problem ist, dass S1_conv1 nicht wirklich eine Variable ist, sondern nur ein Tensor. Gibt es eine Möglichkeit zu tun, was ich will?

41
protas

Es gibt eine Funktion tf.Graph.get_tensor_by_name (). Zum Beispiel:

import tensorflow as tf

c = tf.constant([[1.0, 2.0], [3.0, 4.0]])
d = tf.constant([[1.0, 1.0], [0.0, 1.0]])
e = tf.matmul(c, d, name='example')

with tf.Session() as sess:
    test =  sess.run(e)
    print e.name #example:0
    test = tf.get_default_graph().get_tensor_by_name("example:0")
    print test #Tensor("example:0", shape=(2, 2), dtype=float32)
49
apfalz

Alle Tensoren haben Stringnamen, die Sie wie folgt sehen können

[tensor.name for tensor in tf.get_default_graph().as_graph_def().node]

Sobald Sie den Namen kennen, können Sie den Tensor mit <name>:0 abrufen (0 bezieht sich auf den Endpunkt, der etwas redundant ist).

Zum Beispiel, wenn Sie das tun

tf.constant(1)+tf.constant(2)

Sie haben die folgenden Tensor-Namen

[u'Const', u'Const_1', u'add']

Sie können also die Ausgabe der Addition abrufen

sess.run('add:0')

Beachten Sie, dass dies nicht Teil der öffentlichen API ist. Automatisch generierte String-Tensornamen sind ein Implementierungsdetail und können sich ändern.

30

Alles was Sie in diesem Fall tun müssen, ist:

ft=tf.get_variable('scale1/Scale1_first_relu:0')
0
Kislay Kunal