diff --git a/src/TensorFlowNET.Core/APIs/tf.array.cs b/src/TensorFlowNET.Core/APIs/tf.array.cs
index 91b59842..a0fb3c72 100644
--- a/src/TensorFlowNET.Core/APIs/tf.array.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.array.cs
@@ -28,7 +28,7 @@ namespace Tensorflow
///
///
///
- public static Tensor transpose(Tensor a, int[] perm = null, string name = "transpose", bool conjugate = false)
+ public static Tensor transpose(T1 a, T2 perm, string name = "transpose", bool conjugate = false)
=> array_ops.transpose(a, perm, name, conjugate);
public static Tensor squeeze(Tensor input, int[] axis = null, string name = null, int squeeze_dims = -1)
diff --git a/src/TensorFlowNET.Core/Gradients/array_grad.py.cs b/src/TensorFlowNET.Core/Gradients/array_grad.py.cs
index 64e2a9ec..cdd319ea 100644
--- a/src/TensorFlowNET.Core/Gradients/array_grad.py.cs
+++ b/src/TensorFlowNET.Core/Gradients/array_grad.py.cs
@@ -10,5 +10,21 @@ namespace Tensorflow.Gradients
{
return new Tensor[] { array_ops.reshape(grads[0], array_ops.shape(op.inputs[0])), null };
}
+
+ public static Tensor[] _SqueezeGrad(Operation op, Tensor[] grads)
+ {
+ return new Tensor[] { _ReshapeToInput(op, grads[0]) };
+ }
+
+ private static Tensor _ReshapeToInput(Operation op, Tensor grad)
+ {
+ return array_ops.reshape(grad, array_ops.shape(op.inputs[0]));
+ }
+
+ public static Tensor[] _TransposeGrad(Operation op, Tensor[] grads)
+ {
+ var p = op.inputs[1];
+ return new Tensor[] { array_ops.transpose(grads[0], array_ops.invert_permutation(p)), null };
+ }
}
}
diff --git a/src/TensorFlowNET.Core/Gradients/control_flow_grad.py.cs b/src/TensorFlowNET.Core/Gradients/control_flow_grad.py.cs
new file mode 100644
index 00000000..5e27f38c
--- /dev/null
+++ b/src/TensorFlowNET.Core/Gradients/control_flow_grad.py.cs
@@ -0,0 +1,22 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow.Gradients
+{
+ public class control_flow_grad
+ {
+ public static Tensor[] _MergeGrad(Operation op, Tensor[] grads)
+ {
+ var grad = grads[0];
+ var _ = grads[1];
+ var input_op = op.inputs[0].op;
+ var graph = ops.get_default_graph();
+ var op_ctxt = control_flow_util.GetOutputContext(input_op);
+ var pred = op_ctxt.pred;
+
+ var results = control_flow_ops._SwitchRefOrTensor(grad, pred, name: "cond_grad");
+ return new Tensor[] { results.Item1, results.Item2 };
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Core/Gradients/nn_grad.py.cs b/src/TensorFlowNET.Core/Gradients/nn_grad.py.cs
index 9f5e2391..a4840fd7 100644
--- a/src/TensorFlowNET.Core/Gradients/nn_grad.py.cs
+++ b/src/TensorFlowNET.Core/Gradients/nn_grad.py.cs
@@ -93,7 +93,7 @@ namespace Tensorflow.Gradients
var in_lastdim = array_ops.gather(math_ops.cast(in_shape, TF_DataType.TF_INT64),
array_ops.size(in_shape) - 1);
- var outerdim = array_ops.shape(ind_2d);
+ var outerdim = array_ops.shape(ind_2d)[0];
// Compute linear indices(flattened to 1D).
var cast1 = math_ops.cast(outerdim, TF_DataType.TF_INT64);
@@ -102,7 +102,17 @@ namespace Tensorflow.Gradients
var cast2 = math_ops.cast(dim2, TF_DataType.TF_INT32);
var ind = array_ops.reshape(ind_2d + cast2, new int[] { -1 });
- throw new NotImplementedException("nn_grad._TopKGrad");
+ // Substitute grad to appropriate locations and fill the rest with zeros,
+ // finally reshaping it to the original input shape.
+ var scatter = gen_array_ops.scatter_nd(array_ops.expand_dims(ind, -1),
+ array_ops.reshape(grad, new int[] { -1 }),
+ new Tensor[] { math_ops.reduce_prod(in_shape) });
+
+ return new Tensor[]
+ {
+ array_ops.reshape(scatter, in_shape),
+ array_ops.zeros(new int[0], dtype: TF_DataType.TF_INT32)
+ };
}
}
}
diff --git a/src/TensorFlowNET.Core/Gradients/ops.gradient_function_mapping.cs b/src/TensorFlowNET.Core/Gradients/ops.gradient_function_mapping.cs
index a19e0db2..9601077b 100644
--- a/src/TensorFlowNET.Core/Gradients/ops.gradient_function_mapping.cs
+++ b/src/TensorFlowNET.Core/Gradients/ops.gradient_function_mapping.cs
@@ -26,6 +26,8 @@ namespace Tensorflow
return math_grad._IdGrad(oper, out_grads);
case "MatMul":
return math_grad._MatMulGrad(oper, out_grads);
+ case "Merge":
+ return control_flow_grad._MergeGrad(oper, out_grads);
case "Mul":
return math_grad._MulGrad(oper, out_grads);
case "Mean":
@@ -42,8 +44,12 @@ namespace Tensorflow
return array_grad._ReshapeGrad(oper, out_grads);
case "Relu":
return nn_grad._ReluGrad(oper, out_grads);
+ case "Squeeze":
+ return array_grad._SqueezeGrad(oper, out_grads);
case "SoftmaxCrossEntropyWithLogits":
return nn_grad._SoftmaxCrossEntropyWithLogitsGrad(oper, out_grads);
+ case "Transpose":
+ return array_grad._TransposeGrad(oper, out_grads);
case "TopK":
case "TopKV2":
return nn_grad._TopKGrad(oper, out_grads);
diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs
index 23799892..9516c42f 100644
--- a/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs
+++ b/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs
@@ -10,22 +10,28 @@ namespace Tensorflow.Operations
public class CondContext : ControlFlowContext
{
private string _name;
+
///
/// The boolean tensor for the cond predicate
///
private Tensor _pred;
+ public Tensor pred => _pred;
+
///
/// The predicate tensor in this branch
///
private Tensor _pivot;
+
///
/// 0 or 1 representing this branch
///
private int _branch;
+
///
///
///
private List _values = new List();
+
private Dictionary _external_values = new Dictionary();
///
diff --git a/src/TensorFlowNET.Core/Operations/Operation.Control.cs b/src/TensorFlowNET.Core/Operations/Operation.Control.cs
index 74078e27..654e1b81 100644
--- a/src/TensorFlowNET.Core/Operations/Operation.Control.cs
+++ b/src/TensorFlowNET.Core/Operations/Operation.Control.cs
@@ -32,5 +32,10 @@ namespace Tensorflow
{
_control_flow_context = ctx;
}
+
+ public CondContext _get_control_flow_context()
+ {
+ return _control_flow_context;
+ }
}
}
diff --git a/src/TensorFlowNET.Core/Operations/array_ops.py.cs b/src/TensorFlowNET.Core/Operations/array_ops.py.cs
index 39df4c37..dc793fea 100644
--- a/src/TensorFlowNET.Core/Operations/array_ops.py.cs
+++ b/src/TensorFlowNET.Core/Operations/array_ops.py.cs
@@ -7,7 +7,8 @@ namespace Tensorflow
{
public class array_ops : Python
{
- public static Tensor placeholder_with_default(T input, int[] shape, string name = null) => gen_array_ops.placeholder_with_default(input, shape, name);
+ public static Tensor placeholder_with_default(T input, int[] shape, string name = null)
+ => gen_array_ops.placeholder_with_default(input, shape, name);
public static Tensor zeros(Shape shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null)
{
@@ -111,14 +112,14 @@ namespace Tensorflow
});
}
- public static Tensor expand_dims(Tensor input, int axis = -1, string name = null, int dim = -1) => expand_dims_v2(input, axis, name);
+ public static Tensor expand_dims(Tensor input, int axis = -1, string name = null, int dim = -1)
+ => expand_dims_v2(input, axis, name);
- private static Tensor expand_dims_v2(Tensor input, int axis, string name = null) => gen_array_ops.expand_dims(input, axis, name);
+ private static Tensor expand_dims_v2(Tensor input, int axis, string name = null)
+ => gen_array_ops.expand_dims(input, axis, name);
public static Tensor rank(Tensor input, string name = null)
- {
- return math_ops.rank_internal(input, name, optimize: true);
- }
+ => math_ops.rank_internal(input, name, optimize: true);
///
/// Creates a tensor with all elements set to 1.
@@ -132,9 +133,7 @@ namespace Tensorflow
=> ones_like_impl(tensor, dtype, name, optimize);
public static Tensor reshape(T1 tensor, T2 shape, string name = null)
- {
- return gen_array_ops.reshape(tensor, shape, null);
- }
+ => gen_array_ops.reshape(tensor, shape, null);
private static Tensor ones_like_impl(T tensor, TF_DataType dtype, string name, bool optimize = true)
{
@@ -239,14 +238,10 @@ namespace Tensorflow
///
/// A `Tensor` of type `out_type`.
public static Tensor shape(Tensor input, string name = null, TF_DataType out_type = TF_DataType.TF_INT32)
- {
- return shape_internal(input, name, optimize: true, out_type: out_type);
- }
+ => shape_internal(input, name, optimize: true, out_type: out_type);
public static Tensor size(Tensor input, string name = null, bool optimize = true, TF_DataType out_type = TF_DataType.TF_INT32)
- {
- return size_internal(input, name, optimize: optimize, out_type: out_type);
- }
+ => size_internal(input, name, optimize: optimize, out_type: out_type);
private static Tensor shape_internal(Tensor input, string name = null, bool optimize = true, TF_DataType out_type = TF_DataType.TF_INT32)
{
@@ -323,8 +318,46 @@ namespace Tensorflow
///
///
public static Tensor stop_gradient(Tensor input, string name = null)
+ => gen_array_ops.stop_gradient(input, name);
+
+ ///
+ /// Extracts a strided slice of a tensor (generalized python array indexing).
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ public static Tensor strided_slice(Tensor input_, Tensor begin, Tensor end,
+ Tensor strides = null,
+ int begin_mask = 0,
+ int end_mask = 0,
+ int ellipsis_mask = 0,
+ int new_axis_mask = 0,
+ int shrink_axis_mask = 0,
+ string name = null)
{
- return gen_array_ops.stop_gradient(input, name);
+ var op = gen_array_ops.strided_slice(
+ input: input_,
+ begin: begin,
+ end: end,
+ strides: strides,
+ begin_mask: begin_mask,
+ end_mask: end_mask,
+ ellipsis_mask: ellipsis_mask,
+ new_axis_mask: new_axis_mask,
+ shrink_axis_mask: shrink_axis_mask,
+ name: name);
+
+ string parent_name = name;
+
+ return op;
}
///
@@ -345,14 +378,14 @@ namespace Tensorflow
/// Contains the same data as `input`, but has one or more dimensions of
/// size 1 removed.
public static Tensor squeeze(Tensor input, int[] axis = null, string name = null, int[] squeeze_dims = null)
- {
- return gen_array_ops.squeeze(input, axis, name);
- }
+ => gen_array_ops.squeeze(input, axis, name);
public static Tensor identity(Tensor input, string name = null)
- {
- return gen_array_ops.identity(input, name);
- }
+ => gen_array_ops.identity(input, name);
+
+ public static Tensor invert_permutation(Tensor x, string name = null)
+ => gen_array_ops.invert_permutation(x, name: name);
+
///
/// Computes the shape of a broadcast given symbolic shapes.
/// When shape_x and shape_y are Tensors representing shapes(i.e.the result of
@@ -368,26 +401,19 @@ namespace Tensorflow
/// A rank 1 integer `Tensor`, representing the shape of y.
/// A rank 1 integer `Tensor` representing the broadcasted shape.
public static Tensor broadcast_dynamic_shape(Tensor shape_x, Tensor shape_y)
- {
- return gen_array_ops.broadcast_args(shape_x, shape_y);
- }
+ => gen_array_ops.broadcast_args(shape_x, shape_y);
public static Tensor broadcast_static_shape(Tensor shape_x, Tensor shape_y)
- {
- return Framework.common_shapes.broadcast_shape(shape_x, shape_y);
- }
+ => Framework.common_shapes.broadcast_shape(shape_x, shape_y);
public static Tensor gather(Tensor @params, Tensor indices, string name = null, int axis = 0)
- {
- return gen_array_ops.gather_v2(@params, indices, axis, name: name);
- }
+ => gen_array_ops.gather_v2(@params, indices, axis, name: name);
- public static Tensor transpose(Tensor a, int[] perm = null, string name = "transpose", bool conjugate = false)
+ public static Tensor transpose(T1 a, T2 perm, string name = "transpose", bool conjugate = false)
{
return with(ops.name_scope(name, "transpose", new { a }), scope =>
{
- name = scope;
- return gen_array_ops.transpose(a, perm, name);
+ return gen_array_ops.transpose(a, perm, name: scope);
});
}
diff --git a/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs b/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs
index 4ea21ee6..03773c9f 100644
--- a/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs
+++ b/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs
@@ -138,6 +138,22 @@ namespace Tensorflow
return gen_array_ops.identity(data, name: name);
}
+ ///
+ /// Forwards `data` to an output determined by `pred`.
+ ///
+ ///
+ ///
+ ///
+ ///
+ public static (Tensor, Tensor) _SwitchRefOrTensor(Tensor data, Tensor pred, string name = "Switch")
+ {
+ data = ops.convert_to_tensor_or_indexed_slices(data, name: "data");
+
+ ops.colocate_with(data, ignore_existing: true);
+
+ return @switch(data, pred, name: name);
+ }
+
public static Tensor[] cond(Tensor pred,
Func true_fn = null,
Func false_fn = null,
diff --git a/src/TensorFlowNET.Core/Operations/control_flow_util.py.cs b/src/TensorFlowNET.Core/Operations/control_flow_util.py.cs
index 1b8a304c..146b681c 100644
--- a/src/TensorFlowNET.Core/Operations/control_flow_util.py.cs
+++ b/src/TensorFlowNET.Core/Operations/control_flow_util.py.cs
@@ -1,6 +1,7 @@
using System;
using System.Collections.Generic;
using System.Text;
+using Tensorflow.Operations;
namespace Tensorflow
{
@@ -25,5 +26,12 @@ namespace Tensorflow
{
return op.type == "Switch" || op.type == "RefSwitch";
}
+
+ public static CondContext GetOutputContext(Operation op)
+ {
+ var ctxt = op._get_control_flow_context();
+
+ return ctxt;
+ }
}
}
diff --git a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs
index ae6e1f09..1d193a8a 100644
--- a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs
@@ -74,6 +74,13 @@ namespace Tensorflow
return _op.outputs[0];
}
+ public static Tensor invert_permutation(Tensor x, string name = null)
+ {
+ var _op = _op_def_lib._apply_op_helper("InvertPermutation", name, new { x });
+
+ return _op.outputs[0];
+ }
+
public static Tensor log(Tensor x, string name = null)
{
var _op = _op_def_lib._apply_op_helper("Log", name: name, args: new { x });
@@ -163,6 +170,12 @@ namespace Tensorflow
return _op.outputs[0];
}
+ public static Tensor scatter_nd(Tensor indices, Tensor updates, Tensor[] shape, string name = null)
+ {
+ var _op = _op_def_lib._apply_op_helper("ScatterNd", name, new { indices, updates, shape });
+ return _op.outputs[0];
+ }
+
public static Tensor shape(Tensor input, TF_DataType out_type = TF_DataType.TF_INT32, string name = null)
{
var _op = _op_def_lib._apply_op_helper("Shape", name, new { input, out_type });
@@ -181,7 +194,7 @@ namespace Tensorflow
return _op.outputs[0];
}
- public static Tensor transpose(Tensor x, int[] perm, string name = null)
+ public static Tensor transpose(T1 x, T2 perm, string name = null)
{
var _op = _op_def_lib._apply_op_helper("Transpose", name, new { x, perm });
return _op.outputs[0];
@@ -200,6 +213,30 @@ namespace Tensorflow
return _op.outputs[0];
}
+ public static Tensor strided_slice(Tensor input, Tensor begin, Tensor end, Tensor strides,
+ int begin_mask = 0,
+ int end_mask = 0,
+ int ellipsis_mask = 0,
+ int new_axis_mask = 0,
+ int shrink_axis_mask = 0,
+ string name = null)
+ {
+ var _op = _op_def_lib._apply_op_helper("StridedSlice", name, new
+ {
+ input,
+ begin,
+ end,
+ strides,
+ begin_mask,
+ end_mask,
+ ellipsis_mask,
+ new_axis_mask,
+ shrink_axis_mask
+ });
+
+ return _op.outputs[0];
+ }
+
public static Tensor slice(Tensor input, Tb[] begin, Ts[] size, string name = null)
{
var _op = _op_def_lib._apply_op_helper("Slice", name, new { input, begin, size });
diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs
index af432f15..5d037696 100644
--- a/src/TensorFlowNET.Core/Tensors/Tensor.cs
+++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs
@@ -187,6 +187,61 @@ namespace Tensorflow
}
}
+ public Tensor this[int slice_spec]
+ {
+ get
+ {
+ var slice_spec_s = new int[] { slice_spec };
+ var begin = new List();
+ var end = new List();
+ var strides = new List();
+
+ var index = 0;
+ var (new_axis_mask, shrink_axis_mask) = (0, 0);
+ var (begin_mask, end_mask) = (0, 0);
+ var ellipsis_mask = 0;
+
+ foreach(var s in slice_spec_s)
+ {
+ {
+ begin.Add(s);
+ end.Add(s + 1);
+ strides.Add(1);
+ shrink_axis_mask |= (1 << index);
+ }
+
+ index += 1;
+ }
+
+ return with(ops.name_scope(null, "strided_slice", new { begin, end, strides }), scope =>
+ {
+ string name = scope;
+ if(begin != null)
+ {
+ var (packed_begin, packed_end, packed_strides) =
+ (array_ops.stack(begin.ToArray()),
+ array_ops.stack(end.ToArray()),
+ array_ops.stack(strides.ToArray()));
+
+ return gen_array_ops.strided_slice(
+ this,
+ packed_begin,
+ packed_end,
+ packed_strides,
+ begin_mask: begin_mask,
+ end_mask: end_mask,
+ shrink_axis_mask: shrink_axis_mask,
+ new_axis_mask: new_axis_mask,
+ ellipsis_mask: ellipsis_mask,
+ name: name);
+ }
+
+ throw new NotImplementedException("");
+ });
+ }
+
+ }
+
public override string ToString()
{
if(NDims == 0)