diff --git a/src/TensorFlowNET.Core/Gradients/array_grad.cs b/src/TensorFlowNET.Core/Gradients/array_grad.cs
new file mode 100644
index 00000000..399578cf
--- /dev/null
+++ b/src/TensorFlowNET.Core/Gradients/array_grad.cs
@@ -0,0 +1,159 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using Tensorflow.Operations;
+using static Tensorflow.Python;
+
+namespace Tensorflow.Gradients
+{
+ /// <summary>
+ /// tensorflow\python\ops\array_grad.py
+ /// </summary>
+ public class array_grad
+ {
+ public static Tensor[] _ConcatGradV2(Operation op, Tensor[] grads)
+ {
+ var grad = grads[0];
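+ // For ConcatV2 the axis is the last input, so the value inputs span
+ // [0, -1) and dim_index is -1 (Python-style negative indexing).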
+ return _ConcatGradHelper(op, grad, start_value_index: 0, end_value_index: -1, dim_index: -1);
+ }
+
+ /// <summary>
+ /// Gradient for concat op.
+ /// </summary>
+ /// <param name="op">An operation.</param>
+ /// <param name="grad">
+ /// `Tensor` or `IndexedSlices` representing the gradients with respect
+ /// to each output of the op.
+ /// </param>
+ /// <param name="start_value_index">An integer index of the first value in the op.inputs.</param>
+ /// <param name="end_value_index">An integer index of the last value in the op.inputs.</param>
+ /// <param name="dim_index">An integer index of concat_dim or axis parameter in op.inputs.</param>
+ /// <returns>
+ /// Tensors representing the partial gradients with respect to each input
+ /// of the op.
+ /// </returns>
+ private static Tensor[] _ConcatGradHelper(Operation op, Tensor grad, int start_value_index, int end_value_index, int dim_index)
+ {
+ // Degenerate concatenation, just return grad.
+ if (len(op.inputs) == 2)
+ return end_value_index <= dim_index ? new Tensor[] { grad, null } : new Tensor[] { null, grad };
+
+ // Normalize Python-style negative indices before slicing op.inputs;
+ // Enumerable.Take with a negative count would otherwise yield nothing.
+ var num_inputs = op.inputs._inputs.Length;
+ if (dim_index < 0) dim_index += num_inputs;
+ if (end_value_index < 0) end_value_index += num_inputs;
+ var concat_dim = op.inputs[dim_index];
+ var input_values = op.inputs._inputs.Skip(start_value_index).Take(end_value_index - start_value_index).ToArray();
+
+ var out_grads = new List<Tensor>();
+ if (constant_op.is_constant(concat_dim))
+ {
+ /*If concat_dim is a constant defined in a different context,
+ then we duplicate it in the current context to avoid passing it
+ through an Enter node.
+ This is a small optimization in general, but it is required when
+ compiling with XLA, as XLA needs the concat input to be folded into a
+ constant.*/
+ var grad_context = control_flow_util.GetOutputContext(grad.op);
+ var dim_context = control_flow_util.GetOutputContext(concat_dim.op);
+ if (dim_context != grad_context)
+ {
+ var value = tensor_util.constant_value(concat_dim);
+ concat_dim = constant_op.constant(value: value, dtype: concat_dim.dtype);
+ }
+ }
+
+ // Using mod here for convenience since concat_dim is already verified
+ // in concat implementation to be within the allowed [-rank, rank) range.
+ var non_neg_concat_dim = concat_dim % array_ops.rank(input_values[0]);
+
+ // Get the inputs' tensor shapes
+ var sizes = _ExtractInputShapes(input_values);
+
+ /* The magic number of 16 was found through benchmarking a range of sizes
+ on CPUs and a Maxwell TitanX. A speedup was seen in a large majority of
+ cases when switching implementations at N=16, but it is possible that
+ there will be a small number of performance regressions.*/
+ if (len(sizes) > 16)
+ {
+ // extract the size of each input along the concat dimension
+ var slice = array_ops.slice(array_ops.stack(sizes, axis: 1),
+ new Tensor[] { non_neg_concat_dim, tf.constant(0) },
+ new Tensor[] { tf.constant(1), tf.constant(-1) });
+ var squeeze_sizes = array_ops.squeeze(slice);
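+ // squeeze_sizes is a 1-D vector holding each input's extent along the
+ // concat axis; splitting grad by these sizes yields one piece per input.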
+ out_grads = gen_ops.split(grad, squeeze_sizes, non_neg_concat_dim).ToList();
+ }
+ else
+ {
+ var offset = gen_ops.concat_offset(non_neg_concat_dim, sizes);
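+ // ConcatOffset returns the begin coordinates of each input within the
+ // concatenated tensor; slicing grad there recovers that input's gradient.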
+ foreach (var (begin, size) in zip(offset, sizes))
+ out_grads.Add(gen_ops.slice(grad, begin, size));
+ }
+
+ return (end_value_index <= dim_index ?
+ out_grads.Concat(new Tensor[] { null }) :
+ new Tensor[] { null }.Concat(out_grads)).ToArray();
+ }
+
+ /// <summary>
+ /// Extract the shapes of a set of input tensors.
+ /// </summary>
+ /// <param name="inputs">The input tensors.</param>
+ /// <returns>Constant shape tensors when every shape is statically known, otherwise the outputs of a single `ShapeN` op.</returns>
+ private static Tensor[] _ExtractInputShapes(Tensor[] inputs)
+ {
+ var sizes = new Tensor[inputs.Length];
+ bool fully_known = true;
+ for(int i = 0; i < inputs.Length; i++)
+ {
+ var x = inputs[i];
+
+ var input_shape = array_ops.shape(x);
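+ // shape(x) folds into a Const op only when x's static shape is fully
+ // known; any unknown dimension forces the ShapeN fallback below.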
+ if (!(input_shape is Tensor) || input_shape.op.type != "Const")
+ {
+ fully_known = false;
+ break;
+ }
+
+ sizes[i] = input_shape;
+ }
+
+ if (fully_known)
+ return sizes;
+ else
+ return gen_ops.shape_n(inputs);
+ }
+
+ public static Tensor[] _ReshapeGrad(Operation op, Tensor[] grads)
+ {
+ return new Tensor[] { array_ops.reshape(grads[0], array_ops.shape(op.inputs[0])), null };
+ }
+
+ public static Tensor[] _SqueezeGrad(Operation op, Tensor[] grads)
+ {
+ return new Tensor[] { _ReshapeToInput(op, grads[0]) };
+ }
+
+ private static Tensor _ReshapeToInput(Operation op, Tensor grad)
+ {
+ return array_ops.reshape(grad, array_ops.shape(op.inputs[0]));
+ }
+
+ public static Tensor[] _TransposeGrad(Operation op, Tensor[] grads)
+ {
+ var p = op.inputs[1];
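+ // Transposing by the inverted permutation routes each gradient element
+ // back to its source position; the permutation input itself gets no gradient.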
+ return new Tensor[] { array_ops.transpose(grads[0], array_ops.invert_permutation(p)), null };
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Core/Gradients/array_grad.py.cs b/src/TensorFlowNET.Core/Gradients/array_grad.py.cs
deleted file mode 100644
index cdd319ea..00000000
--- a/src/TensorFlowNET.Core/Gradients/array_grad.py.cs
+++ /dev/null
@@ -1,30 +0,0 @@
-using System;
-using System.Collections.Generic;
-using System.Text;
-
-namespace Tensorflow.Gradients
-{
- public class array_grad
- {
- public static Tensor[] _ReshapeGrad(Operation op, Tensor[] grads)
- {
- return new Tensor[] { array_ops.reshape(grads[0], array_ops.shape(op.inputs[0])), null };
- }
-
- public static Tensor[] _SqueezeGrad(Operation op, Tensor[] grads)
- {
- return new Tensor[] { _ReshapeToInput(op, grads[0]) };
- }
-
- private static Tensor _ReshapeToInput(Operation op, Tensor grad)
- {
- return array_ops.reshape(grad, array_ops.shape(op.inputs[0]));
- }
-
- public static Tensor[] _TransposeGrad(Operation op, Tensor[] grads)
- {
- var p = op.inputs[1];
- return new Tensor[] { array_ops.transpose(grads[0], array_ops.invert_permutation(p)), null };
- }
- }
-}
diff --git a/src/TensorFlowNET.Core/Gradients/nn_grad.py.cs b/src/TensorFlowNET.Core/Gradients/nn_grad.cs
similarity index 99%
rename from src/TensorFlowNET.Core/Gradients/nn_grad.py.cs
rename to src/TensorFlowNET.Core/Gradients/nn_grad.cs
index 0bd03046..9a70b90d 100644
--- a/src/TensorFlowNET.Core/Gradients/nn_grad.py.cs
+++ b/src/TensorFlowNET.Core/Gradients/nn_grad.cs
@@ -6,6 +6,9 @@ using Tensorflow.Operations;
namespace Tensorflow.Gradients
{
+ /// <summary>
+ /// tensorflow\python\ops\nn_grad.py
+ /// </summary>
public class nn_grad
{
/// <summary>
diff --git a/src/TensorFlowNET.Core/Gradients/ops.gradient_function_mapping.cs b/src/TensorFlowNET.Core/Gradients/ops.gradient_function_mapping.cs
index d01d47be..eab9948b 100644
--- a/src/TensorFlowNET.Core/Gradients/ops.gradient_function_mapping.cs
+++ b/src/TensorFlowNET.Core/Gradients/ops.gradient_function_mapping.cs
@@ -22,6 +22,8 @@ namespace Tensorflow
return math_grad._AddGrad(oper, out_grads);
case "BiasAdd":
return nn_grad._BiasAddGrad(oper, out_grads);
+ case "ConcatV2":
+ return array_grad._ConcatGradV2(oper, out_grads);
case "Exp":
return math_grad._ExpGrad(oper, out_grads);
case "Identity":
diff --git a/src/TensorFlowNET.Core/Tensors/constant_op.cs b/src/TensorFlowNET.Core/Tensors/constant_op.cs
index cabb4e37..0d4ef0b4 100644
--- a/src/TensorFlowNET.Core/Tensors/constant_op.cs
+++ b/src/TensorFlowNET.Core/Tensors/constant_op.cs
@@ -88,5 +88,18 @@ namespace Tensorflow
return constant_op.constant(s_list, name: name);
}
+
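+ /// <summary>
+ /// Returns true when the given tensor or operation is produced by a Const op.
+ /// </summary>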
+ public static bool is_constant(ITensorOrOperation tensor_or_op)
+ {
+ if (tensor_or_op is Tensor tensor)
+ return tensor.op.type == "Const";
+ else if (tensor_or_op is Operation op)
+ return op.type == "Const";
+ else
+ throw new ValueError($"is_constant: unexpected input type {tensor_or_op.GetType().Name}");
+ }
}
}