Browse Source

fix unassigned shape #406

tags/v0.12
Oceania2018 6 years ago
parent
commit
d857df1216
4 changed files with 42 additions and 24 deletions
  1. +4
    -2
      src/TensorFlowNET.Core/Gradients/math_grad.cs
  2. +19
    -17
      src/TensorFlowNET.Core/Tensors/Tensor.cs
  3. +12
    -5
      src/TensorFlowNET.Core/Tensors/TensorShape.cs
  4. +7
    -0
      test/TensorFlowNET.UnitTest/TensorShapeTest.cs

+ 4
- 2
src/TensorFlowNET.Core/Gradients/math_grad.cs View File

@@ -42,7 +42,8 @@ namespace Tensorflow.Gradients
var x = op.inputs[0];
var y = op.inputs[1];
var grad = grads[0];
if (grad is Tensor && _ShapesFullySpecifiedAndEqual(x, y, grad))
if (grad is Tensor &&
_ShapesFullySpecifiedAndEqual(x, y, grad))
return new Tensor[] { grad, grad };

var sx = array_ops.shape(x);
@@ -375,7 +376,8 @@ namespace Tensorflow.Gradients
var grad = grads[0];
var x = op.inputs[0];
var y = op.inputs[1];
if (grad is Tensor && _ShapesFullySpecifiedAndEqual(x, y, grad))
if (grad is Tensor &&
_ShapesFullySpecifiedAndEqual(x, y, grad))
return new Tensor[] { grad, -grad };

var sx = array_ops.shape(x);


+ 19
- 17
src/TensorFlowNET.Core/Tensors/Tensor.cs View File

@@ -132,10 +132,10 @@ namespace Tensorflow

public int[] _shape_tuple()
{
return NDims < 0 ? null : shape;
return rank < 0 ? null : shape;
}

public TensorShape TensorShape => tensor_util.to_shape(shape);
public TensorShape TensorShape => rank < 0 ? new TensorShape() : tensor_util.to_shape(shape);

/// <summary>
/// Updates the shape of this tensor.
@@ -165,6 +165,7 @@ namespace Tensorflow

/// <summary>
/// number of dimensions <br></br>
/// -1 Unknown <br></br>
/// 0 Scalar (magnitude only) <br></br>
/// 1 Vector (magnitude and direction) <br></br>
/// 2 Matrix (table of numbers) <br></br>
@@ -178,11 +179,13 @@ namespace Tensorflow
{
if (_handle == IntPtr.Zero)
{
var status = new Status();
var output = _as_tf_output();
int ndim = c_api.TF_GraphGetTensorNumDims(op.graph, output, status);
status.Check();
return ndim;
using (var status = new Status())
{
var output = _as_tf_output();
int ndim = c_api.TF_GraphGetTensorNumDims(op.graph, output, status);
status.Check();
return ndim;
}
}

return c_api.TF_NumDims(_handle);
@@ -440,16 +443,15 @@ namespace Tensorflow
public override string ToString()
{
// this can throw IndexOutOfRangeException
//if(NDims == 0)
//{
// switch (dtype)
// {
// case TF_DataType.TF_INT32:
// return Data<int>()[0].ToString();
// }
//}

return $"tf.Tensor '{name}' shape=({string.Join(",", shape)}) dtype={dtype}";
switch (rank)
{
case -1:
return $"tf.Tensor '{name}' shape=<unknown> dtype={dtype}";
case 0:
return $"tf.Tensor '{name}' shape=() dtype={dtype}";
default:
return $"tf.Tensor '{name}' shape=({string.Join(",", shape)}) dtype={dtype}";
}
}

/// <summary>


+ 12
- 5
src/TensorFlowNET.Core/Tensors/TensorShape.cs View File

@@ -24,12 +24,13 @@ namespace Tensorflow
/// <summary>
/// Returns the rank of this shape.
/// </summary>
public int ndim => shape.NDim;
public int ndim => rank;

private int _rank;
/// <summary>
/// Returns the rank of this shape.
/// </summary>
public int rank => shape.NDim;
public int rank => _rank > -1 ? shape.NDim : -1;

/// <summary>
/// Returns the size this shape represents.
@@ -52,6 +53,12 @@ namespace Tensorflow
}
}

public TensorShape()
{
_rank = -1;
shape = new Shape();
}

public TensorShape(TensorShapeProto proto)
{
if (proto.UnknownRank) return;
@@ -77,7 +84,7 @@ namespace Tensorflow
switch (dims.Length)
{
case 0: shape = new Shape(new int[0]); break;
case 1: shape = Shape.Vector((int)dims[0]); break;
case 1: shape = Shape.Vector(dims[0]); break;
case 2: shape = Shape.Matrix(dims[0], dims[1]); break;
default: shape = new Shape(dims); break;
}
@@ -127,7 +134,7 @@ namespace Tensorflow
/// <returns></returns>
public bool is_fully_defined()
{
return dims != null && dims.Count(x => x < 1) == 0;
return rank > -1 && dims != null && dims.Count(x => x < 1) == 0;
}

public bool is_compatible_with(TensorShape shape2)
@@ -204,7 +211,7 @@ namespace Tensorflow
/// <returns></returns>
public TensorShape merge_with(TensorShape other)
{
if (dims.Length == 0)
if (dims == null)
return other;

var new_dims = new List<int>();


+ 7
- 0
test/TensorFlowNET.UnitTest/TensorShapeTest.cs View File

@@ -56,5 +56,12 @@ namespace TensorFlowNET.UnitTest
TensorShape shape = (Unknown, 1, 2, 3, Unknown);
shape.GetPrivate<Shape>("shape").Should().BeShaped(-1, 1, 2, 3, -1);
}

[TestMethod]
public void Case7()
{
TensorShape shape = new TensorShape();
Assert.AreEqual(shape.rank, -1);
}
}
}

Loading…
Cancel
Save