Browse Source

finished #140

tags/v0.1.0-Tensor
Oceania2018 6 years ago
parent
commit
c56e2c89a3
4 changed files with 14 additions and 9 deletions
  1. +4
    -1
      src/TensorFlowNET.Core/Operations/Operation.cs
  2. +5
    -0
      src/TensorFlowNET.Core/Tensors/Tensor.Implicit.cs
  3. +0
    -6
      src/TensorFlowNET.Core/Variables/RefVariable.Implicit.cs
  4. +5
    -2
      test/TensorFlowNET.UnitTest/VariableTest.cs

+ 4
- 1
src/TensorFlowNET.Core/Operations/Operation.cs View File

@@ -38,6 +38,10 @@ namespace Tensorflow
return; return;


_handle = handle; _handle = handle;

_outputs = new Tensor[NumOutputs];
for (int i = 0; i < NumOutputs; i++)
_outputs[i] = new Tensor(this, i, OutputType(i));
} }


public Operation(Graph g, string opType, string oper_name) public Operation(Graph g, string opType, string oper_name)
@@ -99,7 +103,6 @@ namespace Tensorflow


// Initialize self._outputs. // Initialize self._outputs.
output_types = new TF_DataType[NumOutputs]; output_types = new TF_DataType[NumOutputs];

for (int i = 0; i < NumOutputs; i++) for (int i = 0; i < NumOutputs; i++)
output_types[i] = OutputType(i); output_types[i] = OutputType(i);




+ 5
- 0
src/TensorFlowNET.Core/Tensors/Tensor.Implicit.cs View File

@@ -16,6 +16,11 @@ namespace Tensorflow
return constant_op.Constant(scalar); return constant_op.Constant(scalar);
} }


public static implicit operator int(Tensor tensor)
{
return tensor.Data<int>()[0];
}

public static implicit operator IntPtr(Tensor tensor) public static implicit operator IntPtr(Tensor tensor)
{ {
return tensor._handle; return tensor._handle;


+ 0
- 6
src/TensorFlowNET.Core/Variables/RefVariable.Implicit.cs View File

@@ -23,12 +23,6 @@ namespace Tensorflow


public static implicit operator RefVariable(Tensor var) public static implicit operator RefVariable(Tensor var)
{ {
switch (var.dtype)
{
case TF_DataType.TF_INT32:
return tf.Variable(var.Data<int>()[0]);
}

return null; return null;
} }
} }


+ 5
- 2
test/TensorFlowNET.UnitTest/VariableTest.cs View File

@@ -43,7 +43,8 @@ namespace TensorFlowNET.UnitTest
[TestMethod] [TestMethod]
public void Add() public void Add()
{ {
var x = tf.Variable(10, name: "x");
int result = 0;
Tensor x = tf.Variable(10, name: "x");


var model = tf.global_variables_initializer(); var model = tf.global_variables_initializer();
using (var session = tf.Session()) using (var session = tf.Session())
@@ -52,10 +53,12 @@ namespace TensorFlowNET.UnitTest
for(int i = 0; i < 5; i++) for(int i = 0; i < 5; i++)
{ {
x = x + 1; x = x + 1;
var result = session.run(x);
result = session.run(x);
print(result); print(result);
} }
} }

Assert.AreEqual(15, result);
} }
} }
} }

Loading…
Cancel
Save