diff --git a/test/TensorFlowNET.Examples/ImageProcess/ObjectDetection.cs b/test/TensorFlowNET.Examples/ImageProcess/ObjectDetection.cs index 32f7841b..9c4760d7 100644 --- a/test/TensorFlowNET.Examples/ImageProcess/ObjectDetection.cs +++ b/test/TensorFlowNET.Examples/ImageProcess/ObjectDetection.cs @@ -32,7 +32,7 @@ namespace TensorFlowNET.Examples { public bool Enabled { get; set; } = true; public string Name => "Object Detection"; - public bool IsImportingGraph { get; set; } = false; + public bool IsImportingGraph { get; set; } = true; public float MIN_SCORE = 0.5f; @@ -42,16 +42,34 @@ namespace TensorFlowNET.Examples string labelFile = "mscoco_label_map.pbtxt"; string picFile = "input.jpg"; + NDArray imgArr; + public bool Run() { PrepareData(); // read in the input image - var imgArr = ReadTensorFromImageFile(Path.Join(imageDir, "input.jpg")); + imgArr = ReadTensorFromImageFile(Path.Join(imageDir, "input.jpg")); + + var graph = IsImportingGraph ? ImportGraph() : BuildGraph(); + + with(tf.Session(graph), sess => Predict(sess)); - var graph = new Graph().as_default(); + return true; + } + + public Graph ImportGraph() + { + var graph = new Graph().as_default(); graph.Import(Path.Join(modelDir, pbFile)); + return graph; + } + + public void Predict(Session sess) + { + var graph = tf.get_default_graph(); + Tensor tensorNum = graph.OperationByName("num_detections"); Tensor tensorBoxes = graph.OperationByName("detection_boxes"); Tensor tensorScores = graph.OperationByName("detection_scores"); @@ -59,16 +77,11 @@ namespace TensorFlowNET.Examples Tensor imgTensor = graph.OperationByName("image_tensor"); Tensor[] outTensorArr = new Tensor[] { tensorNum, tensorBoxes, tensorScores, tensorClasses }; - with(tf.Session(graph), sess => - { - var results = sess.run(outTensorArr, new FeedItem(imgTensor, imgArr)); - - NDArray[] resultArr = results.Data(); - - buildOutputImage(resultArr); - }); + var results = sess.run(outTensorArr, new FeedItem(imgTensor, imgArr)); - return true; + NDArray[] resultArr = results.Data(); + + buildOutputImage(resultArr); } public void PrepareData() @@ -159,29 +172,8 @@ namespace TensorFlowNET.Examples } } - public Graph ImportGraph() - { - throw new NotImplementedException(); - } - - public Graph BuildGraph() - { - throw new NotImplementedException(); - } - - public void Train(Session sess) - { - throw new NotImplementedException(); - } - - public void Predict(Session sess) - { - throw new NotImplementedException(); - } - - public void Test(Session sess) - { - throw new NotImplementedException(); - } + public Graph BuildGraph() => throw new NotImplementedException(); + public void Train(Session sess) => throw new NotImplementedException(); + public void Test(Session sess) => throw new NotImplementedException(); } }