2017-07-12 33 views
8

Chcę zreorganizować węzły modelu tensorflow .pb, więc najpierw pobieram NodeDef z GraphDef i otrzymuję attr, używając NodeDef.attr(). Dla węzła "Conv2D". Mogę uzyskać parametry takie jak kroki, padding, data_format, use_cudnn_on_gpu z attr, ale nie mogę uzyskać parametrów formatu wagi. Używany przeze mnie język to C++. Jak go zdobyć! Dziękuję Ci!Jak uzyskać format wag z modelu TensorFlow .pb?

Odpowiedz

4

Conv2D posiada dwa wejścia: pierwszy jest dane, a drugi jest filter (lub wagi), więc można po prostu sprawdzić format drugiego wejścia Conv2D. Jeśli używasz C++, możesz spróbować:

# Assuming inputs: conv2d_node, node_map. 
filter_node_name = conv2d_node.input(1) 
filter_node = node_map[filter_node_name] 
# You might need to check identity node here. 
# Get the shape of filter_node using NodeDef.attr()