diff --git a/src/TensorFlowNET.Core/Tensors/TensorShape.cs b/src/TensorFlowNET.Core/Tensors/TensorShape.cs index 13258f79..382b0002 100644 --- a/src/TensorFlowNET.Core/Tensors/TensorShape.cs +++ b/src/TensorFlowNET.Core/Tensors/TensorShape.cs @@ -1,35 +1,84 @@ using NumSharp; using System; +using System.Diagnostics.CodeAnalysis; using System.Linq; +using System.Runtime.CompilerServices; namespace Tensorflow { /// - /// Represents the shape of a `Tensor`. + /// Represents the shape of a `Tensor`. /// + /// https://www.tensorflow.org/api_docs/python/tf/TensorShape public class TensorShape { - private Shape shape; + private readonly Shape shape; + + /// + /// Returns a list of Dimensions, or None if the shape is unspecified. + /// public int[] dims => shape.Dimensions; + + /// + /// Returns the rank of this shape. + /// public int ndim => shape.NDim; + + /// + /// Returns the rank of this shape. + /// + public int rank => shape.NDim; + + /// + /// Returns the size this shape represents. + /// public int size => shape.Size; public TensorShape(TensorShapeProto proto) { if (proto.UnknownRank) return; + switch (proto.Dim.Count) + { + case 0: shape = new Shape(new int[0]); break; + case 1: shape = Shape.Vector((int) proto.Dim[0].Size); break; + case 2: shape = Shape.Matrix((int) proto.Dim[0].Size, (int) proto.Dim[1].Size); break; + default: + var protodims = proto.Dim; + var len = protodims.Count; + var dims = new int[len]; + for (int i = 0; i < len; i++) + dims[i] = (int) protodims[i].Size; + - shape.reshape(proto.Dim.Select(x => (int)x.Size).ToArray()); + shape = new Shape(dims); break; + } } public TensorShape(params int[] dims) { - shape = new Shape(dims); + switch (dims.Length) + { + case 0: shape = new Shape(new int[0]); break; + case 1: shape = Shape.Vector((int) dims[0]); break; + case 2: shape = Shape.Matrix(dims[0], dims[1]); break; + default: shape = new Shape(dims); break; + } } + /// + /// + /// + /// + /// + /// When is not an Index. + [SuppressMessage("ReSharper", "PossibleInvalidOperationException")] public TensorShape this[Slice slice] { get { + if (slice.IsIndex == false) + throw new ArgumentException("Slice must be an index."); + return new TensorShape(dims.Skip(slice.Start.Value) .Take(slice.Length.Value) .ToArray()); @@ -37,7 +86,7 @@ namespace Tensorflow } /// - /// Returns True iff `self` is fully defined in every dimension. + /// Returns True iff `self` is fully defined in every dimension. /// /// public bool is_fully_defined() @@ -50,6 +99,7 @@ namespace Tensorflow throw new NotImplementedException("TensorShape is_compatible_with"); } + [SuppressMessage("ReSharper", "ParameterHidesMember")] public TensorShape with_rank_at_least(int rank) { if (rank != ndim) @@ -59,35 +109,63 @@ namespace Tensorflow } /// - /// Returns the concatenation of the dimension in `self` and `other`. + /// Returns the concatenation of the dimension in `self` and `other`. /// /// /// - public TensorShape concatenate(int[] other_) + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public TensorShape concatenate(int[] other) { - var other = new TensorShape(other_); + return concatenate(new TensorShape(other)); + } - if (ndim < 0 || other.ndim < 0) + /// + /// Returns the concatenation of the dimension in `self` and `other`. + /// + /// + /// + public TensorShape concatenate(TensorShape other) + { + var otherShape = other; + + if (ndim < 0 || otherShape.ndim < 0) return new TensorShape(); else { - var concatenate_dims = new int[ndim + other.ndim]; + var concatenate_dims = new int[ndim + otherShape.ndim]; for (int i = 0; i < ndim; i++) concatenate_dims[i] = dims[i]; - for (int i = 0; i < other.ndim; i++) - concatenate_dims[ndim + i] = other.dims[i]; + for (int i = 0; i < otherShape.ndim; i++) + concatenate_dims[ndim + i] = otherShape.dims[i]; return new TensorShape(concatenate_dims); } } - public static implicit operator TensorShape(Shape shape) => new TensorShape(shape.Dimensions); - public static implicit operator Shape(TensorShape shape) => new Shape(shape.dims); + public static implicit operator TensorShape(Shape shape) => new TensorShape((int[]) shape.Dimensions.Clone()); + public static implicit operator Shape(TensorShape shape) => new Shape((int[]) shape.dims.Clone()); + + public static implicit operator int[](TensorShape shape) => (int[])shape.dims.Clone(); //we clone to avoid any changes public static implicit operator TensorShape(int[] dims) => new TensorShape(dims); - public static implicit operator int[](TensorShape shape) => shape.dims; + + public static explicit operator int(TensorShape shape) => shape.size; + public static explicit operator TensorShape(int dim) => new TensorShape(dim); + + public static explicit operator (int, int)(TensorShape shape) => shape.dims.Length == 2 ? (shape.dims[0], shape.dims[1]) : (0, 0); public static implicit operator TensorShape((int, int) dims) => new TensorShape(dims.Item1, dims.Item2); + + public static explicit operator (int, int, int)(TensorShape shape) => shape.dims.Length == 3 ? (shape.dims[0], shape.dims[1], shape.dims[2]) : (0, 0, 0); public static implicit operator TensorShape((int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3); + + public static explicit operator (int, int, int, int)(TensorShape shape) => shape.dims.Length == 4 ? (shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3]) : (0, 0, 0, 0); public static implicit operator TensorShape((int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4); + + public static explicit operator (int, int, int, int, int)(TensorShape shape) => shape.dims.Length == 5 ? (shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3], shape.dims[4]) : (0, 0, 0, 0, 0); + public static implicit operator TensorShape((int, int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5); + + public static explicit operator (int, int, int, int, int, int)(TensorShape shape) => shape.dims.Length == 6 ? (shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3], shape.dims[4], shape.dims[5]) : (0, 0, 0, 0, 0, 0); + public static implicit operator TensorShape((int, int, int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5, dims.Item6); + } }