2016-09-05 11 views
11

TensorFlow docs, TFRecordReader, TextLineReader, QueueRunner vb. Ve sıralarını kullanarak verileri okumak için bir dizi yolu açıklar.TensorFlow ağı, giriş üretmek için bir jeneratör kullanarak nasıl eğitilir?

Yapmak istediğim şey çok daha basit: Ben (X, y) tuples olarak sonsuz sayıda egzersiz verisi üreten bir python üreteci işlevim var (her ikisi de sayı dizisidir ve ilk boyut toplu iştir) boyut). Ben sadece bu verileri girdi olarak kullanan bir ağı eğitmek istiyorum.

Verileri üreten bir jeneratörü kullanarak TensorFlow ağını eğitmek için kendi başına yeterli bir örnek var mı?

def generator(data): 
    ... 
    yield (X, y) 

Artık modeli mimarisini tarif etmektedir başka bir fonksiyon gerekir: (MNIST veya cifar örnekler çizgisinde)

+2

Sizin durumunuz için yararlı olabilecek ['tf.data.Dataset.from_generator'] (https://www.tensorflow.org/api_docs/python/tf/data/Dataset#from_generator) var. – Jakub

cevap

15

verileri üreten bir işlev olduğunu varsayalım. X'i işleyen ve y'yi çıktı olarak (yani, sinir ağı) öngörmek zorunda olan herhangi bir işlev olabilir.

işlevinizi varsayalım, girdi olarak X ve Y kabul bir şekilde X, y için bir tahmin hesaplar ve y arasındaki (regresyon durumunda, örneğin çapraz entropi veya MSE) kayıp fonksiyonunu verir ve Y tahmin:

X = tf.placeholder(tf.float32, shape=(batch_size, x_dim)) 
y = tf.placeholder(tf.float32, shape=(batch_size, y_dim)) 

yer tutucular Somet şunlardır:

def neural_network(X, y): 
    # computation of prediction for y using X 
    ... 
    return loss(y, y_pred) 

modeliniz işi yapmak için, X ve y her ikisi için yer tutucular tanımlamak ve sonra bir oturum çalıştırmak için gereken feed_dict tarafından oturumu çalıştırırken, belirtmeniz gereken "serbest değişken" gibi hing:

with tf.Session() as sess: 
    # variables need to be initialized before any sess.run() calls 
    tf.global_variables_initializer().run() 

    for X_batch, y_batch in generator(data): 
     feed_dict = {X: X_batch, y: y_batch} 
     _, loss_value, ... = sess.run([train_op, loss, ...], feed_dict) 
     # train_op here stands for optimization operation you have defined 
     # and loss for loss function (return value of neural_network function) 

bunu yararlı bulacağını umuyorum. Bununla birlikte, bunun tam olarak bir uygulama değil, neredeyse hiçbir ayrıntı belirtmediğiniz için bir sahte kod olduğunu unutmayın.

+0

Bir sonraki işlevi manuel olarak kodlamak yerine, jeneratör fonksiyonunu bir tahmin ediciye geçirmenin bir yolu var mı? – skadoosh

+0

@skadoosh Keras'ı kullanmayı düşünmelisiniz. –

+2

@skadoosh - tensorflow kullanmak istiyorsanız, sürüm 1.6 'tf.data.Dataset.from_generator' ile yapılabilen bir' tf.data.Dataset'ı kabul etmek için 'Estimator.train''na izin verir. – Jakub