diff --git a/src/TensorFlowNET.Core/Gradients/array_grad.cs b/src/TensorFlowNET.Core/Gradients/array_grad.cs index 74e9ef10..b7ca5cf9 100644 --- a/src/TensorFlowNET.Core/Gradients/array_grad.cs +++ b/src/TensorFlowNET.Core/Gradients/array_grad.cs @@ -231,6 +231,12 @@ namespace Tensorflow.Gradients return new Tensor[] { x_grad, null }; } + [RegisterGradient("Split")] + public static Tensor[] _SplitGrad(Operation op, Tensor[] grads) + { + return new Tensor[] { null, array_ops.concat(list(grads), op.inputs[0]) }; + } + [RegisterGradient("Squeeze")] public static Tensor[] _SqueezeGrad(Operation op, Tensor[] grads) {