Browse Source

placeholder in operation

tags/v0.1.0-Tensor
Oceania2018 7 years ago
parent
commit
9f2065e4aa
4 changed files with 60 additions and 3 deletions
  1. +16
    -0
      src/TensorFlowNET.Core/Eager/Context.cs
  2. +11
    -0
      src/TensorFlowNET.Core/Operation.cs
  3. +21
    -2
      src/TensorFlowNET.Core/tf.cs
  4. +12
    -1
      test/TensorFlowNET.UnitTest/OperationsTest.cs

+ 16
- 0
src/TensorFlowNET.Core/Eager/Context.cs View File

@@ -0,0 +1,16 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace TensorFlowNET.Core.Eager
{
public class Context
{
public static int GRAPH_MODE = 0;
public static int EAGER_MODE = 1;

public int default_execution_mode;


}
}

+ 11
- 0
src/TensorFlowNET.Core/Operation.cs View File

@@ -17,6 +17,17 @@ namespace TensorFlowNET.Core
public Tensor[] outputs => _outputs;
public Tensor[] inputs;

public Operation(Graph g, string opType, string oper_name)
{
_graph = g;

var status = new Status();

var desc = c_api.TF_NewOperation(g.Handle, opType, oper_name);
c_api.TF_SetAttrType(desc, "dtype", DataType.DtInt32);
c_api.TF_FinishOperation(desc, status.Handle);
}

public Operation(NodeDef node_def, Graph g, object inputs = null, TF_DataType[] output_types = null, object control_inputs = null, TF_DataType[] input_types = null, string original_op = "", OpDef op_def = null)
{
_graph = g;


+ 21
- 2
src/TensorFlowNET.Core/tf.cs View File

@@ -5,6 +5,7 @@ using System.Text;
using TF_DataType = Tensorflow.DataType;
using attr_value_pb2 = Tensorflow;
using Tensorflow;
using TensorFlowNET.Core.Eager;

namespace TensorFlowNET.Core
{
@@ -12,11 +13,25 @@ namespace TensorFlowNET.Core
{
public static DataType float32 = DataType.DtFloat;

public static Context context = new Context();

public static Graph g = new Graph(c_api.TF_NewGraph());

public delegate void Deallocator(IntPtr data, IntPtr size, IntPtr deallocatorData);

public static unsafe Tensor add(Tensor a, Tensor b)
{
return null;
}

public static unsafe Tensor placeholder(DataType dtype, TensorShape shape = null)
{
return gen_array_ops.placeholder(dtype, shape);
var g = ops.get_default_graph();
var op = new Operation(g, "Placeholder", "feed");

var tensor = new Tensor(op, 0, dtype);
//return gen_array_ops.placeholder(dtype, shape);
return tensor;
}

public static unsafe Tensor constant(object value)
@@ -38,6 +53,11 @@ namespace TensorFlowNET.Core
return const_tensor;
}

public static void enable_eager_execution()
{
context.default_execution_mode = Context.EAGER_MODE;
}

public static Deallocator FreeTensorDataDelegate = FreeTensorData;

[MonoPInvokeCallback(typeof(Deallocator))]
@@ -55,7 +75,6 @@ namespace TensorFlowNET.Core

public static Graph Graph()
{
Graph g = new Graph(c_api.TF_NewGraph());
return g;
}



+ 12
- 1
test/TensorFlowNET.UnitTest/OperationsTest.cs View File

@@ -18,7 +18,18 @@ namespace TensorFlowNET.UnitTest
[TestMethod]
public void placeholder()
{
var x = tf.placeholder(tf.float32, shape: new TensorShape(1024, 1024));
var x = tf.placeholder(tf.float32);
}

[TestMethod]
public void add()
{
var a = tf.placeholder(tf.float32);
var b = tf.placeholder(tf.float32);
var c = tf.add(a, b);

//sess.run(adder_node, { a: 3, b: 4.5})
//sess.run(adder_node, {a: [1,3], b: [2, 4]})
}
}
}

Loading…
Cancel
Save