Browse Source

GetNodeDef

tags/v0.1.0-Tensor
Oceania2018 7 years ago
parent
commit
740ca28965
12 changed files with 44 additions and 23 deletions
  1. +1
    -0
      src/TensorFlowNET.Core/Operations/Operation.cs
  2. +3
    -0
      src/TensorFlowNET.Core/Operations/c_api.ops.cs
  3. +6
    -3
      src/TensorFlowNET.Core/Operations/ops.cs
  4. +2
    -1
      src/TensorFlowNET.Core/Status/Status.cs
  5. +1
    -1
      src/TensorFlowNET.Core/Tensors/Tensor.cs
  6. +1
    -6
      src/TensorFlowNET.Core/c_api_util.cs
  7. +10
    -1
      test/TensorFlowNET.Examples/Program.cs
  8. +1
    -0
      test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj
  9. +8
    -5
      test/TensorFlowNET.UnitTest/GraphTest.cs
  10. +1
    -1
      test/TensorFlowNET.UnitTest/OperationsTest.cs
  11. +1
    -1
      test/TensorFlowNET.UnitTest/TensorTest.cs
  12. +9
    -4
      test/TensorFlowNET.UnitTest/c_test_util.cs

+ 1
- 0
src/TensorFlowNET.Core/Operations/Operation.cs View File

@@ -63,6 +63,7 @@ namespace Tensorflow


_id_value = Graph._next_id(); _id_value = Graph._next_id();
_handle = ops._create_c_op(g, node_def, inputs); _handle = ops._create_c_op(g, node_def, inputs);
NumOutputs = c_api.TF_OperationNumOutputs(_handle);


_outputs = new Tensor[NumOutputs]; _outputs = new Tensor[NumOutputs];
for (int i = 0; i < NumOutputs; i++) for (int i = 0; i < NumOutputs; i++)


+ 3
- 0
src/TensorFlowNET.Core/Operations/c_api.ops.cs View File

@@ -90,6 +90,9 @@ namespace Tensorflow
[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern TF_DataType TF_OperationOutputType(TF_Output oper_out); public static extern TF_DataType TF_OperationOutputType(TF_Output oper_out);


[DllImport(TensorFlowLibName)]
public static extern void TF_OperationToNodeDef(IntPtr oper, IntPtr buffer, IntPtr status);

[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern int TF_OperationOutputListLength(IntPtr oper, string arg_name, IntPtr status); public static extern int TF_OperationOutputListLength(IntPtr oper, string arg_name, IntPtr status);




+ 6
- 3
src/TensorFlowNET.Core/Operations/ops.cs View File

@@ -6,6 +6,7 @@ using System.Threading;
using Tensorflow; using Tensorflow;
using node_def_pb2 = Tensorflow; using node_def_pb2 = Tensorflow;
using Google.Protobuf; using Google.Protobuf;
using System.Linq;


namespace Tensorflow namespace Tensorflow
{ {
@@ -27,12 +28,14 @@ namespace Tensorflow
var op_desc = c_api.TF_NewOperation(graph, node_def.Op, node_def.Name); var op_desc = c_api.TF_NewOperation(graph, node_def.Op, node_def.Name);


// Add inputs // Add inputs
if(inputs != null)
if(inputs != null && inputs.Count > 0)
{ {
foreach (var op_input in inputs)
/*foreach (var op_input in inputs)
{ {
c_api.TF_AddInput(op_desc, op_input._as_tf_output()); c_api.TF_AddInput(op_desc, op_input._as_tf_output());
}
}*/

c_api.TF_AddInputList(op_desc, inputs.Select(x => x._as_tf_output()).ToArray(), inputs.Count);
} }


var status = new Status(); var status = new Status();


+ 2
- 1
src/TensorFlowNET.Core/Status/Status.cs View File

@@ -40,7 +40,8 @@ namespace Tensorflow
{ {
if(Code != TF_Code.TF_OK) if(Code != TF_Code.TF_OK)
{ {
throw new Exception(Message);
Console.WriteLine(Message);
// throw new Exception(Message);
} }
} }




+ 1
- 1
src/TensorFlowNET.Core/Tensors/Tensor.cs View File

@@ -118,7 +118,7 @@ namespace Tensorflow


public TF_Output _as_tf_output() public TF_Output _as_tf_output()
{ {
return c_api_util.tf_output(op, value_index);
return new TF_Output(op, value_index);
} }


public T[] Data<T>() public T[] Data<T>()


+ 1
- 6
src/TensorFlowNET.Core/c_api_util.cs View File

@@ -8,12 +8,7 @@ namespace Tensorflow
{ {
public static TF_Output tf_output(IntPtr c_op, int index) public static TF_Output tf_output(IntPtr c_op, int index)
{ {
var ret = new TF_Output();
ret.oper = c_op;
ret.index = index;

return ret;
return new TF_Output(c_op, index);
} }
} }
} }

+ 10
- 1
test/TensorFlowNET.Examples/Program.cs View File

@@ -12,7 +12,16 @@ namespace TensorFlowNET.Examples
foreach(Type type in assembly.GetTypes().Where(x => x.GetInterfaces().Contains(typeof(IExample)))) foreach(Type type in assembly.GetTypes().Where(x => x.GetInterfaces().Contains(typeof(IExample))))
{ {
var example = (IExample)Activator.CreateInstance(type); var example = (IExample)Activator.CreateInstance(type);
example.Run();

try
{
example.Run();
}
catch (Exception ex)
{
Console.WriteLine(ex);
Console.ReadLine();
}
} }
} }
} }


