Browse Source

move import graph to ImportGraph.

tags/v0.10
Oceania2018 6 years ago
parent
commit
2b630fb190
1 changed files with 28 additions and 36 deletions
  1. +28
    -36
      test/TensorFlowNET.Examples/ImageProcess/ObjectDetection.cs

+ 28
- 36
test/TensorFlowNET.Examples/ImageProcess/ObjectDetection.cs View File

@@ -32,7 +32,7 @@ namespace TensorFlowNET.Examples
{
public bool Enabled { get; set; } = true;
public string Name => "Object Detection";
public bool IsImportingGraph { get; set; } = false;
public bool IsImportingGraph { get; set; } = true;

public float MIN_SCORE = 0.5f;

@@ -42,16 +42,34 @@ namespace TensorFlowNET.Examples
string labelFile = "mscoco_label_map.pbtxt";
string picFile = "input.jpg";

NDArray imgArr;

public bool Run()
{
PrepareData();

// read in the input image
var imgArr = ReadTensorFromImageFile(Path.Join(imageDir, "input.jpg"));
imgArr = ReadTensorFromImageFile(Path.Join(imageDir, "input.jpg"));

var graph = IsImportingGraph ? ImportGraph() : BuildGraph();

with(tf.Session(graph), sess => Predict(sess));

var graph = new Graph().as_default();
return true;
}

public Graph ImportGraph()
{
var graph = new Graph().as_default();
graph.Import(Path.Join(modelDir, pbFile));

return graph;
}

public void Predict(Session sess)
{
var graph = tf.get_default_graph();

Tensor tensorNum = graph.OperationByName("num_detections");
Tensor tensorBoxes = graph.OperationByName("detection_boxes");
Tensor tensorScores = graph.OperationByName("detection_scores");
@@ -59,16 +77,11 @@ namespace TensorFlowNET.Examples
Tensor imgTensor = graph.OperationByName("image_tensor");
Tensor[] outTensorArr = new Tensor[] { tensorNum, tensorBoxes, tensorScores, tensorClasses };

with(tf.Session(graph), sess =>
{
var results = sess.run(outTensorArr, new FeedItem(imgTensor, imgArr));
NDArray[] resultArr = results.Data<NDArray>();
buildOutputImage(resultArr);
});
var results = sess.run(outTensorArr, new FeedItem(imgTensor, imgArr));

return true;
NDArray[] resultArr = results.Data<NDArray>();

buildOutputImage(resultArr);
}

public void PrepareData()
@@ -159,29 +172,8 @@ namespace TensorFlowNET.Examples
}
}

public Graph ImportGraph()
{
throw new NotImplementedException();
}

public Graph BuildGraph()
{
throw new NotImplementedException();
}

public void Train(Session sess)
{
throw new NotImplementedException();
}

public void Predict(Session sess)
{
throw new NotImplementedException();
}

public void Test(Session sess)
{
throw new NotImplementedException();
}
public Graph BuildGraph() => throw new NotImplementedException();
public void Train(Session sess) => throw new NotImplementedException();
public void Test(Session sess) => throw new NotImplementedException();
}
}

Loading…
Cancel
Save