42

Próbuję przypisać nową wartość do zmiennej tensorflow w python.Jak przypisać wartość zmiennej TensorFlow?

import tensorflow as tf 
import numpy as np 

x = tf.Variable(0) 
init = tf.initialize_all_variables() 
sess = tf.InteractiveSession() 
sess.run(init) 

print(x.eval()) 

x.assign(1) 
print(x.eval()) 

Ale wyjście mogę to

0 
0 

Więc wartość nie uległa zmianie. czego mi brakuje?

Odpowiedz

72

Oświadczenie x.assign(1) faktycznie nie przypisujemy wartość 1 do x, ale raczej tworzy tf.Operation że trzeba wyraźnie run zaktualizować zmienną * Wywołanie Operation.run() lub Session.run() mogą być wykorzystane do uruchomienia operacji.:

assign_op = x.assign(1) 
sess.run(assign_op) # or `assign_op.op.run()` 
print(x.eval()) 
# ==> 1 

(* W rzeczywistości, to zwraca tf.Tensor, co odpowiada zaktualizowanej wartości zmiennej, aby ułatwić zadania łańcuchowych.)

+0

Dzięki! assign_op.run() daje błąd: AttributeError: Obiekt 'Tensor' nie ma atrybutu 'run'. Ale sess.run (assign_op) działa idealnie dobrze. – abora

+0

W tym przykładzie, czy dane "Zmienna" 'x przechowywane w pamięci przed operacją' assign'/zmienny tensor zostały nadpisane lub czy utworzono nowy tensor, który przechowuje zaktualizowaną wartość? – dannygoldstein

+3

Obecna implementacja 'assign()' nadpisuje istniejącą wartość. – mrry

-4

Jest łatwiejszy podejście:

x = tf.Variable(0) 
x = x + 1 
print x.eval() 
+2

o.p. badał użycie 'tf.assign', a nie dodawania. – vega

6

Przede wszystkim można przypisać wartości do zmiennych/stałych właśnie przez karmienie wartości do nich w ten sam sposób to zrobić z symbolami. Więc jest to całkowicie legalne zrobić:

import tensorflow as tf 
x = tf.Variable(0) 
with tf.Session() as sess: 
    sess.run(tf.global_variables_initializer()) 
    print sess.run(x, feed_dict={x: 3}) 

Odnośnie Twojego zamieszanie z operatorem tf.assign(). W TF nic nie jest wykonywane przed uruchomieniem go wewnątrz sesji. Więc zawsze musisz zrobić coś takiego: op_name = tf.some_function_that_create_op(params), a następnie w trakcie sesji przeprowadzasz sess.run(op_name). Korzystanie przypisać jako przykład można zrobić coś takiego:

import tensorflow as tf 
x = tf.Variable(0) 
y = tf.assign(x, 1) 
with tf.Session() as sess: 
    sess.run(tf.global_variables_initializer()) 
    print sess.run(x) 
    print sess.run(y) 
    print sess.run(x) 
+1

Należy zauważyć, że podawanie wartości za pośrednictwem 'feed_dict' nie przypisuje tej wartości do zmiennej. –

2

Ponadto, należy zauważyć, że jeśli używasz your_tensor.assign(), wówczas tf.global_variables_initializer nie musi być wywołana jawnie ponieważ operacja przypisać zrobi to za Ciebie w tle.

Przykład:

In [212]: w = tf.Variable(12) 
In [213]: w_new = w.assign(34) 

In [214]: with tf.Session() as sess: 
    ...:  sess.run(w_new) 
    ...:  print(w_new.eval()) 

# output 
34 

Jednak to nie będzie zainicjować wszystkie zmienne, ale będzie to tylko zainicjować zmienną, na którym assign został stracony w dniu.

2

Można również przypisać nową wartość do tf.Variable bez dodawania operacji do wykresu: tf.Variable.load(value, session). Ta funkcja może również oszczędzać dodawanie symboli zastępczych podczas przypisywania wartości spoza wykresu i jest przydatna w przypadku sfinalizowania wykresu.

import tensorflow as tf 
x = tf.Variable(0) 
sess = tf.Session() 
sess.run(tf.global_variables_initializer()) 
print(sess.run(x)) # Prints 0. 
x.load(1, sess) 
print(sess.run(x)) # Prints 1.