using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using static Tensorflow.Binding;
using System.Collections.Generic;
using System;
using System.Linq;

namespace Tensorflow.Keras.Layers
{
    /// <summary>
    /// Layer that reshapes inputs into the given shape.
    /// The first (batch) dimension of the input is kept and the configured
    /// target shape is appended after it.
    /// </summary>
    public class Reshape : Layer
    {
        // Layer configuration; only assigned in the constructor.
        readonly ReshapeArgs args;

        /// <summary>
        /// Creates the layer from its argument object.
        /// </summary>
        /// <param name="args">Layer arguments; <c>TargetShape</c> / <c>TargetShapeObjects</c> are read when the layer runs.</param>
        public Reshape(ReshapeArgs args)
            : base(args)
        {
            this.args = args;
        }

        /// <summary>
        /// Reshapes <paramref name="inputs"/> to [batch, ...TargetShape].
        /// </summary>
        /// <param name="inputs">Tensor(s) to reshape.</param>
        /// <param name="state">Not used by this layer.</param>
        /// <param name="training">Not used by this layer.</param>
        /// <returns>The reshaped tensor.</returns>
        /// <exception cref="NotImplementedException">
        /// Thrown when <c>TargetShapeObjects</c> is set, or (outside eager mode) when
        /// <see cref="ComputeOutputShape"/> cannot compute the static shape.
        /// </exception>
        protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
        {
            // The first entry is the runtime size of dimension 0 of the input.
            var shapes = new List<Tensor> { array_ops.shape(inputs)[0] };

            // Target dimensions are created with the same dtype as the batch-size
            // tensor so they can be packed into a single shape tensor.
            var dtype = shapes[0].dtype;

            if (args.TargetShapeObjects != null)
            {
                // Not supported yet. The intended implementation was:
                // shapes.AddRange(args.TargetShapeObjects);
                throw new NotImplementedException("");
            }

            if (args.TargetShape != null)
            {
                shapes.AddRange(args.TargetShape.dims.Select(x => constant_op.constant(x, dtype)));
            }

            var shape = ops.convert_to_tensor(shapes);
            var result = array_ops.reshape(inputs, shape);

            // Outside eager execution, attach the statically computed output shape.
            if (!tf.Context.executing_eagerly())
            {
                result.shape = ComputeOutputShape(inputs.shape);
            }

            return result;
        }

        /// <summary>
        /// Computes the static output shape: the first dimension of
        /// <paramref name="input_shape"/> followed by the configured target shape.
        /// </summary>
        /// <param name="input_shape">Static shape of the input.</param>
        /// <returns>
        /// [input_shape[0], ...TargetShape], or just [input_shape[0]] when no target
        /// shape is configured (matching what <c>Call</c> builds in that case).
        /// </returns>
        /// <exception cref="NotImplementedException">
        /// Thrown when any non-leading input dimension is -1, or when only
        /// <c>TargetShapeObjects</c> is configured.
        /// </exception>
        public override Shape ComputeOutputShape(Shape input_shape)
        {
            // A -1 in any dimension after the first is not handled yet.
            if (input_shape.dims.Skip(1).Contains(-1))
            {
                throw new NotImplementedException("");
            }

            var batchShape = new Shape(input_shape.dims[0]);

            if (args.TargetShape != null)
            {
                return batchShape.concatenate(args.TargetShape.dims);
            }

            // Without a static target shape, mirror Call: object-based targets are
            // unsupported, otherwise only the leading dimension remains.
            if (args.TargetShapeObjects != null)
            {
                throw new NotImplementedException("");
            }

            return batchShape;
        }
    }
}