| @@ -10,6 +10,7 @@ namespace Tensorflow.Eager | |||||
| /// </summary> | /// </summary> | ||||
| public class pywrap_tfe_src | public class pywrap_tfe_src | ||||
| { | { | ||||
| static int kFastPathExecuteInputStartIndex = 0; | |||||
| public static EagerTensor TFE_Py_FastPathExecute(Context ctx, | public static EagerTensor TFE_Py_FastPathExecute(Context ctx, | ||||
| string device_name, | string device_name, | ||||
| string opName, | string opName, | ||||
| @@ -28,7 +29,7 @@ namespace Tensorflow.Eager | |||||
| // Set non-inferred attrs, including setting defaults if the attr is passed in | // Set non-inferred attrs, including setting defaults if the attr is passed in | ||||
| // as None. | // as None. | ||||
| for (int i = op_def.InputArg.Count; i < args_size; i += 2) | |||||
| for (int i = kFastPathExecuteInputStartIndex + op_def.InputArg.Count; i < args_size; i += 2) | |||||
| { | { | ||||
| var attr_name = args[i].ToString(); | var attr_name = args[i].ToString(); | ||||
| var attr_value = args[i + 1]; | var attr_value = args[i + 1]; | ||||
| @@ -38,20 +39,39 @@ namespace Tensorflow.Eager | |||||
| if(attr_name == attr.Name) | if(attr_name == attr.Name) | ||||
| { | { | ||||
| SetOpAttrWithDefaults(ctx, op, attr, attr_name, attr_value, attr_list_sizes, status); | SetOpAttrWithDefaults(ctx, op, attr, attr_name, attr_value, attr_list_sizes, status); | ||||
| status.Check(true); | |||||
| break; | break; | ||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| c_api.TFE_OpSetDevice(op, device_name, status); | c_api.TFE_OpSetDevice(op, device_name, status); | ||||
| status.Check(true); | |||||
| // Add inferred attrs and inputs. | |||||
| for (int i = 0; i < op_def.InputArg.Count; i++) | for (int i = 0; i < op_def.InputArg.Count; i++) | ||||
| { | { | ||||
| var input_arg = op_def.InputArg[i]; | var input_arg = op_def.InputArg[i]; | ||||
| int len = (args[kFastPathExecuteInputStartIndex + i] as object[]).Length; | |||||
| if (!string.IsNullOrEmpty(input_arg.NumberAttr)) | if (!string.IsNullOrEmpty(input_arg.NumberAttr)) | ||||
| { | { | ||||
| c_api.TFE_OpSetAttrInt(op, input_arg.NumberAttr, 0); | |||||
| attr_list_sizes[input_arg.NumberAttr] = 0; | |||||
| c_api.TFE_OpSetAttrInt(op, input_arg.NumberAttr, len); | |||||
| attr_list_sizes[input_arg.NumberAttr] = len; | |||||
| if (len > 0) | |||||
| { | |||||
| var fast_input_array = (object[])args[i]; | |||||
| // First item adds the type attr. | |||||
| if (!AddInputToOp(fast_input_array[i], true, input_arg, op, status)) | |||||
| return null; | |||||
| for (var j = 1; j < len; j++) | |||||
| { | |||||
| // Since the list is homogeneous, we don't need to re-add the attr. | |||||
| if (!AddInputToOp(fast_input_array[j], false, input_arg, op, status)) | |||||
| return null; | |||||
| } | |||||
| } | |||||
| } | } | ||||
| else if (!string.IsNullOrEmpty(input_arg.TypeListAttr)) | else if (!string.IsNullOrEmpty(input_arg.TypeListAttr)) | ||||
| { | { | ||||
| @@ -60,14 +80,7 @@ namespace Tensorflow.Eager | |||||
| else | else | ||||
| { | { | ||||
| // The item is a single item. | // The item is a single item. | ||||
| switch (args[i]) | |||||
| { | |||||
| case Tensor inputTensor: | |||||
| AddInputToOp(inputTensor, true, input_arg, op, status); | |||||
| break; | |||||
| default: | |||||
| throw new NotImplementedException(""); | |||||
| } | |||||
| AddInputToOp(args[i], true, input_arg, op, status); | |||||
| } | } | ||||
| } | } | ||||
| @@ -106,13 +119,23 @@ namespace Tensorflow.Eager | |||||
| /// <param name="op"></param> | /// <param name="op"></param> | ||||
| /// <param name="status"></param> | /// <param name="status"></param> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| private static bool AddInputToOp(Tensor input, | |||||
| private static bool AddInputToOp(object inputs, | |||||
| bool add_type_attr, | bool add_type_attr, | ||||
| ArgDef input_arg, | ArgDef input_arg, | ||||
| IntPtr op, | IntPtr op, | ||||
| Status status) | Status status) | ||||
| { | { | ||||
| var input_handle = c_api.TFE_NewTensorHandle(input, status); | |||||
| IntPtr input_handle = IntPtr.Zero; | |||||
| switch (inputs) | |||||
| { | |||||
| case Tensor input: | |||||
| input_handle = c_api.TFE_NewTensorHandle(input, status); | |||||
| break; | |||||
| default: | |||||
| throw new NotImplementedException(""); | |||||
| } | |||||
| if(add_type_attr && !string.IsNullOrEmpty(input_arg.TypeAttr)) | if(add_type_attr && !string.IsNullOrEmpty(input_arg.TypeAttr)) | ||||
| { | { | ||||