2016-04-17 45 views
8

Tam olarak python modelleri C++' da kullanılmak üzere nasıl dışa aktarılmalıdır?İhracat Python'dan Tensorflow grafikleri C++ 'da kullanım için

ben bu yazının benzer bir şey yapmaya çalışıyorum: Ben kuruluşundan biri yerine C++ API kendi TF modeli içe çalışıyorum https://www.tensorflow.org/versions/r0.8/tutorials/image_recognition/index.html

. Giriş boyutunu ve yolları değiştirdim, ancak garip hatalar ortaya çıkıyor. Bütün gün yığın taşması ve diğer forumları okudum ama boşuna harcadım.

Grafiği dışa aktarmak için iki yöntem denedim.

Yöntem 1: metagrafi.

...loading inputs, setting up the model, etc.... 

sess = tf.InteractiveSession() 
sess.run(tf.initialize_all_variables()) 


for i in range(num_steps): 
    x_batch, y_batch = batch(50) 
    if i%10 == 0: 
     train_accuracy = accuracy.eval(feed_dict={ 
     x:x_batch, y_: y_batch, keep_prob: 1.0}) 
     print("step %d, training accuracy %g"%(i, train_accuracy)) 
    train_step.run(feed_dict={x: x_batch, y_: y_batch, keep_prob: 0.5}) 

print("test accuracy %g"%accuracy.eval(feed_dict={ 
    x: features_test, y_: labels_test, keep_prob: 1.0})) 

saver = tf.train.Saver(tf.all_variables()) 
checkpoint = 
    '/home/sander/tensorflow/tensorflow/examples/cat_face/data/model.ckpt' 
    saver.save(sess, checkpoint) 

    tf.train.export_meta_graph(filename= 
    '/home/sander/tensorflow/tensorflow/examples/cat_face/data/cat_graph.pb', 
    meta_info_def=None, 
    graph_def=sess.graph_def, 
    saver_def=saver.restore(sess, checkpoint), 
    collection_list=None, as_text=False) 

Yöntem 1 programı çalıştırmak için çalışırken aşağıdaki hata verir:

Yöntem 2: write_graph:

tf.train.write_graph(sess.graph_def, 
'/home/sander/tensorflow/tensorflow/examples/cat_face/data/', 
'cat_graph.pb', as_text=False) 
Ayrıca grafik dışa yönelik başka bir metodu güvenilir
[libprotobuf ERROR 
google/protobuf/src/google/protobuf/wire_format_lite.cc:532] String field 
'tensorflow.NodeDef.op' contains invalid UTF-8 data when parsing a protocol 
buffer. Use the 'bytes' type if you intend to send raw bytes. 
E tensorflow/examples/cat_face/main.cc:281] Not found: Failed to load 
compute graph at 'tensorflow/examples/cat_face/data/cat_graph.pb' 

Bu sürüm aslında bir şeyler yüklemek gibi görünüyor, ancak değişkenler hakkında bir hata alıyorum başlatıldı: İlk başta

Running model failed: Failed precondition: Attempting to use uninitialized 
value weight1 
[[Node: weight1/read = Identity[T=DT_FLOAT, _class=["loc:@weight1"], 
_device="/job:localhost/replica:0/task:0/cpu:0"](weight1)]] 
+3

var. Bu Değişkenler kullanmak ve geri yükleme operasyonları çalıştırmak zorunda kalmadan kaçının - https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py –

+0

Ah, bunu gördüm. Ama ben argümanlarını nasıl dolduracağımı öğrenmek için uğraşıyorum, tıpkı export_meta_graph'daki her argümana neyin gireceğini bilmiyorum gibi. Bunun için bazı örnek kodları biliyor musunuz? – Sander

+1

Burada bir örnek var: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph_test.py –

cevap

0

, sen

with tf.Session() as sess: 
//Build network here 
tf.train.write_graph(sess.graph.as_graph_def(), "C:\\output\\", "mymodel.pb") 

Sonra aşağıdaki komutu kullanarak dosyaya tanımını grafiğini

saver = tf.train.Saver(tf.global_variables()) 
saver.save(sess, "C:\\output\\mymodel.ckpt") 

Ardından koruyucu kullanarak modelinizi kaydetmeniz gerekir, size sahip olacaktır Çıktınızdaki 2 dosya, mymodel.ckpt, mymodel.pb

Download freeze_graph.py from here ve C: \ output \ komutunda aşağıdaki komutu çalıştırın. Sizin için farklıysa çıkış düğümü adını değiştirin.

piton freeze_graph.py --input_graph mymodel.pb --input_checkpoint mymodel.ckpt --output_node_names SoftMax/Reshape_1 --output_graph mymodelforc.pb

Sen C doğrudan mymodelforc.pb kullanabilirsiniz

sen

#include "tensorflow/core/public/session.h" 
#include "tensorflow/core/platform/env.h" 
#include "tensorflow/cc/ops/image_ops.h" 

Session* session; 
NewSession(SessionOptions(), &session); 

GraphDef graph_def; 
ReadBinaryProto(Env::Default(), "C:\\output\\mymodelforc.pb", &graph_def); 

session->Create(graph_def); 

Şimdi çıkarım için oturumu kullanabilirsiniz proto dosyasını yüklemek için aşağıdaki C kodu kullanabilirsiniz. Aşağıdaki şekilde

Sen çıkarsama parametresini uygulayabilirsiniz: ": Kullanım freeze_graph Yöntem 3"

// Same dimension and type as input of your network 
tensorflow::Tensor input_tensor(tensorflow::DT_FLOAT, tensorflow::TensorShape({ 1, height, width, channel })); 
std::vector<tensorflow::Tensor> finalOutput; 

// Fill input tensor with your input data 

std::string InputName = "input"; // Your input placeholder's name 
std::string OutputName = "softmax/Reshape_1"; // Your output placeholder's name 

session->Run({ { InputName, input_tensor } }, { OutputName }, {}, &finalOutput); 

// finalOutput will contain the inference output that you search for 
+0

Deniz tensorflow versiyonu nedir? Sormamın nedeni 'saver.save' işlevinin kullanıyorum paketinde' .ckpt.meta' dosyası oluşturuyor gibi görünüyor. Sanırım saver.export_meta_graph'ın yarattığı aynı şey ... Son internet araştırması bunun R11 ile R12 arasında bir fark olduğunu gösteriyor gibi görünüyor, ama bunu çok yakın zamanda yazdınız ve hangi sürümü kullandığınızı merak ediyorum. – Geronimo

+0

Şu anki güncel tensorflow kaynağıyla bu kodu doğruladım. Ancak, C++ için "write_graph" kullanarak zaten ihracat yaptığınız gibi ckpt.meta dosyalarına gerek yoktur. Karışıklığı önlemek için kodu güncelleyeceğim –