Browse Source

fix OperationsTest.addInPlaceholder.

tags/v0.1.0-Tensor
Oceania2018 6 years ago
parent
commit
422bacfd80
8 changed files with 54 additions and 33 deletions
  1. +19
    -0
      src/TensorFlowNET.Core/APIs/tf.math.cs
  2. +1
    -9
      src/TensorFlowNET.Core/Operations/gen_array_ops.cs
  3. +4
    -2
      src/TensorFlowNET.Core/Sessions/BaseSession.cs
  4. +2
    -13
      src/TensorFlowNET.Core/Tensors/dtypes.cs
  5. +1
    -5
      src/TensorFlowNET.Core/tf.cs
  6. +24
    -2
      test/TensorFlowNET.Examples/BasicOperations.cs
  7. +1
    -1
      test/TensorFlowNET.UnitTest/CSession.cs
  8. +2
    -1
      test/TensorFlowNET.UnitTest/OperationsTest.cs

+ 19
- 0
src/TensorFlowNET.Core/APIs/tf.math.cs View File

@@ -0,0 +1,19 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
public static partial class tf
{
public static unsafe Tensor add(Tensor a, Tensor b)
{
return gen_math_ops.add(a, b);
}

public static unsafe Tensor multiply(Tensor x, Tensor y)
{
return gen_math_ops.mul(x, y);
}
}
}

+ 1
- 9
src/TensorFlowNET.Core/Operations/gen_array_ops.cs View File

@@ -12,13 +12,6 @@ namespace Tensorflow

public static Tensor placeholder(TF_DataType dtype, TensorShape shape = null)
{
/*var g = ops.get_default_graph();
var op = new Operation(g, "Placeholder", "feed");

var tensor = new Tensor(op, 0, dtype);

return tensor;*/

var keywords = new Dictionary<string, object>();
keywords.Add("dtype", dtype);
keywords.Add("shape", shape);
@@ -31,8 +24,7 @@ namespace Tensorflow
_attrs["dtype"] = _op.get_attr("dtype");
_attrs["shape"] = _op.get_attr("shape");

var tensor = new Tensor(_op, 0, dtype);
return tensor;
return new Tensor(_op, 0, dtype);
}
}
}

+ 4
- 2
src/TensorFlowNET.Core/Sessions/BaseSession.cs View File

@@ -105,8 +105,8 @@ namespace Tensorflow
c_api.TF_SessionRun(_session,
run_options: null,
inputs: feed_dict.Select(f => f.Key).ToArray(),
input_values: new IntPtr[] { },
ninputs: 0,
input_values: feed_dict.Select(f => (IntPtr)f.Value).ToArray(),
ninputs: feed_dict.Length,
outputs: fetch_list,
output_values: output_values,
noutputs: fetch_list.Length,
@@ -115,6 +115,8 @@ namespace Tensorflow
run_metadata: IntPtr.Zero,
status: status);

status.Check(true);

object[] result = new object[fetch_list.Length];

for (int i = 0; i < fetch_list.Length; i++)


+ 2
- 13
src/TensorFlowNET.Core/Tensors/dtypes.cs View File

@@ -37,20 +37,9 @@ namespace Tensorflow

switch (type)
{
case TF_DataType.TF_INT32:
dtype = DataType.DtInt32;
break;
case TF_DataType.TF_FLOAT:
dtype = DataType.DtFloat;
break;
case TF_DataType.TF_DOUBLE:
dtype = DataType.DtDouble;
break;
case TF_DataType.TF_STRING:
dtype = DataType.DtString;
break;
default:
throw new Exception("Not Implemented");
Enum.TryParse(((int)type).ToString(), out dtype);
break;
}

return dtype;


+ 1
- 5
src/TensorFlowNET.Core/tf.cs View File

@@ -10,6 +10,7 @@ namespace Tensorflow
{
public static partial class tf
{
public static TF_DataType int16 = TF_DataType.TF_INT16;
public static TF_DataType float32 = TF_DataType.TF_FLOAT;
public static TF_DataType chars = TF_DataType.TF_STRING;

@@ -22,11 +23,6 @@ namespace Tensorflow
return new RefVariable(data, dtype);
}

public static unsafe Tensor add(Tensor a, Tensor b)
{
return gen_math_ops.add(a, b);
}

public static unsafe Tensor placeholder(TF_DataType dtype, TensorShape shape = null)
{
return gen_array_ops.placeholder(dtype, shape);


+ 24
- 2
test/TensorFlowNET.Examples/BasicOperations.cs View File

@@ -11,6 +11,8 @@ namespace TensorFlowNET.Examples
/// </summary>
public class BasicOperations : IExample
{
private Session sess;

public void Run()
{
// Basic constant operations
@@ -18,14 +20,34 @@ namespace TensorFlowNET.Examples
// of the Constant op.
var a = tf.constant(2);
var b = tf.constant(3);
var c = a * b;
// Launch the default graph.
using (var sess = tf.Session())
using (sess = tf.Session())
{
Console.WriteLine("a=2, b=3");
Console.WriteLine($"Addition with constants: {sess.run(a + b)}");
Console.WriteLine($"Multiplication with constants: {sess.run(a * b)}");
}

// Basic Operations with variable as graph input
// The value returned by the constructor represents the output
// of the Variable op. (define as input when running session)
// tf Graph input
a = tf.placeholder(tf.int16);
b = tf.placeholder(tf.int16);

// Define some operations
var add = tf.add(a, b);
var mul = tf.multiply(a, b);

// Launch the default graph.
using(sess = tf.Session())
{
// var feed_dict = new Dictionary<string, >
// Run every operation with variable input
// Console.WriteLine($"Addition with variables: {sess.run(add, feed_dict: {a: 2, b: 3})}");
// Console.WriteLine($"Multiplication with variables: {}");
}
}
}
}

+ 1
- 1
test/TensorFlowNET.UnitTest/CSession.cs View File

@@ -78,7 +78,7 @@ namespace TensorFlowNET.UnitTest
var output_values_ptr = output_values_.Select(x => (IntPtr)x).ToArray();
IntPtr targets_ptr = IntPtr.Zero;

c_api.TF_SessionRun(session_, null, inputs_ptr, input_values_ptr, 1,
c_api.TF_SessionRun(session_, null, inputs_ptr, input_values_ptr, inputs_ptr.Length,
outputs_ptr, output_values_ptr, outputs_.Count,
targets_ptr, targets_.Count,
IntPtr.Zero, s);


+ 2
- 1
test/TensorFlowNET.UnitTest/OperationsTest.cs View File

@@ -35,7 +35,8 @@ namespace TensorFlowNET.UnitTest
feed_dict.Add(a, 3.0f);
feed_dict.Add(b, 2.0f);

//var o = sess.run(c, feed_dict);
var o = sess.run(c, feed_dict);
Assert.AreEqual(o, 5.0f);
}
}



Loading…
Cancel
Save