Browse Source

Completed `TEST(CAPI, Graph)` porting

tags/v0.1.0-Tensor
Oceania2018 7 years ago
parent
commit
64cc96bc69
12 changed files with 241 additions and 24 deletions
  1. +0
    -6
      TensorFlow.NET.sln
  2. +19
    -0
      src/TensorFlowNET.Core/Graphs/c_api.graph.cs
  3. +2
    -0
      src/TensorFlowNET.Core/Operations/Operation.cs
  4. +26
    -0
      src/TensorFlowNET.Core/Operations/OperationDescription.cs
  5. +2
    -1
      src/TensorFlowNET.Core/Operations/TF_OperationDescription.cs
  6. +1
    -5
      src/TensorFlowNET.Core/TensorFlowNET.Core.csproj
  7. +1
    -0
      src/TensorFlowNET.Core/c_api.cs
  8. +1
    -2
      test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj
  9. +81
    -4
      test/TensorFlowNET.UnitTest/GraphTest.cs
  10. +1
    -1
      test/TensorFlowNET.UnitTest/StatusTest.cs
  11. +1
    -2
      test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj
  12. +106
    -3
      test/TensorFlowNET.UnitTest/c_test_util.cs

+ 0
- 6
TensorFlow.NET.sln View File

@@ -9,8 +9,6 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Core", "src\T
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Examples", "test\TensorFlowNET.Examples\TensorFlowNET.Examples.csproj", "{1FE60088-157C-4140-91AB-E96B915E4BAE}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "NumSharp.Core", "..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj", "{DA680126-DA60-4CE3-9094-72C355C081D3}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
@@ -29,10 +27,6 @@ Global
{1FE60088-157C-4140-91AB-E96B915E4BAE}.Debug|Any CPU.Build.0 = Debug|Any CPU
{1FE60088-157C-4140-91AB-E96B915E4BAE}.Release|Any CPU.ActiveCfg = Release|Any CPU
{1FE60088-157C-4140-91AB-E96B915E4BAE}.Release|Any CPU.Build.0 = Release|Any CPU
{DA680126-DA60-4CE3-9094-72C355C081D3}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{DA680126-DA60-4CE3-9094-72C355C081D3}.Debug|Any CPU.Build.0 = Debug|Any CPU
{DA680126-DA60-4CE3-9094-72C355C081D3}.Release|Any CPU.ActiveCfg = Release|Any CPU
{DA680126-DA60-4CE3-9094-72C355C081D3}.Release|Any CPU.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE


+ 19
- 0
src/TensorFlowNET.Core/Graphs/c_api.graph.cs View File

@@ -31,6 +31,25 @@ namespace Tensorflow
[DllImport(TensorFlowLibName)]
public static extern void TF_GraphGetTensorShape(IntPtr graph, TF_Output output, long[] dims, int num_dims, IntPtr status);

/// <summary>
/// Iterate through the operations of a graph.
/// </summary>
/// <param name="graph"></param>
/// <param name="pos"></param>
/// <returns></returns>
[DllImport(TensorFlowLibName)]
public static extern IntPtr TF_GraphNextOperation(IntPtr graph, ref uint pos);

/// <summary>
/// Returns the operation in the graph with `oper_name`. Returns nullptr if
/// no operation found.
/// </summary>
/// <param name="graph"></param>
/// <param name="oper_name"></param>
/// <returns></returns>
[DllImport(TensorFlowLibName)]
public static extern IntPtr TF_GraphOperationByName(IntPtr graph, string oper_name);

/// <summary>
/// Sets the shape of the Tensor referenced by `output` in `graph` to
/// the shape described by `dims` and `num_dims`.


+ 2
- 0
src/TensorFlowNET.Core/Operations/Operation.cs View File

@@ -116,6 +116,8 @@ namespace Tensorflow
{
case IntPtr val:
return val == _handle;
case Operation val:
return val._handle == _handle;
}

return base.Equals(obj);


+ 26
- 0
src/TensorFlowNET.Core/Operations/OperationDescription.cs View File

@@ -0,0 +1,26 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
public class OperationDescription
{
private IntPtr _handle;

public OperationDescription(IntPtr handle)
{
_handle = handle;
}

public static implicit operator OperationDescription(IntPtr handle)
{
return new OperationDescription(handle);
}

public static implicit operator IntPtr(OperationDescription desc)
{
return desc._handle;
}
}
}

+ 2
- 1
src/TensorFlowNET.Core/Operations/TF_OperationDescription.cs View File

