| @@ -117,6 +117,137 @@ namespace Tensorflow.Gradients | |||||
| }; | }; | ||||
| } | } | ||||
| public static string ellipsis = "..."; | |||||
| [RegisterGradient("Einsum")] | |||||
| public static Tensor[] _EinsumGrad(Operation op, Tensor[] grads) | |||||
| { | |||||
| // Gradient for Einsum. | |||||
| string equation = (string)op.get_attr("equation"); | |||||
| string[] split_equation = equation.Split(new string[] { "->" }, StringSplitOptions.None); | |||||
| var input_subs = split_equation[0]; | |||||
| var output_subs = split_equation[1]; | |||||
| if (op.inputs.Length == 1) | |||||
| { | |||||
| var input_shape = array_ops.shape(op.inputs[0]); | |||||
| var reduced_label_set = new HashSet<char>(new HashSet<char>(input_subs).Except(new HashSet<char>(output_subs + ellipsis))); | |||||
| if (reduced_label_set.Count == 0) | |||||
| return new Tensor[] { math_ops.einsum(string.Format("{0}->{1}", output_subs, input_subs), new Tensors(grads)) }; | |||||
| return new Tensor[] { _GetGradReduced(new Tensors(grads), output_subs, input_subs, input_shape, reduced_label_set) }; | |||||
| } | |||||
| string[] split_input_subs = input_subs.Split(new string[] { "," }, StringSplitOptions.None); | |||||
| var x_subs = split_input_subs[0]; | |||||
| var y_subs = split_input_subs[1]; | |||||
| // Add ellipsis for broadcasted dimensions if any operand does not have it. | |||||
| // This is because the equation "...ij,jk->ik" may be valid if the 0th input's | |||||
| // batch shape is empty, but the VJP equation "jk,ik->...ij" is not valid | |||||
| // because only the output subscripts contain ellipsis. | |||||
| if (output_subs.Contains(ellipsis)) | |||||
| { | |||||
| if (!x_subs.Contains(ellipsis)) | |||||
| x_subs += ellipsis; | |||||
| if (!y_subs.Contains(ellipsis)) | |||||
| y_subs += ellipsis; | |||||
| } | |||||
| // Obtain the gradients wrt the inputs x and y, without taking into account | |||||
| // the unbroadcasting. | |||||
| var x = op.inputs[0]; | |||||
| var y = op.inputs[1]; | |||||
| if (grads.GetDataType().is_complex()) | |||||
| { | |||||
| x = math_ops.conj(x); | |||||
| y = math_ops.conj(y); | |||||
| } | |||||
| var x_shape = array_ops.shape(x); | |||||
| var y_shape = array_ops.shape(y); | |||||
| var grad_x = _GetGradWrt(grads, y, x_shape, x_subs, y_subs, output_subs); | |||||
| var grad_y = _GetGradWrt(grads, x, y_shape, y_subs, x_subs, output_subs); | |||||
| if (!output_subs.Contains(ellipsis)) | |||||
| return new Tensor[] { grad_x, grad_y }; | |||||
| var bx = _GetBcastSubshape(x_subs); | |||||
| int bx_start = bx[0], bx_end = bx[1]; | |||||
| var by = _GetBcastSubshape(y_subs); | |||||
| int by_start = by[0], by_end = by[1]; | |||||
| var x_shape_static = x.shape; | |||||
| var y_shape_static = y.shape; | |||||
| if(x_shape_static.IsFullyDefined && | |||||
| y_shape_static.IsFullyDefined && | |||||
| x_shape_static[string.Format("{0}:{1}",bx_start,bx_end)] == y_shape_static[string.Format("{0}:{1}", by_start, by_end)]) | |||||
| return new Tensor[] { grad_x, grad_y }; | |||||
| var r = gen_array_ops.broadcast_gradient_args(x_shape[string.Format("{0}:{1}", bx_start, bx_end)], | |||||
| y_shape[string.Format("{0}:{1}", by_start, by_end)]); | |||||
| var rx = r[0]; | |||||
| var ry = r[1]; | |||||
| grad_x = array_ops.reshape(math_ops.reduce_sum(grad_x, bx_start + rx), x_shape); | |||||
| grad_y = array_ops.reshape(math_ops.reduce_sum(grad_y, by_start + ry), y_shape); | |||||
| return new Tensor[] { grad_x, grad_y }; | |||||
| } | |||||
| protected static Tensor _GetGradWrt(Tensor[] output_grads, Tensor other_operand, Tensor input_shape, | |||||
| string input_subs, string other_subs, string output_subs) | |||||
| { | |||||
| var reduced_label_set = new HashSet<char>(new HashSet<char>(input_subs).Except(new HashSet<char>(output_subs + other_subs + "."))); | |||||
| var left_subs = string.Join("", input_subs.Where(s => !reduced_label_set.Contains(s))); | |||||
| var grad_reduced = math_ops.einsum(string.Format("{0},{1}->{2}", output_subs, other_subs, left_subs), new Tensors((Tensors)output_grads, other_operand)); | |||||
| if (reduced_label_set.Count == 0) | |||||
| return grad_reduced; | |||||
| return _GetGradReduced(grad_reduced, left_subs, input_subs, input_shape, reduced_label_set); | |||||
| } | |||||
| protected static Tensor _GetGradReduced(Tensor output_grad, string output_subs, string input_subs, Tensor input_shape, HashSet<char> reduced_label_set) | |||||
| { | |||||
| string reduced_subs; | |||||
| Tensor reduced_dims; | |||||
| List<int> reduced_axes; | |||||
| _GetReducedSubscripts(reduced_label_set, input_shape, input_subs, out reduced_subs, out reduced_dims, out reduced_axes); | |||||
| bool has_repeated_labels = ( | |||||
| new HashSet<char>(input_subs).Count + new HashSet<char>(output_subs).Count < | |||||
| input_subs.Length + output_subs.Length); | |||||
| var input_subs_without_reduced_labels = string.Join("", input_subs.Where(s => !reduced_label_set.Contains(s))); | |||||
| if (!has_repeated_labels && input_subs_without_reduced_labels == output_subs) | |||||
| { | |||||
| var reduced_shape = math_ops.reduced_shape(input_shape, ops.convert_to_tensor(reduced_axes)); | |||||
| return gen_array_ops.broadcast_to(array_ops.reshape(output_grad, reduced_shape), input_shape); | |||||
| } | |||||
| else | |||||
| { | |||||
| var grad_shape_with_reduced_labels = array_ops.concat(new Tensor[] { reduced_dims, array_ops.shape(new Tensors(output_grad)) }, axis: 0); | |||||
| var reduced_shape = array_ops.concat(new Tensor[] { array_ops.ones(reduced_label_set.Count, dtype: dtypes.int32), array_ops.shape(new Tensors(output_grad)) }, axis: 0); | |||||
| var broadcasted_grad = gen_array_ops.broadcast_to(array_ops.reshape(output_grad, reduced_shape), grad_shape_with_reduced_labels); | |||||
| return math_ops.einsum(string.Format("{0}->{1}", reduced_subs + output_subs, input_subs), new Tensors(broadcasted_grad)); | |||||
| } | |||||
| } | |||||
| protected static void _GetReducedSubscripts(HashSet<char> reduced_label_set, Tensor input_shape, string subscripts, out string reduced_subs, out Tensor reduced_dims, out List<int> reduced_axes) | |||||
| { | |||||
| reduced_subs = string.Join("", reduced_label_set.Select(c => c.ToString())); | |||||
| reduced_axes = reduced_subs.Select(s => _GetAxisFromLabel(subscripts, s)).ToList(); | |||||
| reduced_dims = array_ops.stack(reduced_axes.Select(ax => input_shape[ax]).ToList()); | |||||
| } | |||||
| protected static int _GetAxisFromLabel(string subscripts, char label) | |||||
| { | |||||
| var splits = subscripts.Split(new string[] { ellipsis }, StringSplitOptions.None); | |||||
| var index = splits[0].IndexOf(label); | |||||
| if (index != -1) return index; | |||||
| if (splits.Length < 2) throw new OutOfRangeError(); | |||||
| index = splits[1].IndexOf(label); | |||||
| if (index != -1) return index; | |||||
| throw new ValueError(); | |||||
| } | |||||
| protected static int[] _GetBcastSubshape(string subscripts) | |||||
| { | |||||
| int start = subscripts.IndexOf(ellipsis); | |||||
| if (start == -1) return new int[] { 0, 0 }; | |||||
| int remaining = subscripts.Length - (start + ellipsis.Length); | |||||
| int end; | |||||
| if (remaining > 0) end = remaining; | |||||
| else throw new Exception(); | |||||
| return new int[] { start, end }; | |||||
| } | |||||
| /// <summary> | /// <summary> | ||||
| /// Returns grad * exp(x). | /// Returns grad * exp(x). | ||||
| /// </summary> | /// </summary> | ||||