Browse Source

#144

tags/v0.8.0
Oceania2018 6 years ago
parent
commit
f516966f34
4 changed files with 12 additions and 14 deletions
  1. +1
    -1
      src/TensorFlowNET.Core/Sessions/Session.cs
  2. +6
    -8
      test/TensorFlowNET.UnitTest/CApiGradientsTest.cs
  3. +3
    -3
      test/TensorFlowNET.UnitTest/CSession.cs
  4. +2
    -2
      test/TensorFlowNET.UnitTest/SessionTest.cs

+ 1
- 1
src/TensorFlowNET.Core/Sessions/Session.cs View File

@@ -47,8 +47,8 @@ namespace Tensorflow
public void Dispose() public void Dispose()
{ {
Options.Dispose(); Options.Dispose();
Status.Dispose();
c_api.TF_DeleteSession(_handle, Status); c_api.TF_DeleteSession(_handle, Status);
Status.Dispose();
} }


public void __enter__() public void __enter__()


+ 6
- 8
test/TensorFlowNET.UnitTest/CApiGradientsTest.cs View File

@@ -31,16 +31,16 @@ namespace TensorFlowNET.UnitTest
BuildSuccessGraph(inputs, outputs); BuildSuccessGraph(inputs, outputs);
BuildExpectedGraph(grad_inputs_provided, expected_grad_outputs); BuildExpectedGraph(grad_inputs_provided, expected_grad_outputs);


AddGradients(grad_inputs_provided, string.Empty, inputs, 2, outputs, 1,
AddGradients(grad_inputs_provided, "test", inputs, 2, outputs, 1,
grad_outputs); grad_outputs);
// EXPECT_EQ(TF_OK, TF_GetCode(s_));
EXPECT_EQ(TF_OK, TF_GetCode(s_));


// Compare that the graphs match. // Compare that the graphs match.
GraphDef expected_gdef; GraphDef expected_gdef;
GraphDef gdef; GraphDef gdef;
EXPECT_TRUE(GetGraphDef(expected_graph_, out expected_gdef)); EXPECT_TRUE(GetGraphDef(expected_graph_, out expected_gdef));
EXPECT_TRUE(GetGraphDef(graph_, out gdef)); EXPECT_TRUE(GetGraphDef(graph_, out gdef));
//TF_EXPECT_GRAPH_EQ(expected_gdef, gdef);
// Assert.IsTrue(expected_gdef.ToString().Equals(gdef.ToString()));


// Compare that the output of the gradients of both graphs match. // Compare that the output of the gradients of both graphs match.
RunGraphsAndCompareOutputs(grad_outputs, expected_grad_outputs); RunGraphsAndCompareOutputs(grad_outputs, expected_grad_outputs);
@@ -65,16 +65,14 @@ namespace TensorFlowNET.UnitTest
var csession = new CSession(graph_, s_); var csession = new CSession(graph_, s_);
var expected_csession = new CSession(expected_graph_, s_); var expected_csession = new CSession(expected_graph_, s_);


var grad_outputs_vec = new List<IntPtr>();
grad_outputs_vec.AddRange(grad_outputs.Select(x => x.oper));
var grad_outputs_vec = grad_outputs;
csession.SetOutputs(grad_outputs_vec); csession.SetOutputs(grad_outputs_vec);
csession.Run(s_); csession.Run(s_);
ASSERT_EQ(TF_OK, TF_GetCode(s_)); ASSERT_EQ(TF_OK, TF_GetCode(s_));
var out0 = csession.output_tensor(0); var out0 = csession.output_tensor(0);
var out1 = csession.output_tensor(1); var out1 = csession.output_tensor(1);


