| @@ -30,49 +30,19 @@ namespace Tensorflow.Contexts | |||
| public sealed partial class Context | |||
| { | |||
| // [DebuggerStepThrough] | |||
| public T RunInAutoMode<T>(Func<T> graphAction, Func<T> eagerAction, params object[] args) | |||
| { | |||
| if (tf.Context.has_graph_arg(args)) | |||
| { | |||
| if (executing_eagerly()) | |||
| { | |||
| graph_mode(); | |||
| var result = graphAction(); | |||
| restore_mode(); | |||
| return result; | |||
| } | |||
| else | |||
| { | |||
| return graphAction(); | |||
| } | |||
| } | |||
| else | |||
| { | |||
| if (tf.Context.executing_eagerly()) | |||
| { | |||
| return eagerAction(); | |||
| } | |||
| else | |||
| { | |||
| return graphAction(); | |||
| } | |||
| } | |||
| } | |||
| // [DebuggerStepThrough] | |||
| public Tensors RunInAutoMode2(string OpType, string Name, AutoModeArgs args) | |||
| public Tensors ExecuteOp(string OpType, string Name, AutoModeArgs args) | |||
| { | |||
| var inputArgs = ConvertToDict(args.OpInputArgs); | |||
| var attrDict = ConvertToDict(args.OpAttrs); | |||
| Func<Tensor> graphAction = () => | |||
| Func<Tensors> graphAction = () => | |||
| { | |||
| foreach (var attr in attrDict) | |||
| inputArgs[attr.Key] = attr.Value; | |||
| return tf.OpDefLib._apply_op_helper(OpType, Name, inputArgs).output; | |||
| return tf.OpDefLib._apply_op_helper(OpType, Name, inputArgs).outputs; | |||
| }; | |||
| Func<Tensor> eagerAction = () => | |||
| Func<Tensors> eagerAction = () => | |||
| { | |||
| var attrs = new object[attrDict.Count() * 2]; | |||
| int i = 0; | |||
| @@ -87,7 +57,7 @@ namespace Tensorflow.Contexts | |||
| OpType, Name, | |||
| null, | |||
| inputArgs.Values.ToArray(), | |||
| attrs).FirstOrDefault(); | |||
| attrs); | |||
| }; | |||
| if (tf.Context.has_graph_arg(inputArgs.Values)) | |||
| @@ -269,29 +269,24 @@ namespace Tensorflow.Operations | |||
| } | |||
| public static Tensor[] fused_batch_norm_grad_v3(FusedBatchNormParams @params) | |||
| => tf.Context.RunInAutoMode(() | |||
| => tf.OpDefLib._apply_op_helper("FusedBatchNormGradV3", name: @params.Name, | |||
| args: new | |||
| { | |||
| y_backprop = @params.YBackprop, | |||
| x = @params.X, | |||
| scale = @params.Scale, | |||
| reserve_space_1 = @params.ReserveSpace1, | |||
| reserve_space_2 = @params.ReserveSpace2, | |||
| reserve_space_3 = @params.ReserveSpace3, | |||
| epsilon = @params.Epsilon, | |||
| data_format = @params.DataFormat, | |||
| is_training = @params.IsTraining | |||
| }).outputs, () | |||
| => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
| "FusedBatchNormGradV3", @params.Name, | |||
| null, | |||
| @params.YBackprop, @params.X, @params.Scale, | |||
| @params.ReserveSpace1, @params.ReserveSpace2, @params.ReserveSpace3, | |||
| "epsilon", @params.Epsilon, | |||
| "data_format", @params.DataFormat, | |||
| "is_training", @params.IsTraining), | |||
| @params.YBackprop); | |||
| => tf.Context.ExecuteOp("FusedBatchNormGradV3", @params.Name, new AutoModeArgs | |||
| { | |||
| OpInputArgs = new | |||
| { | |||
| y_backprop = @params.YBackprop, | |||
| x = @params.X, | |||
| scale = @params.Scale, | |||
| reserve_space_1 = @params.ReserveSpace1, | |||
| reserve_space_2 = @params.ReserveSpace2, | |||
| reserve_space_3 = @params.ReserveSpace3 | |||
| }, | |||
| OpAttrs = new | |||
| { | |||
| epsilon = @params.Epsilon, | |||
| data_format = @params.DataFormat, | |||
| is_training = @params.IsTraining | |||
| } | |||
| }); | |||
| public static Tensor[] fused_batch_norm(Tensor x, | |||
| Tensor scale, | |||
| @@ -388,14 +383,10 @@ namespace Tensorflow.Operations | |||
| } | |||
| public static Tensor log_softmax(Tensor logits, string name = null) | |||
| => tf.Context.RunInAutoMode(() | |||
| => tf.OpDefLib._apply_op_helper("LogSoftmax", name: name, | |||
| args: new { logits }).output, () | |||
| => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
| "LogSoftmax", name, | |||
| null, | |||
| logits).FirstOrDefault(), | |||
| logits); | |||
| => tf.Context.ExecuteOp("LogSoftmax", name, new AutoModeArgs | |||
| { | |||
| OpInputArgs = new { logits } | |||
| }); | |||
| /// <summary> | |||
| /// Says whether the targets are in the top `K` predictions. | |||
| @@ -418,19 +409,11 @@ namespace Tensorflow.Operations | |||
| } | |||
| public static Tensor leaky_relu(Tensor features, float alpha = 0.2f, string name = null) | |||
| => tf.Context.RunInAutoMode(() | |||
| => tf.OpDefLib._apply_op_helper("LeakyRelu", name: name, | |||
| args: new | |||
| { | |||
| features, | |||
| alpha | |||
| }).output, () | |||
| => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
| "LeakyRelu", name, | |||
| null, | |||
| features, | |||
| "alpha", alpha).FirstOrDefault(), | |||
| features); | |||
| => tf.Context.ExecuteOp("LeakyRelu", name, new AutoModeArgs | |||
| { | |||
| OpInputArgs = new { features }, | |||
| OpAttrs = new { alpha } | |||
| }); | |||
| public static Tensor max_pool(Tensor input, | |||
| int[] ksize, | |||
| @@ -737,7 +737,7 @@ namespace Tensorflow | |||
| public static Tensor strided_slice_grad(Tensor shape, Tensor begin, Tensor end, Tensor strides, Tensor dy, | |||
| long begin_mask = 0, long end_mask = 0, long ellipsis_mask = 0, long new_axis_mask = 0, | |||
| long shrink_axis_mask = 0, string name = null) | |||
| => tf.Context.RunInAutoMode2("StridedSliceGrad", name, new AutoModeArgs | |||
| => tf.Context.ExecuteOp("StridedSliceGrad", name, new AutoModeArgs | |||
| { | |||
| OpInputArgs = new | |||
| { | |||
| @@ -960,7 +960,7 @@ namespace Tensorflow | |||
| => gen_array_ops.slice(input, begin, size, name: name); | |||
| public static Tensor slice(Tensor input, Tensor begin, Tensor size, string name = null) | |||
| => tf.Context.RunInAutoMode2("Slice", name, new AutoModeArgs | |||
| => tf.Context.ExecuteOp("Slice", name, new AutoModeArgs | |||
| { | |||
| OpInputArgs = new { input, begin, size }, | |||
| GetGradientAttrs = (op) => new | |||
| @@ -72,14 +72,10 @@ namespace Tensorflow | |||
| } | |||
| public static Tensor concat_v2(Tensor[] values, int axis, string name = null) | |||
| => tf.Context.RunInAutoMode(() | |||
| => tf.OpDefLib._apply_op_helper("ConcatV2", name: name, | |||
| args: new { values, axis }).output, () | |||
| => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
| "ConcatV2", name, | |||
| null, | |||
| values, axis).FirstOrDefault(), | |||
| values); | |||
| => tf.Context.ExecuteOp("ConcatV2", name, new AutoModeArgs | |||
| { | |||
| OpInputArgs = new { values, axis } | |||
| }); | |||
| private static Tensor concat_v2_eager_fallback<T1, T2>(T1[] values, T2 axis, string name, Context ctx) | |||
| { | |||
| @@ -202,14 +198,11 @@ namespace Tensorflow | |||
| } | |||
| public static Tensor pack(Tensor[] values, int axis = 0, string name = null) | |||
| => tf.Context.RunInAutoMode(() | |||
| => tf.OpDefLib._apply_op_helper("Pack", name, new { values, axis }).output, () | |||
| => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
| "Pack", name, | |||
| null, | |||
| values, | |||
| "axis", axis).FirstOrDefault(), | |||
| values, axis); | |||
| => tf.Context.ExecuteOp("Pack", name, new AutoModeArgs | |||
| { | |||
| OpInputArgs = new { values }, | |||
| OpAttrs = new { axis } | |||
| }); | |||
| /// <summary> | |||
| /// Return a tensor with the same shape and contents as the input tensor or value. | |||
| @@ -326,31 +319,16 @@ namespace Tensorflow | |||
| } | |||
| public static Tensor reshape<T>(Tensor tensor, T shape, string name = null) | |||
| => tf.Context.RunInAutoMode(() | |||
| => tf.OpDefLib._apply_op_helper("Reshape", name, new { tensor, shape }).output, () | |||
| => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
| "Reshape", name, | |||
| null, | |||
| tensor, shape).FirstOrDefault(), | |||
| tensor, shape); | |||
| => tf.Context.ExecuteOp("Reshape", name, new AutoModeArgs | |||
| { | |||
| OpInputArgs = new { tensor, shape } | |||
| }); | |||
| public static Tensor reshape(Tensor tensor, object[] shape, string name = null) | |||
| { | |||
| try | |||
| => tf.Context.ExecuteOp("Reshape", name, new AutoModeArgs | |||
| { | |||
| return tf.Context.RunInAutoMode(() | |||
| => tf.OpDefLib._apply_op_helper("Reshape", name, new { tensor, shape }).output, () | |||
| => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
| "Reshape", name, | |||
| null, | |||
| tensor, shape).FirstOrDefault(), | |||
| tensor, shape); | |||
| } | |||
| catch (InvalidArgumentError ex) | |||
| { | |||
| return reshape_eager_fallback(tensor, shape, name, tf.Context); | |||
| } | |||
| } | |||
| OpInputArgs = new { tensor, shape } | |||
| }); | |||
| private static Tensor reshape_eager_fallback(Tensor tensor, object[] shape, string name, Context ctx) | |||
| { | |||
| @@ -467,15 +445,11 @@ namespace Tensorflow | |||
| } | |||
| public static Tensor shape(Tensor input, TF_DataType out_type = TF_DataType.TF_INT32, string name = null) | |||
| => tf.Context.RunInAutoMode(() | |||
| => tf.OpDefLib._apply_op_helper("Shape", name, | |||
| new { input, out_type }).output, () | |||
| => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
| "Shape", name, | |||
| null, | |||
| input, | |||
| "out_type", out_type).FirstOrDefault(), | |||
| input); | |||
| => tf.Context.ExecuteOp("Shape", name, new AutoModeArgs | |||
| { | |||
| OpInputArgs = new { input }, | |||
| OpAttrs = new { out_type } | |||
| }); | |||
| /// <summary> | |||
| /// Returns shape of tensors. | |||
| @@ -559,22 +533,16 @@ namespace Tensorflow | |||
| } | |||
| public static Tensor tile(Tensor input, Tensor multiples, string name = null) | |||
| => tf.Context.RunInAutoMode(() | |||
| => tf.OpDefLib._apply_op_helper("Tile", name, new { input, multiples }).output, () | |||
| => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
| "Tile", name, | |||
| null, | |||
| input, multiples).FirstOrDefault(), | |||
| input, multiples); | |||
| => tf.Context.ExecuteOp("Tile", name, new AutoModeArgs | |||
| { | |||
| OpInputArgs = new { input, multiples } | |||
| }); | |||
| public static Tensor tile(Tensor input, object[] multiples, string name = null) | |||
| => tf.Context.RunInAutoMode(() | |||
| => tf.OpDefLib._apply_op_helper("Tile", name, new { input, multiples }).output, () | |||
| => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
| "Tile", name, | |||
| null, | |||
| input, multiples).FirstOrDefault(), | |||
| input, multiples); | |||
| => tf.Context.ExecuteOp("Tile", name, new AutoModeArgs | |||
| { | |||
| OpInputArgs = new { input, multiples } | |||
| }); | |||
| public static Tensor transpose<T1>(Tensor x, T1 perm, string name = null) | |||
| { | |||
| @@ -592,22 +560,16 @@ namespace Tensorflow | |||
| } | |||
| public static Tensor ones_like(Tensor x, string name = null) | |||
| => tf.Context.RunInAutoMode(() | |||
| => tf.OpDefLib._apply_op_helper("OnesLike", name, new { x }).output, () | |||
| => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
| "OnesLike", name, | |||
| null, | |||
| x).FirstOrDefault(), | |||
| x); | |||
| => tf.Context.ExecuteOp("OnesLike", name, new AutoModeArgs | |||
| { | |||
| OpInputArgs = new { x } | |||
| }); | |||
| public static Tensor zeros_like(Tensor x, string name = null) | |||
| => tf.Context.RunInAutoMode(() | |||
| => tf.OpDefLib._apply_op_helper("ZerosLike", name, new { x }).output, () | |||
| => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
| "ZerosLike", name, | |||
| null, | |||
| x).FirstOrDefault(), | |||
| x); | |||
| => tf.Context.ExecuteOp("ZerosLike", name, new AutoModeArgs | |||
| { | |||
| OpInputArgs = new { x } | |||
| }); | |||
| public static Tensor stop_gradient(Tensor x, string name = null) | |||
| { | |||
| @@ -623,53 +585,37 @@ namespace Tensorflow | |||
| long new_axis_mask = 0, | |||
| long shrink_axis_mask = 0, | |||
| string name = null) | |||
| => tf.Context.RunInAutoMode(() | |||
| => tf.OpDefLib._apply_op_helper("StridedSlice", name, new | |||
| => tf.Context.ExecuteOp("StridedSlice", name, new AutoModeArgs | |||
| { | |||
| input, | |||
| begin, | |||
| end, | |||
| strides, | |||
| begin_mask, | |||
| end_mask, | |||
| ellipsis_mask, | |||
| new_axis_mask, | |||
| shrink_axis_mask | |||
| }).output, () | |||
| => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
| "StridedSlice", name, | |||
| null, | |||
| input, begin, end, strides, | |||
| "begin_mask", begin_mask, | |||
| "end_mask", end_mask, | |||
| "ellipsis_mask", ellipsis_mask, | |||
| "new_axis_mask", new_axis_mask, | |||
| "shrink_axis_mask", shrink_axis_mask).FirstOrDefault(), | |||
| input, begin, end, strides); | |||
| public static Operation resource_strided_slice_assign(Tensor input, Tensor begin, Tensor end, Tensor strides, Tensor value, | |||
| OpInputArgs = new { input, begin, end, strides }, | |||
| OpAttrs = new | |||
| { | |||
| begin_mask, | |||
| end_mask, | |||
| ellipsis_mask, | |||
| new_axis_mask, | |||
| shrink_axis_mask | |||
| } | |||
| }); | |||
| public static Tensor resource_strided_slice_assign(Tensor input, Tensor begin, Tensor end, Tensor strides, Tensor value, | |||
| int begin_mask = 0, | |||
| int end_mask = 0, | |||
| int ellipsis_mask = 0, | |||
| int new_axis_mask = 0, | |||
| int shrink_axis_mask = 0, | |||
| string name = null) | |||
| => tf.Context.RunInAutoMode(() | |||
| => tf.OpDefLib._apply_op_helper("ResourceStridedSliceAssign", name, new | |||
| => tf.Context.ExecuteOp("ResourceStridedSliceAssign", name, new AutoModeArgs | |||
| { | |||
| input, begin, end, strides, value, | |||
| begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask | |||
| }).output, () | |||
| => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
| "ResourceStridedSliceAssign", name, | |||
| null, | |||
| input, begin, end, strides, value, | |||
| "begin_mask", begin_mask, | |||
| "end_mask", end_mask, | |||
| "ellipsis_mask", ellipsis_mask, | |||
| "new_axis_mask", new_axis_mask, | |||
| "shrink_axis_mask", shrink_axis_mask).FirstOrDefault(), | |||
| input, begin, end, strides, value); | |||
| OpInputArgs = new { input, begin, end, strides, value }, | |||
| OpAttrs = new { | |||
| begin_mask, | |||
| end_mask, | |||
| ellipsis_mask, | |||
| new_axis_mask, | |||
| shrink_axis_mask | |||
| } | |||
| }); | |||
| public static Tensor strided_slice<T>(Tensor input, T[] begin, T[] end, T[] strides, | |||
| int begin_mask = 0, | |||
| @@ -222,25 +222,15 @@ namespace Tensorflow | |||
| public static Tensor resize_nearest_neighbor<Tsize>(Tensor images, Tsize size, bool align_corners = false, | |||
| bool half_pixel_centers = false, string name = null) | |||
| => tf.Context.RunInAutoMode(() | |||
| => tf.OpDefLib._apply_op_helper("ResizeNearestNeighbor", name: name, args: new | |||
| => tf.Context.ExecuteOp("ResizeNearestNeighbor", name, new AutoModeArgs | |||
| { | |||
| images, | |||
| size, | |||
| align_corners, | |||
| half_pixel_centers | |||
| }).output, () | |||
| => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
| "ResizeNearestNeighbor", name, | |||
| null, | |||
| images, size, | |||
| "align_corners", align_corners, | |||
| "half_pixel_centers", half_pixel_centers).FirstOrDefault(), | |||
| images); | |||
| OpInputArgs = new { images, size }, | |||
| OpAttrs = new { align_corners, half_pixel_centers } | |||
| }); | |||
| public static Tensor resize_nearest_neighbor_grad(Tensor grads, Tensor size, bool align_corners = false, | |||
| bool half_pixel_centers = false, string name = null) | |||
| => tf.Context.RunInAutoMode2("ResizeNearestNeighborGrad", name, new AutoModeArgs | |||
| => tf.Context.ExecuteOp("ResizeNearestNeighborGrad", name, new AutoModeArgs | |||
| { | |||
| OpInputArgs = new { grads, size }, | |||
| OpAttrs = new { align_corners, half_pixel_centers }, | |||
| @@ -116,13 +116,10 @@ namespace Tensorflow | |||
| /// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) | |||
| /// </remarks> | |||
| public static Tensor div_no_nan(Tensor x, Tensor y, string name = null) | |||
| => tf.Context.RunInAutoMode(() | |||
| => tf.OpDefLib._apply_op_helper("DivNoNan", name: name, new { x, y }).output, () | |||
| => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
| "DivNoNan", name, | |||
| null, | |||
| x, y).FirstOrDefault(), | |||
| x, y); | |||
| => tf.Context.ExecuteOp("DivNoNan", name, new AutoModeArgs | |||
| { | |||
| OpInputArgs = new { x, y } | |||
| }); | |||
| public static Tensor mean(Tensor input, int axis, bool keep_dims = false, string name = null) | |||
| => mean(input, ops.convert_to_tensor(axis), keep_dims: keep_dims, name: name); | |||
| @@ -141,7 +138,7 @@ namespace Tensorflow | |||
| /// <param name="name"> A name for the operation (optional).</param> | |||
| /// <returns> A `Tensor`. Has the same type as `input`.</returns> | |||
| public static Tensor mean(Tensor input, Tensor axis, bool keep_dims = false, string name = null) | |||
| => tf.Context.RunInAutoMode2("Mean", name, new AutoModeArgs | |||
| => tf.Context.ExecuteOp("Mean", name, new AutoModeArgs | |||
| { | |||
| OpInputArgs = new { input, axis }, | |||
| OpAttrs = new { keep_dims, reduction_indices = axis }, | |||
| @@ -318,13 +315,10 @@ namespace Tensorflow | |||
| /// Specifically, <c>y = 1 / (1 + exp(-x))</c>. | |||
| /// </remarks> | |||
| public static Tensor sigmoid(Tensor x, string name = "Sigmoid") | |||
| => tf.Context.RunInAutoMode(() | |||
| => tf.OpDefLib._apply_op_helper("Sigmoid", name: name, new { x }).output, () | |||
| => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
| "Sigmoid", name, | |||
| null, | |||
| x).FirstOrDefault(), | |||
| x); | |||
| => tf.Context.ExecuteOp("Sigmoid", name, new AutoModeArgs | |||
| { | |||
| OpInputArgs = new { x } | |||
| }); | |||
| /// <summary> | |||
| /// Computes the gradient of the sigmoid of <c>x</c> wrt its input. | |||
| @@ -344,7 +338,7 @@ namespace Tensorflow | |||
| /// <c>dy</c> is the corresponding input gradient. | |||
| /// </remarks> | |||
| public static Tensor sigmoid_grad(Tensor y, Tensor dy, string name = "SigmoidGrad") | |||
| => tf.Context.RunInAutoMode2("SigmoidGrad", name, new AutoModeArgs | |||
| => tf.Context.ExecuteOp("SigmoidGrad", name, new AutoModeArgs | |||
| { | |||
| OpInputArgs = new { y, dy } | |||
| }); | |||
| @@ -576,13 +570,10 @@ namespace Tensorflow | |||
| } | |||
| public static Tensor log1p(Tensor x, string name = null) | |||
| => tf.Context.RunInAutoMode(() | |||
| => tf.OpDefLib._apply_op_helper("Log1p", name: name, new { x }).output, () | |||
| => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
| "Log1p", name, | |||
| null, | |||
| x).FirstOrDefault(), | |||
| x); | |||
| => tf.Context.ExecuteOp("Log1p", name, new AutoModeArgs | |||
| { | |||
| OpInputArgs = new { x } | |||
| }); | |||
| public static Tensor logical_and(Tensor x, Tensor y, string name = null) | |||
| => tf.OpDefLib._apply_op_helper("LogicalAnd", name, args: new { x, y }); | |||
| @@ -691,13 +682,10 @@ namespace Tensorflow | |||
| /// <param name="name"> A name for the operation (optional).</param> | |||
| /// <returns> A `Tensor`. Has the same type as `x`.</returns> | |||
| public static Tensor exp(Tensor x, string name = null) | |||
| => tf.Context.RunInAutoMode(() | |||
| => tf.OpDefLib._apply_op_helper("Exp", name, args: new { x }).output, () | |||
| => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
| "Exp", name, | |||
| null, | |||
| x).FirstOrDefault(), | |||
| x); | |||
| => tf.Context.ExecuteOp("Exp", name, new AutoModeArgs | |||
| { | |||
| OpInputArgs = new { x } | |||
| }); | |||
| /// <summary> | |||
| /// Computes natural logarithm of x element-wise. | |||
| @@ -739,14 +727,11 @@ namespace Tensorflow | |||
| } | |||
| public static Tensor cast(Tensor x, TF_DataType DstT, bool Truncate = false, string name = null) | |||
| => tf.Context.RunInAutoMode(() | |||
| => tf.OpDefLib._apply_op_helper("Cast", name, args: new { x, DstT, Truncate }).output, () | |||
| => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
| "Cast", name, | |||
| null, | |||
| x, | |||
| "DstT", DstT, "Truncate", Truncate).FirstOrDefault(), | |||
| x); | |||
| => tf.Context.ExecuteOp("Cast", name, new AutoModeArgs | |||
| { | |||
| OpInputArgs = new { x }, | |||
| OpAttrs = new { DstT, Truncate } | |||
| }); | |||
| public static Tensor neg(Tensor x, string name = null) | |||
| { | |||
| @@ -783,7 +768,7 @@ namespace Tensorflow | |||
| } | |||
| public static Tensor sub(Tensor x, Tensor y, string name = null) | |||
| => tf.Context.RunInAutoMode2("Sub", name, new AutoModeArgs | |||
| => tf.Context.ExecuteOp("Sub", name, new AutoModeArgs | |||
| { | |||
| OpInputArgs = new { x, y } | |||
| }); | |||
| @@ -1087,14 +1072,17 @@ namespace Tensorflow | |||
| } | |||
| public static Tensor _max<Tx, Ty>(Tx input, Ty axis, bool keep_dims = false, string name = null) | |||
| => tf.Context.RunInAutoMode(() | |||
| => tf.OpDefLib._apply_op_helper("Max", name, new { input, reduction_indices = axis, keep_dims }).output, () | |||
| => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
| "Max", name, | |||
| null, | |||
| input, axis, | |||
| "keep_dims", keep_dims).FirstOrDefault(), | |||
| input as Tensor); | |||
| => tf.Context.ExecuteOp("Max", name, new AutoModeArgs | |||
| { | |||
| OpInputArgs = new { input, axis }, | |||
| OpAttrs = new { keep_dims, reduction_indices = axis }, | |||
| GetGradientAttrs = (op) => new | |||
| { | |||
| T = op.get_attr<TF_DataType>("T"), | |||
| align_corners = op.get_attr<bool>("align_corners"), | |||
| half_pixel_centers = op.get_attr<bool>("half_pixel_centers") | |||
| } | |||
| }); | |||
| public static Tensor _min<Tx, Ty>(Tx input, Ty axis, bool keep_dims = false, string name = null) | |||
| { | |||
| @@ -1170,13 +1158,10 @@ namespace Tensorflow | |||
| /// <param name="name"></param> | |||
| /// <returns></returns> | |||
| public static Tensor range(Tensor start, Tensor limit, Tensor delta, string name = null) | |||
| => tf.Context.RunInAutoMode(() | |||
| => tf.OpDefLib._apply_op_helper("Range", name, new { start, limit, delta }).output, () | |||
| => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
| "Range", name, | |||
| null, | |||
| start, limit, delta).FirstOrDefault(), | |||
| start, limit, delta); | |||
| => tf.Context.ExecuteOp("Range", name, new AutoModeArgs | |||
| { | |||
| OpInputArgs = new { start, limit, delta } | |||
| }); | |||
| /// <summary> | |||
| /// Rounds the values of a tensor to the nearest integer, element-wise. | |||
| @@ -45,7 +45,7 @@ namespace Tensorflow | |||
| => gen_math_ops.add(x, y, name); | |||
| public static Tensor add_v2(Tensor x, Tensor y, string name = null) | |||
| => tf.Context.RunInAutoMode2("AddV2", name, new AutoModeArgs | |||
| => tf.Context.ExecuteOp("AddV2", name, new AutoModeArgs | |||
| { | |||
| OpInputArgs = new { x, y } | |||
| }); | |||
| @@ -261,7 +261,7 @@ namespace Tensorflow | |||
| /// <param name="name"></param> | |||
| /// <returns></returns> | |||
| public static Tensor erf(Tensor x, string name = null) | |||
| => tf.Context.RunInAutoMode2("Erf", name, new AutoModeArgs | |||
| => tf.Context.ExecuteOp("Erf", name, new AutoModeArgs | |||
| { | |||
| OpInputArgs = new { x } | |||
| }); | |||
| @@ -270,7 +270,7 @@ namespace Tensorflow | |||
| => gen_math_ops.sqrt(x, name: name); | |||
| public static Tensor multiply(Tensor x, Tensor y, string name = null) | |||
| => tf.Context.RunInAutoMode2("Mul", name, new AutoModeArgs | |||
| => tf.Context.ExecuteOp("Mul", name, new AutoModeArgs | |||
| { | |||
| OpInputArgs = new { x, y } | |||
| }); | |||