From d857df1216608e2ec2d6052163ebef6a250cd5d5 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Mon, 30 Sep 2019 00:00:09 -0500 Subject: [PATCH] fix unassigned shape #406 --- src/TensorFlowNET.Core/Gradients/math_grad.cs | 6 ++-- src/TensorFlowNET.Core/Tensors/Tensor.cs | 36 ++++++++++--------- src/TensorFlowNET.Core/Tensors/TensorShape.cs | 17 ++++++--- .../TensorFlowNET.UnitTest/TensorShapeTest.cs | 7 ++++ 4 files changed, 42 insertions(+), 24 deletions(-) diff --git a/src/TensorFlowNET.Core/Gradients/math_grad.cs b/src/TensorFlowNET.Core/Gradients/math_grad.cs index 5146c777..49dcbc45 100644 --- a/src/TensorFlowNET.Core/Gradients/math_grad.cs +++ b/src/TensorFlowNET.Core/Gradients/math_grad.cs @@ -42,7 +42,8 @@ namespace Tensorflow.Gradients var x = op.inputs[0]; var y = op.inputs[1]; var grad = grads[0]; - if (grad is Tensor && _ShapesFullySpecifiedAndEqual(x, y, grad)) + if (grad is Tensor && + _ShapesFullySpecifiedAndEqual(x, y, grad)) return new Tensor[] { grad, grad }; var sx = array_ops.shape(x); @@ -375,7 +376,8 @@ namespace Tensorflow.Gradients var grad = grads[0]; var x = op.inputs[0]; var y = op.inputs[1]; - if (grad is Tensor && _ShapesFullySpecifiedAndEqual(x, y, grad)) + if (grad is Tensor && + _ShapesFullySpecifiedAndEqual(x, y, grad)) return new Tensor[] { grad, -grad }; var sx = array_ops.shape(x); diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs index f980fea9..f3ad2efd 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs @@ -132,10 +132,10 @@ namespace Tensorflow public int[] _shape_tuple() { - return NDims < 0 ? null : shape; + return rank < 0 ? null : shape; } - public TensorShape TensorShape => tensor_util.to_shape(shape); + public TensorShape TensorShape => rank < 0 ? new TensorShape() : tensor_util.to_shape(shape); /// /// Updates the shape of this tensor. @@ -165,6 +165,7 @@ namespace Tensorflow /// /// number of dimensions

+ /// -1 Unknown

/// 0 Scalar (magnitude only)

/// 1 Vector (magnitude and direction)

/// 2 Matrix (table of numbers)

@@ -178,11 +179,13 @@ namespace Tensorflow { if (_handle == IntPtr.Zero) { - var status = new Status(); - var output = _as_tf_output(); - int ndim = c_api.TF_GraphGetTensorNumDims(op.graph, output, status); - status.Check(); - return ndim; + using (var status = new Status()) + { + var output = _as_tf_output(); + int ndim = c_api.TF_GraphGetTensorNumDims(op.graph, output, status); + status.Check(); + return ndim; + } } return c_api.TF_NumDims(_handle); @@ -440,16 +443,15 @@ namespace Tensorflow public override string ToString() { // this can throw IndexOutOfRangeException - //if(NDims == 0) - //{ - // switch (dtype) - // { - // case TF_DataType.TF_INT32: - // return Data()[0].ToString(); - // } - //} - - return $"tf.Tensor '{name}' shape=({string.Join(",", shape)}) dtype={dtype}"; + switch (rank) + { + case -1: + return $"tf.Tensor '{name}' shape= dtype={dtype}"; + case 0: + return $"tf.Tensor '{name}' shape=() dtype={dtype}"; + default: + return $"tf.Tensor '{name}' shape=({string.Join(",", shape)}) dtype={dtype}"; + } } /// diff --git a/src/TensorFlowNET.Core/Tensors/TensorShape.cs b/src/TensorFlowNET.Core/Tensors/TensorShape.cs index f8417924..1e239d50 100644 --- a/src/TensorFlowNET.Core/Tensors/TensorShape.cs +++ b/src/TensorFlowNET.Core/Tensors/TensorShape.cs @@ -24,12 +24,13 @@ namespace Tensorflow /// /// Returns the rank of this shape. /// - public int ndim => shape.NDim; + public int ndim => rank; + private int _rank; /// /// Returns the rank of this shape. /// - public int rank => shape.NDim; + public int rank => _rank > -1 ? shape.NDim : -1; /// /// Returns the size this shape represents. @@ -52,6 +53,12 @@ namespace Tensorflow } } + public TensorShape() + { + _rank = -1; + shape = new Shape(); + } + public TensorShape(TensorShapeProto proto) { if (proto.UnknownRank) return; @@ -77,7 +84,7 @@ namespace Tensorflow switch (dims.Length) { case 0: shape = new Shape(new int[0]); break; - case 1: shape = Shape.Vector((int)dims[0]); break; + case 1: shape = Shape.Vector(dims[0]); break; case 2: shape = Shape.Matrix(dims[0], dims[1]); break; default: shape = new Shape(dims); break; } @@ -127,7 +134,7 @@ namespace Tensorflow /// public bool is_fully_defined() { - return dims != null && dims.Count(x => x < 1) == 0; + return rank > -1 && dims != null && dims.Count(x => x < 1) == 0; } public bool is_compatible_with(TensorShape shape2) @@ -204,7 +211,7 @@ namespace Tensorflow /// public TensorShape merge_with(TensorShape other) { - if (dims.Length == 0) + if (dims == null) return other; var new_dims = new List(); diff --git a/test/TensorFlowNET.UnitTest/TensorShapeTest.cs b/test/TensorFlowNET.UnitTest/TensorShapeTest.cs index efa7def3..b7846ce3 100644 --- a/test/TensorFlowNET.UnitTest/TensorShapeTest.cs +++ b/test/TensorFlowNET.UnitTest/TensorShapeTest.cs @@ -56,5 +56,12 @@ namespace TensorFlowNET.UnitTest TensorShape shape = (Unknown, 1, 2, 3, Unknown); shape.GetPrivate("shape").Should().BeShaped(-1, 1, 2, 3, -1); } + + [TestMethod] + public void Case7() + { + TensorShape shape = new TensorShape(); + Assert.AreEqual(shape.rank, -1); + } } } \ No newline at end of file