@@ -9,6 +9,7 @@ namespace Tensorflow
public struct TF_OperationDescription
{
public IntPtr node_builder;
//public TF_Graph graph;
public IntPtr graph;
public IntPtr colocation_constraints;
}
}

+ 1
- 5
src/TensorFlowNET.Core/TensorFlowNET.Core.csproj View File

@@ -33,11 +33,7 @@

<ItemGroup>
<PackageReference Include="Google.Protobuf" Version="3.6.1" />
<PackageReference Include="NumSharp" Version="0.6.2" />
</ItemGroup>

<ItemGroup>
<ProjectReference Include="..\..\..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj" />
<PackageReference Include="NumSharp" Version="0.6.3" />
</ItemGroup>

<ItemGroup>


+ 1
- 0
src/TensorFlowNET.Core/c_api.cs View File

@@ -22,6 +22,7 @@ namespace Tensorflow
/// int32_t => int
/// int64_t* => long[]
/// size_t* => unlong[]
/// size_t* => ref uint
/// void* => IntPtr
/// string => IntPtr c_api.StringPiece(IntPtr)
/// </summary>


+ 1
- 2
test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj View File

@@ -6,11 +6,10 @@
</PropertyGroup>

<ItemGroup>
<PackageReference Include="NumSharp" Version="0.6.2" />
<PackageReference Include="NumSharp" Version="0.6.3" />
</ItemGroup>

<ItemGroup>
<ProjectReference Include="..\..\..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj" />
<ProjectReference Include="..\..\src\TensorFlowNET.Core\TensorFlowNET.Core.csproj" />
</ItemGroup>



+ 81
- 4
test/TensorFlowNET.UnitTest/GraphTest.cs View File

@@ -104,21 +104,98 @@ namespace TensorFlowNET.UnitTest
Assert.IsFalse(found_placeholder);
found_placeholder = true;
}
/*else if (IsScalarConst(n, 3))
else if (c_test_util.IsScalarConst(n, 3))
{
Assert.IsFalse(found_scalar_const);
found_scalar_const = true;
}
else if (IsAddN(n, 2))
else if (c_test_util.IsAddN(n, 2))
{
Assert.IsFalse(found_add);
found_add = true;
}
else
{
ADD_FAILURE() << "Unexpected NodeDef: " << ProtoDebugString(n);
}*/
Assert.Fail($"Unexpected NodeDef: {n}");
}
}
Assert.IsTrue(found_placeholder);
Assert.IsTrue(found_scalar_const);
Assert.IsTrue(found_add);

// Add another oper to the graph.
var neg = c_test_util.Neg(add, graph, s);
Assert.AreEqual(TF_Code.TF_OK, s.Code);

// Serialize to NodeDef.
var node_def = c_test_util.GetNodeDef(neg);

// Validate NodeDef is what we expect.
Assert.IsTrue(c_test_util.IsNeg(node_def, "add"));

// Serialize to GraphDef.
var graph_def2 = c_test_util.GetGraphDef(graph);

// Compare with first GraphDef + added NodeDef.
graph_def.Node.Add(node_def);
Assert.AreEqual(graph_def.ToString(), graph_def2.ToString());

// Look up some nodes by name.
Operation neg2 = c_api.TF_GraphOperationByName(graph, "neg");
Assert.AreEqual(neg, neg2);
var node_def2 = c_test_util.GetNodeDef(neg2);
Assert.AreEqual(node_def.ToString(), node_def2.ToString());

Operation feed2 = c_api.TF_GraphOperationByName(graph, "feed");
Assert.AreEqual(feed, feed2);
node_def = c_test_util.GetNodeDef(feed);
node_def2 = c_test_util.GetNodeDef(feed2);
Assert.AreEqual(node_def.ToString(), node_def2.ToString());

// Test iterating through the nodes of a graph.
found_placeholder = false;
found_scalar_const = false;
found_add = false;
bool found_neg = false;
uint pos = 0;
Operation oper;

while((oper = c_api.TF_GraphNextOperation(graph, ref pos)) != IntPtr.Zero)
{
if (oper.Equals(feed))
{
Assert.IsFalse(found_placeholder);
found_placeholder = true;
}
else if (oper.Equals(three))
{
Assert.IsFalse(found_scalar_const);
found_scalar_const = true;
}
else if (oper.Equals(add))
{
Assert.IsFalse(found_add);
found_add = true;
}
else if (oper.Equals(neg))
{
Assert.IsFalse(found_neg);
found_neg = true;
}
else
{
node_def = c_test_util.GetNodeDef(oper);
Assert.Fail($"Unexpected Node: {node_def.ToString()}");
}
}

