Chcę zreorganizować węzły modelu tensorflow .pb, więc najpierw pobieram NodeDef z GraphDef i otrzymuję attr, używając NodeDef.attr(). Dla węzła "Conv2D". Mogę uzyskać parametry takie jak kroki, padding, data_format, use_cudnn_on_gpu z attr, ale nie mogę uzyskać parametrów formatu wagi. Używany przeze mnie język to C++. Jak go zdobyć! Dziękuję Ci!Jak uzyskać format wag z modelu TensorFlow .pb?
8
A
Odpowiedz
4
Conv2D
posiada dwa wejścia: pierwszy jest dane, a drugi jest filter
(lub wagi), więc można po prostu sprawdzić format drugiego wejścia Conv2D
. Jeśli używasz C++, możesz spróbować:
# Assuming inputs: conv2d_node, node_map.
filter_node_name = conv2d_node.input(1)
filter_node = node_map[filter_node_name]
# You might need to check identity node here.
# Get the shape of filter_node using NodeDef.attr()