2015-12-27 8 views
15

Próbuję użyć tensorflow do nauki transferu. Pobrałem wstępnie opracowany model inception3 z samouczka. W kodzie, dla przewidywania:Karmienie danych obrazu w tensorflow do nauki transferu

prediction = sess.run(softmax_tensor,{'DecodeJpeg/contents:0'}:image_data}) 

Czy istnieje sposób na zasilenie obrazu png. Próbowałem zmienić DecodeJpeg na DecodePng, ale nie zadziałało. Poza tym, co powinienem zmienić, jeśli chcę podawać zdekodowany plik obrazu jak tablicę numpy lub partię tablic?

Dzięki!

Odpowiedz

27

Dostarczony wykres InceptionV3 używany w classify_image.py obsługuje wyłącznie obrazy JPEG gotowe do użycia. Istnieją dwa sposoby, można użyć tego wykresu z obrazów PNG:

  1. przekonwertować obraz PNG do height x width x 3 (channels) Numpy tablicą, na przykład za pomocą PIL, następnie karmić 'DecodeJpeg:0' tensora:

    import numpy as np 
    from PIL import Image 
    # ... 
    
    image = Image.open("example.png") 
    image_array = np.array(image)[:, :, 0:3] # Select RGB channels only. 
    
    prediction = sess.run(softmax_tensor, {'DecodeJpeg:0': image_array}) 
    

    może złudzenia, 'DecodeJpeg:0' jest wyjście z DecodeJpeg op, więc karmiąc ten tensor, jesteś w stanie karmić surowych danych obrazu.

  2. Dodaj do importowanego wykresu opcję tf.image.decode_png(). Po prostu zmiana nazwy podawanego tensora z 'DecodeJpeg/contents:0' na 'DecodePng/contents:0' nie działa, ponieważ nie ma 'DecodePng' op w dostarczonym grafie. Możesz dodać taki węzeł na wykresie za pomocą input_map argument tf.import_graph_def():

    png_data = tf.placeholder(tf.string, shape=[]) 
    decoded_png = tf.image.decode_png(png_data, channels=3) 
    # ... 
    
    graph_def = ... 
    softmax_tensor = tf.import_graph_def(
        graph_def, 
        input_map={'DecodeJpeg:0': decoded_png}, 
        return_elements=['softmax:0']) 
    
    sess.run(softmax_tensor, {png_data: ...}) 
    
+0

Próbowałem z pierwszej metody. Wypisuje "Nie można pobrać elementu z kanału". Nie jestem pewien, dlaczego. Ale twoja druga praca. Dzięki!! –

+0

Hmm, błąd "Nie można pobrać elementu z kanału" jest dziwny - oznacza to, że 'image_array' traktowany jest jako tablica łańcuchów, więc może być coś czego brakuje w konwersji typu twojego obrazu na tensor TensorFlow. – mrry

+2

Z drugiej odpowiedzi zakładam, że "DecodeJpeg: 0" jest konstruowane z 'jpg_data = tf.placeholder (tf.string, shape = []); decoded_jpg = tf.image.decode_jpeg (jpg_data, channels = 3) 'Czeka na ciąg znaków zamiast tablicy numpy. –

1

Poniższy kod powinien obsługiwać obu przypadkach.

import numpy as np 
from PIL import Image 

image_file = 'test.jpeg' 
with tf.Session() as sess: 

    #  softmax_tensor = sess.graph.get_tensor_by_name('final_result:0') 
    if image_file.lower().endswith('.jpeg'): 
     image_data = tf.gfile.FastGFile(image_file, 'rb').read() 
     prediction = sess.run('final_result:0', {'DecodeJpeg/contents:0': image_data}) 
    elif image_file.lower().endswith('.png'): 
     image = Image.open(image_file) 
     image_array = np.array(image)[:, :, 0:3] 
     prediction = sess.run('final_result:0', {'DecodeJpeg:0': image_array}) 

    prediction = prediction[0]  
    print(prediction) 

lub krótsza wersja z bezpośrednich strun:

image_file = 'test.png' # or 'test.jpeg' 
image_data = tf.gfile.FastGFile(image_file, 'rb').read() 
ph = tf.placeholder(tf.string, shape=[]) 

with tf.Session() as sess:   
    predictions = sess.run(output_layer_name, {ph: image_data})