java label api_使用python API进行的培训作为Java API中LabelImage模块的输入?
我有java tensorflow API的問題.我使用python tensorflow API運行訓練,生成文件output_graph.pb和output_labels.txt.現在出于某種原因,我想使用這些文件作為java tensorflow API中LabelImage模塊的輸入.我認為一切都會正常工作,因為該模塊只需要一個.pb和一個.txt.不過,當我運行模塊時,我收到此錯誤:
2017-04-26 10:12:56.711402: W tensorflow/core/framework/op_def_util.cc:332] Op BatchNormWithGlobalNormalization is deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization().
Exception in thread "main" java.lang.IllegalArgumentException: No Operation named [input] in the Graph
at org.tensorflow.Session$Runner.operationByName(Session.java:343)
at org.tensorflow.Session$Runner.feed(Session.java:137)
at org.tensorflow.Session$Runner.feed(Session.java:126)
at it.zero11.LabelImage.executeInceptionGraph(LabelImage.java:115)
at it.zero11.LabelImage.main(LabelImage.java:68)
如果你幫助我找到問題所在,我將非常感激.此外,我想問你是否有辦法從java tensorflow API運行培訓,因為這會使事情變得更容易.
更確切地說:
output_graph_def = graph_util.convert_variables_to_constants(
sess, graph.as_graph_def(), [FLAGS.final_tensor_name])
with gfile.FastGFile(FLAGS.output_graph, 'wb') as f:
f.write(output_graph_def.SerializeToString())
with gfile.FastGFile(FLAGS.output_labels, 'w') as f:
f.write('\n'.join(image_lists.keys()) + '\n')
解決方法:
LabelImage.java中默認使用的模型與正在重新訓練的模型不同,因此輸入和輸出節點的名稱不對齊.請注意,TensorFlow模型是圖形,feed()和fetch()的參數是圖形中節點的名稱.因此,您需要知道適合您的模型的名稱.
看一下retrain.py,它似乎有一個節點,它將JPEG文件的原始內容作為輸入(節點DecodeJpeg/contents),并在節點final_result中生成標簽集.
如果是這種情況,那么你會在Java中執行類似下面的操作(并且您不需要構造圖形來對圖像進行標準化,因為這似乎是重新訓練模型的一部分,因此將LabelImage.java:64替換為某些內容喜歡:
try (Tensor image = Tensor.create(imageBytes);
Graph g = new Graph()) {
g.importGraphDef(graphDef);
try (Session s = new Session(g);
// Note the change to the name of the node and the fact
// that it is being provided the raw imageBytes as input
Tensor result = s.runner().feed("DecodeJpeg/contents", image).fetch("final_result").run().get(0)) {
final long[] rshape = result.shape();
if (result.numDimensions() != 2 || rshape[0] != 1) {
throw new RuntimeException(
String.format(
"Expected model to produce a [1 N] shaped tensor where N is the number of labels, instead it produced one with shape %s",
Arrays.toString(rshape)));
}
int nlabels = (int) rshape[1];
float[] probabilities = result.copyTo(new float[1][nlabels])[0];
// At this point nlabels = number of classes in your retrained model
DoSomethingWith(probabilities);
}
}
希望有所幫助.
標簽:image-recognition,java,python,tensorflow
來源: https://codeday.me/bug/20191007/1864494.html
總結
以上是生活随笔為你收集整理的java label api_使用python API进行的培训作为Java API中LabelImage模块的输入?的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: java静态方法声明_方法本地类中的Ja
- 下一篇: php crypt加密 盐值,PHP c