+ 1
- 0
test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj View File

@@ -10,6 +10,7 @@
</ItemGroup> </ItemGroup>


<ItemGroup> <ItemGroup>
<ProjectReference Include="..\..\..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj" />
<ProjectReference Include="..\..\src\TensorFlowNET.Core\TensorFlowNET.Core.csproj" /> <ProjectReference Include="..\..\src\TensorFlowNET.Core\TensorFlowNET.Core.csproj" />
</ItemGroup> </ItemGroup>




+ 8
- 5
test/TensorFlowNET.UnitTest/GraphTest.cs View File

@@ -33,20 +33,23 @@ namespace TensorFlowNET.UnitTest
Assert.AreEqual(0, feed.NumControlOutputs); Assert.AreEqual(0, feed.NumControlOutputs);


AttrValue attr_value = null; AttrValue attr_value = null;
c_test_util.GetAttrValue(feed, "dtype", ref attr_value, s);
Assert.IsTrue(c_test_util.GetAttrValue(feed, "dtype", ref attr_value, s));
Assert.AreEqual(attr_value.Type, DataType.DtInt32); Assert.AreEqual(attr_value.Type, DataType.DtInt32);


// Test not found errors in TF_Operation*() query functions. // Test not found errors in TF_Operation*() query functions.
// Assert.AreEqual(-1, c_api.TF_OperationOutputListLength(feed, "bogus", s));
// Assert.AreEqual(TF_Code.TF_INVALID_ARGUMENT, s.Code);
// Assert.IsFalse(c_test_util.GetAttrValue(feed, "missing", ref attr_value, s));
// Assert.AreEqual("Operation 'feed' has no attr named 'missing'.", s.Message);
Assert.AreEqual(-1, c_api.TF_OperationOutputListLength(feed, "bogus", s));
Assert.AreEqual(TF_Code.TF_INVALID_ARGUMENT, s.Code);
//Assert.IsFalse(c_test_util.GetAttrValue(feed, "missing", ref attr_value, s));
//Assert.AreEqual("Operation '' has no attr named 'missing'.", s.Message);


// Make a constant oper with the scalar "3". // Make a constant oper with the scalar "3".
var three = c_test_util.ScalarConst(3, graph, s); var three = c_test_util.ScalarConst(3, graph, s);


// Add oper. // Add oper.
var add = c_test_util.Add(feed, three, graph, s); var add = c_test_util.Add(feed, three, graph, s);

NodeDef node_def = null;
c_test_util.GetNodeDef(feed, ref node_def);
} }
} }
} }

+ 1
- 1
test/TensorFlowNET.UnitTest/OperationsTest.cs View File

@@ -43,7 +43,7 @@ namespace TensorFlowNET.UnitTest
public void addInConstant() public void addInConstant()
{ {
var a = tf.constant(4.0f); var a = tf.constant(4.0f);
var b = tf.constant(5.0f);
var b = tf.placeholder(tf.float32);
var c = tf.add(a, b); var c = tf.add(a, b);


using (var sess = tf.Session()) using (var sess = tf.Session())


+ 1
- 1
test/TensorFlowNET.UnitTest/TensorTest.cs View File

@@ -139,7 +139,7 @@ namespace TensorFlowNET.UnitTest
c_api.TF_GraphGetTensorShape(graph, feed_out_0, null, num_dims, s); c_api.TF_GraphGetTensorShape(graph, feed_out_0, null, num_dims, s);
//Assert.IsTrue(s.Code == TF_Code.TF_OK); //Assert.IsTrue(s.Code == TF_Code.TF_OK);


graph.Dispose();
// graph.Dispose();
s.Dispose(); s.Dispose();
} }
} }


+ 9
- 4
test/TensorFlowNET.UnitTest/c_test_util.cs View File

@@ -42,6 +42,15 @@ namespace TensorFlowNET.UnitTest
return s.Code == TF_Code.TF_OK; return s.Code == TF_Code.TF_OK;
} }


public static bool GetNodeDef(Operation oper, ref NodeDef node_def)
{
var s = new Status();
var buffer = new Buffer();
c_api.TF_OperationToNodeDef(oper, buffer, s);

return s.Code == TF_Code.TF_OK;
}

public static void PlaceholderHelper(Graph graph, Status s, string name, TF_DataType dtype, long[] dims, ref Operation op) public static void PlaceholderHelper(Graph graph, Status s, string name, TF_DataType dtype, long[] dims, ref Operation op)
{ {
var desc = c_api.TF_NewOperation(graph, "Placeholder", name); var desc = c_api.TF_NewOperation(graph, "Placeholder", name);
@@ -69,10 +78,6 @@ namespace TensorFlowNET.UnitTest
c_api.TF_SetAttrType(desc, "dtype", t.dtype); c_api.TF_SetAttrType(desc, "dtype", t.dtype);
op = c_api.TF_FinishOperation(desc, s); op = c_api.TF_FinishOperation(desc, s);
s.Check(); s.Check();
if(op == null)
{
throw new Exception("c_api.TF_FinishOperation failed.");
}
} }


public static Operation Const(Tensor t, Graph graph, Status s, string name) public static Operation Const(Tensor t, Graph graph, Status s, string name)


Loading…
Cancel
Save