Browse Source

GetNodeDef

tags/v0.1.0-Tensor
Oceania2018 7 years ago
parent
commit
740ca28965
12 changed files with 44 additions and 23 deletions
  1. +1
    -0
      src/TensorFlowNET.Core/Operations/Operation.cs
  2. +3
    -0
      src/TensorFlowNET.Core/Operations/c_api.ops.cs
  3. +6
    -3
      src/TensorFlowNET.Core/Operations/ops.cs
  4. +2
    -1
      src/TensorFlowNET.Core/Status/Status.cs
  5. +1
    -1
      src/TensorFlowNET.Core/Tensors/Tensor.cs
  6. +1
    -6
      src/TensorFlowNET.Core/c_api_util.cs
  7. +10
    -1
      test/TensorFlowNET.Examples/Program.cs
  8. +1
    -0
      test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj
  9. +8
    -5
      test/TensorFlowNET.UnitTest/GraphTest.cs
  10. +1
    -1
      test/TensorFlowNET.UnitTest/OperationsTest.cs
  11. +1
    -1
      test/TensorFlowNET.UnitTest/TensorTest.cs
  12. +9
    -4
      test/TensorFlowNET.UnitTest/c_test_util.cs

+ 1
- 0
src/TensorFlowNET.Core/Operations/Operation.cs View File

@@ -63,6 +63,7 @@ namespace Tensorflow

_id_value = Graph._next_id();
_handle = ops._create_c_op(g, node_def, inputs);
NumOutputs = c_api.TF_OperationNumOutputs(_handle);

_outputs = new Tensor[NumOutputs];
for (int i = 0; i < NumOutputs; i++)


+ 3
- 0
src/TensorFlowNET.Core/Operations/c_api.ops.cs View File

@@ -90,6 +90,9 @@ namespace Tensorflow
[DllImport(TensorFlowLibName)]
public static extern TF_DataType TF_OperationOutputType(TF_Output oper_out);

[DllImport(TensorFlowLibName)]
public static extern void TF_OperationToNodeDef(IntPtr oper, IntPtr buffer, IntPtr status);

[DllImport(TensorFlowLibName)]
public static extern int TF_OperationOutputListLength(IntPtr oper, string arg_name, IntPtr status);



+ 6
- 3
src/TensorFlowNET.Core/Operations/ops.cs View File

@@ -6,6 +6,7 @@ using System.Threading;
using Tensorflow;
using node_def_pb2 = Tensorflow;
using Google.Protobuf;
using System.Linq;

namespace Tensorflow
{
@@ -27,12 +28,14 @@ namespace Tensorflow
var op_desc = c_api.TF_NewOperation(graph, node_def.Op, node_def.Name);

// Add inputs
if(inputs != null)
if(inputs != null && inputs.Count > 0)
{
foreach (var op_input in inputs)
/*foreach (var op_input in inputs)
{
c_api.TF_AddInput(op_desc, op_input._as_tf_output());
}
}*/

c_api.TF_AddInputList(op_desc, inputs.Select(x => x._as_tf_output()).ToArray(), inputs.Count);
}

var status = new Status();


+ 2
- 1
src/TensorFlowNET.Core/Status/Status.cs View File

@@ -40,7 +40,8 @@ namespace Tensorflow
{
if(Code != TF_Code.TF_OK)
{
throw new Exception(Message);
Console.WriteLine(Message);
// throw new Exception(Message);
}
}



+ 1
- 1
src/TensorFlowNET.Core/Tensors/Tensor.cs View File

@@ -118,7 +118,7 @@ namespace Tensorflow

public TF_Output _as_tf_output()
{
return c_api_util.tf_output(op, value_index);
return new TF_Output(op, value_index);
}

public T[] Data<T>()


+ 1
- 6
src/TensorFlowNET.Core/c_api_util.cs View File

@@ -8,12 +8,7 @@ namespace Tensorflow
{
public static TF_Output tf_output(IntPtr c_op, int index)
{
var ret = new TF_Output();
ret.oper = c_op;
ret.index = index;

return ret;
return new TF_Output(c_op, index);
}
}
}

+ 10
- 1
test/TensorFlowNET.Examples/Program.cs View File

