Browse Source

Merge pull request #378 from SciSharp/unknown-dimension

Added support for unknown dimension #377
tags/v0.12
Haiping Chen GitHub 6 years ago
parent
commit
9e237a6733
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 182 additions and 6 deletions
  1. +20
    -2
      src/TensorFlowNET.Core/APIs/tf.layers.cs
  2. +30
    -1
      src/TensorFlowNET.Core/Binding.Util.cs
  3. +11
    -0
      src/TensorFlowNET.Core/Binding.cs
  4. +43
    -3
      src/TensorFlowNET.Core/Tensors/TensorShape.cs
  5. +60
    -0
      test/TensorFlowNET.UnitTest/TensorShapeTest.cs
  6. +18
    -0
      test/TensorFlowNET.UnitTest/layers_test/flatten.cs

+ 20
- 2
src/TensorFlowNET.Core/APIs/tf.layers.cs View File

@@ -15,6 +15,8 @@
******************************************************************************/

using System.Collections.Generic;
using System.Linq;
using NumSharp;
using Tensorflow.Keras.Layers;
using Tensorflow.Operations.Activation;
using static Tensorflow.Binding;
@@ -182,6 +184,7 @@ namespace Tensorflow
string name = null,
string data_format = "channels_last")
{
var input_shape = inputs.shape;
if (inputs.shape.Length == 0)
throw new ValueError($"Input 0 of layer flatten is incompatible with the layer: : expected min_ndim={1}, found ndim={0}. Full shape received: ()");

@@ -193,9 +196,24 @@ namespace Tensorflow
inputs = array_ops.transpose(inputs, premutation.ToArray());
}

var ret = array_ops.reshape(inputs, new int[] {inputs.shape[0], -1});
ret.set_shape(new int[] {inputs.shape[0], -1});
var ret = array_ops.reshape(inputs, compute_output_shape(input_shape));
//ret.set_shape(compute_output_shape(ret.shape));
return ret;

int[] compute_output_shape(int[] inputshape)
{
if (inputshape == null || inputshape.Length == 0)
inputshape = new int[] {1};

if (inputshape.Skip(1).All(d => d > 0))
{
int[] output_shape = new int[2];
output_shape[0] = inputshape[0];
output_shape[1] = inputshape.Skip(1).Aggregate(1, (acc, rhs) => acc*rhs); //calculate size of all the rest dimensions
return output_shape;
} else
return new int[] {inputshape[0], -1}; //-1 == Binding.None
}
}
}
}


+ 30
- 1
src/TensorFlowNET.Core/Binding.Util.cs View File

@@ -21,6 +21,7 @@ using System.Collections.Generic;
using System.ComponentModel;
using System.Diagnostics;
using System.Linq;
using NumSharp.Utilities;

namespace Tensorflow
{
@@ -29,9 +30,37 @@ namespace Tensorflow
/// </summary>
public static partial class Binding
{
private static string _tostring(object obj)
{
switch (obj)
{
case NDArray nd:
return nd.ToString(false);
case Array arr:
if (arr.Rank!=1 || arr.GetType().GetElementType()?.IsArray == true)
arr = Arrays.Flatten(arr);
var objs = toObjectArray(arr);
return $"[{string.Join(", ", objs.Select(_tostring))}]";
default:
return obj?.ToString() ?? "null";
}

object[] toObjectArray(Array arr)
{
var len = arr.LongLength;
var ret = new object[len];
for (long i = 0; i < len; i++)
{
ret[i] = arr.GetValue(i);
}

return ret;
}
}

public static void print(object obj)
{
Console.WriteLine(obj.ToString());
Console.WriteLine(_tostring(obj));
}

public static int len(object a)


+ 11
- 0
src/TensorFlowNET.Core/Binding.cs View File

@@ -7,5 +7,16 @@ namespace Tensorflow
public static partial class Binding
{
public static tensorflow tf { get; } = New<tensorflow>();

/// <summary>
/// Alias to null, similar to python's None.
/// For TensorShape, please use Unknown
/// </summary>
public static readonly object None = null;

/// <summary>
/// Used for TensorShape None
/// </summary>
public static readonly int Unknown = -1;
}
}

+ 43
- 3
src/TensorFlowNET.Core/Tensors/TensorShape.cs View File

@@ -3,6 +3,7 @@ using System;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Runtime.CompilerServices;
using NumSharp.Utilities;

namespace Tensorflow
{
@@ -32,7 +33,23 @@ namespace Tensorflow
/// <summary>
/// Returns the size this shape represents.
/// </summary>
public int size => shape.Size;
public int size
{
get
{
var dims = shape.Dimensions;
var computed = 1;
for (int i = 0; i < dims.Length; i++)
{
var val = dims[i];
if (val <= 0)
continue;
computed *= val;
}

return computed;
}
}

public TensorShape(TensorShapeProto proto)
{
@@ -59,12 +76,30 @@ namespace Tensorflow
switch (dims.Length)
{
case 0: shape = new Shape(new int[0]); break;
case 1: shape = Shape.Vector((int) dims[0]); break;
case 1: shape = Shape.Vector((int)dims[0]); break;
case 2: shape = Shape.Matrix(dims[0], dims[1]); break;
default: shape = new Shape(dims); break;
}
}

public TensorShape(int[][] dims)
{
if(dims.Length == 1)
{
switch (dims[0].Length)
{
case 0: shape = new Shape(new int[0]); break;
case 1: shape = Shape.Vector((int)dims[0][0]); break;
case 2: shape = Shape.Matrix(dims[0][0], dims[1][2]); break;
default: shape = new Shape(dims[0]); break;
}
}
else
{
throw new NotImplementedException("TensorShape int[][] dims");
}
}

/// <summary>
///
/// </summary>
@@ -188,6 +223,11 @@ namespace Tensorflow

public static explicit operator (int, int, int, int, int, int)(TensorShape shape) => shape.dims.Length == 6 ? (shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3], shape.dims[4], shape.dims[5]) : (0, 0, 0, 0, 0, 0);
public static implicit operator TensorShape((int, int, int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5, dims.Item6);

public static explicit operator (int, int, int, int, int, int, int)(TensorShape shape) => shape.dims.Length == 7 ? (shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3], shape.dims[4], shape.dims[5], shape.dims[6]) : (0, 0, 0, 0, 0, 0, 0);
public static implicit operator TensorShape((int, int, int, int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5, dims.Item6, dims.Item7);
public static explicit operator (int, int, int, int, int, int, int, int)(TensorShape shape) => shape.dims.Length == 8 ? (shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3], shape.dims[4], shape.dims[5], shape.dims[6], shape.dims[7]) : (0, 0, 0, 0, 0, 0, 0, 0);
public static implicit operator TensorShape((int, int, int, int, int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5, dims.Item6, dims.Item7, dims.Item8);
}
}

