| @@ -10,6 +10,7 @@ namespace Tensorflow.Eager | |||
| /// </summary> | |||
| public class pywrap_tfe_src | |||
| { | |||
| static int kFastPathExecuteInputStartIndex = 0; | |||
| public static EagerTensor TFE_Py_FastPathExecute(Context ctx, | |||
| string device_name, | |||
| string opName, | |||
| @@ -28,7 +29,7 @@ namespace Tensorflow.Eager | |||
| // Set non-inferred attrs, including setting defaults if the attr is passed in | |||
| // as None. | |||
| for (int i = op_def.InputArg.Count; i < args_size; i += 2) | |||
| for (int i = kFastPathExecuteInputStartIndex + op_def.InputArg.Count; i < args_size; i += 2) | |||
| { | |||
| var attr_name = args[i].ToString(); | |||
| var attr_value = args[i + 1]; | |||
| @@ -38,20 +39,39 @@ namespace Tensorflow.Eager | |||
| if(attr_name == attr.Name) | |||
| { | |||
| SetOpAttrWithDefaults(ctx, op, attr, attr_name, attr_value, attr_list_sizes, status); | |||
| status.Check(true); | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| c_api.TFE_OpSetDevice(op, device_name, status); | |||
| status.Check(true); | |||
| // Add inferred attrs and inputs. | |||
| for (int i = 0; i < op_def.InputArg.Count; i++) | |||
| { | |||
| var input_arg = op_def.InputArg[i]; | |||
| int len = (args[kFastPathExecuteInputStartIndex + i] as object[]).Length; | |||
| if (!string.IsNullOrEmpty(input_arg.NumberAttr)) | |||
| { | |||
| c_api.TFE_OpSetAttrInt(op, input_arg.NumberAttr, 0); | |||
| attr_list_sizes[input_arg.NumberAttr] = 0; | |||
| c_api.TFE_OpSetAttrInt(op, input_arg.NumberAttr, len); | |||
| attr_list_sizes[input_arg.NumberAttr] = len; | |||
| if (len > 0) | |||
| { | |||
| var fast_input_array = (object[])args[i]; | |||
| // First item adds the type attr. | |||
| if (!AddInputToOp(fast_input_array[i], true, input_arg, op, status)) | |||
| return null; | |||
| for (var j = 1; j < len; j++) | |||
| { | |||
| // Since the list is homogeneous, we don't need to re-add the attr. | |||
| if (!AddInputToOp(fast_input_array[j], false, input_arg, op, status)) | |||
| return null; | |||
| } | |||
| } | |||
| } | |||
| else if (!string.IsNullOrEmpty(input_arg.TypeListAttr)) | |||
| { | |||
| @@ -60,14 +80,7 @@ namespace Tensorflow.Eager | |||
| else | |||
| { | |||
| // The item is a single item. | |||
| switch (args[i]) | |||
| { | |||
| case Tensor inputTensor: | |||
| AddInputToOp(inputTensor, true, input_arg, op, status); | |||
| break; | |||
| default: | |||
| throw new NotImplementedException(""); | |||
| } | |||
| AddInputToOp(args[i], true, input_arg, op, status); | |||
| } | |||
| } | |||
| @@ -106,13 +119,23 @@ namespace Tensorflow.Eager | |||
| /// <param name="op"></param> | |||
| /// <param name="status"></param> | |||
| /// <returns></returns> | |||
| private static bool AddInputToOp(Tensor input, | |||
| private static bool AddInputToOp(object inputs, | |||
| bool add_type_attr, | |||
| ArgDef input_arg, | |||
| IntPtr op, | |||
| Status status) | |||
| { | |||
| var input_handle = c_api.TFE_NewTensorHandle(input, status); | |||
| IntPtr input_handle = IntPtr.Zero; | |||
| switch (inputs) | |||
| { | |||
| case Tensor input: | |||
| input_handle = c_api.TFE_NewTensorHandle(input, status); | |||
| break; | |||
| default: | |||
| throw new NotImplementedException(""); | |||
| } | |||
| if(add_type_attr && !string.IsNullOrEmpty(input_arg.TypeAttr)) | |||
| { | |||