diff --git a/src/TensorFlowNET.Core/APIs/tf.linalg.cs b/src/TensorFlowNET.Core/APIs/tf.linalg.cs
index b4141460..398fd508 100644
--- a/src/TensorFlowNET.Core/APIs/tf.linalg.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.linalg.cs
@@ -22,7 +22,7 @@ namespace Tensorflow
=> gen_array_ops.diag(diagonal, name: name);
public Tensor matmul(Tensor a, Tensor b)
- => gen_math_ops.mat_mul(a, b);
+ => math_ops.matmul(a, b);
public Tensor batch_matmul(Tensor x, Tensor y)
=> gen_math_ops.batch_mat_mul(x, y);
diff --git a/src/TensorFlowNET.Core/Eager/EagerTensor.cs b/src/TensorFlowNET.Core/Eager/EagerTensor.cs
index a659e0b6..6c61eb28 100644
--- a/src/TensorFlowNET.Core/Eager/EagerTensor.cs
+++ b/src/TensorFlowNET.Core/Eager/EagerTensor.cs
@@ -1,4 +1,5 @@
-using System;
+using NumSharp;
+using System;
using System.Collections.Generic;
using System.Text;
@@ -18,6 +19,10 @@ namespace Tensorflow.Eager
{
}
+ public EagerTensor(NDArray value, string device_name) : base(value)
+ {
+ }
+
public override string ToString()
{
switch (rank)
diff --git a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs
index aa0a6785..80c27c73 100644
--- a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs
@@ -638,6 +638,14 @@ namespace Tensorflow
///
public static Tensor mat_mul(Tensor a, Tensor b, bool transpose_a = false, bool transpose_b = false, string name = null)
{
+ if (tf.context.executing_eagerly())
+ {
+ var _result = wrap_tfe_src.TFE_Py_FastPathExecute(tf.context, tf.context.device_name,
+ "MatMul", name, null,
+ a, b, "transpose_a", transpose_a, "transpose_b", transpose_b);
+ return _result;
+ }
+
var _op = _op_def_lib._apply_op_helper("MatMul", name, args: new { a, b, transpose_a, transpose_b });
return _op.output;
@@ -738,10 +746,18 @@ namespace Tensorflow
{
if (tf.context.executing_eagerly())
{
- var _result = wrap_tfe_src.TFE_Py_FastPathExecute(tf.context, tf.context.device_name,
- "Sum", name, null,
- input, axis, "keep_dims", keep_dims);
- return _result;
+ try
+ {
+ var _result = wrap_tfe_src.TFE_Py_FastPathExecute(tf.context, tf.context.device_name,
+ "Sum", name, null,
+ input, axis, "keep_dims", keep_dims);
+ return _result;
+ }
+ catch (Exception)
+ {
+ return _sum_eager_fallback(input as Tensor[], axis as Tensor,
+ keep_dims: keep_dims, name: name, ctx: tf.context);
+ }
}
var _op = _op_def_lib._apply_op_helper("Sum", name, args: new { input, reduction_indices = axis, keep_dims });
@@ -749,6 +765,18 @@ namespace Tensorflow
return _op.outputs[0];
}
+ private static Tensor _sum_eager_fallback(Tensor[] inputs, Tensor axis, bool keep_dims = false, string name = null, Context ctx = null)
+ {
+ var (_attr_T, input) = _execute.args_to_matching_eager(inputs, ctx);
+ var (_attr_Tidx, axis1) = _execute.args_to_matching_eager(new[] { axis }, ctx, TF_DataType.TF_INT32);
+ var _inputs_flat = new Tensor[] { input, axis1 };
+
+ var _attrs = new object[] { "keep_dims", keep_dims, "T", _attr_T, "Tidx", _attr_Tidx };
+
+ var _result = _execute.execute(ctx, "Sum", _inputs_flat, _attrs, name: name);
+ return _result;
+ }
+
///
/// Creates a sequence of numbers.
///
diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Value.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Value.cs
index 0b51bec5..72c660bb 100644
--- a/src/TensorFlowNET.Core/Tensors/Tensor.Value.cs
+++ b/src/TensorFlowNET.Core/Tensors/Tensor.Value.cs
@@ -163,6 +163,8 @@ namespace Tensorflow
return StringData();
case TF_DataType.TF_INT32:
return ToArray();
+ case TF_DataType.TF_FLOAT:
+ return ToArray();
default:
return BufferToArray();
}
diff --git a/src/TensorFlowNET.Core/Tensors/constant_op.cs b/src/TensorFlowNET.Core/Tensors/constant_op.cs
index 2635f1d4..0d47b8ba 100644
--- a/src/TensorFlowNET.Core/Tensors/constant_op.cs
+++ b/src/TensorFlowNET.Core/Tensors/constant_op.cs
@@ -14,6 +14,7 @@
limitations under the License.
******************************************************************************/
+using NumSharp;
using System;
using System.Collections.Generic;
using Tensorflow.Eager;
@@ -84,6 +85,8 @@ namespace Tensorflow
{
switch (value)
{
+ case NDArray nd:
+ return new EagerTensor(nd, ctx.device_name);
case string str:
return new EagerTensor(str, ctx.device_name);
case int int32: