2016-08-01 40 views
5

Ben şu var basit yer tutucular:Tensorflow'da tf.cond içindeki işlevlere parmetreler nasıl aktarılır?

x = tf.placeholder(tf.float32, shape=[1]) 
y = tf.placeholder(tf.float32, shape=[1]) 
z = tf.placeholder(tf.float32, shape=[1]) 

olarak tanımlanan iki işlev f1 ve f2 vardır:

pred = tf.placeholder(tf.bool, shape=[1]) 
result = tf.cond(pred, f1(x,y), f2(y,z)) 

:

def fn1(a, b): 
    return tf.mul(a, b) 
def fn2(a, b): 
    return tf.add(a, b) 

Şimdi beklenen durumuna göre sonucu hesaplamak istiyorum Ama bana fn1 and fn2 must be callable diyerek bir hata veriyor.

Çalışma zamanında parametreleri alabilmeleri için fn1 ve fn2'u nasıl yazabilirim? Aşağıdaki çağırmak istiyorum:

sess.run(result, feed_dict={x:1,y:2,z:3,pred:True}) 

cevap

1

en kolay çağrısında Fonksiyonlarınızı tanımlamak olacaktır: Sen lambda kullanarak işlevlerine parametreler iletebilirsiniz

result = tf.cond(pred, lambda: tf.mul(a, b), lambda: tf.add(a, b)) 
9

ve kod gibidir körük.

x = tf.placeholder(tf.float32) 
y = tf.placeholder(tf.float32) 
z = tf.placeholder(tf.float32) 

def fn1(a, b): 
    return tf.mul(a, b) 

def fn2(a, b): 
    return tf.add(a, b) 

pred = tf.placeholder(tf.bool) 
result = tf.cond(pred, lambda: fn1(x, y), lambda: fn2(y, z)) 

Sonra bellowing olarak arayabilirsiniz:

with tf.Session() as sess: 
    print sess.run(result, feed_dict={x: 1, y: 2, z: 3, pred: True}) 
    # The result is 2.0