| @@ -1,35 +1,84 @@ | |||||
| using NumSharp; | using NumSharp; | ||||
| using System; | using System; | ||||
| using System.Diagnostics.CodeAnalysis; | |||||
| using System.Linq; | using System.Linq; | ||||
| using System.Runtime.CompilerServices; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| /// <summary> | /// <summary> | ||||
| /// Represents the shape of a `Tensor`. | |||||
| /// Represents the shape of a `Tensor`. | |||||
| /// </summary> | /// </summary> | ||||
| /// <remarks>https://www.tensorflow.org/api_docs/python/tf/TensorShape</remarks> | |||||
| public class TensorShape | public class TensorShape | ||||
| { | { | ||||
| private Shape shape; | |||||
| private readonly Shape shape; | |||||
| /// <summary> | |||||
| /// Returns a list of Dimensions, or None if the shape is unspecified. | |||||
| /// </summary> | |||||
| public int[] dims => shape.Dimensions; | public int[] dims => shape.Dimensions; | ||||
| /// <summary> | |||||
| /// Returns the rank of this shape. | |||||
| /// </summary> | |||||
| public int ndim => shape.NDim; | public int ndim => shape.NDim; | ||||
| /// <summary> | |||||
| /// Returns the rank of this shape. | |||||
| /// </summary> | |||||
| public int rank => shape.NDim; | |||||
| /// <summary> | |||||
| /// Returns the size this shape represents. | |||||
| /// </summary> | |||||
| public int size => shape.Size; | public int size => shape.Size; | ||||
| public TensorShape(TensorShapeProto proto) | public TensorShape(TensorShapeProto proto) | ||||
| { | { | ||||
| if (proto.UnknownRank) return; | if (proto.UnknownRank) return; | ||||
| switch (proto.Dim.Count) | |||||
| { | |||||
| case 0: shape = new Shape(new int[0]); break; | |||||
| case 1: shape = Shape.Vector((int) proto.Dim[0].Size); break; | |||||
| case 2: shape = Shape.Matrix((int) proto.Dim[0].Size, (int) proto.Dim[1].Size); break; | |||||
| default: | |||||
| var protodims = proto.Dim; | |||||
| var len = protodims.Count; | |||||
| var dims = new int[len]; | |||||
| for (int i = 0; i < len; i++) | |||||
| dims[i] = (int) protodims[i].Size; | |||||
| shape.reshape(proto.Dim.Select(x => (int)x.Size).ToArray()); | |||||
| shape = new Shape(dims); break; | |||||
| } | |||||
| } | } | ||||
| public TensorShape(params int[] dims) | public TensorShape(params int[] dims) | ||||
| { | { | ||||
| shape = new Shape(dims); | |||||
| switch (dims.Length) | |||||
| { | |||||
| case 0: shape = new Shape(new int[0]); break; | |||||
| case 1: shape = Shape.Vector((int) dims[0]); break; | |||||
| case 2: shape = Shape.Matrix(dims[0], dims[1]); break; | |||||
| default: shape = new Shape(dims); break; | |||||
| } | |||||
| } | } | ||||
| /// <summary> | |||||
| /// | |||||
| /// </summary> | |||||
| /// <param name="slice"></param> | |||||
| /// <returns></returns> | |||||
| /// <exception cref="ArgumentException">When <see cref="Slice"/> is not an Index.</exception> | |||||
| [SuppressMessage("ReSharper", "PossibleInvalidOperationException")] | |||||
| public TensorShape this[Slice slice] | public TensorShape this[Slice slice] | ||||
| { | { | ||||
| get | get | ||||
| { | { | ||||
| if (slice.IsIndex == false) | |||||
| throw new ArgumentException("Slice must be an index."); | |||||
| return new TensorShape(dims.Skip(slice.Start.Value) | return new TensorShape(dims.Skip(slice.Start.Value) | ||||
| .Take(slice.Length.Value) | .Take(slice.Length.Value) | ||||
| .ToArray()); | .ToArray()); | ||||
| @@ -37,7 +86,7 @@ namespace Tensorflow | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| /// Returns True iff `self` is fully defined in every dimension. | |||||
| /// Returns True iff `self` is fully defined in every dimension. | |||||
| /// </summary> | /// </summary> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| public bool is_fully_defined() | public bool is_fully_defined() | ||||
| @@ -50,6 +99,7 @@ namespace Tensorflow | |||||
| throw new NotImplementedException("TensorShape is_compatible_with"); | throw new NotImplementedException("TensorShape is_compatible_with"); | ||||
| } | } | ||||
| [SuppressMessage("ReSharper", "ParameterHidesMember")] | |||||
| public TensorShape with_rank_at_least(int rank) | public TensorShape with_rank_at_least(int rank) | ||||
| { | { | ||||
| if (rank != ndim) | if (rank != ndim) | ||||
| @@ -59,35 +109,63 @@ namespace Tensorflow | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| /// Returns the concatenation of the dimension in `self` and `other`. | |||||
| /// Returns the concatenation of the dimension in `self` and `other`. | |||||
| /// </summary> | /// </summary> | ||||
| /// <param name="other"></param> | /// <param name="other"></param> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| public TensorShape concatenate(int[] other_) | |||||
| [MethodImpl(MethodImplOptions.AggressiveInlining)] | |||||
| public TensorShape concatenate(int[] other) | |||||
| { | { | ||||
| var other = new TensorShape(other_); | |||||
| return concatenate(new TensorShape(other)); | |||||
| } | |||||
| if (ndim < 0 || other.ndim < 0) | |||||
| /// <summary> | |||||
| /// Returns the concatenation of the dimension in `self` and `other`. | |||||
| /// </summary> | |||||
| /// <param name="other"></param> | |||||
| /// <returns></returns> | |||||
| public TensorShape concatenate(TensorShape other) | |||||
| { | |||||
| var otherShape = other; | |||||
| if (ndim < 0 || otherShape.ndim < 0) | |||||
| return new TensorShape(); | return new TensorShape(); | ||||
| else | else | ||||
| { | { | ||||
| var concatenate_dims = new int[ndim + other.ndim]; | |||||
| var concatenate_dims = new int[ndim + otherShape.ndim]; | |||||
| for (int i = 0; i < ndim; i++) | for (int i = 0; i < ndim; i++) | ||||
| concatenate_dims[i] = dims[i]; | concatenate_dims[i] = dims[i]; | ||||
| for (int i = 0; i < other.ndim; i++) | |||||
| concatenate_dims[ndim + i] = other.dims[i]; | |||||
| for (int i = 0; i < otherShape.ndim; i++) | |||||
| concatenate_dims[ndim + i] = otherShape.dims[i]; | |||||
| return new TensorShape(concatenate_dims); | return new TensorShape(concatenate_dims); | ||||
| } | } | ||||
| } | } | ||||
| public static implicit operator TensorShape(Shape shape) => new TensorShape(shape.Dimensions); | |||||
| public static implicit operator Shape(TensorShape shape) => new Shape(shape.dims); | |||||
| public static implicit operator TensorShape(Shape shape) => new TensorShape((int[]) shape.Dimensions.Clone()); | |||||
| public static implicit operator Shape(TensorShape shape) => new Shape((int[]) shape.dims.Clone()); | |||||
| public static implicit operator int[](TensorShape shape) => (int[])shape.dims.Clone(); //we clone to avoid any changes | |||||
| public static implicit operator TensorShape(int[] dims) => new TensorShape(dims); | public static implicit operator TensorShape(int[] dims) => new TensorShape(dims); | ||||
| public static implicit operator int[](TensorShape shape) => shape.dims; | |||||
| public static explicit operator int(TensorShape shape) => shape.size; | |||||
| public static explicit operator TensorShape(int dim) => new TensorShape(dim); | |||||
| public static explicit operator (int, int)(TensorShape shape) => shape.dims.Length == 2 ? (shape.dims[0], shape.dims[1]) : (0, 0); | |||||
| public static implicit operator TensorShape((int, int) dims) => new TensorShape(dims.Item1, dims.Item2); | public static implicit operator TensorShape((int, int) dims) => new TensorShape(dims.Item1, dims.Item2); | ||||
| public static explicit operator (int, int, int)(TensorShape shape) => shape.dims.Length == 3 ? (shape.dims[0], shape.dims[1], shape.dims[2]) : (0, 0, 0); | |||||
| public static implicit operator TensorShape((int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3); | public static implicit operator TensorShape((int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3); | ||||
| public static explicit operator (int, int, int, int)(TensorShape shape) => shape.dims.Length == 4 ? (shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3]) : (0, 0, 0, 0); | |||||
| public static implicit operator TensorShape((int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4); | public static implicit operator TensorShape((int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4); | ||||
| public static explicit operator (int, int, int, int, int)(TensorShape shape) => shape.dims.Length == 5 ? (shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3], shape.dims[4]) : (0, 0, 0, 0, 0); | |||||
| public static implicit operator TensorShape((int, int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5); | |||||
| public static explicit operator (int, int, int, int, int, int)(TensorShape shape) => shape.dims.Length == 6 ? (shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3], shape.dims[4], shape.dims[5]) : (0, 0, 0, 0, 0, 0); | |||||
| public static implicit operator TensorShape((int, int, int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5, dims.Item6); | |||||
| } | } | ||||
| } | } | ||||