Chciałbym zarządzać moim szkoleniem z tf.estimator.Estimator
, ale mam problem z używaniem go razem z interfejsem API tf.data
.Jak używać inicjalizujących iteratorów tf.data w pliku input_fn tf.estimator?
mam coś takiego:
def model_fn(features, labels, params, mode):
# Defines model's ops.
# Initializes with tf.train.Scaffold.
# Returns an tf.estimator.EstimatorSpec.
def input_fn():
dataset = tf.data.TextLineDataset("test.txt")
# map, shuffle, padded_batch, etc.
iterator = dataset.make_initializable_iterator()
return iterator.get_next()
estimator = tf.estimator.Estimator(model_fn)
estimator.train(input_fn)
Ponieważ nie można używać make_one_shot_iterator
dla mojego przypadku użycia, mój problem jest to, że input_fn
zawiera iterator, który powinien zostać zainicjowany w ciągu model_fn
(tu używam tf.train.Scaffold
zainicjować operacje lokalne).
Ponadto zrozumiałem, że możemy nie tylko używać input_fn = iterator.get_next
, ponieważ inne operacje nie będą dodawane do tego samego wykresu.
Jaki jest zalecany sposób inicjowania iteratora?
@guillaumeklin - nie dodasz 'tf.add_to_collection (tf.GraphKeys.TABLE_INITIALIZERS, iterator.initializer)' w input_fn()? – reese0106
Tak, możesz dodać tę linię w 'input_fn()' tuż przed 'return iterator.get_next()'. – guillaumekln