Browse Source

TensorShape: Added implicit conversions for object type.

tags/v0.12
Eli Belash 6 years ago
parent
commit
420e195aca
2 changed files with 76 additions and 0 deletions
  1. +16
    -0
      src/TensorFlowNET.Core/Tensors/TensorShape.cs
  2. +60
    -0
      test/TensorFlowNET.UnitTest/TensorShapeTest.cs

+ 16
- 0
src/TensorFlowNET.Core/Tensors/TensorShape.cs View File

@@ -254,6 +254,22 @@ namespace Tensorflow

public static explicit operator (int, int, int, int, int, int)(TensorShape shape) => shape.dims.Length == 6 ? (shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3], shape.dims[4], shape.dims[5]) : (0, 0, 0, 0, 0, 0);
public static implicit operator TensorShape((int, int, int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5, dims.Item6);
public static explicit operator (int, int, int, int, int, int, int)(TensorShape shape) => shape.dims.Length == 7 ? (shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3], shape.dims[4], shape.dims[5], shape.dims[6]) : (0, 0, 0, 0, 0, 0, 0);
public static implicit operator TensorShape((int, int, int, int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5, dims.Item6, dims.Item7);
public static explicit operator (int, int, int, int, int, int, int, int)(TensorShape shape) => shape.dims.Length == 8 ? (shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3], shape.dims[4], shape.dims[5], shape.dims[6], shape.dims[7]) : (0, 0, 0, 0, 0, 0, 0, 0);
public static implicit operator TensorShape((int, int, int, int, int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5, dims.Item6, dims.Item7, dims.Item8);
public static implicit operator TensorShape(int?[] dims) => new TensorShape(dims);
public static implicit operator TensorShape(int? dim) => new TensorShape(dim);
public static implicit operator TensorShape((object, object) dims) => new TensorShape(dims.Item1, dims.Item2);
public static implicit operator TensorShape((object, object, object) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3);
public static implicit operator TensorShape((object, object, object, object) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4);
public static implicit operator TensorShape((object, object, object, object, object) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5);
public static implicit operator TensorShape((object, object, object, object, object, object) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5, dims.Item6);
public static implicit operator TensorShape((object, object, object, object, object, object, object) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5, dims.Item6, dims.Item7);
public static implicit operator TensorShape((object, object, object, object, object, object, object, object) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5, dims.Item6, dims.Item7, dims.Item8);

}
}

+ 60
- 0
test/TensorFlowNET.UnitTest/TensorShapeTest.cs View File

@@ -0,0 +1,60 @@
using System;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using NumSharp;
using Tensorflow;
using static Tensorflow.Binding;

namespace TensorFlowNET.UnitTest
{
[TestClass]
public class TensorShapeTest
{
[TestMethod]
public void Case1()
{
int? a = 2;
int? b = 3;
var dims = new object[] {(int?) None, a, b};
new TensorShape(dims).GetPrivate<Shape>("shape").Should().BeShaped(-1, 2, 3);
}

[TestMethod]
public void Case2()
{
int? a = 2;
int? b = 3;
var dims = new object[] {(int?) None, a, b};
new TensorShape(new object[] {dims}).GetPrivate<Shape>("shape").Should().BeShaped(-1, 2, 3);
}

[TestMethod]
public void Case3()
{
int? a = 2;
int? b = null;
var dims = new object[] {(int?) None, a, b};
new TensorShape(new object[] {dims}).GetPrivate<Shape>("shape").Should().BeShaped(-1, 2, -1);
}

[TestMethod]
public void Case4()
{
TensorShape shape = (None, None);
shape.GetPrivate<Shape>("shape").Should().BeShaped(-1, -1);
}

[TestMethod]
public void Case5()
{
TensorShape shape = (1, None, 3);
shape.GetPrivate<Shape>("shape").Should().BeShaped(1, -1, 3);
}

[TestMethod]
public void Case6()
{
TensorShape shape = (None, 1, 2, 3, None);
shape.GetPrivate<Shape>("shape").Should().BeShaped(-1, 1, 2, 3, -1);
}
}
}

Loading…
Cancel
Save