diff --git a/src/TensorFlowNET.Core/Contexts/Context.AutoMode.cs b/src/TensorFlowNET.Core/Contexts/Context.AutoMode.cs index 2f22865c..ee5d0765 100644 --- a/src/TensorFlowNET.Core/Contexts/Context.AutoMode.cs +++ b/src/TensorFlowNET.Core/Contexts/Context.AutoMode.cs @@ -30,49 +30,19 @@ namespace Tensorflow.Contexts public sealed partial class Context { // [DebuggerStepThrough] - public T RunInAutoMode(Func graphAction, Func eagerAction, params object[] args) - { - if (tf.Context.has_graph_arg(args)) - { - if (executing_eagerly()) - { - graph_mode(); - var result = graphAction(); - restore_mode(); - return result; - } - else - { - return graphAction(); - } - } - else - { - if (tf.Context.executing_eagerly()) - { - return eagerAction(); - } - else - { - return graphAction(); - } - } - } - - // [DebuggerStepThrough] - public Tensors RunInAutoMode2(string OpType, string Name, AutoModeArgs args) + public Tensors ExecuteOp(string OpType, string Name, AutoModeArgs args) { var inputArgs = ConvertToDict(args.OpInputArgs); var attrDict = ConvertToDict(args.OpAttrs); - Func graphAction = () => + Func graphAction = () => { foreach (var attr in attrDict) inputArgs[attr.Key] = attr.Value; - return tf.OpDefLib._apply_op_helper(OpType, Name, inputArgs).output; + return tf.OpDefLib._apply_op_helper(OpType, Name, inputArgs).outputs; }; - Func eagerAction = () => + Func eagerAction = () => { var attrs = new object[attrDict.Count() * 2]; int i = 0; @@ -87,7 +57,7 @@ namespace Tensorflow.Contexts OpType, Name, null, inputArgs.Values.ToArray(), - attrs).FirstOrDefault(); + attrs); }; if (tf.Context.has_graph_arg(inputArgs.Values)) diff --git a/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs b/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs index e2815f81..bc3cc735 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs @@ -269,29 +269,24 @@ namespace Tensorflow.Operations } public static Tensor[] fused_batch_norm_grad_v3(FusedBatchNormParams @params) - => tf.Context.RunInAutoMode(() - => tf.OpDefLib._apply_op_helper("FusedBatchNormGradV3", name: @params.Name, - args: new - { - y_backprop = @params.YBackprop, - x = @params.X, - scale = @params.Scale, - reserve_space_1 = @params.ReserveSpace1, - reserve_space_2 = @params.ReserveSpace2, - reserve_space_3 = @params.ReserveSpace3, - epsilon = @params.Epsilon, - data_format = @params.DataFormat, - is_training = @params.IsTraining - }).outputs, () - => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "FusedBatchNormGradV3", @params.Name, - null, - @params.YBackprop, @params.X, @params.Scale, - @params.ReserveSpace1, @params.ReserveSpace2, @params.ReserveSpace3, - "epsilon", @params.Epsilon, - "data_format", @params.DataFormat, - "is_training", @params.IsTraining), - @params.YBackprop); + => tf.Context.ExecuteOp("FusedBatchNormGradV3", @params.Name, new AutoModeArgs + { + OpInputArgs = new + { + y_backprop = @params.YBackprop, + x = @params.X, + scale = @params.Scale, + reserve_space_1 = @params.ReserveSpace1, + reserve_space_2 = @params.ReserveSpace2, + reserve_space_3 = @params.ReserveSpace3 + }, + OpAttrs = new + { + epsilon = @params.Epsilon, + data_format = @params.DataFormat, + is_training = @params.IsTraining + } + }); public static Tensor[] fused_batch_norm(Tensor x, Tensor scale, @@ -388,14 +383,10 @@ namespace Tensorflow.Operations } public static Tensor log_softmax(Tensor logits, string name = null) - => tf.Context.RunInAutoMode(() - => tf.OpDefLib._apply_op_helper("LogSoftmax", name: name, - args: new { logits }).output, () - => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "LogSoftmax", name, - null, - logits).FirstOrDefault(), - logits); + => tf.Context.ExecuteOp("LogSoftmax", name, new AutoModeArgs + { + OpInputArgs = new { logits } + }); /// /// Says whether the targets are in the top `K` predictions. @@ -418,19 +409,11 @@ namespace Tensorflow.Operations } public static Tensor leaky_relu(Tensor features, float alpha = 0.2f, string name = null) - => tf.Context.RunInAutoMode(() - => tf.OpDefLib._apply_op_helper("LeakyRelu", name: name, - args: new - { - features, - alpha - }).output, () - => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "LeakyRelu", name, - null, - features, - "alpha", alpha).FirstOrDefault(), - features); + => tf.Context.ExecuteOp("LeakyRelu", name, new AutoModeArgs + { + OpInputArgs = new { features }, + OpAttrs = new { alpha } + }); public static Tensor max_pool(Tensor input, int[] ksize, diff --git a/src/TensorFlowNET.Core/Operations/array_ops.cs b/src/TensorFlowNET.Core/Operations/array_ops.cs index bc4b1206..235d2a95 100644 --- a/src/TensorFlowNET.Core/Operations/array_ops.cs +++ b/src/TensorFlowNET.Core/Operations/array_ops.cs @@ -737,7 +737,7 @@ namespace Tensorflow public static Tensor strided_slice_grad(Tensor shape, Tensor begin, Tensor end, Tensor strides, Tensor dy, long begin_mask = 0, long end_mask = 0, long ellipsis_mask = 0, long new_axis_mask = 0, long shrink_axis_mask = 0, string name = null) - => tf.Context.RunInAutoMode2("StridedSliceGrad", name, new AutoModeArgs + => tf.Context.ExecuteOp("StridedSliceGrad", name, new AutoModeArgs { OpInputArgs = new { @@ -960,7 +960,7 @@ namespace Tensorflow => gen_array_ops.slice(input, begin, size, name: name); public static Tensor slice(Tensor input, Tensor begin, Tensor size, string name = null) - => tf.Context.RunInAutoMode2("Slice", name, new AutoModeArgs + => tf.Context.ExecuteOp("Slice", name, new AutoModeArgs { OpInputArgs = new { input, begin, size }, GetGradientAttrs = (op) => new diff --git a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs index e29227c4..fd1b4c8d 100644 --- a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs @@ -72,14 +72,10 @@ namespace Tensorflow } public static Tensor concat_v2(Tensor[] values, int axis, string name = null) - => tf.Context.RunInAutoMode(() - => tf.OpDefLib._apply_op_helper("ConcatV2", name: name, - args: new { values, axis }).output, () - => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "ConcatV2", name, - null, - values, axis).FirstOrDefault(), - values); + => tf.Context.ExecuteOp("ConcatV2", name, new AutoModeArgs + { + OpInputArgs = new { values, axis } + }); private static Tensor concat_v2_eager_fallback(T1[] values, T2 axis, string name, Context ctx) { @@ -202,14 +198,11 @@ namespace Tensorflow } public static Tensor pack(Tensor[] values, int axis = 0, string name = null) - => tf.Context.RunInAutoMode(() - => tf.OpDefLib._apply_op_helper("Pack", name, new { values, axis }).output, () - => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "Pack", name, - null, - values, - "axis", axis).FirstOrDefault(), - values, axis); + => tf.Context.ExecuteOp("Pack", name, new AutoModeArgs + { + OpInputArgs = new { values }, + OpAttrs = new { axis } + }); /// /// Return a tensor with the same shape and contents as the input tensor or value. @@ -326,31 +319,16 @@ namespace Tensorflow } public static Tensor reshape(Tensor tensor, T shape, string name = null) - => tf.Context.RunInAutoMode(() - => tf.OpDefLib._apply_op_helper("Reshape", name, new { tensor, shape }).output, () - => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "Reshape", name, - null, - tensor, shape).FirstOrDefault(), - tensor, shape); + => tf.Context.ExecuteOp("Reshape", name, new AutoModeArgs + { + OpInputArgs = new { tensor, shape } + }); public static Tensor reshape(Tensor tensor, object[] shape, string name = null) - { - try + => tf.Context.ExecuteOp("Reshape", name, new AutoModeArgs { - return tf.Context.RunInAutoMode(() - => tf.OpDefLib._apply_op_helper("Reshape", name, new { tensor, shape }).output, () - => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "Reshape", name, - null, - tensor, shape).FirstOrDefault(), - tensor, shape); - } - catch (InvalidArgumentError ex) - { - return reshape_eager_fallback(tensor, shape, name, tf.Context); - } - } + OpInputArgs = new { tensor, shape } + }); private static Tensor reshape_eager_fallback(Tensor tensor, object[] shape, string name, Context ctx) { @@ -467,15 +445,11 @@ namespace Tensorflow } public static Tensor shape(Tensor input, TF_DataType out_type = TF_DataType.TF_INT32, string name = null) - => tf.Context.RunInAutoMode(() - => tf.OpDefLib._apply_op_helper("Shape", name, - new { input, out_type }).output, () - => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "Shape", name, - null, - input, - "out_type", out_type).FirstOrDefault(), - input); + => tf.Context.ExecuteOp("Shape", name, new AutoModeArgs + { + OpInputArgs = new { input }, + OpAttrs = new { out_type } + }); /// /// Returns shape of tensors. @@ -559,22 +533,16 @@ namespace Tensorflow } public static Tensor tile(Tensor input, Tensor multiples, string name = null) - => tf.Context.RunInAutoMode(() - => tf.OpDefLib._apply_op_helper("Tile", name, new { input, multiples }).output, () - => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "Tile", name, - null, - input, multiples).FirstOrDefault(), - input, multiples); + => tf.Context.ExecuteOp("Tile", name, new AutoModeArgs + { + OpInputArgs = new { input, multiples } + }); public static Tensor tile(Tensor input, object[] multiples, string name = null) - => tf.Context.RunInAutoMode(() - => tf.OpDefLib._apply_op_helper("Tile", name, new { input, multiples }).output, () - => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "Tile", name, - null, - input, multiples).FirstOrDefault(), - input, multiples); + => tf.Context.ExecuteOp("Tile", name, new AutoModeArgs + { + OpInputArgs = new { input, multiples } + }); public static Tensor transpose(Tensor x, T1 perm, string name = null) { @@ -592,22 +560,16 @@ namespace Tensorflow } public static Tensor ones_like(Tensor x, string name = null) - => tf.Context.RunInAutoMode(() - => tf.OpDefLib._apply_op_helper("OnesLike", name, new { x }).output, () - => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "OnesLike", name, - null, - x).FirstOrDefault(), - x); + => tf.Context.ExecuteOp("OnesLike", name, new AutoModeArgs + { + OpInputArgs = new { x } + }); public static Tensor zeros_like(Tensor x, string name = null) - => tf.Context.RunInAutoMode(() - => tf.OpDefLib._apply_op_helper("ZerosLike", name, new { x }).output, () - => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "ZerosLike", name, - null, - x).FirstOrDefault(), - x); + => tf.Context.ExecuteOp("ZerosLike", name, new AutoModeArgs + { + OpInputArgs = new { x } + }); public static Tensor stop_gradient(Tensor x, string name = null) { @@ -623,53 +585,37 @@ namespace Tensorflow long new_axis_mask = 0, long shrink_axis_mask = 0, string name = null) - => tf.Context.RunInAutoMode(() - => tf.OpDefLib._apply_op_helper("StridedSlice", name, new + => tf.Context.ExecuteOp("StridedSlice", name, new AutoModeArgs { - input, - begin, - end, - strides, - begin_mask, - end_mask, - ellipsis_mask, - new_axis_mask, - shrink_axis_mask - }).output, () - => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "StridedSlice", name, - null, - input, begin, end, strides, - "begin_mask", begin_mask, - "end_mask", end_mask, - "ellipsis_mask", ellipsis_mask, - "new_axis_mask", new_axis_mask, - "shrink_axis_mask", shrink_axis_mask).FirstOrDefault(), - input, begin, end, strides); - - public static Operation resource_strided_slice_assign(Tensor input, Tensor begin, Tensor end, Tensor strides, Tensor value, + OpInputArgs = new { input, begin, end, strides }, + OpAttrs = new + { + begin_mask, + end_mask, + ellipsis_mask, + new_axis_mask, + shrink_axis_mask + } + }); + + public static Tensor resource_strided_slice_assign(Tensor input, Tensor begin, Tensor end, Tensor strides, Tensor value, int begin_mask = 0, int end_mask = 0, int ellipsis_mask = 0, int new_axis_mask = 0, int shrink_axis_mask = 0, string name = null) - => tf.Context.RunInAutoMode(() - => tf.OpDefLib._apply_op_helper("ResourceStridedSliceAssign", name, new + => tf.Context.ExecuteOp("ResourceStridedSliceAssign", name, new AutoModeArgs { - input, begin, end, strides, value, - begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask - }).output, () - => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "ResourceStridedSliceAssign", name, - null, - input, begin, end, strides, value, - "begin_mask", begin_mask, - "end_mask", end_mask, - "ellipsis_mask", ellipsis_mask, - "new_axis_mask", new_axis_mask, - "shrink_axis_mask", shrink_axis_mask).FirstOrDefault(), - input, begin, end, strides, value); + OpInputArgs = new { input, begin, end, strides, value }, + OpAttrs = new { + begin_mask, + end_mask, + ellipsis_mask, + new_axis_mask, + shrink_axis_mask + } + }); public static Tensor strided_slice(Tensor input, T[] begin, T[] end, T[] strides, int begin_mask = 0, diff --git a/src/TensorFlowNET.Core/Operations/gen_image_ops.cs b/src/TensorFlowNET.Core/Operations/gen_image_ops.cs index 955f2db3..19a5be0d 100644 --- a/src/TensorFlowNET.Core/Operations/gen_image_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_image_ops.cs @@ -222,25 +222,15 @@ namespace Tensorflow public static Tensor resize_nearest_neighbor(Tensor images, Tsize size, bool align_corners = false, bool half_pixel_centers = false, string name = null) - => tf.Context.RunInAutoMode(() - => tf.OpDefLib._apply_op_helper("ResizeNearestNeighbor", name: name, args: new + => tf.Context.ExecuteOp("ResizeNearestNeighbor", name, new AutoModeArgs { - images, - size, - align_corners, - half_pixel_centers - }).output, () - => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "ResizeNearestNeighbor", name, - null, - images, size, - "align_corners", align_corners, - "half_pixel_centers", half_pixel_centers).FirstOrDefault(), - images); + OpInputArgs = new { images, size }, + OpAttrs = new { align_corners, half_pixel_centers } + }); public static Tensor resize_nearest_neighbor_grad(Tensor grads, Tensor size, bool align_corners = false, bool half_pixel_centers = false, string name = null) - => tf.Context.RunInAutoMode2("ResizeNearestNeighborGrad", name, new AutoModeArgs + => tf.Context.ExecuteOp("ResizeNearestNeighborGrad", name, new AutoModeArgs { OpInputArgs = new { grads, size }, OpAttrs = new { align_corners, half_pixel_centers }, diff --git a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs index b40b3b91..ba981d6f 100644 --- a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs @@ -116,13 +116,10 @@ namespace Tensorflow /// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) /// public static Tensor div_no_nan(Tensor x, Tensor y, string name = null) - => tf.Context.RunInAutoMode(() - => tf.OpDefLib._apply_op_helper("DivNoNan", name: name, new { x, y }).output, () - => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "DivNoNan", name, - null, - x, y).FirstOrDefault(), - x, y); + => tf.Context.ExecuteOp("DivNoNan", name, new AutoModeArgs + { + OpInputArgs = new { x, y } + }); public static Tensor mean(Tensor input, int axis, bool keep_dims = false, string name = null) => mean(input, ops.convert_to_tensor(axis), keep_dims: keep_dims, name: name); @@ -141,7 +138,7 @@ namespace Tensorflow /// A name for the operation (optional). /// A `Tensor`. Has the same type as `input`. public static Tensor mean(Tensor input, Tensor axis, bool keep_dims = false, string name = null) - => tf.Context.RunInAutoMode2("Mean", name, new AutoModeArgs + => tf.Context.ExecuteOp("Mean", name, new AutoModeArgs { OpInputArgs = new { input, axis }, OpAttrs = new { keep_dims, reduction_indices = axis }, @@ -318,13 +315,10 @@ namespace Tensorflow /// Specifically, y = 1 / (1 + exp(-x)). /// public static Tensor sigmoid(Tensor x, string name = "Sigmoid") - => tf.Context.RunInAutoMode(() - => tf.OpDefLib._apply_op_helper("Sigmoid", name: name, new { x }).output, () - => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "Sigmoid", name, - null, - x).FirstOrDefault(), - x); + => tf.Context.ExecuteOp("Sigmoid", name, new AutoModeArgs + { + OpInputArgs = new { x } + }); /// /// Computes the gradient of the sigmoid of x wrt its input. @@ -344,7 +338,7 @@ namespace Tensorflow /// dy is the corresponding input gradient. /// public static Tensor sigmoid_grad(Tensor y, Tensor dy, string name = "SigmoidGrad") - => tf.Context.RunInAutoMode2("SigmoidGrad", name, new AutoModeArgs + => tf.Context.ExecuteOp("SigmoidGrad", name, new AutoModeArgs { OpInputArgs = new { y, dy } }); @@ -576,13 +570,10 @@ namespace Tensorflow } public static Tensor log1p(Tensor x, string name = null) - => tf.Context.RunInAutoMode(() - => tf.OpDefLib._apply_op_helper("Log1p", name: name, new { x }).output, () - => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "Log1p", name, - null, - x).FirstOrDefault(), - x); + => tf.Context.ExecuteOp("Log1p", name, new AutoModeArgs + { + OpInputArgs = new { x } + }); public static Tensor logical_and(Tensor x, Tensor y, string name = null) => tf.OpDefLib._apply_op_helper("LogicalAnd", name, args: new { x, y }); @@ -691,13 +682,10 @@ namespace Tensorflow /// A name for the operation (optional). /// A `Tensor`. Has the same type as `x`. public static Tensor exp(Tensor x, string name = null) - => tf.Context.RunInAutoMode(() - => tf.OpDefLib._apply_op_helper("Exp", name, args: new { x }).output, () - => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "Exp", name, - null, - x).FirstOrDefault(), - x); + => tf.Context.ExecuteOp("Exp", name, new AutoModeArgs + { + OpInputArgs = new { x } + }); /// /// Computes natural logarithm of x element-wise. @@ -739,14 +727,11 @@ namespace Tensorflow } public static Tensor cast(Tensor x, TF_DataType DstT, bool Truncate = false, string name = null) - => tf.Context.RunInAutoMode(() - => tf.OpDefLib._apply_op_helper("Cast", name, args: new { x, DstT, Truncate }).output, () - => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "Cast", name, - null, - x, - "DstT", DstT, "Truncate", Truncate).FirstOrDefault(), - x); + => tf.Context.ExecuteOp("Cast", name, new AutoModeArgs + { + OpInputArgs = new { x }, + OpAttrs = new { DstT, Truncate } + }); public static Tensor neg(Tensor x, string name = null) { @@ -783,7 +768,7 @@ namespace Tensorflow } public static Tensor sub(Tensor x, Tensor y, string name = null) - => tf.Context.RunInAutoMode2("Sub", name, new AutoModeArgs + => tf.Context.ExecuteOp("Sub", name, new AutoModeArgs { OpInputArgs = new { x, y } }); @@ -1087,14 +1072,17 @@ namespace Tensorflow } public static Tensor _max(Tx input, Ty axis, bool keep_dims = false, string name = null) - => tf.Context.RunInAutoMode(() - => tf.OpDefLib._apply_op_helper("Max", name, new { input, reduction_indices = axis, keep_dims }).output, () - => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "Max", name, - null, - input, axis, - "keep_dims", keep_dims).FirstOrDefault(), - input as Tensor); + => tf.Context.ExecuteOp("Max", name, new AutoModeArgs + { + OpInputArgs = new { input, axis }, + OpAttrs = new { keep_dims, reduction_indices = axis }, + GetGradientAttrs = (op) => new + { + T = op.get_attr("T"), + align_corners = op.get_attr("align_corners"), + half_pixel_centers = op.get_attr("half_pixel_centers") + } + }); public static Tensor _min(Tx input, Ty axis, bool keep_dims = false, string name = null) { @@ -1170,13 +1158,10 @@ namespace Tensorflow /// /// public static Tensor range(Tensor start, Tensor limit, Tensor delta, string name = null) - => tf.Context.RunInAutoMode(() - => tf.OpDefLib._apply_op_helper("Range", name, new { start, limit, delta }).output, () - => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "Range", name, - null, - start, limit, delta).FirstOrDefault(), - start, limit, delta); + => tf.Context.ExecuteOp("Range", name, new AutoModeArgs + { + OpInputArgs = new { start, limit, delta } + }); /// /// Rounds the values of a tensor to the nearest integer, element-wise. diff --git a/src/TensorFlowNET.Core/Operations/math_ops.cs b/src/TensorFlowNET.Core/Operations/math_ops.cs index fc8a28d5..b26d84d6 100644 --- a/src/TensorFlowNET.Core/Operations/math_ops.cs +++ b/src/TensorFlowNET.Core/Operations/math_ops.cs @@ -45,7 +45,7 @@ namespace Tensorflow => gen_math_ops.add(x, y, name); public static Tensor add_v2(Tensor x, Tensor y, string name = null) - => tf.Context.RunInAutoMode2("AddV2", name, new AutoModeArgs + => tf.Context.ExecuteOp("AddV2", name, new AutoModeArgs { OpInputArgs = new { x, y } }); @@ -261,7 +261,7 @@ namespace Tensorflow /// /// public static Tensor erf(Tensor x, string name = null) - => tf.Context.RunInAutoMode2("Erf", name, new AutoModeArgs + => tf.Context.ExecuteOp("Erf", name, new AutoModeArgs { OpInputArgs = new { x } }); @@ -270,7 +270,7 @@ namespace Tensorflow => gen_math_ops.sqrt(x, name: name); public static Tensor multiply(Tensor x, Tensor y, string name = null) - => tf.Context.RunInAutoMode2("Mul", name, new AutoModeArgs + => tf.Context.ExecuteOp("Mul", name, new AutoModeArgs { OpInputArgs = new { x, y } });