diff --git a/src/TensorFlowNET.Core/Tensors/TensorShape.cs b/src/TensorFlowNET.Core/Tensors/TensorShape.cs
index 13258f79..382b0002 100644
--- a/src/TensorFlowNET.Core/Tensors/TensorShape.cs
+++ b/src/TensorFlowNET.Core/Tensors/TensorShape.cs
@@ -1,35 +1,84 @@
using NumSharp;
using System;
+using System.Diagnostics.CodeAnalysis;
using System.Linq;
+using System.Runtime.CompilerServices;
namespace Tensorflow
{
///
- /// Represents the shape of a `Tensor`.
+ /// Represents the shape of a `Tensor`.
///
+ /// https://www.tensorflow.org/api_docs/python/tf/TensorShape
public class TensorShape
{
- private Shape shape;
+ private readonly Shape shape;
+
+ ///
+ /// Returns a list of Dimensions, or None if the shape is unspecified.
+ ///
public int[] dims => shape.Dimensions;
+
+ ///
+ /// Returns the rank of this shape.
+ ///
public int ndim => shape.NDim;
+
+ ///
+ /// Returns the rank of this shape.
+ ///
+ public int rank => shape.NDim;
+
+ ///
+ /// Returns the size this shape represents.
+ ///
public int size => shape.Size;
public TensorShape(TensorShapeProto proto)
{
if (proto.UnknownRank) return;
+ switch (proto.Dim.Count)
+ {
+ case 0: shape = new Shape(new int[0]); break;
+ case 1: shape = Shape.Vector((int) proto.Dim[0].Size); break;
+ case 2: shape = Shape.Matrix((int) proto.Dim[0].Size, (int) proto.Dim[1].Size); break;
+ default:
+ var protodims = proto.Dim;
+ var len = protodims.Count;
+ var dims = new int[len];
+ for (int i = 0; i < len; i++)
+ dims[i] = (int) protodims[i].Size;
+
- shape.reshape(proto.Dim.Select(x => (int)x.Size).ToArray());
+ shape = new Shape(dims); break;
+ }
}
public TensorShape(params int[] dims)
{
- shape = new Shape(dims);
+ switch (dims.Length)
+ {
+ case 0: shape = new Shape(new int[0]); break;
+ case 1: shape = Shape.Vector((int) dims[0]); break;
+ case 2: shape = Shape.Matrix(dims[0], dims[1]); break;
+ default: shape = new Shape(dims); break;
+ }
}
+ ///
+ ///
+ ///
+ ///
+ ///
+ /// When is not an Index.
+ [SuppressMessage("ReSharper", "PossibleInvalidOperationException")]
public TensorShape this[Slice slice]
{
get
{
+ if (slice.IsIndex == false)
+ throw new ArgumentException("Slice must be an index.");
+
return new TensorShape(dims.Skip(slice.Start.Value)
.Take(slice.Length.Value)
.ToArray());
@@ -37,7 +86,7 @@ namespace Tensorflow
}
///
- /// Returns True iff `self` is fully defined in every dimension.
+ /// Returns True iff `self` is fully defined in every dimension.
///
///
public bool is_fully_defined()
@@ -50,6 +99,7 @@ namespace Tensorflow
throw new NotImplementedException("TensorShape is_compatible_with");
}
+ [SuppressMessage("ReSharper", "ParameterHidesMember")]
public TensorShape with_rank_at_least(int rank)
{
if (rank != ndim)
@@ -59,35 +109,63 @@ namespace Tensorflow
}
///
- /// Returns the concatenation of the dimension in `self` and `other`.
+ /// Returns the concatenation of the dimension in `self` and `other`.
///
///
///
- public TensorShape concatenate(int[] other_)
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public TensorShape concatenate(int[] other)
{
- var other = new TensorShape(other_);
+ return concatenate(new TensorShape(other));
+ }
- if (ndim < 0 || other.ndim < 0)
+ ///
+ /// Returns the concatenation of the dimension in `self` and `other`.
+ ///
+ ///
+ ///
+ public TensorShape concatenate(TensorShape other)
+ {
+ var otherShape = other;
+
+ if (ndim < 0 || otherShape.ndim < 0)
return new TensorShape();
else
{
- var concatenate_dims = new int[ndim + other.ndim];
+ var concatenate_dims = new int[ndim + otherShape.ndim];
for (int i = 0; i < ndim; i++)
concatenate_dims[i] = dims[i];
- for (int i = 0; i < other.ndim; i++)
- concatenate_dims[ndim + i] = other.dims[i];
+ for (int i = 0; i < otherShape.ndim; i++)
+ concatenate_dims[ndim + i] = otherShape.dims[i];
return new TensorShape(concatenate_dims);
}
}
- public static implicit operator TensorShape(Shape shape) => new TensorShape(shape.Dimensions);
- public static implicit operator Shape(TensorShape shape) => new Shape(shape.dims);
+ public static implicit operator TensorShape(Shape shape) => new TensorShape((int[]) shape.Dimensions.Clone());
+ public static implicit operator Shape(TensorShape shape) => new Shape((int[]) shape.dims.Clone());
+
+ public static implicit operator int[](TensorShape shape) => (int[])shape.dims.Clone(); //we clone to avoid any changes
public static implicit operator TensorShape(int[] dims) => new TensorShape(dims);
- public static implicit operator int[](TensorShape shape) => shape.dims;
+
+ public static explicit operator int(TensorShape shape) => shape.size;
+ public static explicit operator TensorShape(int dim) => new TensorShape(dim);
+
+ public static explicit operator (int, int)(TensorShape shape) => shape.dims.Length == 2 ? (shape.dims[0], shape.dims[1]) : (0, 0);
public static implicit operator TensorShape((int, int) dims) => new TensorShape(dims.Item1, dims.Item2);
+
+ public static explicit operator (int, int, int)(TensorShape shape) => shape.dims.Length == 3 ? (shape.dims[0], shape.dims[1], shape.dims[2]) : (0, 0, 0);
public static implicit operator TensorShape((int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3);
+
+ public static explicit operator (int, int, int, int)(TensorShape shape) => shape.dims.Length == 4 ? (shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3]) : (0, 0, 0, 0);
public static implicit operator TensorShape((int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4);
+
+ public static explicit operator (int, int, int, int, int)(TensorShape shape) => shape.dims.Length == 5 ? (shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3], shape.dims[4]) : (0, 0, 0, 0, 0);
+ public static implicit operator TensorShape((int, int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5);
+
+ public static explicit operator (int, int, int, int, int, int)(TensorShape shape) => shape.dims.Length == 6 ? (shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3], shape.dims[4], shape.dims[5]) : (0, 0, 0, 0, 0, 0);
+ public static implicit operator TensorShape((int, int, int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5, dims.Item6);
+
}
}