| @@ -1,35 +1,84 @@ | |||
| using NumSharp; | |||
| using System; | |||
| using System.Diagnostics.CodeAnalysis; | |||
| using System.Linq; | |||
| using System.Runtime.CompilerServices; | |||
| namespace Tensorflow | |||
| { | |||
| /// <summary> | |||
| /// Represents the shape of a `Tensor`. | |||
| /// Represents the shape of a `Tensor`. | |||
| /// </summary> | |||
| /// <remarks>https://www.tensorflow.org/api_docs/python/tf/TensorShape</remarks> | |||
| public class TensorShape | |||
| { | |||
| private Shape shape; | |||
| private readonly Shape shape; | |||
| /// <summary> | |||
| /// Returns a list of Dimensions, or None if the shape is unspecified. | |||
| /// </summary> | |||
| public int[] dims => shape.Dimensions; | |||
| /// <summary> | |||
| /// Returns the rank of this shape. | |||
| /// </summary> | |||
| public int ndim => shape.NDim; | |||
| /// <summary> | |||
| /// Returns the rank of this shape. | |||
| /// </summary> | |||
| public int rank => shape.NDim; | |||
| /// <summary> | |||
| /// Returns the size this shape represents. | |||
| /// </summary> | |||
| public int size => shape.Size; | |||
| public TensorShape(TensorShapeProto proto) | |||
| { | |||
| if (proto.UnknownRank) return; | |||
| switch (proto.Dim.Count) | |||
| { | |||
| case 0: shape = new Shape(new int[0]); break; | |||
| case 1: shape = Shape.Vector((int) proto.Dim[0].Size); break; | |||
| case 2: shape = Shape.Matrix((int) proto.Dim[0].Size, (int) proto.Dim[1].Size); break; | |||
| default: | |||
| var protodims = proto.Dim; | |||
| var len = protodims.Count; | |||
| var dims = new int[len]; | |||
| for (int i = 0; i < len; i++) | |||
| dims[i] = (int) protodims[i].Size; | |||
| shape.reshape(proto.Dim.Select(x => (int)x.Size).ToArray()); | |||
| shape = new Shape(dims); break; | |||
| } | |||
| } | |||
| public TensorShape(params int[] dims) | |||
| { | |||
| shape = new Shape(dims); | |||
| switch (dims.Length) | |||
| { | |||
| case 0: shape = new Shape(new int[0]); break; | |||
| case 1: shape = Shape.Vector((int) dims[0]); break; | |||
| case 2: shape = Shape.Matrix(dims[0], dims[1]); break; | |||
| default: shape = new Shape(dims); break; | |||
| } | |||
| } | |||
| /// <summary> | |||
| /// | |||
| /// </summary> | |||
| /// <param name="slice"></param> | |||
| /// <returns></returns> | |||
| /// <exception cref="ArgumentException">When <see cref="Slice"/> is not an Index.</exception> | |||
| [SuppressMessage("ReSharper", "PossibleInvalidOperationException")] | |||
| public TensorShape this[Slice slice] | |||
| { | |||
| get | |||
| { | |||
| if (slice.IsIndex == false) | |||
| throw new ArgumentException("Slice must be an index."); | |||
| return new TensorShape(dims.Skip(slice.Start.Value) | |||
| .Take(slice.Length.Value) | |||
| .ToArray()); | |||
| @@ -37,7 +86,7 @@ namespace Tensorflow | |||
| } | |||
| /// <summary> | |||
| /// Returns True iff `self` is fully defined in every dimension. | |||
| /// Returns True iff `self` is fully defined in every dimension. | |||
| /// </summary> | |||
| /// <returns></returns> | |||
| public bool is_fully_defined() | |||
| @@ -50,6 +99,7 @@ namespace Tensorflow | |||
| throw new NotImplementedException("TensorShape is_compatible_with"); | |||
| } | |||
| [SuppressMessage("ReSharper", "ParameterHidesMember")] | |||
| public TensorShape with_rank_at_least(int rank) | |||
| { | |||
| if (rank != ndim) | |||
| @@ -59,35 +109,63 @@ namespace Tensorflow | |||
| } | |||
| /// <summary> | |||
| /// Returns the concatenation of the dimension in `self` and `other`. | |||
| /// Returns the concatenation of the dimension in `self` and `other`. | |||
| /// </summary> | |||
| /// <param name="other"></param> | |||
| /// <returns></returns> | |||
| public TensorShape concatenate(int[] other_) | |||
| [MethodImpl(MethodImplOptions.AggressiveInlining)] | |||
| public TensorShape concatenate(int[] other) | |||
| { | |||
| var other = new TensorShape(other_); | |||
| return concatenate(new TensorShape(other)); | |||
| } | |||
| if (ndim < 0 || other.ndim < 0) | |||
| /// <summary> | |||
| /// Returns the concatenation of the dimension in `self` and `other`. | |||
| /// </summary> | |||
| /// <param name="other"></param> | |||
| /// <returns></returns> | |||
| public TensorShape concatenate(TensorShape other) | |||
| { | |||
| var otherShape = other; | |||
| if (ndim < 0 || otherShape.ndim < 0) | |||
| return new TensorShape(); | |||
| else | |||
| { | |||
| var concatenate_dims = new int[ndim + other.ndim]; | |||
| var concatenate_dims = new int[ndim + otherShape.ndim]; | |||
| for (int i = 0; i < ndim; i++) | |||
| concatenate_dims[i] = dims[i]; | |||
| for (int i = 0; i < other.ndim; i++) | |||
| concatenate_dims[ndim + i] = other.dims[i]; | |||
| for (int i = 0; i < otherShape.ndim; i++) | |||
| concatenate_dims[ndim + i] = otherShape.dims[i]; | |||
| return new TensorShape(concatenate_dims); | |||
| } | |||
| } | |||
| public static implicit operator TensorShape(Shape shape) => new TensorShape(shape.Dimensions); | |||
| public static implicit operator Shape(TensorShape shape) => new Shape(shape.dims); | |||
| public static implicit operator TensorShape(Shape shape) => new TensorShape((int[]) shape.Dimensions.Clone()); | |||
| public static implicit operator Shape(TensorShape shape) => new Shape((int[]) shape.dims.Clone()); | |||
| public static implicit operator int[](TensorShape shape) => (int[])shape.dims.Clone(); //we clone to avoid any changes | |||
| public static implicit operator TensorShape(int[] dims) => new TensorShape(dims); | |||
| public static implicit operator int[](TensorShape shape) => shape.dims; | |||
| public static explicit operator int(TensorShape shape) => shape.size; | |||
| public static explicit operator TensorShape(int dim) => new TensorShape(dim); | |||
| public static explicit operator (int, int)(TensorShape shape) => shape.dims.Length == 2 ? (shape.dims[0], shape.dims[1]) : (0, 0); | |||
| public static implicit operator TensorShape((int, int) dims) => new TensorShape(dims.Item1, dims.Item2); | |||
| public static explicit operator (int, int, int)(TensorShape shape) => shape.dims.Length == 3 ? (shape.dims[0], shape.dims[1], shape.dims[2]) : (0, 0, 0); | |||
| public static implicit operator TensorShape((int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3); | |||
| public static explicit operator (int, int, int, int)(TensorShape shape) => shape.dims.Length == 4 ? (shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3]) : (0, 0, 0, 0); | |||
| public static implicit operator TensorShape((int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4); | |||
| public static explicit operator (int, int, int, int, int)(TensorShape shape) => shape.dims.Length == 5 ? (shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3], shape.dims[4]) : (0, 0, 0, 0, 0); | |||
| public static implicit operator TensorShape((int, int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5); | |||
| public static explicit operator (int, int, int, int, int, int)(TensorShape shape) => shape.dims.Length == 6 ? (shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3], shape.dims[4], shape.dims[5]) : (0, 0, 0, 0, 0, 0); | |||
| public static implicit operator TensorShape((int, int, int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5, dims.Item6); | |||
| } | |||
| } | |||