Assert.IsTrue(found_placeholder);
Assert.IsTrue(found_scalar_const);
Assert.IsTrue(found_add);
Assert.IsTrue(found_neg);

graph.Dispose();
s.Dispose();
}
}
}

+ 1
- 1
test/TensorFlowNET.UnitTest/StatusTest.cs View File

@@ -23,7 +23,7 @@ namespace TensorFlowNET.UnitTest
var s = new Status();
s.SetStatus(TF_Code.TF_CANCELLED, "cancel");
Assert.AreEqual(s.Code, TF_Code.TF_CANCELLED);
// Assert.AreEqual(s.Message, "cancel");
Assert.AreEqual(s.Message, "cancel");
}

[TestMethod]


+ 1
- 2
test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj View File

@@ -19,11 +19,10 @@
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="15.9.0" />
<PackageReference Include="MSTest.TestAdapter" Version="1.4.0" />
<PackageReference Include="MSTest.TestFramework" Version="1.4.0" />
<PackageReference Include="NumSharp" Version="0.6.2" />
<PackageReference Include="NumSharp" Version="0.6.3" />
</ItemGroup>

<ItemGroup>
<ProjectReference Include="..\..\..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj" />
<ProjectReference Include="..\..\src\TensorFlowNET.Core\TensorFlowNET.Core.csproj" />
</ItemGroup>



+ 106
- 3
test/TensorFlowNET.UnitTest/c_test_util.cs View File

@@ -50,16 +50,66 @@ namespace TensorFlowNET.UnitTest
var buffer = new Buffer();
c_api.TF_GraphToGraphDef(graph, buffer, s);
s.Check();
return GraphDef.Parser.ParseFrom(buffer);
var def = GraphDef.Parser.ParseFrom(buffer);
buffer.Dispose();
s.Dispose();
return def;
}

public static bool GetNodeDef(Operation oper, ref NodeDef node_def)
public static NodeDef GetNodeDef(Operation oper)
{
var s = new Status();
var buffer = new Buffer();
c_api.TF_OperationToNodeDef(oper, buffer, s);
s.Check();
var ret = NodeDef.Parser.ParseFrom(buffer);
buffer.Dispose();
s.Dispose();
return ret;
}

return s.Code == TF_Code.TF_OK;
public static bool IsAddN(NodeDef node_def, int n)
{
if (node_def.Op != "AddN" || node_def.Name != "add" ||
node_def.Input.Count != n)
{
return false;
}
bool found_t = false;
bool found_n = false;
foreach (var attr in node_def.Attr)
{
if (attr.Key == "T")
{
if (attr.Value.Type == DataType.DtInt32)
{
found_t = true;
}
else
{
return false;
}
}
else if (attr.Key == "N")
{
if (attr.Value.I == n)
{
found_n = true;
}
else
{
return false;
}
}
}

return found_t && found_n;
}

public static bool IsNeg(NodeDef node_def, string input)
{
return node_def.Op == "Neg" && node_def.Name == "neg" &&
node_def.Input.Count == 1 && node_def.Input[0] == input;
}

public static bool IsPlaceholder(NodeDef node_def)
@@ -93,6 +143,59 @@ namespace TensorFlowNET.UnitTest
return found_dtype && found_shape;
}

public static bool IsScalarConst(NodeDef node_def, int v)
{
if (node_def.Op != "Const" || node_def.Name != "scalar")
{
return false;
}
bool found_dtype = false;
bool found_value = false;
foreach (var attr in node_def.Attr) {
if (attr.Key == "dtype")
{
if (attr.Value.Type == DataType.DtInt32)
{
found_dtype = true;
}
else
{
return false;
}
}
else if (attr.Key == "value")
{
if (attr.Value.Tensor != null &&
attr.Value.Tensor.IntVal.Count == 1 &&
attr.Value.Tensor.IntVal[0] == v)
{
found_value = true;
}
else
{
return false;
}
}
}
return found_dtype && found_value;
}

public static Operation Neg(Operation n, Graph graph, Status s, string name = "neg")
{
return NegHelper(n, graph, s, name);
}

public static Operation NegHelper(Operation n, Graph graph, Status s, string name)
{
OperationDescription desc = c_api.TF_NewOperation(graph, "Neg", name);
var neg_input = new TF_Output(n, 0);
c_api.TF_AddInput(desc, neg_input);
var op = c_api.TF_FinishOperation(desc, s);
s.Check();

return op;
}

public static void PlaceholderHelper(Graph graph, Status s, string name, TF_DataType dtype, long[] dims, ref Operation op)
{
var desc = c_api.TF_NewOperation(graph, "Placeholder", name);


Loading…
Cancel
Save