From 2e89fa7b76515e427347894b4358e936564affe0 Mon Sep 17 00:00:00 2001 From: haiping008 Date: Fri, 15 Mar 2019 17:32:21 -0500 Subject: [PATCH] add strided_slice for Tensor, _SwitchRefOrTensor --- src/TensorFlowNET.Core/APIs/tf.array.cs | 2 +- .../Gradients/array_grad.py.cs | 16 ++++ .../Gradients/control_flow_grad.py.cs | 22 +++++ .../Gradients/nn_grad.py.cs | 14 ++- .../ops.gradient_function_mapping.cs | 6 ++ .../Operations/ControlFlows/CondContext.cs | 6 ++ .../Operations/Operation.Control.cs | 5 + .../Operations/array_ops.py.cs | 94 ++++++++++++------- .../Operations/control_flow_ops.py.cs | 16 ++++ .../Operations/control_flow_util.py.cs | 8 ++ .../Operations/gen_array_ops.cs | 39 +++++++- src/TensorFlowNET.Core/Tensors/Tensor.cs | 55 +++++++++++ 12 files changed, 245 insertions(+), 38 deletions(-) create mode 100644 src/TensorFlowNET.Core/Gradients/control_flow_grad.py.cs diff --git a/src/TensorFlowNET.Core/APIs/tf.array.cs b/src/TensorFlowNET.Core/APIs/tf.array.cs index 91b59842..a0fb3c72 100644 --- a/src/TensorFlowNET.Core/APIs/tf.array.cs +++ b/src/TensorFlowNET.Core/APIs/tf.array.cs @@ -28,7 +28,7 @@ namespace Tensorflow /// /// /// - public static Tensor transpose(Tensor a, int[] perm = null, string name = "transpose", bool conjugate = false) + public static Tensor transpose(T1 a, T2 perm, string name = "transpose", bool conjugate = false) => array_ops.transpose(a, perm, name, conjugate); public static Tensor squeeze(Tensor input, int[] axis = null, string name = null, int squeeze_dims = -1) diff --git a/src/TensorFlowNET.Core/Gradients/array_grad.py.cs b/src/TensorFlowNET.Core/Gradients/array_grad.py.cs index 64e2a9ec..cdd319ea 100644 --- a/src/TensorFlowNET.Core/Gradients/array_grad.py.cs +++ b/src/TensorFlowNET.Core/Gradients/array_grad.py.cs @@ -10,5 +10,21 @@ namespace Tensorflow.Gradients { return new Tensor[] { array_ops.reshape(grads[0], array_ops.shape(op.inputs[0])), null }; } + + public static Tensor[] _SqueezeGrad(Operation op, Tensor[] grads) + { + return new Tensor[] { _ReshapeToInput(op, grads[0]) }; + } + + private static Tensor _ReshapeToInput(Operation op, Tensor grad) + { + return array_ops.reshape(grad, array_ops.shape(op.inputs[0])); + } + + public static Tensor[] _TransposeGrad(Operation op, Tensor[] grads) + { + var p = op.inputs[1]; + return new Tensor[] { array_ops.transpose(grads[0], array_ops.invert_permutation(p)), null }; + } } } diff --git a/src/TensorFlowNET.Core/Gradients/control_flow_grad.py.cs b/src/TensorFlowNET.Core/Gradients/control_flow_grad.py.cs new file mode 100644 index 00000000..5e27f38c --- /dev/null +++ b/src/TensorFlowNET.Core/Gradients/control_flow_grad.py.cs @@ -0,0 +1,22 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Gradients +{ + public class control_flow_grad + { + public static Tensor[] _MergeGrad(Operation op, Tensor[] grads) + { + var grad = grads[0]; + var _ = grads[1]; + var input_op = op.inputs[0].op; + var graph = ops.get_default_graph(); + var op_ctxt = control_flow_util.GetOutputContext(input_op); + var pred = op_ctxt.pred; + + var results = control_flow_ops._SwitchRefOrTensor(grad, pred, name: "cond_grad"); + return new Tensor[] { results.Item1, results.Item2 }; + } + } +} diff --git a/src/TensorFlowNET.Core/Gradients/nn_grad.py.cs b/src/TensorFlowNET.Core/Gradients/nn_grad.py.cs index 9f5e2391..a4840fd7 100644 --- a/src/TensorFlowNET.Core/Gradients/nn_grad.py.cs +++ b/src/TensorFlowNET.Core/Gradients/nn_grad.py.cs @@ -93,7 +93,7 @@ namespace Tensorflow.Gradients var in_lastdim = array_ops.gather(math_ops.cast(in_shape, TF_DataType.TF_INT64), array_ops.size(in_shape) - 1); - var outerdim = array_ops.shape(ind_2d); + var outerdim = array_ops.shape(ind_2d)[0]; // Compute linear indices(flattened to 1D). var cast1 = math_ops.cast(outerdim, TF_DataType.TF_INT64); @@ -102,7 +102,17 @@ namespace Tensorflow.Gradients var cast2 = math_ops.cast(dim2, TF_DataType.TF_INT32); var ind = array_ops.reshape(ind_2d + cast2, new int[] { -1 }); - throw new NotImplementedException("nn_grad._TopKGrad"); + // Substitute grad to appropriate locations and fill the rest with zeros, + // finally reshaping it to the original input shape. + var scatter = gen_array_ops.scatter_nd(array_ops.expand_dims(ind, -1), + array_ops.reshape(grad, new int[] { -1 }), + new Tensor[] { math_ops.reduce_prod(in_shape) }); + + return new Tensor[] + { + array_ops.reshape(scatter, in_shape), + array_ops.zeros(new int[0], dtype: TF_DataType.TF_INT32) + }; } } } diff --git a/src/TensorFlowNET.Core/Gradients/ops.gradient_function_mapping.cs b/src/TensorFlowNET.Core/Gradients/ops.gradient_function_mapping.cs index a19e0db2..9601077b 100644 --- a/src/TensorFlowNET.Core/Gradients/ops.gradient_function_mapping.cs +++ b/src/TensorFlowNET.Core/Gradients/ops.gradient_function_mapping.cs @@ -26,6 +26,8 @@ namespace Tensorflow return math_grad._IdGrad(oper, out_grads); case "MatMul": return math_grad._MatMulGrad(oper, out_grads); + case "Merge": + return control_flow_grad._MergeGrad(oper, out_grads); case "Mul": return math_grad._MulGrad(oper, out_grads); case "Mean": @@ -42,8 +44,12 @@ namespace Tensorflow return array_grad._ReshapeGrad(oper, out_grads); case "Relu": return nn_grad._ReluGrad(oper, out_grads); + case "Squeeze": + return array_grad._SqueezeGrad(oper, out_grads); case "SoftmaxCrossEntropyWithLogits": return nn_grad._SoftmaxCrossEntropyWithLogitsGrad(oper, out_grads); + case "Transpose": + return array_grad._TransposeGrad(oper, out_grads); case "TopK": case "TopKV2": return nn_grad._TopKGrad(oper, out_grads); diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs index 23799892..9516c42f 100644 --- a/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs +++ b/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs @@ -10,22 +10,28 @@ namespace Tensorflow.Operations public class CondContext : ControlFlowContext { private string _name; + /// /// The boolean tensor for the cond predicate /// private Tensor _pred; + public Tensor pred => _pred; + /// /// The predicate tensor in this branch /// private Tensor _pivot; + /// /// 0 or 1 representing this branch /// private int _branch; + /// /// /// private List _values = new List(); + private Dictionary _external_values = new Dictionary(); /// diff --git a/src/TensorFlowNET.Core/Operations/Operation.Control.cs b/src/TensorFlowNET.Core/Operations/Operation.Control.cs index 74078e27..654e1b81 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.Control.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.Control.cs @@ -32,5 +32,10 @@ namespace Tensorflow { _control_flow_context = ctx; } + + public CondContext _get_control_flow_context() + { + return _control_flow_context; + } } } diff --git a/src/TensorFlowNET.Core/Operations/array_ops.py.cs b/src/TensorFlowNET.Core/Operations/array_ops.py.cs index 39df4c37..dc793fea 100644 --- a/src/TensorFlowNET.Core/Operations/array_ops.py.cs +++ b/src/TensorFlowNET.Core/Operations/array_ops.py.cs @@ -7,7 +7,8 @@ namespace Tensorflow { public class array_ops : Python { - public static Tensor placeholder_with_default(T input, int[] shape, string name = null) => gen_array_ops.placeholder_with_default(input, shape, name); + public static Tensor placeholder_with_default(T input, int[] shape, string name = null) + => gen_array_ops.placeholder_with_default(input, shape, name); public static Tensor zeros(Shape shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null) { @@ -111,14 +112,14 @@ namespace Tensorflow }); } - public static Tensor expand_dims(Tensor input, int axis = -1, string name = null, int dim = -1) => expand_dims_v2(input, axis, name); + public static Tensor expand_dims(Tensor input, int axis = -1, string name = null, int dim = -1) + => expand_dims_v2(input, axis, name); - private static Tensor expand_dims_v2(Tensor input, int axis, string name = null) => gen_array_ops.expand_dims(input, axis, name); + private static Tensor expand_dims_v2(Tensor input, int axis, string name = null) + => gen_array_ops.expand_dims(input, axis, name); public static Tensor rank(Tensor input, string name = null) - { - return math_ops.rank_internal(input, name, optimize: true); - } + => math_ops.rank_internal(input, name, optimize: true); /// /// Creates a tensor with all elements set to 1. @@ -132,9 +133,7 @@ namespace Tensorflow => ones_like_impl(tensor, dtype, name, optimize); public static Tensor reshape(T1 tensor, T2 shape, string name = null) - { - return gen_array_ops.reshape(tensor, shape, null); - } + => gen_array_ops.reshape(tensor, shape, null); private static Tensor ones_like_impl(T tensor, TF_DataType dtype, string name, bool optimize = true) { @@ -239,14 +238,10 @@ namespace Tensorflow /// /// A `Tensor` of type `out_type`. public static Tensor shape(Tensor input, string name = null, TF_DataType out_type = TF_DataType.TF_INT32) - { - return shape_internal(input, name, optimize: true, out_type: out_type); - } + => shape_internal(input, name, optimize: true, out_type: out_type); public static Tensor size(Tensor input, string name = null, bool optimize = true, TF_DataType out_type = TF_DataType.TF_INT32) - { - return size_internal(input, name, optimize: optimize, out_type: out_type); - } + => size_internal(input, name, optimize: optimize, out_type: out_type); private static Tensor shape_internal(Tensor input, string name = null, bool optimize = true, TF_DataType out_type = TF_DataType.TF_INT32) { @@ -323,8 +318,46 @@ namespace Tensorflow /// /// public static Tensor stop_gradient(Tensor input, string name = null) + => gen_array_ops.stop_gradient(input, name); + + /// + /// Extracts a strided slice of a tensor (generalized python array indexing). + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + public static Tensor strided_slice(Tensor input_, Tensor begin, Tensor end, + Tensor strides = null, + int begin_mask = 0, + int end_mask = 0, + int ellipsis_mask = 0, + int new_axis_mask = 0, + int shrink_axis_mask = 0, + string name = null) { - return gen_array_ops.stop_gradient(input, name); + var op = gen_array_ops.strided_slice( + input: input_, + begin: begin, + end: end, + strides: strides, + begin_mask: begin_mask, + end_mask: end_mask, + ellipsis_mask: ellipsis_mask, + new_axis_mask: new_axis_mask, + shrink_axis_mask: shrink_axis_mask, + name: name); + + string parent_name = name; + + return op; } /// @@ -345,14 +378,14 @@ namespace Tensorflow /// Contains the same data as `input`, but has one or more dimensions of /// size 1 removed. public static Tensor squeeze(Tensor input, int[] axis = null, string name = null, int[] squeeze_dims = null) - { - return gen_array_ops.squeeze(input, axis, name); - } + => gen_array_ops.squeeze(input, axis, name); public static Tensor identity(Tensor input, string name = null) - { - return gen_array_ops.identity(input, name); - } + => gen_array_ops.identity(input, name); + + public static Tensor invert_permutation(Tensor x, string name = null) + => gen_array_ops.invert_permutation(x, name: name); + /// /// Computes the shape of a broadcast given symbolic shapes. /// When shape_x and shape_y are Tensors representing shapes(i.e.the result of @@ -368,26 +401,19 @@ namespace Tensorflow /// A rank 1 integer `Tensor`, representing the shape of y. /// A rank 1 integer `Tensor` representing the broadcasted shape. public static Tensor broadcast_dynamic_shape(Tensor shape_x, Tensor shape_y) - { - return gen_array_ops.broadcast_args(shape_x, shape_y); - } + => gen_array_ops.broadcast_args(shape_x, shape_y); public static Tensor broadcast_static_shape(Tensor shape_x, Tensor shape_y) - { - return Framework.common_shapes.broadcast_shape(shape_x, shape_y); - } + => Framework.common_shapes.broadcast_shape(shape_x, shape_y); public static Tensor gather(Tensor @params, Tensor indices, string name = null, int axis = 0) - { - return gen_array_ops.gather_v2(@params, indices, axis, name: name); - } + => gen_array_ops.gather_v2(@params, indices, axis, name: name); - public static Tensor transpose(Tensor a, int[] perm = null, string name = "transpose", bool conjugate = false) + public static Tensor transpose(T1 a, T2 perm, string name = "transpose", bool conjugate = false) { return with(ops.name_scope(name, "transpose", new { a }), scope => { - name = scope; - return gen_array_ops.transpose(a, perm, name); + return gen_array_ops.transpose(a, perm, name: scope); }); } diff --git a/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs b/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs index 4ea21ee6..03773c9f 100644 --- a/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs +++ b/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs @@ -138,6 +138,22 @@ namespace Tensorflow return gen_array_ops.identity(data, name: name); } + /// + /// Forwards `data` to an output determined by `pred`. + /// + /// + /// + /// + /// + public static (Tensor, Tensor) _SwitchRefOrTensor(Tensor data, Tensor pred, string name = "Switch") + { + data = ops.convert_to_tensor_or_indexed_slices(data, name: "data"); + + ops.colocate_with(data, ignore_existing: true); + + return @switch(data, pred, name: name); + } + public static Tensor[] cond(Tensor pred, Func true_fn = null, Func false_fn = null, diff --git a/src/TensorFlowNET.Core/Operations/control_flow_util.py.cs b/src/TensorFlowNET.Core/Operations/control_flow_util.py.cs index 1b8a304c..146b681c 100644 --- a/src/TensorFlowNET.Core/Operations/control_flow_util.py.cs +++ b/src/TensorFlowNET.Core/Operations/control_flow_util.py.cs @@ -1,6 +1,7 @@ using System; using System.Collections.Generic; using System.Text; +using Tensorflow.Operations; namespace Tensorflow { @@ -25,5 +26,12 @@ namespace Tensorflow { return op.type == "Switch" || op.type == "RefSwitch"; } + + public static CondContext GetOutputContext(Operation op) + { + var ctxt = op._get_control_flow_context(); + + return ctxt; + } } } diff --git a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs index ae6e1f09..1d193a8a 100644 --- a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs @@ -74,6 +74,13 @@ namespace Tensorflow return _op.outputs[0]; } + public static Tensor invert_permutation(Tensor x, string name = null) + { + var _op = _op_def_lib._apply_op_helper("InvertPermutation", name, new { x }); + + return _op.outputs[0]; + } + public static Tensor log(Tensor x, string name = null) { var _op = _op_def_lib._apply_op_helper("Log", name: name, args: new { x }); @@ -163,6 +170,12 @@ namespace Tensorflow return _op.outputs[0]; } + public static Tensor scatter_nd(Tensor indices, Tensor updates, Tensor[] shape, string name = null) + { + var _op = _op_def_lib._apply_op_helper("ScatterNd", name, new { indices, updates, shape }); + return _op.outputs[0]; + } + public static Tensor shape(Tensor input, TF_DataType out_type = TF_DataType.TF_INT32, string name = null) { var _op = _op_def_lib._apply_op_helper("Shape", name, new { input, out_type }); @@ -181,7 +194,7 @@ namespace Tensorflow return _op.outputs[0]; } - public static Tensor transpose(Tensor x, int[] perm, string name = null) + public static Tensor transpose(T1 x, T2 perm, string name = null) { var _op = _op_def_lib._apply_op_helper("Transpose", name, new { x, perm }); return _op.outputs[0]; @@ -200,6 +213,30 @@ namespace Tensorflow return _op.outputs[0]; } + public static Tensor strided_slice(Tensor input, Tensor begin, Tensor end, Tensor strides, + int begin_mask = 0, + int end_mask = 0, + int ellipsis_mask = 0, + int new_axis_mask = 0, + int shrink_axis_mask = 0, + string name = null) + { + var _op = _op_def_lib._apply_op_helper("StridedSlice", name, new + { + input, + begin, + end, + strides, + begin_mask, + end_mask, + ellipsis_mask, + new_axis_mask, + shrink_axis_mask + }); + + return _op.outputs[0]; + } + public static Tensor slice(Tensor input, Tb[] begin, Ts[] size, string name = null) { var _op = _op_def_lib._apply_op_helper("Slice", name, new { input, begin, size }); diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs index af432f15..5d037696 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs @@ -187,6 +187,61 @@ namespace Tensorflow } } + public Tensor this[int slice_spec] + { + get + { + var slice_spec_s = new int[] { slice_spec }; + var begin = new List(); + var end = new List(); + var strides = new List(); + + var index = 0; + var (new_axis_mask, shrink_axis_mask) = (0, 0); + var (begin_mask, end_mask) = (0, 0); + var ellipsis_mask = 0; + + foreach(var s in slice_spec_s) + { + { + begin.Add(s); + end.Add(s + 1); + strides.Add(1); + shrink_axis_mask |= (1 << index); + } + + index += 1; + } + + return with(ops.name_scope(null, "strided_slice", new { begin, end, strides }), scope => + { + string name = scope; + if(begin != null) + { + var (packed_begin, packed_end, packed_strides) = + (array_ops.stack(begin.ToArray()), + array_ops.stack(end.ToArray()), + array_ops.stack(strides.ToArray())); + + return gen_array_ops.strided_slice( + this, + packed_begin, + packed_end, + packed_strides, + begin_mask: begin_mask, + end_mask: end_mask, + shrink_axis_mask: shrink_axis_mask, + new_axis_mask: new_axis_mask, + ellipsis_mask: ellipsis_mask, + name: name); + } + + throw new NotImplementedException(""); + }); + } + + } + public override string ToString() { if(NDims == 0)