using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using static Tensorflow.Binding;
using System.Collections.Generic;
using System;
using System.Linq;
namespace Tensorflow.Keras.Layers
{
/// <summary>
/// Layer that reshapes inputs into the given shape.
/// </summary>
public class Reshape : Layer
{
    // Layer configuration. TargetShape.dims supplies every output dimension
    // after the leading one.
    private readonly ReshapeArgs args;

    /// <summary>
    /// Creates the layer from its argument object.
    /// </summary>
    /// <param name="args">Arguments whose <c>TargetShape</c> gives the output dimensions that follow the leading dimension.</param>
    public Reshape(ReshapeArgs args)
        : base(args)
    {
        this.args = args;
    }

    /// <summary>
    /// Reshapes <paramref name="inputs"/> to the input's leading dimension
    /// followed by the configured target shape.
    /// </summary>
    /// <param name="inputs">Tensor(s) to reshape.</param>
    /// <param name="state">Not read by this layer.</param>
    /// <param name="is_training">Not read by this layer.</param>
    /// <returns>The reshaped tensor.</returns>
    protected override Tensors Call(Tensors inputs, Tensor state = null, bool is_training = false)
    {
        // The leading dimension (presumably the batch size) has to come from the
        // input's own static shape. The tensor returned by array_ops.shape(inputs)
        // is 1-D with length == rank, so indexing *its* shape would yield the
        // number of dimensions rather than the size of the first one.
        TensorShape input_shape = inputs.shape;
        var shape = new List<int> { input_shape.dims[0] };
        shape.AddRange(args.TargetShape.dims);

        var result = array_ops.reshape(inputs, shape.ToArray());
        if (!tf.Context.executing_eagerly())
        {
            // When not executing eagerly, attach the statically computed output shape.
            result.set_shape(ComputeOutputShape(input_shape));
        }
        return result;
    }

    /// <summary>
    /// Computes the static output shape: the leading dimension of
    /// <paramref name="input_shape"/> followed by the configured target shape.
    /// </summary>
    /// <param name="input_shape">Static shape of the input; its first dimension is carried over unchanged.</param>
    /// <returns>The leading input dimension concatenated with <c>TargetShape.dims</c>.</returns>
    public override TensorShape ComputeOutputShape(TensorShape input_shape)
    {
        // The leading dimension passes through whether it is unknown (-1) or a
        // concrete size, so no special-casing is required.
        TensorShape leading = input_shape.dims[0];
        return leading.concatenate(args.TargetShape.dims);
    }
}
}