diff --git a/src/TensorFlowNET.Core/APIs/tf.math.cs b/src/TensorFlowNET.Core/APIs/tf.math.cs index 741b4ccb..48672e32 100644 --- a/src/TensorFlowNET.Core/APIs/tf.math.cs +++ b/src/TensorFlowNET.Core/APIs/tf.math.cs @@ -6,29 +6,29 @@ namespace Tensorflow { public static partial class tf { - public static unsafe Tensor add(Tensor a, Tensor b) + public static Tensor add(Tensor a, Tensor b) { return gen_math_ops.add(a, b); } - public static unsafe Tensor sub(Tensor a, Tensor b) + public static Tensor sub(Tensor a, Tensor b) { return gen_math_ops.sub(a, b); } - public static unsafe Tensor add(Tensor a, RefVariable b) + public static Tensor multiply(Tensor x, Tensor y) { - return gen_math_ops.add(a, b); + return gen_math_ops.mul(x, y); } - public static unsafe Tensor multiply(Tensor x, Tensor y) + public static Tensor pow(Tensor x, Tensor y) { - return gen_math_ops.mul(x, y); + return gen_math_ops.pow(x, y); } - public static unsafe Tensor multiply(Tensor x, RefVariable y) + public static Tensor reduce_sum(Tensor input, int? axis = null) { - return gen_math_ops.mul(x, y); + return gen_math_ops.sum(input, input); } } } diff --git a/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs b/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs index 8e19be1c..9b9076c8 100644 --- a/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs +++ b/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs @@ -36,50 +36,27 @@ namespace Tensorflow foreach (var input_arg in op_def.InputArg) { var input_name = input_arg.Name; - switch (keywords[input_name]) + if (keywords[input_name] is Tensor value) { - case Tensor value: - if (keywords.ContainsKey(input_name)) - { - inputs.Add(value); - } - - if (!String.IsNullOrEmpty(input_arg.TypeAttr)) - { - attrs[input_arg.TypeAttr] = value.dtype; - } - - if (input_arg.IsRef) - { - - } - else - { - input_types.Add(value.dtype); - } - break; - case RefVariable value: - if (keywords.ContainsKey(input_name)) - { - inputs.Add(value._initial_value); - } - - if (!String.IsNullOrEmpty(input_arg.TypeAttr)) - { - attrs[input_arg.TypeAttr] = value._initial_value.dtype; - } - - if (input_arg.IsRef) - { - - } - else - { - input_types.Add(value._initial_value.dtype); - } - break; + if (keywords.ContainsKey(input_name)) + { + inputs.Add(value); + } + + if (!String.IsNullOrEmpty(input_arg.TypeAttr)) + { + attrs[input_arg.TypeAttr] = value.dtype; + } + + if (input_arg.IsRef) + { + + } + else + { + input_types.Add(value.dtype); + } } - } // Process remaining attrs diff --git a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs index 938c954e..66e2811c 100644 --- a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs @@ -20,59 +20,59 @@ namespace Tensorflow return new Tensor(_op, 0, _op.OutputType(0)); } - public static Tensor add(Tensor a, RefVariable b) + public static Tensor sub(Tensor x, Tensor y) { var keywords = new Dictionary(); - keywords.Add("x", a); - keywords.Add("y", b); + keywords.Add("x", x); + keywords.Add("y", y); - var _op = _op_def_lib._apply_op_helper("Add", name: "add", keywords: keywords); + var _op = _op_def_lib._apply_op_helper("Sub", name: "sub", keywords: keywords); return new Tensor(_op, 0, _op.OutputType(0)); } - public static Tensor sub(Tensor x, Tensor y) + public static Tensor mul(Tensor x, Tensor y) { var keywords = new Dictionary(); keywords.Add("x", x); keywords.Add("y", y); - var _op = _op_def_lib._apply_op_helper("Sub", name: "sub", keywords: keywords); + var _op = _op_def_lib._apply_op_helper("Mul", name: "mul", keywords: keywords); return new Tensor(_op, 0, _op.OutputType(0)); } - public static Tensor mul(Tensor x, Tensor y) + public static Tensor mat_mul(Tensor a, Tensor b, bool transpose_a = false, bool transpose_b = false) { var keywords = new Dictionary(); - keywords.Add("x", x); - keywords.Add("y", y); + keywords.Add("a", a); + keywords.Add("b", b); + keywords.Add("transpose_a", transpose_a); + keywords.Add("transpose_b", transpose_b); - var _op = _op_def_lib._apply_op_helper("Mul", name: "mul", keywords: keywords); + var _op = _op_def_lib._apply_op_helper("MatMul", name: "MatMul", keywords: keywords); return new Tensor(_op, 0, _op.OutputType(0)); } - public static Tensor mul(Tensor x, RefVariable y) + public static Tensor pow(Tensor x, Tensor y) { var keywords = new Dictionary(); keywords.Add("x", x); keywords.Add("y", y); - var _op = _op_def_lib._apply_op_helper("Mul", name: "mul", keywords: keywords); + var _op = _op_def_lib._apply_op_helper("Pow", name: "Pow", keywords: keywords); return new Tensor(_op, 0, _op.OutputType(0)); } - public static Tensor mat_mul(Tensor a, Tensor b, bool transpose_a = false, bool transpose_b = false) + public static Tensor sum(Tensor x, Tensor y) { var keywords = new Dictionary(); - keywords.Add("a", a); - keywords.Add("b", b); - keywords.Add("transpose_a", transpose_a); - keywords.Add("transpose_b", transpose_b); + keywords.Add("x", x); + keywords.Add("y", y); - var _op = _op_def_lib._apply_op_helper("MatMul", name: "MatMul", keywords: keywords); + var _op = _op_def_lib._apply_op_helper("Pow", name: "Pow", keywords: keywords); return new Tensor(_op, 0, _op.OutputType(0)); } diff --git a/src/TensorFlowNET.Core/Sessions/BaseSession.cs b/src/TensorFlowNET.Core/Sessions/BaseSession.cs index f74e3498..66138150 100644 --- a/src/TensorFlowNET.Core/Sessions/BaseSession.cs +++ b/src/TensorFlowNET.Core/Sessions/BaseSession.cs @@ -47,7 +47,7 @@ namespace Tensorflow return result; } - private unsafe object _run(Tensor fetches, Dictionary feed_dict = null) + private object _run(Tensor fetches, Dictionary feed_dict = null) { var feed_dict_tensor = new Dictionary(); diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs index 2653a9e0..f2ac09fc 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs @@ -11,6 +11,11 @@ namespace Tensorflow return gen_math_ops.add(t1, t2); } + public static Tensor operator -(Tensor t1, Tensor t2) + { + return gen_math_ops.sub(t1, t2); + } + public static Tensor operator *(Tensor t1, Tensor t2) { return gen_math_ops.mul(t1, t2); diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs index 4281e7f8..dd37c8b2 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs @@ -211,6 +211,11 @@ namespace Tensorflow status.Dispose(); } + public static implicit operator Tensor(int scalar) + { + return new Tensor(scalar); + } + public static implicit operator IntPtr(Tensor tensor) { return tensor._handle; @@ -220,5 +225,10 @@ namespace Tensorflow { return new Tensor(handle); } + + public static implicit operator Tensor(RefVariable var) + { + return var._initial_value; + } } } diff --git a/test/TensorFlowNET.Examples/LinearRegression.cs b/test/TensorFlowNET.Examples/LinearRegression.cs index 633d7b82..3199a923 100644 --- a/test/TensorFlowNET.Examples/LinearRegression.cs +++ b/test/TensorFlowNET.Examples/LinearRegression.cs @@ -40,7 +40,8 @@ namespace TensorFlowNET.Examples var pred = tf.add(part1, b); // Mean squared error - var cost = tf.reduce_sum(tf.pow(pred - Y, 2)) / (2 * n_samples); + var pow = tf.pow(pred - Y, 2); + //var cost = tf.reduce_sum(pow) / (2 * n_samples); } } }