2015-12-18 12 views
9

Mam trzy tensorów A, B and C w tensorflow, A i B są zarówno kształtu (m, n, r), C jest binarnym tensor kształcie (m, n, 1).Jak jawnie rozgłaszać tensor, aby dopasować kształt innej osoby w tensorflow?

Chcę wybrać elementy z A lub B na podstawie wartości C. Oczywistym narzędziem jest tf.select, jednak nie ma to semantyki nadawania, więc najpierw muszę jawnie nadawać C do tego samego kształtu co A i B.

To byłaby moja pierwsza próba jak to zrobić, ale nie robi tego lubię mieszanie tensora (tf.shape(A)[2]) do listy kształtów.

import tensorflow as tf 
A = tf.random_normal([20, 100, 10]) 
B = tf.random_normal([20, 100, 10]) 
C = tf.random_normal([20, 100, 1]) 
C = tf.greater_equal(C, tf.zeros_like(C)) 

C = tf.tile(C, [1,1,tf.shape(A)[2]]) 
D = tf.select(C, A, B) 

Jakie jest właściwe podejście?

+2

One hack, który działa: można użyć semantykę nadawania * * wielowarstwowego i pomnożyć przez te tensora tak:' Expander = tf.ones_like (B) ', następnie 'C = Expander * C' – wxs

Odpowiedz

9

EDYCJA: We wszystkich wersjach TensorFlow od 0.12rc0 kod w pytaniu działa bezpośrednio. TensorFlow automatycznie rozmieści tensory i liczby Pythona w tensorze. Poniższe rozwiązanie z użyciem tf.pack() jest potrzebne tylko w wersjach starszych niż 0.12rc0. Zauważ, że zmieniono nazwę na tf.pack() na tf.stack() w TensorFlow 1.0.


Twoje rozwiązanie jest bardzo bliskie. należy wymienić linię:

C = tf.tile(C, [1,1,tf.shape(C)[2]]) 

... z następujących czynności:

C = tf.tile(C, tf.pack([1, 1, tf.shape(A)[2]])) 

(Przyczyną problemu jest to, że nie będzie TensorFlow niejawnie przekonwertować listę tensorów i literały Python w tensor. tf.pack() pobiera listę tensorów, więc będzie konwertować każdy z elementów w jego wejścia (1, 1 i tf.shape(C)[2]) do tensora. Ponieważ każdy element jest skalarne, wynik będzie wektorem.)

+1

Myślę, że masz dodatkowe' [i brakuje ')', ale wtedy pojawia się nieco tajemniczy błąd, gdy * uruchamiam * sesję tf: 'InvalidArgumentError: Wejścia do operacji Select_13 z typ Wybierz musi mieć ten sam rozmiar i kształt. Wejście 0: dim {rozmiar: 20} dim {rozmiar: 100} dim {size: 1}! = Input 1: dim {rozmiar: 20} dim {size: 100} dim {size: 10} ' – wxs

+0

Dobra uwaga, i Zaktualizowałem odpowiedź - również argumentem 'tf.shape()' powinno być 'A' (lub' B'). To działa dla mnie - jaki błąd widzisz? – mrry

+0

Tak, naprawiono to teraz :) nie zauważyłem niepoprawnego parametru do 'tf.shape()'. Dzięki! – wxs

2

Tutaj „Sa brudne Hack:

import tensorflow as tf 

def broadcast(tensor, shape): 
    return tensor + tf.zeros(shape, dtype=tensor.dtype) 

A = tf.random_normal([20, 100, 10]) 
B = tf.random_normal([20, 100, 10]) 
C = tf.random_normal([20, 100, 1]) 

C = broadcast(C, A.shape) 
D = tf.select(C, A, B) 
0
import tensorflow as tf 

def broadcast(tensor, shape): 
    """Broadcasts ``x`` to have shape ``shape``. 
                    | 
    Uses ``tf.Assert`` statements to ensure that the broadcast is 
    valid. 

    First calculates the number of missing dimensions in 
    ``tf.shape(x)`` and left-pads the shape of ``x`` with that many 
    ones. Then identifies the dimensions of ``x`` that require 
    tiling and tiles those dimensions appropriately. 

    Args: 
     x (tf.Tensor): The tensor to broadcast. 
     shape (Union[tf.TensorShape, tf.Tensor, Sequence[int]]): 
      The shape to broadcast to. 

    Returns: 
     tf.Tensor: ``x``, reshaped and tiled to have shape ``shape``. 

    """ 
    with tf.name_scope('broadcast') as scope: 
     shape_x = tf.shape(x) 
     rank_x = tf.shape(shape0)[0] 
     shape_t = tf.convert_to_tensor(shape, preferred_dtype=tf.int32) 
     rank_t = tf.shape(shape1)[0] 

     with tf.control_dependencies([tf.Assert(
      rank_t >= rank_x, 
      ['len(shape) must be >= tf.rank(x)', shape_x, shape_t], 
      summarize=255 
     )]): 
      missing_dims = tf.ones(tf.stack([rank_t - rank_x], 0), tf.int32) 

     shape_x_ = tf.concat([missing_dims, shape_x], 0) 
     should_tile = tf.equal(shape_x_, 1) 

     with tf.control_dependencies([tf.Assert(
      tf.reduce_all(tf.logical_or(tf.equal(shape_x_, shape_t), should_tile), 
      ['cannot broadcast shapes', shape_x, shape_t], 
      summarize=255 
     )]): 
      multiples = tf.where(should_tile, shape_t, tf.ones_like(shape_t)) 
      out = tf.tile(tf.reshape(x, shape_x_), multiples, name=scope) 

     try: 
      out.set_shape(shape) 
     except: 
      pass 

     return out 

A = tf.random_normal([20, 100, 10]) 
B = tf.random_normal([20, 100, 10]) 
C = tf.random_normal([20, 100, 1]) 

C = broadcast(C, A.shape) 
D = tf.select(C, A, B)