| @@ -42,7 +42,8 @@ namespace Tensorflow.Gradients | |||
| var x = op.inputs[0]; | |||
| var y = op.inputs[1]; | |||
| var grad = grads[0]; | |||
| if (grad is Tensor && _ShapesFullySpecifiedAndEqual(x, y, grad)) | |||
| if (grad is Tensor && | |||
| _ShapesFullySpecifiedAndEqual(x, y, grad)) | |||
| return new Tensor[] { grad, grad }; | |||
| var sx = array_ops.shape(x); | |||
| @@ -375,7 +376,8 @@ namespace Tensorflow.Gradients | |||
| var grad = grads[0]; | |||
| var x = op.inputs[0]; | |||
| var y = op.inputs[1]; | |||
| if (grad is Tensor && _ShapesFullySpecifiedAndEqual(x, y, grad)) | |||
| if (grad is Tensor && | |||
| _ShapesFullySpecifiedAndEqual(x, y, grad)) | |||
| return new Tensor[] { grad, -grad }; | |||
| var sx = array_ops.shape(x); | |||
| @@ -132,10 +132,10 @@ namespace Tensorflow | |||
| public int[] _shape_tuple() | |||
| { | |||
| return NDims < 0 ? null : shape; | |||
| return rank < 0 ? null : shape; | |||
| } | |||
| public TensorShape TensorShape => tensor_util.to_shape(shape); | |||
| public TensorShape TensorShape => rank < 0 ? new TensorShape() : tensor_util.to_shape(shape); | |||
| /// <summary> | |||
| /// Updates the shape of this tensor. | |||
| @@ -165,6 +165,7 @@ namespace Tensorflow | |||
| /// <summary> | |||
| /// number of dimensions <br></br> | |||
| /// -1 Unknown <br></br> | |||
| /// 0 Scalar (magnitude only) <br></br> | |||
| /// 1 Vector (magnitude and direction) <br></br> | |||
| /// 2 Matrix (table of numbers) <br></br> | |||
| @@ -178,11 +179,13 @@ namespace Tensorflow | |||
| { | |||
| if (_handle == IntPtr.Zero) | |||
| { | |||
| var status = new Status(); | |||
| var output = _as_tf_output(); | |||
| int ndim = c_api.TF_GraphGetTensorNumDims(op.graph, output, status); | |||
| status.Check(); | |||
| return ndim; | |||
| using (var status = new Status()) | |||
| { | |||
| var output = _as_tf_output(); | |||
| int ndim = c_api.TF_GraphGetTensorNumDims(op.graph, output, status); | |||
| status.Check(); | |||
| return ndim; | |||
| } | |||
| } | |||
| return c_api.TF_NumDims(_handle); | |||
| @@ -440,16 +443,15 @@ namespace Tensorflow | |||
| public override string ToString() | |||
| { | |||
| // this can throw IndexOutOfRangeException | |||
| //if(NDims == 0) | |||
| //{ | |||
| // switch (dtype) | |||
| // { | |||
| // case TF_DataType.TF_INT32: | |||
| // return Data<int>()[0].ToString(); | |||
| // } | |||
| //} | |||
| return $"tf.Tensor '{name}' shape=({string.Join(",", shape)}) dtype={dtype}"; | |||
| switch (rank) | |||
| { | |||
| case -1: | |||
| return $"tf.Tensor '{name}' shape=<unknown> dtype={dtype}"; | |||
| case 0: | |||
| return $"tf.Tensor '{name}' shape=() dtype={dtype}"; | |||
| default: | |||
| return $"tf.Tensor '{name}' shape=({string.Join(",", shape)}) dtype={dtype}"; | |||
| } | |||
| } | |||
| /// <summary> | |||
| @@ -24,12 +24,13 @@ namespace Tensorflow | |||
| /// <summary> | |||
| /// Returns the rank of this shape. | |||
| /// </summary> | |||
| public int ndim => shape.NDim; | |||
| public int ndim => rank; | |||
| private int _rank; | |||
| /// <summary> | |||
| /// Returns the rank of this shape. | |||
| /// </summary> | |||
| public int rank => shape.NDim; | |||
| public int rank => _rank > -1 ? shape.NDim : -1; | |||
| /// <summary> | |||
| /// Returns the size this shape represents. | |||
| @@ -52,6 +53,12 @@ namespace Tensorflow | |||
| } | |||
| } | |||
| public TensorShape() | |||
| { | |||
| _rank = -1; | |||
| shape = new Shape(); | |||
| } | |||
| public TensorShape(TensorShapeProto proto) | |||
| { | |||
| if (proto.UnknownRank) return; | |||
| @@ -77,7 +84,7 @@ namespace Tensorflow | |||
| switch (dims.Length) | |||
| { | |||
| case 0: shape = new Shape(new int[0]); break; | |||
| case 1: shape = Shape.Vector((int)dims[0]); break; | |||
| case 1: shape = Shape.Vector(dims[0]); break; | |||
| case 2: shape = Shape.Matrix(dims[0], dims[1]); break; | |||
| default: shape = new Shape(dims); break; | |||
| } | |||
| @@ -127,7 +134,7 @@ namespace Tensorflow | |||
| /// <returns></returns> | |||
| public bool is_fully_defined() | |||
| { | |||
| return dims != null && dims.Count(x => x < 1) == 0; | |||
| return rank > -1 && dims != null && dims.Count(x => x < 1) == 0; | |||
| } | |||
| public bool is_compatible_with(TensorShape shape2) | |||
| @@ -204,7 +211,7 @@ namespace Tensorflow | |||
| /// <returns></returns> | |||
| public TensorShape merge_with(TensorShape other) | |||
| { | |||
| if (dims.Length == 0) | |||
| if (dims == null) | |||
| return other; | |||
| var new_dims = new List<int>(); | |||
| @@ -56,5 +56,12 @@ namespace TensorFlowNET.UnitTest | |||
| TensorShape shape = (Unknown, 1, 2, 3, Unknown); | |||
| shape.GetPrivate<Shape>("shape").Should().BeShaped(-1, 1, 2, 3, -1); | |||
| } | |||
| [TestMethod] | |||
| public void Case7() | |||
| { | |||
| TensorShape shape = new TensorShape(); | |||
| Assert.AreEqual(shape.rank, -1); | |||
| } | |||
| } | |||
| } | |||