| @@ -9,8 +9,6 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Core", "src\T | |||
| EndProject | |||
| Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Examples", "test\TensorFlowNET.Examples\TensorFlowNET.Examples.csproj", "{1FE60088-157C-4140-91AB-E96B915E4BAE}" | |||
| EndProject | |||
| Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "NumSharp.Core", "..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj", "{DA680126-DA60-4CE3-9094-72C355C081D3}" | |||
| EndProject | |||
| Global | |||
| GlobalSection(SolutionConfigurationPlatforms) = preSolution | |||
| Debug|Any CPU = Debug|Any CPU | |||
| @@ -29,10 +27,6 @@ Global | |||
| {1FE60088-157C-4140-91AB-E96B915E4BAE}.Debug|Any CPU.Build.0 = Debug|Any CPU | |||
| {1FE60088-157C-4140-91AB-E96B915E4BAE}.Release|Any CPU.ActiveCfg = Release|Any CPU | |||
| {1FE60088-157C-4140-91AB-E96B915E4BAE}.Release|Any CPU.Build.0 = Release|Any CPU | |||
| {DA680126-DA60-4CE3-9094-72C355C081D3}.Debug|Any CPU.ActiveCfg = Debug|Any CPU | |||
| {DA680126-DA60-4CE3-9094-72C355C081D3}.Debug|Any CPU.Build.0 = Debug|Any CPU | |||
| {DA680126-DA60-4CE3-9094-72C355C081D3}.Release|Any CPU.ActiveCfg = Release|Any CPU | |||
| {DA680126-DA60-4CE3-9094-72C355C081D3}.Release|Any CPU.Build.0 = Release|Any CPU | |||
| EndGlobalSection | |||
| GlobalSection(SolutionProperties) = preSolution | |||
| HideSolutionNode = FALSE | |||
| @@ -31,6 +31,25 @@ namespace Tensorflow | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern void TF_GraphGetTensorShape(IntPtr graph, TF_Output output, long[] dims, int num_dims, IntPtr status); | |||
| /// <summary> | |||
| /// Iterate through the operations of a graph. | |||
| /// </summary> | |||
| /// <param name="graph"></param> | |||
| /// <param name="pos"></param> | |||
| /// <returns></returns> | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern IntPtr TF_GraphNextOperation(IntPtr graph, ref uint pos); | |||
| /// <summary> | |||
| /// Returns the operation in the graph with `oper_name`. Returns nullptr if | |||
| /// no operation found. | |||
| /// </summary> | |||
| /// <param name="graph"></param> | |||
| /// <param name="oper_name"></param> | |||
| /// <returns></returns> | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern IntPtr TF_GraphOperationByName(IntPtr graph, string oper_name); | |||
| /// <summary> | |||
| /// Sets the shape of the Tensor referenced by `output` in `graph` to | |||
| /// the shape described by `dims` and `num_dims`. | |||
| @@ -116,6 +116,8 @@ namespace Tensorflow | |||
| { | |||
| case IntPtr val: | |||
| return val == _handle; | |||
| case Operation val: | |||
| return val._handle == _handle; | |||
| } | |||
| return base.Equals(obj); | |||
| @@ -0,0 +1,26 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow | |||
| { | |||
| public class OperationDescription | |||
| { | |||
| private IntPtr _handle; | |||
| public OperationDescription(IntPtr handle) | |||
| { | |||
| _handle = handle; | |||
| } | |||
| public static implicit operator OperationDescription(IntPtr handle) | |||
| { | |||
| return new OperationDescription(handle); | |||
| } | |||
| public static implicit operator IntPtr(OperationDescription desc) | |||
| { | |||
| return desc._handle; | |||
| } | |||
| } | |||
| } | |||
| @@ -9,6 +9,7 @@ namespace Tensorflow | |||
| public struct TF_OperationDescription | |||
| { | |||
| public IntPtr node_builder; | |||
| //public TF_Graph graph; | |||
| public IntPtr graph; | |||
| public IntPtr colocation_constraints; | |||
| } | |||
| } | |||
| @@ -33,11 +33,7 @@ | |||
| <ItemGroup> | |||
| <PackageReference Include="Google.Protobuf" Version="3.6.1" /> | |||
| <PackageReference Include="NumSharp" Version="0.6.2" /> | |||
| </ItemGroup> | |||
| <ItemGroup> | |||
| <ProjectReference Include="..\..\..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj" /> | |||
| <PackageReference Include="NumSharp" Version="0.6.3" /> | |||
| </ItemGroup> | |||
| <ItemGroup> | |||
| @@ -22,6 +22,7 @@ namespace Tensorflow | |||
| /// int32_t => int | |||
| /// int64_t* => long[] | |||
| /// size_t* => unlong[] | |||
| /// size_t* => ref uint | |||
| /// void* => IntPtr | |||
| /// string => IntPtr c_api.StringPiece(IntPtr) | |||
| /// </summary> | |||
| @@ -6,11 +6,10 @@ | |||
| </PropertyGroup> | |||
| <ItemGroup> | |||
| <PackageReference Include="NumSharp" Version="0.6.2" /> | |||
| <PackageReference Include="NumSharp" Version="0.6.3" /> | |||
| </ItemGroup> | |||
| <ItemGroup> | |||
| <ProjectReference Include="..\..\..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj" /> | |||
| <ProjectReference Include="..\..\src\TensorFlowNET.Core\TensorFlowNET.Core.csproj" /> | |||
| </ItemGroup> | |||
| @@ -104,21 +104,98 @@ namespace TensorFlowNET.UnitTest | |||
| Assert.IsFalse(found_placeholder); | |||
| found_placeholder = true; | |||
| } | |||
| /*else if (IsScalarConst(n, 3)) | |||
| else if (c_test_util.IsScalarConst(n, 3)) | |||
| { | |||
| Assert.IsFalse(found_scalar_const); | |||
| found_scalar_const = true; | |||
| } | |||
| else if (IsAddN(n, 2)) | |||
| else if (c_test_util.IsAddN(n, 2)) | |||
| { | |||
| Assert.IsFalse(found_add); | |||
| found_add = true; | |||
| } | |||
| else | |||
| { | |||
| ADD_FAILURE() << "Unexpected NodeDef: " << ProtoDebugString(n); | |||
| }*/ | |||
| Assert.Fail($"Unexpected NodeDef: {n}"); | |||
| } | |||
| } | |||
| Assert.IsTrue(found_placeholder); | |||
| Assert.IsTrue(found_scalar_const); | |||
| Assert.IsTrue(found_add); | |||
| // Add another oper to the graph. | |||
| var neg = c_test_util.Neg(add, graph, s); | |||
| Assert.AreEqual(TF_Code.TF_OK, s.Code); | |||
| // Serialize to NodeDef. | |||
| var node_def = c_test_util.GetNodeDef(neg); | |||
| // Validate NodeDef is what we expect. | |||
| Assert.IsTrue(c_test_util.IsNeg(node_def, "add")); | |||
| // Serialize to GraphDef. | |||
| var graph_def2 = c_test_util.GetGraphDef(graph); | |||
| // Compare with first GraphDef + added NodeDef. | |||
| graph_def.Node.Add(node_def); | |||
| Assert.AreEqual(graph_def.ToString(), graph_def2.ToString()); | |||
| // Look up some nodes by name. | |||
| Operation neg2 = c_api.TF_GraphOperationByName(graph, "neg"); | |||
| Assert.AreEqual(neg, neg2); | |||
| var node_def2 = c_test_util.GetNodeDef(neg2); | |||
| Assert.AreEqual(node_def.ToString(), node_def2.ToString()); | |||
| Operation feed2 = c_api.TF_GraphOperationByName(graph, "feed"); | |||
| Assert.AreEqual(feed, feed2); | |||
| node_def = c_test_util.GetNodeDef(feed); | |||
| node_def2 = c_test_util.GetNodeDef(feed2); | |||
| Assert.AreEqual(node_def.ToString(), node_def2.ToString()); | |||
| // Test iterating through the nodes of a graph. | |||
| found_placeholder = false; | |||
| found_scalar_const = false; | |||
| found_add = false; | |||
| bool found_neg = false; | |||
| uint pos = 0; | |||
| Operation oper; | |||
| while((oper = c_api.TF_GraphNextOperation(graph, ref pos)) != IntPtr.Zero) | |||
| { | |||
| if (oper.Equals(feed)) | |||
| { | |||
| Assert.IsFalse(found_placeholder); | |||
| found_placeholder = true; | |||
| } | |||
| else if (oper.Equals(three)) | |||
| { | |||
| Assert.IsFalse(found_scalar_const); | |||
| found_scalar_const = true; | |||
| } | |||
| else if (oper.Equals(add)) | |||
| { | |||
| Assert.IsFalse(found_add); | |||
| found_add = true; | |||
| } | |||
| else if (oper.Equals(neg)) | |||
| { | |||
| Assert.IsFalse(found_neg); | |||
| found_neg = true; | |||
| } | |||
| else | |||
| { | |||
| node_def = c_test_util.GetNodeDef(oper); | |||
| Assert.Fail($"Unexpected Node: {node_def.ToString()}"); | |||
| } | |||
| } | |||
| Assert.IsTrue(found_placeholder); | |||
| Assert.IsTrue(found_scalar_const); | |||
| Assert.IsTrue(found_add); | |||
| Assert.IsTrue(found_neg); | |||
| graph.Dispose(); | |||
| s.Dispose(); | |||
| } | |||
| } | |||
| } | |||
| @@ -23,7 +23,7 @@ namespace TensorFlowNET.UnitTest | |||
| var s = new Status(); | |||
| s.SetStatus(TF_Code.TF_CANCELLED, "cancel"); | |||
| Assert.AreEqual(s.Code, TF_Code.TF_CANCELLED); | |||
| // Assert.AreEqual(s.Message, "cancel"); | |||
| Assert.AreEqual(s.Message, "cancel"); | |||
| } | |||
| [TestMethod] | |||
| @@ -19,11 +19,10 @@ | |||
| <PackageReference Include="Microsoft.NET.Test.Sdk" Version="15.9.0" /> | |||
| <PackageReference Include="MSTest.TestAdapter" Version="1.4.0" /> | |||
| <PackageReference Include="MSTest.TestFramework" Version="1.4.0" /> | |||
| <PackageReference Include="NumSharp" Version="0.6.2" /> | |||
| <PackageReference Include="NumSharp" Version="0.6.3" /> | |||
| </ItemGroup> | |||
| <ItemGroup> | |||
| <ProjectReference Include="..\..\..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj" /> | |||
| <ProjectReference Include="..\..\src\TensorFlowNET.Core\TensorFlowNET.Core.csproj" /> | |||
| </ItemGroup> | |||
| @@ -50,16 +50,66 @@ namespace TensorFlowNET.UnitTest | |||
| var buffer = new Buffer(); | |||
| c_api.TF_GraphToGraphDef(graph, buffer, s); | |||
| s.Check(); | |||
| return GraphDef.Parser.ParseFrom(buffer); | |||
| var def = GraphDef.Parser.ParseFrom(buffer); | |||
| buffer.Dispose(); | |||
| s.Dispose(); | |||
| return def; | |||
| } | |||
| public static bool GetNodeDef(Operation oper, ref NodeDef node_def) | |||
| public static NodeDef GetNodeDef(Operation oper) | |||
| { | |||
| var s = new Status(); | |||
| var buffer = new Buffer(); | |||
| c_api.TF_OperationToNodeDef(oper, buffer, s); | |||
| s.Check(); | |||
| var ret = NodeDef.Parser.ParseFrom(buffer); | |||
| buffer.Dispose(); | |||
| s.Dispose(); | |||
| return ret; | |||
| } | |||
| return s.Code == TF_Code.TF_OK; | |||
| public static bool IsAddN(NodeDef node_def, int n) | |||
| { | |||
| if (node_def.Op != "AddN" || node_def.Name != "add" || | |||
| node_def.Input.Count != n) | |||
| { | |||
| return false; | |||
| } | |||
| bool found_t = false; | |||
| bool found_n = false; | |||
| foreach (var attr in node_def.Attr) | |||
| { | |||
| if (attr.Key == "T") | |||
| { | |||
| if (attr.Value.Type == DataType.DtInt32) | |||
| { | |||
| found_t = true; | |||
| } | |||
| else | |||
| { | |||
| return false; | |||
| } | |||
| } | |||
| else if (attr.Key == "N") | |||
| { | |||
| if (attr.Value.I == n) | |||
| { | |||
| found_n = true; | |||
| } | |||
| else | |||
| { | |||
| return false; | |||
| } | |||
| } | |||
| } | |||
| return found_t && found_n; | |||
| } | |||
| public static bool IsNeg(NodeDef node_def, string input) | |||
| { | |||
| return node_def.Op == "Neg" && node_def.Name == "neg" && | |||
| node_def.Input.Count == 1 && node_def.Input[0] == input; | |||
| } | |||
| public static bool IsPlaceholder(NodeDef node_def) | |||
| @@ -93,6 +143,59 @@ namespace TensorFlowNET.UnitTest | |||
| return found_dtype && found_shape; | |||
| } | |||
| public static bool IsScalarConst(NodeDef node_def, int v) | |||
| { | |||
| if (node_def.Op != "Const" || node_def.Name != "scalar") | |||
| { | |||
| return false; | |||
| } | |||
| bool found_dtype = false; | |||
| bool found_value = false; | |||
| foreach (var attr in node_def.Attr) { | |||
| if (attr.Key == "dtype") | |||
| { | |||
| if (attr.Value.Type == DataType.DtInt32) | |||
| { | |||
| found_dtype = true; | |||
| } | |||
| else | |||
| { | |||
| return false; | |||
| } | |||
| } | |||
| else if (attr.Key == "value") | |||
| { | |||
| if (attr.Value.Tensor != null && | |||
| attr.Value.Tensor.IntVal.Count == 1 && | |||
| attr.Value.Tensor.IntVal[0] == v) | |||
| { | |||
| found_value = true; | |||
| } | |||
| else | |||
| { | |||
| return false; | |||
| } | |||
| } | |||
| } | |||
| return found_dtype && found_value; | |||
| } | |||
| public static Operation Neg(Operation n, Graph graph, Status s, string name = "neg") | |||
| { | |||
| return NegHelper(n, graph, s, name); | |||
| } | |||
| public static Operation NegHelper(Operation n, Graph graph, Status s, string name) | |||
| { | |||
| OperationDescription desc = c_api.TF_NewOperation(graph, "Neg", name); | |||
| var neg_input = new TF_Output(n, 0); | |||
| c_api.TF_AddInput(desc, neg_input); | |||
| var op = c_api.TF_FinishOperation(desc, s); | |||
| s.Check(); | |||
| return op; | |||
| } | |||
| public static void PlaceholderHelper(Graph graph, Status s, string name, TF_DataType dtype, long[] dims, ref Operation op) | |||
| { | |||
| var desc = c_api.TF_NewOperation(graph, "Placeholder", name); | |||