From dcfaa77490aec6ebbf63a82da9cbff32e8db08fa Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sat, 14 Mar 2020 08:56:35 -0500 Subject: [PATCH] Create EagerTensor from NDArray. --- src/TensorFlowNET.Core/APIs/tf.linalg.cs | 2 +- src/TensorFlowNET.Core/Eager/EagerTensor.cs | 7 +++- .../Operations/gen_math_ops.cs | 36 ++++++++++++++++--- .../Tensors/Tensor.Value.cs | 2 ++ src/TensorFlowNET.Core/Tensors/constant_op.cs | 3 ++ 5 files changed, 44 insertions(+), 6 deletions(-) diff --git a/src/TensorFlowNET.Core/APIs/tf.linalg.cs b/src/TensorFlowNET.Core/APIs/tf.linalg.cs index b4141460..398fd508 100644 --- a/src/TensorFlowNET.Core/APIs/tf.linalg.cs +++ b/src/TensorFlowNET.Core/APIs/tf.linalg.cs @@ -22,7 +22,7 @@ namespace Tensorflow => gen_array_ops.diag(diagonal, name: name); public Tensor matmul(Tensor a, Tensor b) - => gen_math_ops.mat_mul(a, b); + => math_ops.matmul(a, b); public Tensor batch_matmul(Tensor x, Tensor y) => gen_math_ops.batch_mat_mul(x, y); diff --git a/src/TensorFlowNET.Core/Eager/EagerTensor.cs b/src/TensorFlowNET.Core/Eager/EagerTensor.cs index a659e0b6..6c61eb28 100644 --- a/src/TensorFlowNET.Core/Eager/EagerTensor.cs +++ b/src/TensorFlowNET.Core/Eager/EagerTensor.cs @@ -1,4 +1,5 @@ -using System; +using NumSharp; +using System; using System.Collections.Generic; using System.Text; @@ -18,6 +19,10 @@ namespace Tensorflow.Eager { } + public EagerTensor(NDArray value, string device_name) : base(value) + { + } + public override string ToString() { switch (rank) diff --git a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs index aa0a6785..80c27c73 100644 --- a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs @@ -638,6 +638,14 @@ namespace Tensorflow /// public static Tensor mat_mul(Tensor a, Tensor b, bool transpose_a = false, bool transpose_b = false, string name = null) { + if (tf.context.executing_eagerly()) + { + var _result = wrap_tfe_src.TFE_Py_FastPathExecute(tf.context, tf.context.device_name, + "MatMul", name, null, + a, b, "transpose_a", transpose_a, "transpose_b", transpose_b); + return _result; + } + var _op = _op_def_lib._apply_op_helper("MatMul", name, args: new { a, b, transpose_a, transpose_b }); return _op.output; @@ -738,10 +746,18 @@ namespace Tensorflow { if (tf.context.executing_eagerly()) { - var _result = wrap_tfe_src.TFE_Py_FastPathExecute(tf.context, tf.context.device_name, - "Sum", name, null, - input, axis, "keep_dims", keep_dims); - return _result; + try + { + var _result = wrap_tfe_src.TFE_Py_FastPathExecute(tf.context, tf.context.device_name, + "Sum", name, null, + input, axis, "keep_dims", keep_dims); + return _result; + } + catch (Exception) + { + return _sum_eager_fallback(input as Tensor[], axis as Tensor, + keep_dims: keep_dims, name: name, ctx: tf.context); + } } var _op = _op_def_lib._apply_op_helper("Sum", name, args: new { input, reduction_indices = axis, keep_dims }); @@ -749,6 +765,18 @@ namespace Tensorflow return _op.outputs[0]; } + private static Tensor _sum_eager_fallback(Tensor[] inputs, Tensor axis, bool keep_dims = false, string name = null, Context ctx = null) + { + var (_attr_T, input) = _execute.args_to_matching_eager(inputs, ctx); + var (_attr_Tidx, axis1) = _execute.args_to_matching_eager(new[] { axis }, ctx, TF_DataType.TF_INT32); + var _inputs_flat = new Tensor[] { input, axis1 }; + + var _attrs = new object[] { "keep_dims", keep_dims, "T", _attr_T, "Tidx", _attr_Tidx }; + + var _result = _execute.execute(ctx, "Sum", _inputs_flat, _attrs, name: name); + return _result; + } + /// /// Creates a sequence of numbers. /// diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Value.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Value.cs index 0b51bec5..72c660bb 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.Value.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Value.cs @@ -163,6 +163,8 @@ namespace Tensorflow return StringData(); case TF_DataType.TF_INT32: return ToArray(); + case TF_DataType.TF_FLOAT: + return ToArray(); default: return BufferToArray(); } diff --git a/src/TensorFlowNET.Core/Tensors/constant_op.cs b/src/TensorFlowNET.Core/Tensors/constant_op.cs index 2635f1d4..0d47b8ba 100644 --- a/src/TensorFlowNET.Core/Tensors/constant_op.cs +++ b/src/TensorFlowNET.Core/Tensors/constant_op.cs @@ -14,6 +14,7 @@ limitations under the License. ******************************************************************************/ +using NumSharp; using System; using System.Collections.Generic; using Tensorflow.Eager; @@ -84,6 +85,8 @@ namespace Tensorflow { switch (value) { + case NDArray nd: + return new EagerTensor(nd, ctx.device_name); case string str: return new EagerTensor(str, ctx.device_name); case int int32: