| @@ -38,6 +38,10 @@ namespace Tensorflow | |||||
| return; | return; | ||||
| _handle = handle; | _handle = handle; | ||||
| _outputs = new Tensor[NumOutputs]; | |||||
| for (int i = 0; i < NumOutputs; i++) | |||||
| _outputs[i] = new Tensor(this, i, OutputType(i)); | |||||
| } | } | ||||
| public Operation(Graph g, string opType, string oper_name) | public Operation(Graph g, string opType, string oper_name) | ||||
| @@ -99,7 +103,6 @@ namespace Tensorflow | |||||
| // Initialize self._outputs. | // Initialize self._outputs. | ||||
| output_types = new TF_DataType[NumOutputs]; | output_types = new TF_DataType[NumOutputs]; | ||||
| for (int i = 0; i < NumOutputs; i++) | for (int i = 0; i < NumOutputs; i++) | ||||
| output_types[i] = OutputType(i); | output_types[i] = OutputType(i); | ||||
| @@ -16,6 +16,11 @@ namespace Tensorflow | |||||
| return constant_op.Constant(scalar); | return constant_op.Constant(scalar); | ||||
| } | } | ||||
| public static implicit operator int(Tensor tensor) | |||||
| { | |||||
| return tensor.Data<int>()[0]; | |||||
| } | |||||
| public static implicit operator IntPtr(Tensor tensor) | public static implicit operator IntPtr(Tensor tensor) | ||||
| { | { | ||||
| return tensor._handle; | return tensor._handle; | ||||
| @@ -23,12 +23,6 @@ namespace Tensorflow | |||||
| public static implicit operator RefVariable(Tensor var) | public static implicit operator RefVariable(Tensor var) | ||||
| { | { | ||||
| switch (var.dtype) | |||||
| { | |||||
| case TF_DataType.TF_INT32: | |||||
| return tf.Variable(var.Data<int>()[0]); | |||||
| } | |||||
| return null; | return null; | ||||
| } | } | ||||
| } | } | ||||
| @@ -43,7 +43,8 @@ namespace TensorFlowNET.UnitTest | |||||
| [TestMethod] | [TestMethod] | ||||
| public void Add() | public void Add() | ||||
| { | { | ||||
| var x = tf.Variable(10, name: "x"); | |||||
| int result = 0; | |||||
| Tensor x = tf.Variable(10, name: "x"); | |||||
| var model = tf.global_variables_initializer(); | var model = tf.global_variables_initializer(); | ||||
| using (var session = tf.Session()) | using (var session = tf.Session()) | ||||
| @@ -52,10 +53,12 @@ namespace TensorFlowNET.UnitTest | |||||
| for(int i = 0; i < 5; i++) | for(int i = 0; i < 5; i++) | ||||
| { | { | ||||
| x = x + 1; | x = x + 1; | ||||
| var result = session.run(x); | |||||
| result = session.run(x); | |||||
| print(result); | print(result); | ||||
| } | } | ||||
| } | } | ||||
| Assert.AreEqual(15, result); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||