+ 60
- 0
test/TensorFlowNET.UnitTest/TensorShapeTest.cs View File

@@ -0,0 +1,60 @@
using System;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using NumSharp;
using Tensorflow;
using static Tensorflow.Binding;

namespace TensorFlowNET.UnitTest
{
[TestClass]
public class TensorShapeTest
{
[TestMethod]
public void Case1()
{
int a = 2;
int b = 3;
var dims = new [] { Unknown, a, b};
new TensorShape(dims).GetPrivate<Shape>("shape").Should().BeShaped(-1, 2, 3);
}

[TestMethod]
public void Case2()
{
int a = 2;
int b = 3;
var dims = new[] { Unknown, a, b};
new TensorShape(new [] {dims}).GetPrivate<Shape>("shape").Should().BeShaped(-1, 2, 3);
}

[TestMethod]
public void Case3()
{
int a = 2;
int b = Unknown;
var dims = new [] { Unknown, a, b};
new TensorShape(new [] {dims}).GetPrivate<Shape>("shape").Should().BeShaped(-1, 2, -1);
}

[TestMethod]
public void Case4()
{
TensorShape shape = (Unknown, Unknown);
shape.GetPrivate<Shape>("shape").Should().BeShaped(-1, -1);
}

[TestMethod]
public void Case5()
{
TensorShape shape = (1, Unknown, 3);
shape.GetPrivate<Shape>("shape").Should().BeShaped(1, -1, 3);
}

[TestMethod]
public void Case6()
{
TensorShape shape = (Unknown, 1, 2, 3, Unknown);
shape.GetPrivate<Shape>("shape").Should().BeShaped(-1, 1, 2, 3, -1);
}
}
}

+ 18
- 0
test/TensorFlowNET.UnitTest/layers_test/flatten.cs View File

@@ -36,5 +36,23 @@ namespace TensorFlowNET.UnitTest.layers_test
var input = tf.placeholder(TF_DataType.TF_INT32, new TensorShape());
new Action(() => sess.run(tf.layers.flatten(input), (input, NDArray.Scalar(6)))).Should().Throw<ValueError>();
}

[TestMethod]
public void Case4()
{
var sess = tf.Session().as_default();

var input = tf.placeholder(TF_DataType.TF_INT32, new TensorShape(3, 4, Unknown, 1, 2));
sess.run(tf.layers.flatten(input), (input, np.arange(3 * 4 * 3 * 1 * 2).reshape(3, 4, 3, 1, 2))).Should().BeShaped(3, 24);
}

[TestMethod]
public void Case5()
{
var sess = tf.Session().as_default();

var input = tf.placeholder(TF_DataType.TF_INT32, new TensorShape(Unknown, 4, 3, 1, 2));
sess.run(tf.layers.flatten(input), (input, np.arange(3 * 4 * 3 * 1 * 2).reshape(3, 4, 3, 1, 2))).Should().BeShaped(3, 24);
}
}
}

Loading…
Cancel
Save