2016-07-06 18 views
11

Chcę zobaczyć zmienne, które są zapisywane w punkcie kontrolnym tensorflow wraz z ich wartościami. Jak znaleźć nazwy zmiennych zapisane w punkcie kontrolnym tensorflow?Jak znaleźć nazwy zmiennych zapisane w punkcie kontrolnym tensorflow?

Edycja:

kiedyś tf.train.NewCheckpointReader co wynika here. Ale nie jest to podane w dokumentacji tensorflow. Czy jest jakiś inny sposób?

`

import tensorflow as tf 
    v0 = tf.Variable([[1, 2, 3], [4, 5, 6]], dtype=tf.float32, name="v0") 
    v1 = tf.Variable([[[1], [2]], [[3], [4]], [[5], [6]]], dtype=tf.float32, 
        name="v1") 
    init_all_op = tf.initialize_all_variables() 
    save = tf.train.Saver({"v0": v0, "v1": v1}) 
    checkpoint_path = os.path.join(model_dir, "model.ckpt")  

    with tf.Session() as sess: 
     sess.run(init_all_op) 
     # Saves a checkpoint.  
     save.save(sess, checkpoint_path) 

     # Creates a reader. 
     reader = tf.train.NewCheckpointReader(checkpoint_path) 
     print('reder:\n', reader) 

     # Verifies that the tensors exist. 
     print('is exist v0?', reader.has_tensor("v0")) 
     print('is exist v1?', reader.has_tensor("v1")) 

     # Verifies that debug string contains the right strings. 
     debug_string = reader.debug_string() 
     print('\n All Variables: \n', debug_string) 

     # Verifies get_variable_to_shape_map() returns the correct information. 
     var_map = reader.get_variable_to_shape_map() 
     print('\n All Variables information :\n', var_map) 

     # Verifies get_tensor() returns the tensor value. 
     v0_tensor = reader.get_tensor("v0") 
     v1_tensor = reader.get_tensor("v1") 
     print('\n returns the v0 tensor value:\n', v0_tensor) 
     print('\n returns the v1 tensor value:\n', v1_tensor) 

`

+0

Widziałem, że przyjąłeś odpowiedź. Zatem, jaki jest kod, który napisałeś, aby uruchomić funkcję 'print_tensors_in_checkpoint_file?' Próbowałem użyć tego, ale za każdym razem, gdy robię 'tf.python.tools.inspect_checkpoint.print_tensors_in_checkpoint_file' python mówi, że moduł' tensorflow.python' nie ma atrybut "narzędzia". Myślę, że byłoby niezmiernie pomocne, gdybyś przedstawił mały przykładowy skrypt, jak uruchomić tę funkcję (ponieważ ten plik również nie dostarcza przykładu), zwłaszcza, że ​​zaakceptowałeś odpowiedź, więc zakładam, że coś zadziałało. – Pinocchio

Odpowiedz

4

Można użyć narzędzia inspect_checkpoint.py.

+2

Próbowałem użyć tego, ale gdy robię 'tf.python.tools.inspect_checkpoint.print_tensors_in_checkpoint_file' python mówi, że moduł' tensorflow.python' nie ma atrybutu 'tools'. Myślę, że ti byłby niezmiernie pomocny, gdybyś podał mały przykładowy skrypt, jak uruchomić tę funkcję (ponieważ ten plik nie dostarcza również przykładu). – Pinocchio

19

Przykład użycia:

from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file 
checkpoint_path = os.path.join(model_dir, "model.ckpt") 

# List ALL tensors example output: v0/Adam (DT_FLOAT) [3,3,1,80] 
print_tensors_in_checkpoint_file(file_name=checkpoint_path, tensor_name='') 

# List contents of v0 tensor. 
# Example output: tensor_name: v0 [[[[ 9.27958265e-02 7.40226209e-02 4.52989563e-02 3.15700471e-02 
print_tensors_in_checkpoint_file(file_name=checkpoint_path, tensor_name='v0') 

# List contents of v1 tensor. 
print_tensors_in_checkpoint_file(file_name=checkpoint_path, tensor_name='v1') 

Aktualizacja:all_tensors argument został dodany do print_tensors_in_checkpoint_file od Tensorflow 0.12.0-rc0 więc może trzeba dodać all_tensors=False lub all_tensors=True razie potrzeby.

metoda alternatywna:

from tensorflow.python import pywrap_tensorflow 
checkpoint_path = os.path.join(model_dir, "model.ckpt") 
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path) 
var_to_shape_map = reader.get_variable_to_shape_map() 
for key in var_to_shape_map: 
    print("tensor_name: ", key) 
    print(reader.get_tensor(key)) # Remove this is you want to print only variable names 

Nadzieję, że to pomaga.

+0

Bardzo pomocna, dziękuję! – allen

1

Dodając do powyższego odpowiedź:

Jeśli model jest zapisywany w formacie V2

model-10000.data-00000-of-00001 
model-10000.index 
model-10000.meta 

Twój checkpoint nazwa wejściowy powinien być tylko przedrostek

print_tensors_in_checkpoint_file(file_name='/home/RNN/models/model_10000', tensor_name='',all_tensors=True) 

źródło: przez @LingjiaDeng na https://github.com/tensorflow/tensorflow/issues/7696