5

Tensorflow versiyonum 0.11. Eğitimden sonra bir grafik kaydetmek veya tensorflow'un yükleyebileceği başka bir şey kaydetmek istiyorum.Tensorflow'da eğitildikten sonra modeli nasıl kullanırsınız (kaydet/yükle grafiği)

I/Zaten bu yazıyı okumak ihraç kullanma ve MetaGraph

İçe: Tensorflow: how to save/restore a model?

Benim Save.py dosyası:

X = tf.placeholder("float", [None, 28, 28, 1], name='X') 
Y = tf.placeholder("float", [None, 10], name='Y') 

tf.train.Saver() 
with tf.Session() as sess: 
    ...run something ... 
    final_tensor = tf.nn.softmax(py_x, name='final_result') 
    tf.add_to_collection("final_tensor", final_tensor) 

    predict_op = tf.argmax(py_x, 1) 
    tf.add_to_collection("predict_op", predict_op) 

saver.save(sess, 'my_project') 

Sonra çalıştırmak load.py:

with tf.Session() as sess: 
    new_saver = tf.train.import_meta_graph('my_project.meta') 
    new_saver.restore(sess, 'my_project') 
    predict_op = tf.get_collection("predict_op")[0] 
    for i in range(2): 
     test_indices = np.arange(len(teX)) # Get A Test Batch 
     np.random.shuffle(test_indices) 
     test_indices = test_indices[0:test_size] 

     print(i, np.mean(np.argmax(teY[test_indices], axis=1) == 
         sess.run(predict_op, feed_dict={"X:0": teX[test_indices], 
                 "p_keep_conv:0": 1.0, 
                 "p_keep_hidden:0": 1.0}))) 

ama hata dönmek

Traceback (most recent call last): 
    File "load_05_convolution.py", line 62, in <module> 
    "p_keep_hidden:0": 1.0}))) 
    File "/home/khoa/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 717, in run 
    run_metadata_ptr) 
    File "/home/khoa/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 894, in _run 
    % (np_val.shape, subfeed_t.name, str(subfeed_t.get_shape()))) 
ValueError: Cannot feed value of shape (256, 784) for Tensor u'X:0', which has shape '(?, 28, 28, 1)' 

Gerçekten neden bilmiyorum? o tf.add_to_collection sadece tek bir yer tutucu içerdiğinden

Traceback (most recent call last): 
    File "load_05_convolution.py", line 46, in <module> 
    final_tensor = tf.get_collection("final_result")[0] 
IndexError: list index out of range 

mi: Ben final_tensor = tf.get_collection("final_result")[0]

Başka hata döndürür eklerseniz

? tf.train.write_graph

kullanılarak

II

/I başarıyla dosya 'train.pb'

oluşturulan tf.train.write_graph(graph, 'folder', 'train.pb')

save.py sonuna bu satırı ekleyin Benim load.py:

with tf.gfile.FastGFile('folder/train.pb', 'rb') as f: 
    graph_def = tf.GraphDef() 
    graph_def.ParseFromString(f.read()) 
    _ = tf.import_graph_def(graph_def, name='') 

with tf.Session() as sess: 
    predict_op = sess.graph.get_tensor_by_name('predict_op:0') 
    for i in range(2): 
     test_indices = np.arange(len(teX)) # Get A Test Batch 
     np.random.shuffle(test_indices) 
     test_indices = test_indices[0:test_size] 

     print(i, np.mean(np.argmax(teY[test_indices], axis=1) == 
         sess.run(predict_op, feed_dict={"X:0": teX[test_indices], 
                 "p_keep_conv:0": 1.0, 
                 "p_keep_hidden:0": 1.0}))) 

Ardından hata döndürür:

Traceback (most recent call last): 
    File "load_05_convolution.py", line 22, in <module> 
    graph_def.ParseFromString(f.read()) 
    File "/home/khoa/tensorflow/lib/python2.7/site-packages/google/protobuf/message.py", line 185, in ParseFromString 
    self.MergeFromString(serialized) 
    File "/home/khoa/tensorflow/lib/python2.7/site-packages/google/protobuf/internal/python_message.py", line 1085, in MergeFromString 
    raise message_mod.DecodeError('Unexpected end-group tag.') 
google.protobuf.message.DecodeError: Unexpected end-group tag. 

Kaydetme/yükleme modeli için standart yolu, kodu veya öğreticiyi paylaşmayı düşünür müsünüz? Gerçekten kafam karıştı. (MetaGraph kullanarak)

+0

mi ?? 'new_saver.restore (oturum, 'my_projec')' load.py içinde Yolu doğru şekilde kontrol edin. –

+0

üzgünüm. Yazarken sadece bir hata. Yükte.py 'tich_chap' isimleri ama 'projeye' değişti daha kolay anlaşılır –

+0

@AayushKumarSingha, herhangi bir fikriniz var mı –

cevap

2

İlk çözüm neredeyse çalışır, ancak bir parti olan 4 D tensörü olarak MNIST eğitim örneklerinden toplu beklediği bir tf.placeholder() için MNIST eğitim örneklerini düzleştirilmiş besleyen çünkü hata ortaya çıkar Şekil batch_size x (= 28) x width (= 28) x Bunu çözmenin en kolay yolu, giriş verilerinizi yeniden şekillendirmektir. Bunun yerine bu ifadenin:

print(i, np.mean(np.argmax(teY[test_indices], axis=1) == 
       sess.run(predict_op, feed_dict={ 
        "X:0": teX[test_indices], 
        "p_keep_conv:0": 1.0, 
        "p_keep_hidden:0": 1.0}))) 

... yerine, uygun şekilde giriş verilerini yeniden şekillendirir aşağıdaki deyimi deneyin: Bu bir yazım hatası

print(i, np.mean(np.argmax(teY[test_indices], axis=1) == 
       sess.run(predict_op, feed_dict={ 
        "X:0": teX[test_indices].reshape(-1, 28, 28, 1), 
        "p_keep_conv:0": 1.0, 
        "p_keep_hidden:0": 1.0}))) 
+0

Gerçekten işe yaramıyor. –

+0

@ZHANGJuenjie Daha spesifik olabilir misiniz? Aynı kodu çalıştırmaya ve bir hata vermeye mi çalışıyorsunuz? Eğer öyleyse, hangisi? – mrry