@@ -12,7 +12,16 @@ namespace TensorFlowNET.Examples
foreach(Type type in assembly.GetTypes().Where(x => x.GetInterfaces().Contains(typeof(IExample))))
{
var example = (IExample)Activator.CreateInstance(type);
example.Run();

try
{
example.Run();
}
catch (Exception ex)
{
Console.WriteLine(ex);
Console.ReadLine();
}
}
}
}


+ 1
- 0
test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj View File

@@ -10,6 +10,7 @@
</ItemGroup>

<ItemGroup>
<ProjectReference Include="..\..\..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj" />
<ProjectReference Include="..\..\src\TensorFlowNET.Core\TensorFlowNET.Core.csproj" />
</ItemGroup>



+ 8
- 5
test/TensorFlowNET.UnitTest/GraphTest.cs View File

@@ -33,20 +33,23 @@ namespace TensorFlowNET.UnitTest
Assert.AreEqual(0, feed.NumControlOutputs);

AttrValue attr_value = null;
c_test_util.GetAttrValue(feed, "dtype", ref attr_value, s);
Assert.IsTrue(c_test_util.GetAttrValue(feed, "dtype", ref attr_value, s));
Assert.AreEqual(attr_value.Type, DataType.DtInt32);

// Test not found errors in TF_Operation*() query functions.
// Assert.AreEqual(-1, c_api.TF_OperationOutputListLength(feed, "bogus", s));
// Assert.AreEqual(TF_Code.TF_INVALID_ARGUMENT, s.Code);
// Assert.IsFalse(c_test_util.GetAttrValue(feed, "missing", ref attr_value, s));
// Assert.AreEqual("Operation 'feed' has no attr named 'missing'.", s.Message);
Assert.AreEqual(-1, c_api.TF_OperationOutputListLength(feed, "bogus", s));
Assert.AreEqual(TF_Code.TF_INVALID_ARGUMENT, s.Code);
//Assert.IsFalse(c_test_util.GetAttrValue(feed, "missing", ref attr_value, s));
//Assert.AreEqual("Operation '' has no attr named 'missing'.", s.Message);

// Make a constant oper with the scalar "3".
var three = c_test_util.ScalarConst(3, graph, s);

// Add oper.
var add = c_test_util.Add(feed, three, graph, s);

NodeDef node_def = null;
c_test_util.GetNodeDef(feed, ref node_def);
}
}
}

+ 1
- 1
test/TensorFlowNET.UnitTest/OperationsTest.cs View File

@@ -43,7 +43,7 @@ namespace TensorFlowNET.UnitTest
public void addInConstant()
{
var a = tf.constant(4.0f);
var b = tf.constant(5.0f);
var b = tf.placeholder(tf.float32);
var c = tf.add(a, b);

using (var sess = tf.Session())


+ 1
- 1
test/TensorFlowNET.UnitTest/TensorTest.cs View File

@@ -139,7 +139,7 @@ namespace TensorFlowNET.UnitTest
c_api.TF_GraphGetTensorShape(graph, feed_out_0, null, num_dims, s);
//Assert.IsTrue(s.Code == TF_Code.TF_OK);

graph.Dispose();
// graph.Dispose();
s.Dispose();
}
}


+ 9
- 4
test/TensorFlowNET.UnitTest/c_test_util.cs View File

@@ -42,6 +42,15 @@ namespace TensorFlowNET.UnitTest
return s.Code == TF_Code.TF_OK;
}

public static bool GetNodeDef(Operation oper, ref NodeDef node_def)
{
var s = new Status();
var buffer = new Buffer();
c_api.TF_OperationToNodeDef(oper, buffer, s);

return s.Code == TF_Code.TF_OK;
}

public static void PlaceholderHelper(Graph graph, Status s, string name, TF_DataType dtype, long[] dims, ref Operation op)
{
var desc = c_api.TF_NewOperation(graph, "Placeholder", name);
@@ -69,10 +78,6 @@ namespace TensorFlowNET.UnitTest
c_api.TF_SetAttrType(desc, "dtype", t.dtype);
op = c_api.TF_FinishOperation(desc, s);
s.Check();
if(op == null)
{
throw new Exception("c_api.TF_FinishOperation failed.");
}
}

public static Operation Const(Tensor t, Graph graph, Status s, string name)


Loading…
Cancel
Save