Browse Source

tf.layers.flatten: Added and fixed special unit-test case.

tags/v0.12
Eli Belash 6 years ago
parent
commit
0c32b73fc1
2 changed files with 10 additions and 2 deletions
  1. +1
    -2
      src/TensorFlowNET.Core/APIs/tf.layers.cs
  2. +9
    -0
      test/TensorFlowNET.UnitTest/layers_test/flatten.cs

+ 1
- 2
src/TensorFlowNET.Core/APIs/tf.layers.cs View File

@@ -196,8 +196,7 @@ namespace Tensorflow
inputs = array_ops.transpose(inputs, premutation.ToArray());
}

var ret = array_ops.reshape(inputs, new int[] {input_shape[0], -1});
ret.shape = ret.shape;
var ret = array_ops.reshape(inputs, compute_output_shape(input_shape));
//ret.set_shape(compute_output_shape(ret.shape));
return ret;



+ 9
- 0
test/TensorFlowNET.UnitTest/layers_test/flatten.cs View File

@@ -45,5 +45,14 @@ namespace TensorFlowNET.UnitTest.layers_test
var input = tf.placeholder(TF_DataType.TF_INT32, new TensorShape(3, 4, None, 1, 2));
sess.run(tf.layers.flatten(input), (input, np.arange(3 * 4 * 3 * 1 * 2).reshape(3, 4, 3, 1, 2))).Should().BeShaped(3, 24);
}

[TestMethod]
public void Case5()
{
var sess = tf.Session().as_default();

var input = tf.placeholder(TF_DataType.TF_INT32, new TensorShape(None, 4, 3, 1, 2));
sess.run(tf.layers.flatten(input), (input, np.arange(3 * 4 * 3 * 1 * 2).reshape(3, 4, 3, 1, 2))).Should().BeShaped(3, 24);
}
}
}

Loading…
Cancel
Save