2017-04-06 21 views
11

Próbuję ponownie uruchomić szkolenie modelowe w TensorFlow, podnosząc miejsce, w którym zostało przerwane. Chciałbym użyć ostatnio dodanego (0.12+, myślę) import_meta_graph(), aby nie zrekonstruować wykresu.Python TensorFlow: Jak ponownie uruchomić trening z optymalizatorem i import_meta_graph?

Widziałem na to rozwiązanie, np. Tensorflow: How to save/restore a model?, ale pojawiają się problemy z AdamOptimizer, w szczególności dostaję błąd ValueError: cannot add op with name <my weights variable name>/Adam as that name is already used. This can be fixed by initializing, ale wtedy moje wartości modelu są wyczyszczone!

Istnieją inne odpowiedzi i niektóre pełne przykłady tam, ale zawsze wydają się starsze, więc nie należy uwzględniać nowszego podejścia import_meta_graph() lub nie mają optymalizatora tensorowego. Najbliższe pytanie, jakie mogłem znaleźć, to: tensorflow: saving and restoring session, ale nie ma ostatecznego rozwiązania, a przykład jest dość skomplikowany.

Idealnie chciałbym prosty przykład, zaczynając od zera, zatrzymując się, a następnie podnosząc ponownie. Mam coś, co działa (poniżej), ale zastanawiam się też, czy czegoś brakuje. Z pewnością nie tylko ja to robię?

+0

Miałem ten sam problem z AdamOptimizer. Udało mi się przekonać do działania, umieszczając moje ops w kolekcjach. Ten przykład bardzo mi pomógł: http://www.seaandsailor.com/tensorflow-checkpointing.html –

Odpowiedz

4

Oto, co wymyśliłem, czytając dokumenty, inne podobne rozwiązania oraz próbę i błąd. To prosty automatyczny koder danych losowych. Jeśli zostanie uruchomiony, a następnie uruchomiony ponownie, będzie kontynuowany od miejsca, w którym został przerwany (tj. Funkcja kosztu przy pierwszym uruchomieniu trwa od ~ 0,5 -> 0,3 sekundy uruchomienia rozpoczyna się ~ 0.3). O ile nie przeoczyłem czegoś, wszystkie oszczędności, konstruktorzy, budowanie modeli, add_to_collection są potrzebne i w dokładnej kolejności, ale może być prostszy sposób.

I tak, ładowanie wykresu import_meta_graph nie jest tu tak naprawdę potrzebne, ponieważ kod jest tuż nad, ale jest to, czego chcę w mojej aplikacji.

from __future__ import print_function 
import tensorflow as tf 
import os 
import math 
import numpy as np 

output_dir = "/root/Data/temp" 
model_checkpoint_file_base = os.path.join(output_dir, "model.ckpt") 

input_length = 10 
encoded_length = 3 
learning_rate = 0.001 
n_epochs = 10 
n_batches = 10 
if not os.path.exists(model_checkpoint_file_base + ".meta"): 
    print("Making new") 
    brand_new = True 

    x_in = tf.placeholder(tf.float32, [None, input_length], name="x_in") 
    W_enc = tf.Variable(tf.random_uniform([input_length, encoded_length], 
              -1.0/math.sqrt(input_length), 
              1.0/math.sqrt(input_length)), name="W_enc") 
    b_enc = tf.Variable(tf.zeros(encoded_length), name="b_enc") 
    encoded = tf.nn.tanh(tf.matmul(x_in, W_enc) + b_enc, name="encoded") 
    W_dec = tf.transpose(W_enc, name="W_dec") 
    b_dec = tf.Variable(tf.zeros(input_length), name="b_dec") 
    decoded = tf.nn.tanh(tf.matmul(encoded, W_dec) + b_dec, name="decoded") 
    cost = tf.sqrt(tf.reduce_mean(tf.square(decoded - x_in)), name="cost") 

    saver = tf.train.Saver() 
else: 
    print("Reloading existing") 
    brand_new = False 
    saver = tf.train.import_meta_graph(model_checkpoint_file_base + ".meta") 
    g = tf.get_default_graph() 
    x_in = g.get_tensor_by_name("x_in:0") 
    cost = g.get_tensor_by_name("cost:0") 


sess = tf.Session() 
if brand_new: 
    optimizer = tf.train.AdamOptimizer(learning_rate).minimize(cost) 
    init = tf.global_variables_initializer() 
    sess.run(init) 
    tf.add_to_collection("optimizer", optimizer) 
else: 
    saver.restore(sess, model_checkpoint_file_base) 
    optimizer = tf.get_collection("optimizer")[0] 

for epoch_i in range(n_epochs): 
    for batch in range(n_batches): 
     batch = np.random.rand(50, input_length) 
     _, curr_cost = sess.run([optimizer, cost], feed_dict={x_in: batch}) 
     print("batch_cost:", curr_cost) 
     save_path = tf.train.Saver().save(sess, model_checkpoint_file_base) 
2

Miałem ten sam problem i właśnie zorientowałem się, co było nie tak, przynajmniej w moim kodzie.

W końcu użyłem niewłaściwego pliku w saver.restore(). Funkcja ta musi być podana nazwa pliku bez rozszerzenia pliku, podobnie jak funkcja saver.save():

saver.restore(sess, 'model-1') 

zamiast

saver.restore(sess, 'model-1.data-00000-of-00001') 

Mając to zrobić dokładnie to, co chcesz zrobić: zaczynając od zera, zatrzymanie, a następnie ponowne podniesienie. Nie muszę inicjować drugiego wygaszacza z pliku meta przy użyciu funkcji tf.train.import_meta_graph() i nie muszę jawnie podawać tf.initialize_all_variables() po zainicjowaniu optymalizatora.

Moje kompletnego modelu przywrócić wygląda następująco:

with tf.Session() as sess: 
    saver = tf.train.Saver() 
    sess.run(tf.global_variables_initializer()) 
    saver.restore(sess, model-1) 

myślę, że w protokole V1 nadal musiał dodać .ckpt do nazwy pliku, a dla import_meta_graph() trzeba jeszcze dodać .meta, co może powodować pewne zamieszanie wśród użytkowników. Być może należy to wyraźnie wskazać w dokumentacji.

0

Podczas tworzenia obiektu wygaszacza podczas sesji przywracania może występować problem.

Otrzymałem ten sam błąd, co twój, używając kodów poniżej w sesji przywracania.

saver = tf.train.import_meta_graph('tmp/hsmodel.meta') 
saver.restore(sess, tf.train.latest_checkpoint('tmp/')) 

Ale kiedy zmieniło się w ten sposób,

saver = tf.train.Saver() 
saver.restore(sess, "tmp/hsmodel") 

Błąd zniknął. "tmp/hsmodel" to ścieżka, którą podaję do save.ave (sess, "tmp/hsmodel") w sesji zapisu.

Oto proste przykłady na przechowywanie i przywracanie sesji treningowej sieci MNIST (zawierającej optymalizator Adama). Pomogło mi to w porównaniu z moim kodem i naprawienie problemu.

https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/4_Utils/save_restore_model.py