| @@ -0,0 +1,7 @@ | |||||
| namespace Tensorflow.Keras.ArgsDefinition | |||||
| { | |||||
| public class ReshapeArgs : LayerArgs | |||||
| { | |||||
| public TensorShape TargetShape { get; set; } | |||||
| } | |||||
| } | |||||
| @@ -34,5 +34,16 @@ namespace Tensorflow.Keras.Layers | |||||
| { | { | ||||
| Size = size ?? (2, 2) | Size = size ?? (2, 2) | ||||
| }); | }); | ||||
| /// <summary> | |||||
| /// Layer that reshapes inputs into the given shape. | |||||
| /// </summary> | |||||
| /// <param name="target_shape"></param> | |||||
| /// <returns></returns> | |||||
| public Reshape Reshape(TensorShape target_shape) | |||||
| => new Reshape(new ReshapeArgs | |||||
| { | |||||
| TargetShape = target_shape | |||||
| }); | |||||
| } | } | ||||
| } | } | ||||
| @@ -372,8 +372,8 @@ namespace Tensorflow.Keras.Layers | |||||
| InputShape = input_shape | InputShape = input_shape | ||||
| }); | }); | ||||
| public Add Add(params Tensor[] inputs) | |||||
| => new Add(new MergeArgs { Inputs = inputs }); | |||||
| public Add Add() | |||||
| => new Add(new MergeArgs { }); | |||||
| public GlobalAveragePooling2D GlobalAveragePooling2D() | public GlobalAveragePooling2D GlobalAveragePooling2D() | ||||
| => new GlobalAveragePooling2D(new Pooling2DArgs { }); | => new GlobalAveragePooling2D(new Pooling2DArgs { }); | ||||
| @@ -0,0 +1,34 @@ | |||||
| using Tensorflow.Keras.ArgsDefinition; | |||||
| using Tensorflow.Keras.Engine; | |||||
| using static Tensorflow.KerasApi; | |||||
| using static Tensorflow.Binding; | |||||
| using System.Collections.Generic; | |||||
| using System; | |||||
| namespace Tensorflow.Keras.Layers | |||||
| { | |||||
| /// <summary> | |||||
| /// Layer that reshapes inputs into the given shape. | |||||
| /// </summary> | |||||
| public class Reshape : Layer | |||||
| { | |||||
| ReshapeArgs args; | |||||
| public Reshape(ReshapeArgs args) | |||||
| : base(args) | |||||
| { | |||||
| this.args = args; | |||||
| } | |||||
| protected override Tensors Call(Tensors inputs, Tensor state = null, bool is_training = false) | |||||
| { | |||||
| var shape = new List<int> { inputs.shape[0] }; | |||||
| shape.AddRange(args.TargetShape.dims); | |||||
| var result = array_ops.reshape(inputs, shape.ToArray()); | |||||
| if (!tf.Context.executing_eagerly()) | |||||
| // result = result.set_shape(compute_output_shape(inputs.shape)); | |||||
| throw new NotImplementedException(""); | |||||
| return result; | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -1,6 +1,6 @@ | |||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
| using NumSharp; | using NumSharp; | ||||
| using Tensorflow; | |||||
| using static Tensorflow.Binding; | |||||
| using static Tensorflow.KerasApi; | using static Tensorflow.KerasApi; | ||||
| namespace TensorFlowNET.UnitTest.Keras | namespace TensorFlowNET.UnitTest.Keras | ||||
| @@ -26,5 +26,14 @@ namespace TensorFlowNET.UnitTest.Keras | |||||
| var y = keras.layers.UpSampling2D(size: (1, 2)).Apply(x); | var y = keras.layers.UpSampling2D(size: (1, 2)).Apply(x); | ||||
| Assert.AreEqual((2, 2, 2, 3), y.shape); | Assert.AreEqual((2, 2, 2, 3), y.shape); | ||||
| } | } | ||||
| [TestMethod] | |||||
| public void Reshape() | |||||
| { | |||||
| var inputs = tf.zeros((10, 5, 20)); | |||||
| var outputs = keras.layers.LeakyReLU().Apply(inputs); | |||||
| outputs = keras.layers.Reshape((20, 5)).Apply(outputs); | |||||
| Assert.AreEqual((10, 20, 5), outputs.shape); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||