Browse Source

tf.layers.flatten: Added support for None dimension

tags/v0.12
Eli Belash 6 years ago
parent
commit
61c4317536
2 changed files with 30 additions and 2 deletions
  1. +21
    -2
      src/TensorFlowNET.Core/APIs/tf.layers.cs
  2. +9
    -0
      test/TensorFlowNET.UnitTest/layers_test/flatten.cs

+ 21
- 2
src/TensorFlowNET.Core/APIs/tf.layers.cs View File

@@ -15,6 +15,8 @@
******************************************************************************/

using System.Collections.Generic;
using System.Linq;
using NumSharp;
using Tensorflow.Keras.Layers;
using Tensorflow.Operations.Activation;
using static Tensorflow.Binding;
@@ -182,6 +184,7 @@ namespace Tensorflow
string name = null,
string data_format = "channels_last")
{
var input_shape = inputs.shape;
if (inputs.shape.Length == 0)
throw new ValueError($"Input 0 of layer flatten is incompatible with the layer: : expected min_ndim={1}, found ndim={0}. Full shape received: ()");

@@ -193,9 +196,25 @@ namespace Tensorflow
inputs = array_ops.transpose(inputs, premutation.ToArray());
}

var ret = array_ops.reshape(inputs, new int[] {inputs.shape[0], -1});
ret.set_shape(new int[] {inputs.shape[0], -1});
var ret = array_ops.reshape(inputs, new int[] {input_shape[0], -1});
ret.shape = ret.shape;
//ret.set_shape(compute_output_shape(ret.shape));
return ret;

int[] compute_output_shape(int[] inputshape)
{
if (inputshape == null || inputshape.Length == 0)
inputshape = new int[] {1};

if (inputshape.Skip(1).All(d => d > 0))
{
int[] output_shape = new int[2];
output_shape[0] = inputshape[0];
output_shape[1] = inputshape.Skip(1).Aggregate(1, (acc, rhs) => acc*rhs); //calculate size of all the rest dimensions
return output_shape;
} else
return new int[] {inputshape[0], -1}; //-1 == Binding.None
}
}
}
}


+ 9
- 0
test/TensorFlowNET.UnitTest/layers_test/flatten.cs View File

@@ -36,5 +36,14 @@ namespace TensorFlowNET.UnitTest.layers_test
var input = tf.placeholder(TF_DataType.TF_INT32, new TensorShape());
new Action(() => sess.run(tf.layers.flatten(input), (input, NDArray.Scalar(6)))).Should().Throw<ValueError>();
}

[TestMethod]
public void Case4()
{
var sess = tf.Session().as_default();

var input = tf.placeholder(TF_DataType.TF_INT32, new TensorShape(3, 4, None, 1, 2));
sess.run(tf.layers.flatten(input), (input, np.arange(3 * 4 * 3 * 1 * 2).reshape(3, 4, 3, 1, 2))).Should().BeShaped(3, 24);
}
}
}

Loading…
Cancel
Save