2016-01-26 7 views
16

Mam konfigurację, w której muszę zainicjować LSTM po głównej inicjalizacji, która używa tf.initialize_all_variables(). To znaczy. Chcę zadzwonić tf.initialize_variables([var_list])Tensorflow: Jak uzyskać wszystkie zmienne z rnn_cell.BasicLSTM i rnn_cell.MultiRNNCell

Czy istnieje sposób, aby zebrać wszystkie wewnętrzne zmienne wyszkolić dla obu:

  • rnn_cell.BasicLSTM
  • rnn_cell.MultiRNNCell

tak, że mogę zainicjować WŁAŚNIE te parametry?

Głównym powodem tego chcę, ponieważ nie chcę ponownie zainicjować niektórych wyszkolonych wartości wcześniej.

Odpowiedz

17

Najprostszym sposobem rozwiązania problemu jest użycie zakresu zmiennego. Nazwy zmiennych w zakresie będą poprzedzone prefiksem z jego nazwą. Oto krótki urywek:

cell = rnn_cell.BasicLSTMCell(num_nodes) 

with tf.variable_scope("LSTM") as vs: 
    # Execute the LSTM cell here in any way, for example: 
    for i in range(num_steps): 
    output[i], state = cell(input_data[i], state) 

    # Retrieve just the LSTM variables. 
    lstm_variables = [v for v in tf.all_variables() 
        if v.name.startswith(vs.name)] 

# [..] 
# Initialize the LSTM variables. 
tf.initialize_variables(lstm_variables) 

Działałoby to tak samo z MultiRNNCell.

EDIT: zmienił tf.trainable_variables do tf.all_variables()

+0

To jest idealne, dziękuję. Nie zdawałem sobie sprawy, że 'tf.trainable_variables()' respektuje zakres, ale myślę, że z perspektywy czasu ma to sens! – bge0

+1

Chciałbym dodać, że 'tf.all_variables()' zamiast 'tf.trainable_variables()' byłoby lepszym wyborem. Głównie dlatego, że istnieją rzeczy takie jak optymalizatory, które nie mają zmiennych, które mogą być trenowane, ale które nadal wymagają inicjalizacji. – bge0

+1

Dzięki, masz rację. Zaktualizowałem kod. –

11

Można również użyć tf.get_collection():

cell = rnn_cell.BasicLSTMCell(num_nodes) 
with tf.variable_scope("LSTM") as vs: 
    # Execute the LSTM cell here in any way, for example: 
    for i in range(num_steps): 
    output[i], state = cell(input_data[i], state) 

    lstm_variables = tf.get_collection(tf.GraphKeys.VARIABLES, scope=vs.name) 

(częściowo skopiowany z odpowiedzią Rafała)

Zauważ, że ostatnia linia jest równoważna listowego w Kod Rafała.

Zasadniczo tensorflow przechowuje globalną kolekcję zmiennych, które można pobrać przez tf.all_variables() lub tf.get_collection(tf.GraphKeys.VARIABLES). Jeśli podasz scope (nazwa zakresu) w funkcji tf.get_collection(), wtedy pobierzesz tylko tensory (w tym przypadku zmienne) w kolekcji, której zakresy są w określonym zakresie.

EDYTOWANIE: Możesz także użyć tf.GraphKeys.TRAINABLE_VARIABLES, aby uzyskać tylko zmienne do trenowania. Ale ponieważ vanilla BasicLSTMCell nie inicjuje żadnej nieprzyłączalnej zmiennej, oba będą funkcjonalnie równoważne. Aby uzyskać pełną listę domyślnych kolekcji wykresów, sprawdź: this.

+0

To jest lepszy sposób niż rozwiązanie Rafała :-) –

+1

Tak jak powiedziałem powyżej, być może powinieneś lepiej użyć '' '' tf.get_collection (..., scope = vs.name + "/") 'ponieważ może istnieć inny zasięg o nazwie" LSTM2 ". – Albert