| @@ -28,7 +28,7 @@ namespace Tensorflow | |||||
| /// <param name="name"></param> | /// <param name="name"></param> | ||||
| /// <param name="conjugate"></param> | /// <param name="conjugate"></param> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| public static Tensor transpose(Tensor a, int[] perm = null, string name = "transpose", bool conjugate = false) | |||||
| public static Tensor transpose<T1, T2>(T1 a, T2 perm, string name = "transpose", bool conjugate = false) | |||||
| => array_ops.transpose(a, perm, name, conjugate); | => array_ops.transpose(a, perm, name, conjugate); | ||||
| public static Tensor squeeze(Tensor input, int[] axis = null, string name = null, int squeeze_dims = -1) | public static Tensor squeeze(Tensor input, int[] axis = null, string name = null, int squeeze_dims = -1) | ||||
| @@ -10,5 +10,21 @@ namespace Tensorflow.Gradients | |||||
| { | { | ||||
| return new Tensor[] { array_ops.reshape(grads[0], array_ops.shape(op.inputs[0])), null }; | return new Tensor[] { array_ops.reshape(grads[0], array_ops.shape(op.inputs[0])), null }; | ||||
| } | } | ||||
| public static Tensor[] _SqueezeGrad(Operation op, Tensor[] grads) | |||||
| { | |||||
| return new Tensor[] { _ReshapeToInput(op, grads[0]) }; | |||||
| } | |||||
| private static Tensor _ReshapeToInput(Operation op, Tensor grad) | |||||
| { | |||||
| return array_ops.reshape(grad, array_ops.shape(op.inputs[0])); | |||||
| } | |||||
| public static Tensor[] _TransposeGrad(Operation op, Tensor[] grads) | |||||
| { | |||||
| var p = op.inputs[1]; | |||||
| return new Tensor[] { array_ops.transpose(grads[0], array_ops.invert_permutation(p)), null }; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -0,0 +1,22 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Gradients | |||||
| { | |||||
| public class control_flow_grad | |||||
| { | |||||
| public static Tensor[] _MergeGrad(Operation op, Tensor[] grads) | |||||
| { | |||||
| var grad = grads[0]; | |||||
| var _ = grads[1]; | |||||
| var input_op = op.inputs[0].op; | |||||
| var graph = ops.get_default_graph(); | |||||
| var op_ctxt = control_flow_util.GetOutputContext(input_op); | |||||
| var pred = op_ctxt.pred; | |||||
| var results = control_flow_ops._SwitchRefOrTensor(grad, pred, name: "cond_grad"); | |||||
| return new Tensor[] { results.Item1, results.Item2 }; | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -93,7 +93,7 @@ namespace Tensorflow.Gradients | |||||
| var in_lastdim = array_ops.gather(math_ops.cast(in_shape, TF_DataType.TF_INT64), | var in_lastdim = array_ops.gather(math_ops.cast(in_shape, TF_DataType.TF_INT64), | ||||
| array_ops.size(in_shape) - 1); | array_ops.size(in_shape) - 1); | ||||
| var outerdim = array_ops.shape(ind_2d); | |||||
| var outerdim = array_ops.shape(ind_2d)[0]; | |||||
| // Compute linear indices(flattened to 1D). | // Compute linear indices(flattened to 1D). | ||||
| var cast1 = math_ops.cast(outerdim, TF_DataType.TF_INT64); | var cast1 = math_ops.cast(outerdim, TF_DataType.TF_INT64); | ||||
| @@ -102,7 +102,17 @@ namespace Tensorflow.Gradients | |||||
| var cast2 = math_ops.cast(dim2, TF_DataType.TF_INT32); | var cast2 = math_ops.cast(dim2, TF_DataType.TF_INT32); | ||||
| var ind = array_ops.reshape(ind_2d + cast2, new int[] { -1 }); | var ind = array_ops.reshape(ind_2d + cast2, new int[] { -1 }); | ||||
| throw new NotImplementedException("nn_grad._TopKGrad"); | |||||
| // Substitute grad to appropriate locations and fill the rest with zeros, | |||||
| // finally reshaping it to the original input shape. | |||||
| var scatter = gen_array_ops.scatter_nd(array_ops.expand_dims(ind, -1), | |||||
| array_ops.reshape(grad, new int[] { -1 }), | |||||
| new Tensor[] { math_ops.reduce_prod(in_shape) }); | |||||
| return new Tensor[] | |||||
| { | |||||
| array_ops.reshape(scatter, in_shape), | |||||
| array_ops.zeros(new int[0], dtype: TF_DataType.TF_INT32) | |||||
| }; | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -26,6 +26,8 @@ namespace Tensorflow | |||||
| return math_grad._IdGrad(oper, out_grads); | return math_grad._IdGrad(oper, out_grads); | ||||
| case "MatMul": | case "MatMul": | ||||
| return math_grad._MatMulGrad(oper, out_grads); | return math_grad._MatMulGrad(oper, out_grads); | ||||
| case "Merge": | |||||
| return control_flow_grad._MergeGrad(oper, out_grads); | |||||
| case "Mul": | case "Mul": | ||||
| return math_grad._MulGrad(oper, out_grads); | return math_grad._MulGrad(oper, out_grads); | ||||
| case "Mean": | case "Mean": | ||||
| @@ -42,8 +44,12 @@ namespace Tensorflow | |||||
| return array_grad._ReshapeGrad(oper, out_grads); | return array_grad._ReshapeGrad(oper, out_grads); | ||||
| case "Relu": | case "Relu": | ||||
| return nn_grad._ReluGrad(oper, out_grads); | return nn_grad._ReluGrad(oper, out_grads); | ||||
| case "Squeeze": | |||||
| return array_grad._SqueezeGrad(oper, out_grads); | |||||
| case "SoftmaxCrossEntropyWithLogits": | case "SoftmaxCrossEntropyWithLogits": | ||||
| return nn_grad._SoftmaxCrossEntropyWithLogitsGrad(oper, out_grads); | return nn_grad._SoftmaxCrossEntropyWithLogitsGrad(oper, out_grads); | ||||
| case "Transpose": | |||||
| return array_grad._TransposeGrad(oper, out_grads); | |||||
| case "TopK": | case "TopK": | ||||
| case "TopKV2": | case "TopKV2": | ||||
| return nn_grad._TopKGrad(oper, out_grads); | return nn_grad._TopKGrad(oper, out_grads); | ||||
| @@ -10,22 +10,28 @@ namespace Tensorflow.Operations | |||||
| public class CondContext : ControlFlowContext | public class CondContext : ControlFlowContext | ||||
| { | { | ||||
| private string _name; | private string _name; | ||||
| /// <summary> | /// <summary> | ||||
| /// The boolean tensor for the cond predicate | /// The boolean tensor for the cond predicate | ||||
| /// </summary> | /// </summary> | ||||
| private Tensor _pred; | private Tensor _pred; | ||||
| public Tensor pred => _pred; | |||||
| /// <summary> | /// <summary> | ||||
| /// The predicate tensor in this branch | /// The predicate tensor in this branch | ||||
| /// </summary> | /// </summary> | ||||
| private Tensor _pivot; | private Tensor _pivot; | ||||
| /// <summary> | /// <summary> | ||||
| /// 0 or 1 representing this branch | /// 0 or 1 representing this branch | ||||
| /// </summary> | /// </summary> | ||||
| private int _branch; | private int _branch; | ||||
| /// <summary> | /// <summary> | ||||
| /// | /// | ||||
| /// </summary> | /// </summary> | ||||
| private List<string> _values = new List<string>(); | private List<string> _values = new List<string>(); | ||||
| private Dictionary<string, Tensor> _external_values = new Dictionary<string, Tensor>(); | private Dictionary<string, Tensor> _external_values = new Dictionary<string, Tensor>(); | ||||
| /// <summary> | /// <summary> | ||||
| @@ -32,5 +32,10 @@ namespace Tensorflow | |||||
| { | { | ||||
| _control_flow_context = ctx; | _control_flow_context = ctx; | ||||
| } | } | ||||
| public CondContext _get_control_flow_context() | |||||
| { | |||||
| return _control_flow_context; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -7,7 +7,8 @@ namespace Tensorflow | |||||
| { | { | ||||
| public class array_ops : Python | public class array_ops : Python | ||||
| { | { | ||||
| public static Tensor placeholder_with_default<T>(T input, int[] shape, string name = null) => gen_array_ops.placeholder_with_default(input, shape, name); | |||||
| public static Tensor placeholder_with_default<T>(T input, int[] shape, string name = null) | |||||
| => gen_array_ops.placeholder_with_default(input, shape, name); | |||||
| public static Tensor zeros(Shape shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null) | public static Tensor zeros(Shape shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null) | ||||
| { | { | ||||
| @@ -111,14 +112,14 @@ namespace Tensorflow | |||||
| }); | }); | ||||
| } | } | ||||
| public static Tensor expand_dims(Tensor input, int axis = -1, string name = null, int dim = -1) => expand_dims_v2(input, axis, name); | |||||
| public static Tensor expand_dims(Tensor input, int axis = -1, string name = null, int dim = -1) | |||||
| => expand_dims_v2(input, axis, name); | |||||
| private static Tensor expand_dims_v2(Tensor input, int axis, string name = null) => gen_array_ops.expand_dims(input, axis, name); | |||||
| private static Tensor expand_dims_v2(Tensor input, int axis, string name = null) | |||||
| => gen_array_ops.expand_dims(input, axis, name); | |||||
| public static Tensor rank(Tensor input, string name = null) | public static Tensor rank(Tensor input, string name = null) | ||||
| { | |||||
| return math_ops.rank_internal(input, name, optimize: true); | |||||
| } | |||||
| => math_ops.rank_internal(input, name, optimize: true); | |||||
| /// <summary> | /// <summary> | ||||
| /// Creates a tensor with all elements set to 1. | /// Creates a tensor with all elements set to 1. | ||||
| @@ -132,9 +133,7 @@ namespace Tensorflow | |||||
| => ones_like_impl(tensor, dtype, name, optimize); | => ones_like_impl(tensor, dtype, name, optimize); | ||||
| public static Tensor reshape<T1, T2>(T1 tensor, T2 shape, string name = null) | public static Tensor reshape<T1, T2>(T1 tensor, T2 shape, string name = null) | ||||
| { | |||||
| return gen_array_ops.reshape(tensor, shape, null); | |||||
| } | |||||
| => gen_array_ops.reshape(tensor, shape, null); | |||||
| private static Tensor ones_like_impl<T>(T tensor, TF_DataType dtype, string name, bool optimize = true) | private static Tensor ones_like_impl<T>(T tensor, TF_DataType dtype, string name, bool optimize = true) | ||||
| { | { | ||||
| @@ -239,14 +238,10 @@ namespace Tensorflow | |||||
| /// </param> | /// </param> | ||||
| /// <returns>A `Tensor` of type `out_type`.</returns> | /// <returns>A `Tensor` of type `out_type`.</returns> | ||||
| public static Tensor shape(Tensor input, string name = null, TF_DataType out_type = TF_DataType.TF_INT32) | public static Tensor shape(Tensor input, string name = null, TF_DataType out_type = TF_DataType.TF_INT32) | ||||
| { | |||||
| return shape_internal(input, name, optimize: true, out_type: out_type); | |||||
| } | |||||
| => shape_internal(input, name, optimize: true, out_type: out_type); | |||||
| public static Tensor size(Tensor input, string name = null, bool optimize = true, TF_DataType out_type = TF_DataType.TF_INT32) | public static Tensor size(Tensor input, string name = null, bool optimize = true, TF_DataType out_type = TF_DataType.TF_INT32) | ||||
| { | |||||
| return size_internal(input, name, optimize: optimize, out_type: out_type); | |||||
| } | |||||
| => size_internal(input, name, optimize: optimize, out_type: out_type); | |||||
| private static Tensor shape_internal(Tensor input, string name = null, bool optimize = true, TF_DataType out_type = TF_DataType.TF_INT32) | private static Tensor shape_internal(Tensor input, string name = null, bool optimize = true, TF_DataType out_type = TF_DataType.TF_INT32) | ||||
| { | { | ||||
| @@ -323,8 +318,46 @@ namespace Tensorflow | |||||
| /// <param name="name"></param> | /// <param name="name"></param> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| public static Tensor stop_gradient(Tensor input, string name = null) | public static Tensor stop_gradient(Tensor input, string name = null) | ||||
| => gen_array_ops.stop_gradient(input, name); | |||||
| /// <summary> | |||||
| /// Extracts a strided slice of a tensor (generalized python array indexing). | |||||
| /// </summary> | |||||
| /// <param name="input_"></param> | |||||
| /// <param name="begin"></param> | |||||
| /// <param name="end"></param> | |||||
| /// <param name="strides"></param> | |||||
| /// <param name="begin_mask"></param> | |||||
| /// <param name="end_mask"></param> | |||||
| /// <param name="ellipsis_mask"></param> | |||||
| /// <param name="new_axis_mask"></param> | |||||
| /// <param name="shrink_axis_mask"></param> | |||||
| /// <param name="name"></param> | |||||
| /// <returns></returns> | |||||
| public static Tensor strided_slice(Tensor input_, Tensor begin, Tensor end, | |||||
| Tensor strides = null, | |||||
| int begin_mask = 0, | |||||
| int end_mask = 0, | |||||
| int ellipsis_mask = 0, | |||||
| int new_axis_mask = 0, | |||||
| int shrink_axis_mask = 0, | |||||
| string name = null) | |||||
| { | { | ||||
| return gen_array_ops.stop_gradient(input, name); | |||||
| var op = gen_array_ops.strided_slice( | |||||
| input: input_, | |||||
| begin: begin, | |||||
| end: end, | |||||
| strides: strides, | |||||
| begin_mask: begin_mask, | |||||
| end_mask: end_mask, | |||||
| ellipsis_mask: ellipsis_mask, | |||||
| new_axis_mask: new_axis_mask, | |||||
| shrink_axis_mask: shrink_axis_mask, | |||||
| name: name); | |||||
| string parent_name = name; | |||||
| return op; | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| @@ -345,14 +378,14 @@ namespace Tensorflow | |||||
| /// Contains the same data as `input`, but has one or more dimensions of | /// Contains the same data as `input`, but has one or more dimensions of | ||||
| /// size 1 removed.</returns> | /// size 1 removed.</returns> | ||||
| public static Tensor squeeze(Tensor input, int[] axis = null, string name = null, int[] squeeze_dims = null) | public static Tensor squeeze(Tensor input, int[] axis = null, string name = null, int[] squeeze_dims = null) | ||||
| { | |||||
| return gen_array_ops.squeeze(input, axis, name); | |||||
| } | |||||
| => gen_array_ops.squeeze(input, axis, name); | |||||
| public static Tensor identity(Tensor input, string name = null) | public static Tensor identity(Tensor input, string name = null) | ||||
| { | |||||
| return gen_array_ops.identity(input, name); | |||||
| } | |||||
| => gen_array_ops.identity(input, name); | |||||
| public static Tensor invert_permutation(Tensor x, string name = null) | |||||
| => gen_array_ops.invert_permutation(x, name: name); | |||||
| /// <summary> | /// <summary> | ||||
| /// Computes the shape of a broadcast given symbolic shapes. | /// Computes the shape of a broadcast given symbolic shapes. | ||||
| /// When shape_x and shape_y are Tensors representing shapes(i.e.the result of | /// When shape_x and shape_y are Tensors representing shapes(i.e.the result of | ||||
| @@ -368,26 +401,19 @@ namespace Tensorflow | |||||
| /// <param name="shape_y"> A rank 1 integer `Tensor`, representing the shape of y.</param> | /// <param name="shape_y"> A rank 1 integer `Tensor`, representing the shape of y.</param> | ||||
| /// <returns> A rank 1 integer `Tensor` representing the broadcasted shape.</returns> | /// <returns> A rank 1 integer `Tensor` representing the broadcasted shape.</returns> | ||||
| public static Tensor broadcast_dynamic_shape(Tensor shape_x, Tensor shape_y) | public static Tensor broadcast_dynamic_shape(Tensor shape_x, Tensor shape_y) | ||||
| { | |||||
| return gen_array_ops.broadcast_args(shape_x, shape_y); | |||||
| } | |||||
| => gen_array_ops.broadcast_args(shape_x, shape_y); | |||||
| public static Tensor broadcast_static_shape(Tensor shape_x, Tensor shape_y) | public static Tensor broadcast_static_shape(Tensor shape_x, Tensor shape_y) | ||||
| { | |||||
| return Framework.common_shapes.broadcast_shape(shape_x, shape_y); | |||||
| } | |||||
| => Framework.common_shapes.broadcast_shape(shape_x, shape_y); | |||||
| public static Tensor gather(Tensor @params, Tensor indices, string name = null, int axis = 0) | public static Tensor gather(Tensor @params, Tensor indices, string name = null, int axis = 0) | ||||
| { | |||||
| return gen_array_ops.gather_v2(@params, indices, axis, name: name); | |||||
| } | |||||
| => gen_array_ops.gather_v2(@params, indices, axis, name: name); | |||||
| public static Tensor transpose(Tensor a, int[] perm = null, string name = "transpose", bool conjugate = false) | |||||
| public static Tensor transpose<T1, T2>(T1 a, T2 perm, string name = "transpose", bool conjugate = false) | |||||
| { | { | ||||
| return with(ops.name_scope(name, "transpose", new { a }), scope => | return with(ops.name_scope(name, "transpose", new { a }), scope => | ||||
| { | { | ||||
| name = scope; | |||||
| return gen_array_ops.transpose(a, perm, name); | |||||
| return gen_array_ops.transpose(a, perm, name: scope); | |||||
| }); | }); | ||||
| } | } | ||||
| @@ -138,6 +138,22 @@ namespace Tensorflow | |||||
| return gen_array_ops.identity(data, name: name); | return gen_array_ops.identity(data, name: name); | ||||
| } | } | ||||
| /// <summary> | |||||
| /// Forwards `data` to an output determined by `pred`. | |||||
| /// </summary> | |||||
| /// <param name="data"></param> | |||||
| /// <param name="pred"></param> | |||||
| /// <param name="name"></param> | |||||
| /// <returns></returns> | |||||
| public static (Tensor, Tensor) _SwitchRefOrTensor(Tensor data, Tensor pred, string name = "Switch") | |||||
| { | |||||
| data = ops.convert_to_tensor_or_indexed_slices(data, name: "data"); | |||||
| ops.colocate_with(data, ignore_existing: true); | |||||
| return @switch(data, pred, name: name); | |||||
| } | |||||
| public static Tensor[] cond<T>(Tensor pred, | public static Tensor[] cond<T>(Tensor pred, | ||||
| Func<T[]> true_fn = null, | Func<T[]> true_fn = null, | ||||
| Func<T[]> false_fn = null, | Func<T[]> false_fn = null, | ||||
| @@ -1,6 +1,7 @@ | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Text; | using System.Text; | ||||
| using Tensorflow.Operations; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| @@ -25,5 +26,12 @@ namespace Tensorflow | |||||
| { | { | ||||
| return op.type == "Switch" || op.type == "RefSwitch"; | return op.type == "Switch" || op.type == "RefSwitch"; | ||||
| } | } | ||||
| public static CondContext GetOutputContext(Operation op) | |||||
| { | |||||
| var ctxt = op._get_control_flow_context(); | |||||
| return ctxt; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -74,6 +74,13 @@ namespace Tensorflow | |||||
| return _op.outputs[0]; | return _op.outputs[0]; | ||||
| } | } | ||||
| public static Tensor invert_permutation(Tensor x, string name = null) | |||||
| { | |||||
| var _op = _op_def_lib._apply_op_helper("InvertPermutation", name, new { x }); | |||||
| return _op.outputs[0]; | |||||
| } | |||||
| public static Tensor log(Tensor x, string name = null) | public static Tensor log(Tensor x, string name = null) | ||||
| { | { | ||||
| var _op = _op_def_lib._apply_op_helper("Log", name: name, args: new { x }); | var _op = _op_def_lib._apply_op_helper("Log", name: name, args: new { x }); | ||||
| @@ -163,6 +170,12 @@ namespace Tensorflow | |||||
| return _op.outputs[0]; | return _op.outputs[0]; | ||||
| } | } | ||||
| public static Tensor scatter_nd(Tensor indices, Tensor updates, Tensor[] shape, string name = null) | |||||
| { | |||||
| var _op = _op_def_lib._apply_op_helper("ScatterNd", name, new { indices, updates, shape }); | |||||
| return _op.outputs[0]; | |||||
| } | |||||
| public static Tensor shape(Tensor input, TF_DataType out_type = TF_DataType.TF_INT32, string name = null) | public static Tensor shape(Tensor input, TF_DataType out_type = TF_DataType.TF_INT32, string name = null) | ||||
| { | { | ||||
| var _op = _op_def_lib._apply_op_helper("Shape", name, new { input, out_type }); | var _op = _op_def_lib._apply_op_helper("Shape", name, new { input, out_type }); | ||||
| @@ -181,7 +194,7 @@ namespace Tensorflow | |||||
| return _op.outputs[0]; | return _op.outputs[0]; | ||||
| } | } | ||||
| public static Tensor transpose(Tensor x, int[] perm, string name = null) | |||||
| public static Tensor transpose<T1, T2>(T1 x, T2 perm, string name = null) | |||||
| { | { | ||||
| var _op = _op_def_lib._apply_op_helper("Transpose", name, new { x, perm }); | var _op = _op_def_lib._apply_op_helper("Transpose", name, new { x, perm }); | ||||
| return _op.outputs[0]; | return _op.outputs[0]; | ||||
| @@ -200,6 +213,30 @@ namespace Tensorflow | |||||
| return _op.outputs[0]; | return _op.outputs[0]; | ||||
| } | } | ||||
| public static Tensor strided_slice(Tensor input, Tensor begin, Tensor end, Tensor strides, | |||||
| int begin_mask = 0, | |||||
| int end_mask = 0, | |||||
| int ellipsis_mask = 0, | |||||
| int new_axis_mask = 0, | |||||
| int shrink_axis_mask = 0, | |||||
| string name = null) | |||||
| { | |||||
| var _op = _op_def_lib._apply_op_helper("StridedSlice", name, new | |||||
| { | |||||
| input, | |||||
| begin, | |||||
| end, | |||||
| strides, | |||||
| begin_mask, | |||||
| end_mask, | |||||
| ellipsis_mask, | |||||
| new_axis_mask, | |||||
| shrink_axis_mask | |||||
| }); | |||||
| return _op.outputs[0]; | |||||
| } | |||||
| public static Tensor slice<Tb, Ts>(Tensor input, Tb[] begin, Ts[] size, string name = null) | public static Tensor slice<Tb, Ts>(Tensor input, Tb[] begin, Ts[] size, string name = null) | ||||
| { | { | ||||
| var _op = _op_def_lib._apply_op_helper("Slice", name, new { input, begin, size }); | var _op = _op_def_lib._apply_op_helper("Slice", name, new { input, begin, size }); | ||||
| @@ -187,6 +187,61 @@ namespace Tensorflow | |||||
| } | } | ||||
| } | } | ||||
| public Tensor this[int slice_spec] | |||||
| { | |||||
| get | |||||
| { | |||||
| var slice_spec_s = new int[] { slice_spec }; | |||||
| var begin = new List<int>(); | |||||
| var end = new List<int>(); | |||||
| var strides = new List<int>(); | |||||
| var index = 0; | |||||
| var (new_axis_mask, shrink_axis_mask) = (0, 0); | |||||
| var (begin_mask, end_mask) = (0, 0); | |||||
| var ellipsis_mask = 0; | |||||
| foreach(var s in slice_spec_s) | |||||
| { | |||||
| { | |||||
| begin.Add(s); | |||||
| end.Add(s + 1); | |||||
| strides.Add(1); | |||||
| shrink_axis_mask |= (1 << index); | |||||
| } | |||||
| index += 1; | |||||
| } | |||||
| return with(ops.name_scope(null, "strided_slice", new { begin, end, strides }), scope => | |||||
| { | |||||
| string name = scope; | |||||
| if(begin != null) | |||||
| { | |||||
| var (packed_begin, packed_end, packed_strides) = | |||||
| (array_ops.stack(begin.ToArray()), | |||||
| array_ops.stack(end.ToArray()), | |||||
| array_ops.stack(strides.ToArray())); | |||||
| return gen_array_ops.strided_slice( | |||||
| this, | |||||
| packed_begin, | |||||
| packed_end, | |||||
| packed_strides, | |||||
| begin_mask: begin_mask, | |||||
| end_mask: end_mask, | |||||
| shrink_axis_mask: shrink_axis_mask, | |||||
| new_axis_mask: new_axis_mask, | |||||
| ellipsis_mask: ellipsis_mask, | |||||
| name: name); | |||||
| } | |||||
| throw new NotImplementedException(""); | |||||
| }); | |||||
| } | |||||
| } | |||||
| public override string ToString() | public override string ToString() | ||||
| { | { | ||||
| if(NDims == 0) | if(NDims == 0) | ||||