2016-10-27 11 views
6

Próbuję rozpocząć pracę z TensorFlow w Pythonie, budując prosty NN z przekazywaniem. Mam jedną klasę, która przechowuje masy sieci (zmienne, które są aktualizowane podczas pociągu, i mają pozostać stałe dla środowiska wykonawczego) i inny skrypt do szkolenia sieci, która pobiera dane treningowe, rozdziela je na partie i trenuje sieć w partiach . Kiedy staram się trenować sieci, pojawia się błąd wskazujący, że tensor dane nie są w tym samym wykresie co tensorów NN:TensorFlow: Jak zapewnić, że tensory są na tym samym wykresie

ValueError: Tensor("Placeholder:0", shape=(10, 5), dtype=float32) must be from the same graph as Tensor("windows/embedding/Cast:0", shape=(100232, 50), dtype=float32).

odpowiednich części w skrypcie szkoleniowym są:

def placeholder_inputs(batch_size, ner): 
    windows_placeholder = tf.placeholder(tf.float32, shape=(batch_size, ner.windowsize)) 
    labels_placeholder = tf.placeholder(tf.int32, shape=(batch_size)) 
    return windows_placeholder, labels_placeholder 

with tf.Session() as sess: 
    windows_placeholder, labels_placeholder = placeholder_inputs(batch_size, ner) 
    logits = ner.inference(windows_placeholder) 

A istotne w klasie sieci są:

class WindowNER(object): 
def __init__(self, wv, windowsize=3, dims=[None, 100,5], reg=0.01): 
    self.reg=reg 
    self.windowsize=windowsize 
    self.vocab_size = wv.shape[0] 
    self.embedding_dim = wv.shape[1] 
    with tf.name_scope("embedding"): 
     self.L = tf.cast(tf.Variable(wv, trainable=True, name="L"), tf.float32) 
    with tf.name_scope('hidden1'): 
     self.W = tf.Variable(tf.truncated_normal([windowsize * self.embedding_dim, dims[1]], 
      stddev=1.0/math.sqrt(float(windowsize*self.embedding_dim))), 
     name='weights') 
     self.b1 = tf.Variable(tf.zeros([dims[1]]), name='biases') 
    with tf.name_scope('output'): 
     self.U = tf.Variable(tf.truncated_normal([dims[1], dims[2]], stddev = 1.0/math.sqrt(float(dims[1]))), name='weights') 
     self.b2 = tf.Variable(tf.zeros(dims[2], name='biases')) 


def inference(self, windows): 
    with tf.name_scope("embedding"): 
     embedded_words = tf.reshape(tf.nn.embedding_lookup(self.L, windows), [windows.get_shape()[0], self.windowsize * self.embedding_dim]) 
    with tf.name_scope("hidden1"): 
     h = tf.nn.tanh(tf.matmul(embedded_words, self.W) + self.b1) 
    with tf.name_scope('output'): 
     t = tf.matmul(h, self.U) + self.b2 

Dlaczego istnieją dwa wykresy w pierwszej kolejności, i jak mogę się upewnić, że tensory zastępcze danych są na tym samym wykresie co NN?

Dzięki!

Odpowiedz

5

powinien być w stanie utworzyć wszystkie tensory pod tym samym wykresie robiąc coś takiego:

g = tf.Graph() 
with g.as_default(): 
    windows_placeholder, labels_placeholder = placeholder_inputs(batch_size, ner) 
    logits = ner.inference(windows_placeholder) 

with tf.Session(graph=g) as sess: 
    # Run a session etc 

można przeczytać więcej na temat wykresów w TF tutaj: https://www.tensorflow.org/versions/r0.8/api_docs/python/framework.html#Graph

+0

Dzięki za szybką odpowiedź! Jednak dokonałem tej zmiany (komentując sesję, dopóki nie otrzymam poprawnie zbudowanego wykresu) i nadal otrzymuję ten sam błąd - "Tensor (...) musi pochodzić z tego samego wykresu co Tensor (...)". – user616254

+0

Trudno powiedzieć, nie widząc całego kodu. Ale wydaje się prawdopodobne, że masz kod, który konstruuje operatorów poza zasięgiem 'with g.as_default()' lub jakiś kod, który wywołujesz, tworzy własny wykres. Czy mógłbyś pokazać więcej kodu? (Mówiąc szczerze, następną rzeczą, którą chciałbym wypróbować, jest oprzyrządowanie kodu Pythona Tensorflow, który konstruuje operatorów i drukuje tożsamość wykresu, do którego dodawany jest każdy operator.) –

0

Czasami, gdy masz Błąd taki jak ten, błąd (który często może być użyty z niewłaściwą zmienną z innego wykresu) mógł wystąpić dużo wcześniej i propagowany do operacji, która ostatecznie rzuciła błąd. W związku z tym możesz zbadać tylko tę linię i stwierdzić, że tensory powinny pochodzić z tego samego wykresu, podczas gdy błąd w rzeczywistości leży gdzie indziej.

Najprostszym sposobem sprawdzenia jest wydrukowanie wykresu dla każdej zmiennej/operacji na wykresie. Możesz to zrobić po prostu:

print(variable_name.graph)