wake-up-neo.com

Länge eines Datensatzes in Tensorflow abrufen

source_dataset = tf.data.TextLineDataset('primary.csv')
target_dataset = tf.data.TextLineDataset('secondary.csv')
dataset = tf.data.Dataset.Zip((source_dataset, target_dataset))
dataset = dataset.shard(10000, 0)
dataset = dataset.map(lambda source, target: (tf.string_to_number(tf.string_split([source], delimiter=',').values, tf.int32),
                                              tf.string_to_number(tf.string_split([target], delimiter=',').values, tf.int32)))
dataset = dataset.map(lambda source, target: (source, tf.concat(([start_token], target), axis=0), tf.concat((target, [end_token]), axis=0)))
dataset = dataset.map(lambda source, target_in, target_out: (source, tf.size(source), target_in, target_out, tf.size(target_in)))

dataset = dataset.shuffle(NUM_SAMPLES)  #This is the important line of code

Ich möchte mein gesamtes Dataset vollständig mischen, aber shuffle() erfordert eine Reihe von Samples, um zu ziehen, und tf.Size() arbeitet nicht mit tf.data.Dataset.

Wie kann ich richtig mischen?

6
Evan Weissburg

Ich arbeitete mit tf.data.FixedLengthRecordDataset () und hatte ein ähnliches Problem ... In meinem Fall versuchte ich, nur einen bestimmten Prozentsatz der Rohdaten zu nehmen Feste Länge, ein Workaround für mich war:

totalBytes = sum([os.path.getsize(os.path.join(filepath, filename)) for filename in os.listdir(filepath)])
numRecordsToTake = tf.cast(0.01 * percentage * totalBytes / bytesPerRecord, tf.int64)
dataset = tf.data.FixedLengthRecordDataset(filenames, recordBytes).take(numRecordsToTake)

In Ihrem Fall würde ich vorschlagen, die Anzahl der Datensätze in 'primary.csv' und 'secondary.csv' direkt in Python zu zählen. Alternativ denke ich zu Ihrem Zweck, das Argument buffer_size zu setzen, erfordert nicht wirklich das Zählen der Dateien. Laut der akzeptierten Antwort zur Bedeutung von buffer_size sorgt eine Zahl, die größer ist als die Anzahl der Elemente in der Datenmenge, für eine einheitliche Verteilung der Daten in der gesamten Datenmenge. Es sollte also funktionieren, wenn Sie eine wirklich große Zahl eingeben (die Ihrer Meinung nach die Größe des Datensatzes übertrifft).

1
Ringo