using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using static Tensorflow.Binding;
using System.Collections.Generic;
using System;
using System.Linq;
namespace Tensorflow.Keras.Layers
{
/// <summary>
/// Layer that reshapes inputs into the given shape.
/// The first dimension of the input is kept and the dimensions from
/// <c>args.TargetShape</c> are appended after it.
/// </summary>
public class Reshape : Layer
{
    // Layer configuration; only TargetShape and TargetShapeObjects are read by this class.
    private readonly ReshapeArgs args;

    /// <summary>
    /// Creates the layer from its arguments and forwards them to the base <see cref="Layer"/>.
    /// </summary>
    /// <param name="args">Reshape configuration, stored for use by Call and ComputeOutputShape.</param>
    public Reshape(ReshapeArgs args)
        : base(args)
    {
        this.args = args;
    }

    /// <summary>
    /// Reshapes <paramref name="inputs"/> to [first input dimension, TargetShape dims...].
    /// </summary>
    /// <param name="inputs">Tensors to reshape.</param>
    /// <param name="state">Not used by this layer.</param>
    /// <param name="training">Not used by this layer.</param>
    /// <returns>The reshaped tensor.</returns>
    /// <exception cref="NotImplementedException">
    /// Thrown when <c>args.TargetShapeObjects</c> is set, or (outside eager execution)
    /// when ComputeOutputShape cannot derive the static shape.
    /// </exception>
    protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
    {
        // Element 0 of the runtime shape tensor becomes the leading output dimension.
        var leadingDim = array_ops.shape(inputs)[0];
        var shapes = new List<Tensor> { leadingDim };

        // Target dimensions are created with the same dtype as the shape tensor
        // so that all elements can be packed into one tensor below.
        var dtype = leadingDim.dtype;

        if (args.TargetShapeObjects != null)
        {
            // TODO: append args.TargetShapeObjects to `shapes` once supported.
            throw new NotImplementedException("Reshape with TargetShapeObjects is not implemented.");
        }

        if (args.TargetShape != null)
        {
            shapes.AddRange(args.TargetShape.dims.Select(x => constant_op.constant(x, dtype)));
        }

        var shape = ops.convert_to_tensor(shapes);
        var result = array_ops.reshape(inputs, shape);

        // Outside eager execution the static shape is assigned explicitly.
        if (!tf.Context.executing_eagerly())
        {
            result.shape = ComputeOutputShape(inputs.shape);
        }

        return result;
    }

    /// <summary>
    /// Computes the static output shape: the first dimension of
    /// <paramref name="input_shape"/> followed by the TargetShape dimensions.
    /// </summary>
    /// <param name="input_shape">Static shape of the input.</param>
    /// <returns>
    /// The output shape; when <c>args.TargetShape</c> is null only the first input
    /// dimension is returned, matching the shape Call requests in that case.
    /// </returns>
    /// <exception cref="NotImplementedException">
    /// Thrown when any dimension after the first is -1.
    /// </exception>
    public override Shape ComputeOutputShape(Shape input_shape)
    {
        if (input_shape.dims.Skip(1).Contains(-1))
        {
            throw new NotImplementedException(
                "ComputeOutputShape is not implemented for input shapes containing -1 after the first dimension.");
        }

        var leadingShape = new Shape(input_shape.dims[0]);

        // Call treats a missing TargetShape as "no extra dimensions"; do the same here
        // instead of failing with a NullReferenceException.
        if (args.TargetShape == null)
        {
            return leadingShape;
        }

        return leadingShape.concatenate(args.TargetShape.dims);
    }
}
}