| @@ -42,7 +42,8 @@ namespace Tensorflow.Gradients | |||||
| var x = op.inputs[0]; | var x = op.inputs[0]; | ||||
| var y = op.inputs[1]; | var y = op.inputs[1]; | ||||
| var grad = grads[0]; | var grad = grads[0]; | ||||
| if (grad is Tensor && _ShapesFullySpecifiedAndEqual(x, y, grad)) | |||||
| if (grad is Tensor && | |||||
| _ShapesFullySpecifiedAndEqual(x, y, grad)) | |||||
| return new Tensor[] { grad, grad }; | return new Tensor[] { grad, grad }; | ||||
| var sx = array_ops.shape(x); | var sx = array_ops.shape(x); | ||||
| @@ -375,7 +376,8 @@ namespace Tensorflow.Gradients | |||||
| var grad = grads[0]; | var grad = grads[0]; | ||||
| var x = op.inputs[0]; | var x = op.inputs[0]; | ||||
| var y = op.inputs[1]; | var y = op.inputs[1]; | ||||
| if (grad is Tensor && _ShapesFullySpecifiedAndEqual(x, y, grad)) | |||||
| if (grad is Tensor && | |||||
| _ShapesFullySpecifiedAndEqual(x, y, grad)) | |||||
| return new Tensor[] { grad, -grad }; | return new Tensor[] { grad, -grad }; | ||||
| var sx = array_ops.shape(x); | var sx = array_ops.shape(x); | ||||
| @@ -132,10 +132,10 @@ namespace Tensorflow | |||||
| public int[] _shape_tuple() | public int[] _shape_tuple() | ||||
| { | { | ||||
| return NDims < 0 ? null : shape; | |||||
| return rank < 0 ? null : shape; | |||||
| } | } | ||||
| public TensorShape TensorShape => tensor_util.to_shape(shape); | |||||
| public TensorShape TensorShape => rank < 0 ? new TensorShape() : tensor_util.to_shape(shape); | |||||
| /// <summary> | /// <summary> | ||||
| /// Updates the shape of this tensor. | /// Updates the shape of this tensor. | ||||
| @@ -165,6 +165,7 @@ namespace Tensorflow | |||||
| /// <summary> | /// <summary> | ||||
| /// number of dimensions <br></br> | /// number of dimensions <br></br> | ||||
| /// -1 Unknown <br></br> | |||||
| /// 0 Scalar (magnitude only) <br></br> | /// 0 Scalar (magnitude only) <br></br> | ||||
| /// 1 Vector (magnitude and direction) <br></br> | /// 1 Vector (magnitude and direction) <br></br> | ||||
| /// 2 Matrix (table of numbers) <br></br> | /// 2 Matrix (table of numbers) <br></br> | ||||
| @@ -178,11 +179,13 @@ namespace Tensorflow | |||||
| { | { | ||||
| if (_handle == IntPtr.Zero) | if (_handle == IntPtr.Zero) | ||||
| { | { | ||||
| var status = new Status(); | |||||
| var output = _as_tf_output(); | |||||
| int ndim = c_api.TF_GraphGetTensorNumDims(op.graph, output, status); | |||||
| status.Check(); | |||||
| return ndim; | |||||
| using (var status = new Status()) | |||||
| { | |||||
| var output = _as_tf_output(); | |||||
| int ndim = c_api.TF_GraphGetTensorNumDims(op.graph, output, status); | |||||
| status.Check(); | |||||
| return ndim; | |||||
| } | |||||
| } | } | ||||
| return c_api.TF_NumDims(_handle); | return c_api.TF_NumDims(_handle); | ||||
| @@ -440,16 +443,15 @@ namespace Tensorflow | |||||
| public override string ToString() | public override string ToString() | ||||
| { | { | ||||
| // this can throw IndexOutOfRangeException | // this can throw IndexOutOfRangeException | ||||
| //if(NDims == 0) | |||||
| //{ | |||||
| // switch (dtype) | |||||
| // { | |||||
| // case TF_DataType.TF_INT32: | |||||
| // return Data<int>()[0].ToString(); | |||||
| // } | |||||
| //} | |||||
| return $"tf.Tensor '{name}' shape=({string.Join(",", shape)}) dtype={dtype}"; | |||||
| switch (rank) | |||||
| { | |||||
| case -1: | |||||
| return $"tf.Tensor '{name}' shape=<unknown> dtype={dtype}"; | |||||
| case 0: | |||||
| return $"tf.Tensor '{name}' shape=() dtype={dtype}"; | |||||
| default: | |||||
| return $"tf.Tensor '{name}' shape=({string.Join(",", shape)}) dtype={dtype}"; | |||||
| } | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| @@ -24,12 +24,13 @@ namespace Tensorflow | |||||
| /// <summary> | /// <summary> | ||||
| /// Returns the rank of this shape. | /// Returns the rank of this shape. | ||||
| /// </summary> | /// </summary> | ||||
| public int ndim => shape.NDim; | |||||
| public int ndim => rank; | |||||
| private int _rank; | |||||
| /// <summary> | /// <summary> | ||||
| /// Returns the rank of this shape. | /// Returns the rank of this shape. | ||||
| /// </summary> | /// </summary> | ||||
| public int rank => shape.NDim; | |||||
| public int rank => _rank > -1 ? shape.NDim : -1; | |||||
| /// <summary> | /// <summary> | ||||
| /// Returns the size this shape represents. | /// Returns the size this shape represents. | ||||
| @@ -52,6 +53,12 @@ namespace Tensorflow | |||||
| } | } | ||||
| } | } | ||||
| public TensorShape() | |||||
| { | |||||
| _rank = -1; | |||||
| shape = new Shape(); | |||||
| } | |||||
| public TensorShape(TensorShapeProto proto) | public TensorShape(TensorShapeProto proto) | ||||
| { | { | ||||
| if (proto.UnknownRank) return; | if (proto.UnknownRank) return; | ||||
| @@ -77,7 +84,7 @@ namespace Tensorflow | |||||
| switch (dims.Length) | switch (dims.Length) | ||||
| { | { | ||||
| case 0: shape = new Shape(new int[0]); break; | case 0: shape = new Shape(new int[0]); break; | ||||
| case 1: shape = Shape.Vector((int)dims[0]); break; | |||||
| case 1: shape = Shape.Vector(dims[0]); break; | |||||
| case 2: shape = Shape.Matrix(dims[0], dims[1]); break; | case 2: shape = Shape.Matrix(dims[0], dims[1]); break; | ||||
| default: shape = new Shape(dims); break; | default: shape = new Shape(dims); break; | ||||
| } | } | ||||
| @@ -127,7 +134,7 @@ namespace Tensorflow | |||||
| /// <returns></returns> | /// <returns></returns> | ||||
| public bool is_fully_defined() | public bool is_fully_defined() | ||||
| { | { | ||||
| return dims != null && dims.Count(x => x < 1) == 0; | |||||
| return rank > -1 && dims != null && dims.Count(x => x < 1) == 0; | |||||
| } | } | ||||
| public bool is_compatible_with(TensorShape shape2) | public bool is_compatible_with(TensorShape shape2) | ||||
| @@ -204,7 +211,7 @@ namespace Tensorflow | |||||
| /// <returns></returns> | /// <returns></returns> | ||||
| public TensorShape merge_with(TensorShape other) | public TensorShape merge_with(TensorShape other) | ||||
| { | { | ||||
| if (dims.Length == 0) | |||||
| if (dims == null) | |||||
| return other; | return other; | ||||
| var new_dims = new List<int>(); | var new_dims = new List<int>(); | ||||
| @@ -56,5 +56,12 @@ namespace TensorFlowNET.UnitTest | |||||
| TensorShape shape = (Unknown, 1, 2, 3, Unknown); | TensorShape shape = (Unknown, 1, 2, 3, Unknown); | ||||
| shape.GetPrivate<Shape>("shape").Should().BeShaped(-1, 1, 2, 3, -1); | shape.GetPrivate<Shape>("shape").Should().BeShaped(-1, 1, 2, 3, -1); | ||||
| } | } | ||||
| [TestMethod] | |||||
| public void Case7() | |||||
| { | |||||
| TensorShape shape = new TensorShape(); | |||||
| Assert.AreEqual(shape.rank, -1); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||