wake-up-neo.com

Tensor-Objekte können nicht wiederholt werden, wenn die Eager-Ausführung nicht aktiviert ist. Um über diesen Tensor zu iterieren, verwenden Sie `tf.map_fn`

Ich versuche meine eigene Verlustfunktion zu erstellen:

def custom_mse(y_true, y_pred):
    tmp = 10000000000
    a = list(itertools.permutations(y_pred))
    for i in range(0, len(a)): 
     t = K.mean(K.square(a[i] - y_true), axis=-1)
     if t < tmp :
        tmp = t
     return tmp

Es sollte Permutationen des vorhergesagten Vektors erzeugen und den kleinsten Verlust zurückgeben.

   "`Tensor` objects are not iterable when eager execution is not "
TypeError: `Tensor` objects are not iterable when eager execution is not enabled. To iterate over this tensor use `tf.map_fn`.

error. Ich kann keine Quelle für diesen Fehler finden. Warum passiert das?

Danke für deine Hilfe.

10
Darlyn

Der Fehler tritt auf, weil y_pred ist ein Tensor (ohne eifrige Ausführung nicht iterierbar), und itertools.permutations erwartet einen iterierbaren, aus dem Permutationen erstellt werden können. Außerdem würde der Teil, in dem Sie den minimalen Verlust berechnen, ebenfalls nicht funktionieren, da die Werte von Tensor t zum Zeitpunkt der Diagrammerstellung unbekannt sind.

Anstatt den Tensor zu permutieren, würde ich Permutationen der Indizes erstellen (dies können Sie zum Zeitpunkt der Diagrammerstellung tun) und dann die permutierten Indizes vom Tensor erfassen. Angenommen, Ihr Keras-Backend ist TensorFlow und y_true/y_pred sind 2-dimensional, Ihre Verlustfunktion könnte wie folgt implementiert werden:

def custom_mse(y_true, y_pred):
    batch_size, n_elems = y_pred.get_shape()
    idxs = list(itertools.permutations(range(n_elems)))
    permutations = tf.gather(y_pred, idxs, axis=-1)  # Shape=(batch_size, n_permutations, n_elems)
    mse = K.square(permutations - y_true[:, None, :])  # Shape=(batch_size, n_permutations, n_elems)
    mean_mse = K.mean(mse, axis=-1)  # Shape=(batch_size, n_permutations)
    min_mse = K.min(mean_mse, axis=-1)  # Shape=(batch_size,)
    return min_mse
10
rvinas