diff --git a/src/TensorFlowNET.Core/Eager/c_api.eager.cs b/src/TensorFlowNET.Core/Eager/c_api.eager.cs index 48f0a5d5..6b542382 100644 --- a/src/TensorFlowNET.Core/Eager/c_api.eager.cs +++ b/src/TensorFlowNET.Core/Eager/c_api.eager.cs @@ -333,6 +333,18 @@ namespace Tensorflow [DllImport(TensorFlowLibName)] public static extern TFE_Executor TFE_ContextGetExecutorForThread(IntPtr ctx); + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// EagerTensorHandle [DllImport(TensorFlowLibName)] public static extern IntPtr TFE_FastPathExecute(IntPtr ctx, string device_name, diff --git a/src/TensorFlowNET.Core/Eager/wrap_tfe_src.TFE_FastPathExecute.cs b/src/TensorFlowNET.Core/Eager/wrap_tfe_src.TFE_FastPathExecute.cs index 7b4226f9..297a7a83 100644 --- a/src/TensorFlowNET.Core/Eager/wrap_tfe_src.TFE_FastPathExecute.cs +++ b/src/TensorFlowNET.Core/Eager/wrap_tfe_src.TFE_FastPathExecute.cs @@ -173,7 +173,7 @@ namespace Tensorflow.Eager return true; } - private static void SetOpAttrs(Context ctx, TFE_Op op, object[] attrs, int start_index, Status out_status) + public static void SetOpAttrs(Context ctx, TFE_Op op, object[] attrs, int start_index, Status out_status) { var len = attrs.Length; for (int i = 0; i < len; i += 2) @@ -181,7 +181,7 @@ namespace Tensorflow.Eager var key = attrs[start_index + i].ToString(); var value = attrs[start_index + i + 1]; - byte is_list = 0; + byte is_list = 0; var type = c_api.TFE_OpGetAttrType(op, key, ref is_list, out_status); if (!out_status.ok()) return; if (is_list != 0) @@ -205,7 +205,7 @@ namespace Tensorflow.Eager /// /// /// - public static void SetOpAttrWithDefaults(Context ctx, IntPtr op, AttrDef attr, + private static void SetOpAttrWithDefaults(Context ctx, IntPtr op, AttrDef attr, string attr_name, object attr_value, Dictionary attr_list_sizes, Status status) diff --git a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs index 70509ad5..00483e2e 100644 --- a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs @@ -156,7 +156,7 @@ namespace Tensorflow var tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, "Pack", name, values.Select(x => (x as EagerTensor).EagerTensorHandle).ToArray(), 1, - (op) => wrap_tfe_src.SetOpAttrWithDefaults(tf.context, op, null, "axis", axis, null, status), + op => wrap_tfe_src.SetOpAttrs(tf.context, op, new object[] { "axis", axis }, 0 , status), status); status.Check(true); return new EagerTensor(tensor); diff --git a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs index c1a9a0db..6f481f31 100644 --- a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs @@ -138,16 +138,19 @@ namespace Tensorflow "Mean", name, new IntPtr[] { - (input as EagerTensor).EagerTensorHandle, - (axis as EagerTensor).EagerTensorHandle + input as EagerTensor, + axis as EagerTensor }, 2, - (op) => wrap_tfe_src.SetOpAttrWithDefaults(tf.context, op, null, "keep_dims", keep_dims, null, status), + op => wrap_tfe_src.SetOpAttrs(tf.context, op, new object[] { "keep_dims", keep_dims }, 0, status), status); status.Check(true); return new EagerTensor(tensor); } catch (Exception) { + /*tensors = c_api.TFE_Execute(tf.context, tf.context.device_name, op_name, + inputs, attrs, num_outputs);*/ + return mean_eager_fallback(input as Tensor[], axis as Tensor, keep_dims: keep_dims, name: name, ctx: tf.context); } } @@ -220,8 +223,8 @@ namespace Tensorflow var _result = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, "Add", name, new IntPtr[] { - (x as EagerTensor).EagerTensorHandle, - (y as EagerTensor).EagerTensorHandle + x as EagerTensor, + y as EagerTensor }, 2, null, status); status.Check(true); return new EagerTensor(_result); @@ -595,12 +598,8 @@ namespace Tensorflow using var status = new Status(); var tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, "Cast", name, - new IntPtr[] { (x as EagerTensor).EagerTensorHandle }, 1, - (op) => - { - wrap_tfe_src.SetOpAttrWithDefaults(tf.context, op, null, "DstT", DstT, null, status); - wrap_tfe_src.SetOpAttrWithDefaults(tf.context, op, null, "Truncate", Truncate, null, status); - }, + new IntPtr[] { x as EagerTensor }, 1, + op => wrap_tfe_src.SetOpAttrs(tf.context, op, new object[] { "DstT", DstT, "Truncate", Truncate }, 0, status), status); status.Check(true); return new EagerTensor(tensor); @@ -649,8 +648,8 @@ namespace Tensorflow var _result = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, "Sub", name, new IntPtr[] { - (x as EagerTensor).EagerTensorHandle, - (y as EagerTensor).EagerTensorHandle + x as EagerTensor, + y as EagerTensor }, 2, null, status); status.Check(true); return new EagerTensor(_result); @@ -743,8 +742,8 @@ namespace Tensorflow var _result = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, "Mul", name, new IntPtr[] { - (x as EagerTensor).EagerTensorHandle, - (y as EagerTensor).EagerTensorHandle + x as EagerTensor, + y as EagerTensor }, 2, null, status); status.Check(true); return new EagerTensor(_result); @@ -785,8 +784,8 @@ namespace Tensorflow var _result = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, "RealDiv", name, new IntPtr[] { - (x as EagerTensor).EagerTensorHandle, - (y as EagerTensor).EagerTensorHandle + x as EagerTensor, + y as EagerTensor }, 2, null, status); status.Check(true); return new EagerTensor(_result); @@ -988,9 +987,9 @@ namespace Tensorflow var tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, "Range", name, new IntPtr[] { - (start as EagerTensor).EagerTensorHandle, - (limit as EagerTensor).EagerTensorHandle, - (delta as EagerTensor).EagerTensorHandle + start as EagerTensor, + limit as EagerTensor, + delta as EagerTensor }, 3, null, status); status.Check(true); return new EagerTensor(tensor); diff --git a/src/TensorFlowNET.Core/Operations/gen_resource_variable_ops.cs b/src/TensorFlowNET.Core/Operations/gen_resource_variable_ops.cs index 97079aa7..e33bb66c 100644 --- a/src/TensorFlowNET.Core/Operations/gen_resource_variable_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_resource_variable_ops.cs @@ -32,10 +32,10 @@ namespace Tensorflow using var status = new Status(); var tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, "AssignVariableOp", name, - new[] + new IntPtr[] { - (resource as EagerTensor).EagerTensorHandle, - (value as EagerTensor).EagerTensorHandle + resource as EagerTensor, + value as EagerTensor }, 2, null, status); status.Check(true); return tensor; @@ -53,7 +53,7 @@ namespace Tensorflow using var status = new Status(); var tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, "VarIsInitializedOp", name, - new[] { (resource as EagerTensor).EagerTensorHandle }, + new IntPtr[] { resource as EagerTensor }, 1, null, status); status.Check(true); return new EagerTensor(tensor); @@ -80,13 +80,15 @@ namespace Tensorflow { using var status = new Status(); var tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, - "VarHandleOp", name, null, 0, op => - { - wrap_tfe_src.SetOpAttrWithDefaults(tf.context, op, null, "container", container, null, status); - wrap_tfe_src.SetOpAttrWithDefaults(tf.context, op, null, "shared_name", shared_name, null, status); - wrap_tfe_src.SetOpAttrWithDefaults(tf.context, op, null, "dtype", dtype, null, status); - wrap_tfe_src.SetOpAttrWithDefaults(tf.context, op, null, "shape", shape.dims, null, status); - }, status); + "VarHandleOp", name, null, 0, + op => wrap_tfe_src.SetOpAttrs(tf.context, op, new object[] + { + "container", container, + "shared_name", shared_name, + "dtype", dtype, + "shape", shape.dims + }, 0, status), + status); status.Check(true); return new EagerTensor(tensor); } @@ -114,8 +116,9 @@ namespace Tensorflow { using var status = new Status(); var tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, - "ReadVariableOp", name, new IntPtr[] { (resource as EagerTensor).EagerTensorHandle }, 1, - (op) => wrap_tfe_src.SetOpAttrWithDefaults(tf.context, op, null, "dtype", dtype, null, status), + "ReadVariableOp", name, + new IntPtr[] { resource as EagerTensor }, 1, + op => wrap_tfe_src.SetOpAttrs(tf.context, op, new object[] { "dtype", dtype }, 0, status), status); status.Check(true); return new EagerTensor(tensor);