| @@ -333,6 +333,18 @@ namespace Tensorflow | |||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern TFE_Executor TFE_ContextGetExecutorForThread(IntPtr ctx); | public static extern TFE_Executor TFE_ContextGetExecutorForThread(IntPtr ctx); | ||||
| /// <summary> | |||||
| /// | |||||
| /// </summary> | |||||
| /// <param name="ctx"></param> | |||||
| /// <param name="device_name"></param> | |||||
| /// <param name="op_name"></param> | |||||
| /// <param name="name"></param> | |||||
| /// <param name="args"></param> | |||||
| /// <param name="input_size"></param> | |||||
| /// <param name="set_op_attrs"></param> | |||||
| /// <param name="status"></param> | |||||
| /// <returns>EagerTensorHandle</returns> | |||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern IntPtr TFE_FastPathExecute(IntPtr ctx, | public static extern IntPtr TFE_FastPathExecute(IntPtr ctx, | ||||
| string device_name, | string device_name, | ||||
| @@ -173,7 +173,7 @@ namespace Tensorflow.Eager | |||||
| return true; | return true; | ||||
| } | } | ||||
| private static void SetOpAttrs(Context ctx, TFE_Op op, object[] attrs, int start_index, Status out_status) | |||||
| public static void SetOpAttrs(Context ctx, TFE_Op op, object[] attrs, int start_index, Status out_status) | |||||
| { | { | ||||
| var len = attrs.Length; | var len = attrs.Length; | ||||
| for (int i = 0; i < len; i += 2) | for (int i = 0; i < len; i += 2) | ||||
| @@ -181,7 +181,7 @@ namespace Tensorflow.Eager | |||||
| var key = attrs[start_index + i].ToString(); | var key = attrs[start_index + i].ToString(); | ||||
| var value = attrs[start_index + i + 1]; | var value = attrs[start_index + i + 1]; | ||||
| byte is_list = 0; | |||||
| byte is_list = 0; | |||||
| var type = c_api.TFE_OpGetAttrType(op, key, ref is_list, out_status); | var type = c_api.TFE_OpGetAttrType(op, key, ref is_list, out_status); | ||||
| if (!out_status.ok()) return; | if (!out_status.ok()) return; | ||||
| if (is_list != 0) | if (is_list != 0) | ||||
| @@ -205,7 +205,7 @@ namespace Tensorflow.Eager | |||||
| /// <param name="attr_value"></param> | /// <param name="attr_value"></param> | ||||
| /// <param name="attr_list_sizes"></param> | /// <param name="attr_list_sizes"></param> | ||||
| /// <param name="status"></param> | /// <param name="status"></param> | ||||
| public static void SetOpAttrWithDefaults(Context ctx, IntPtr op, AttrDef attr, | |||||
| private static void SetOpAttrWithDefaults(Context ctx, IntPtr op, AttrDef attr, | |||||
| string attr_name, object attr_value, | string attr_name, object attr_value, | ||||
| Dictionary<string, long> attr_list_sizes, | Dictionary<string, long> attr_list_sizes, | ||||
| Status status) | Status status) | ||||
| @@ -156,7 +156,7 @@ namespace Tensorflow | |||||
| var tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | var tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | ||||
| "Pack", name, | "Pack", name, | ||||
| values.Select(x => (x as EagerTensor).EagerTensorHandle).ToArray(), 1, | values.Select(x => (x as EagerTensor).EagerTensorHandle).ToArray(), 1, | ||||
| (op) => wrap_tfe_src.SetOpAttrWithDefaults(tf.context, op, null, "axis", axis, null, status), | |||||
| op => wrap_tfe_src.SetOpAttrs(tf.context, op, new object[] { "axis", axis }, 0 , status), | |||||
| status); | status); | ||||
| status.Check(true); | status.Check(true); | ||||
| return new EagerTensor(tensor); | return new EagerTensor(tensor); | ||||
| @@ -138,16 +138,19 @@ namespace Tensorflow | |||||
| "Mean", name, | "Mean", name, | ||||
| new IntPtr[] | new IntPtr[] | ||||
| { | { | ||||
| (input as EagerTensor).EagerTensorHandle, | |||||
| (axis as EagerTensor).EagerTensorHandle | |||||
| input as EagerTensor, | |||||
| axis as EagerTensor | |||||
| }, 2, | }, 2, | ||||
| (op) => wrap_tfe_src.SetOpAttrWithDefaults(tf.context, op, null, "keep_dims", keep_dims, null, status), | |||||
| op => wrap_tfe_src.SetOpAttrs(tf.context, op, new object[] { "keep_dims", keep_dims }, 0, status), | |||||
| status); | status); | ||||
| status.Check(true); | status.Check(true); | ||||
| return new EagerTensor(tensor); | return new EagerTensor(tensor); | ||||
| } | } | ||||
| catch (Exception) | catch (Exception) | ||||
| { | { | ||||
| /*tensors = c_api.TFE_Execute(tf.context, tf.context.device_name, op_name, | |||||
| inputs, attrs, num_outputs);*/ | |||||
| return mean_eager_fallback(input as Tensor[], axis as Tensor, keep_dims: keep_dims, name: name, ctx: tf.context); | return mean_eager_fallback(input as Tensor[], axis as Tensor, keep_dims: keep_dims, name: name, ctx: tf.context); | ||||
| } | } | ||||
| } | } | ||||
| @@ -220,8 +223,8 @@ namespace Tensorflow | |||||
| var _result = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | var _result = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | ||||
| "Add", name, new IntPtr[] | "Add", name, new IntPtr[] | ||||
| { | { | ||||
| (x as EagerTensor).EagerTensorHandle, | |||||
| (y as EagerTensor).EagerTensorHandle | |||||
| x as EagerTensor, | |||||
| y as EagerTensor | |||||
| }, 2, null, status); | }, 2, null, status); | ||||
| status.Check(true); | status.Check(true); | ||||
| return new EagerTensor(_result); | return new EagerTensor(_result); | ||||
| @@ -595,12 +598,8 @@ namespace Tensorflow | |||||
| using var status = new Status(); | using var status = new Status(); | ||||
| var tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | var tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | ||||
| "Cast", name, | "Cast", name, | ||||
| new IntPtr[] { (x as EagerTensor).EagerTensorHandle }, 1, | |||||
| (op) => | |||||
| { | |||||
| wrap_tfe_src.SetOpAttrWithDefaults(tf.context, op, null, "DstT", DstT, null, status); | |||||
| wrap_tfe_src.SetOpAttrWithDefaults(tf.context, op, null, "Truncate", Truncate, null, status); | |||||
| }, | |||||
| new IntPtr[] { x as EagerTensor }, 1, | |||||
| op => wrap_tfe_src.SetOpAttrs(tf.context, op, new object[] { "DstT", DstT, "Truncate", Truncate }, 0, status), | |||||
| status); | status); | ||||
| status.Check(true); | status.Check(true); | ||||
| return new EagerTensor(tensor); | return new EagerTensor(tensor); | ||||
| @@ -649,8 +648,8 @@ namespace Tensorflow | |||||
| var _result = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | var _result = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | ||||
| "Sub", name, new IntPtr[] | "Sub", name, new IntPtr[] | ||||
| { | { | ||||
| (x as EagerTensor).EagerTensorHandle, | |||||
| (y as EagerTensor).EagerTensorHandle | |||||
| x as EagerTensor, | |||||
| y as EagerTensor | |||||
| }, 2, null, status); | }, 2, null, status); | ||||
| status.Check(true); | status.Check(true); | ||||
| return new EagerTensor(_result); | return new EagerTensor(_result); | ||||
| @@ -743,8 +742,8 @@ namespace Tensorflow | |||||
| var _result = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | var _result = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | ||||
| "Mul", name, new IntPtr[] | "Mul", name, new IntPtr[] | ||||
| { | { | ||||
| (x as EagerTensor).EagerTensorHandle, | |||||
| (y as EagerTensor).EagerTensorHandle | |||||
| x as EagerTensor, | |||||
| y as EagerTensor | |||||
| }, 2, null, status); | }, 2, null, status); | ||||
| status.Check(true); | status.Check(true); | ||||
| return new EagerTensor(_result); | return new EagerTensor(_result); | ||||
| @@ -785,8 +784,8 @@ namespace Tensorflow | |||||
| var _result = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | var _result = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | ||||
| "RealDiv", name, new IntPtr[] | "RealDiv", name, new IntPtr[] | ||||
| { | { | ||||
| (x as EagerTensor).EagerTensorHandle, | |||||
| (y as EagerTensor).EagerTensorHandle | |||||
| x as EagerTensor, | |||||
| y as EagerTensor | |||||
| }, 2, null, status); | }, 2, null, status); | ||||
| status.Check(true); | status.Check(true); | ||||
| return new EagerTensor(_result); | return new EagerTensor(_result); | ||||
| @@ -988,9 +987,9 @@ namespace Tensorflow | |||||
| var tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | var tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | ||||
| "Range", name, new IntPtr[] | "Range", name, new IntPtr[] | ||||
| { | { | ||||
| (start as EagerTensor).EagerTensorHandle, | |||||
| (limit as EagerTensor).EagerTensorHandle, | |||||
| (delta as EagerTensor).EagerTensorHandle | |||||
| start as EagerTensor, | |||||
| limit as EagerTensor, | |||||
| delta as EagerTensor | |||||
| }, 3, null, status); | }, 3, null, status); | ||||
| status.Check(true); | status.Check(true); | ||||
| return new EagerTensor(tensor); | return new EagerTensor(tensor); | ||||
| @@ -32,10 +32,10 @@ namespace Tensorflow | |||||
| using var status = new Status(); | using var status = new Status(); | ||||
| var tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | var tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | ||||
| "AssignVariableOp", name, | "AssignVariableOp", name, | ||||
| new[] | |||||
| new IntPtr[] | |||||
| { | { | ||||
| (resource as EagerTensor).EagerTensorHandle, | |||||
| (value as EagerTensor).EagerTensorHandle | |||||
| resource as EagerTensor, | |||||
| value as EagerTensor | |||||
| }, 2, null, status); | }, 2, null, status); | ||||
| status.Check(true); | status.Check(true); | ||||
| return tensor; | return tensor; | ||||
| @@ -53,7 +53,7 @@ namespace Tensorflow | |||||
| using var status = new Status(); | using var status = new Status(); | ||||
| var tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | var tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | ||||
| "VarIsInitializedOp", name, | "VarIsInitializedOp", name, | ||||
| new[] { (resource as EagerTensor).EagerTensorHandle }, | |||||
| new IntPtr[] { resource as EagerTensor }, | |||||
| 1, null, status); | 1, null, status); | ||||
| status.Check(true); | status.Check(true); | ||||
| return new EagerTensor(tensor); | return new EagerTensor(tensor); | ||||
| @@ -80,13 +80,15 @@ namespace Tensorflow | |||||
| { | { | ||||
| using var status = new Status(); | using var status = new Status(); | ||||
| var tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | var tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | ||||
| "VarHandleOp", name, null, 0, op => | |||||
| { | |||||
| wrap_tfe_src.SetOpAttrWithDefaults(tf.context, op, null, "container", container, null, status); | |||||
| wrap_tfe_src.SetOpAttrWithDefaults(tf.context, op, null, "shared_name", shared_name, null, status); | |||||
| wrap_tfe_src.SetOpAttrWithDefaults(tf.context, op, null, "dtype", dtype, null, status); | |||||
| wrap_tfe_src.SetOpAttrWithDefaults(tf.context, op, null, "shape", shape.dims, null, status); | |||||
| }, status); | |||||
| "VarHandleOp", name, null, 0, | |||||
| op => wrap_tfe_src.SetOpAttrs(tf.context, op, new object[] | |||||
| { | |||||
| "container", container, | |||||
| "shared_name", shared_name, | |||||
| "dtype", dtype, | |||||
| "shape", shape.dims | |||||
| }, 0, status), | |||||
| status); | |||||
| status.Check(true); | status.Check(true); | ||||
| return new EagerTensor(tensor); | return new EagerTensor(tensor); | ||||
| } | } | ||||
| @@ -114,8 +116,9 @@ namespace Tensorflow | |||||
| { | { | ||||
| using var status = new Status(); | using var status = new Status(); | ||||
| var tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | var tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | ||||
| "ReadVariableOp", name, new IntPtr[] { (resource as EagerTensor).EagerTensorHandle }, 1, | |||||
| (op) => wrap_tfe_src.SetOpAttrWithDefaults(tf.context, op, null, "dtype", dtype, null, status), | |||||
| "ReadVariableOp", name, | |||||
| new IntPtr[] { resource as EagerTensor }, 1, | |||||
| op => wrap_tfe_src.SetOpAttrs(tf.context, op, new object[] { "dtype", dtype }, 0, status), | |||||
| status); | status); | ||||
| status.Check(true); | status.Check(true); | ||||
| return new EagerTensor(tensor); | return new EagerTensor(tensor); | ||||