var expected_grad_outputs_vec = new List<IntPtr>();
expected_grad_outputs_vec.AddRange(expected_grad_outputs.Select(x => x.oper));
var expected_grad_outputs_vec = expected_grad_outputs;
expected_csession.SetOutputs(expected_grad_outputs_vec); expected_csession.SetOutputs(expected_grad_outputs_vec);
expected_csession.Run(s_); expected_csession.Run(s_);
ASSERT_EQ(TF_OK, TF_GetCode(s_)); ASSERT_EQ(TF_OK, TF_GetCode(s_));
@@ -197,7 +195,7 @@ namespace TensorFlowNET.UnitTest
var matmul2 = MatMul(expected_graph_, s_, const0, const3, var matmul2 = MatMul(expected_graph_, s_, const0, const3,
"gradients/MatMul_1", true, false); "gradients/MatMul_1", true, false);
expected_grad_outputs[0] = new TF_Output(matmul1, 0); expected_grad_outputs[0] = new TF_Output(matmul1, 0);
expected_grad_outputs[1] = new TF_Output( matmul2, 0);
expected_grad_outputs[1] = new TF_Output(matmul2, 0);
} }


private Operation OnesLike(Graph graph, Status s, Operation input, string name) private Operation OnesLike(Graph graph, Status s, Operation input, string name)


+ 3
- 3
test/TensorFlowNET.UnitTest/CSession.cs View File

@@ -49,13 +49,13 @@ namespace TensorFlowNET.UnitTest
input_values_.Clear(); input_values_.Clear();
} }


public void SetOutputs(List<IntPtr> outputs)
public void SetOutputs(TF_Output[] outputs)
{ {
ResetOutputValues(); ResetOutputValues();
outputs_.Clear(); outputs_.Clear();
foreach (var output in outputs) foreach (var output in outputs)
{ {
outputs_.Add(new TF_Output(output, 0));
outputs_.Add(output);
output_values_.Add(IntPtr.Zero); output_values_.Add(IntPtr.Zero);
} }
} }
@@ -75,7 +75,7 @@ namespace TensorFlowNET.UnitTest
var inputs_ptr = inputs_.ToArray(); var inputs_ptr = inputs_.ToArray();
var input_values_ptr = input_values_.Select(x => (IntPtr)x).ToArray(); var input_values_ptr = input_values_.Select(x => (IntPtr)x).ToArray();
var outputs_ptr = outputs_.ToArray(); var outputs_ptr = outputs_.ToArray();
var output_values_ptr = output_values_.Select(x => (IntPtr)x).ToArray();
var output_values_ptr = output_values_.Select(x => IntPtr.Zero).ToArray();
IntPtr[] targets_ptr = new IntPtr[0]; IntPtr[] targets_ptr = new IntPtr[0];


c_api.TF_SessionRun(session_, null, inputs_ptr, input_values_ptr, inputs_ptr.Length, c_api.TF_SessionRun(session_, null, inputs_ptr, input_values_ptr, inputs_ptr.Length,


+ 2
- 2
test/TensorFlowNET.UnitTest/SessionTest.cs View File

@@ -37,7 +37,7 @@ namespace TensorFlowNET.UnitTest
inputs.Add(feed, new Tensor(3)); inputs.Add(feed, new Tensor(3));
csession.SetInputs(inputs); csession.SetInputs(inputs);


var outputs = new List<IntPtr> { add };
var outputs = new TF_Output[] { new TF_Output(add, 0) };
csession.SetOutputs(outputs); csession.SetOutputs(outputs);


csession.Run(s); csession.Run(s);
@@ -56,7 +56,7 @@ namespace TensorFlowNET.UnitTest
inputs = new Dictionary<Operation, Tensor>(); inputs = new Dictionary<Operation, Tensor>();
inputs.Add(feed, new Tensor(7)); inputs.Add(feed, new Tensor(7));
csession.SetInputs(inputs); csession.SetInputs(inputs);
outputs = new List<IntPtr> { neg };
outputs = new TF_Output[] { new TF_Output(neg, 0) };
csession.SetOutputs(outputs); csession.SetOutputs(outputs);
csession.Run(s); csession.Run(s);
ASSERT_EQ(TF_Code.TF_OK, s.Code); ASSERT_EQ(TF_Code.TF_OK, s.Code);


Loading…
Cancel
Save