Browse Source

fix unassigned shape #406

tags/v0.12
Oceania2018 6 years ago
parent
commit
d857df1216
4 changed files with 42 additions and 24 deletions
  1. +4
    -2
      src/TensorFlowNET.Core/Gradients/math_grad.cs
  2. +19
    -17
      src/TensorFlowNET.Core/Tensors/Tensor.cs
  3. +12
    -5
      src/TensorFlowNET.Core/Tensors/TensorShape.cs
  4. +7
    -0
      test/TensorFlowNET.UnitTest/TensorShapeTest.cs

+ 4
- 2
src/TensorFlowNET.Core/Gradients/math_grad.cs View File

@@ -42,7 +42,8 @@ namespace Tensorflow.Gradients
var x = op.inputs[0]; var x = op.inputs[0];
var y = op.inputs[1]; var y = op.inputs[1];
var grad = grads[0]; var grad = grads[0];
if (grad is Tensor && _ShapesFullySpecifiedAndEqual(x, y, grad))
if (grad is Tensor &&
_ShapesFullySpecifiedAndEqual(x, y, grad))
return new Tensor[] { grad, grad }; return new Tensor[] { grad, grad };


var sx = array_ops.shape(x); var sx = array_ops.shape(x);
@@ -375,7 +376,8 @@ namespace Tensorflow.Gradients
var grad = grads[0]; var grad = grads[0];
var x = op.inputs[0]; var x = op.inputs[0];
var y = op.inputs[1]; var y = op.inputs[1];
if (grad is Tensor && _ShapesFullySpecifiedAndEqual(x, y, grad))
if (grad is Tensor &&
_ShapesFullySpecifiedAndEqual(x, y, grad))
return new Tensor[] { grad, -grad }; return new Tensor[] { grad, -grad };


var sx = array_ops.shape(x); var sx = array_ops.shape(x);


+ 19
- 17
src/TensorFlowNET.Core/Tensors/Tensor.cs View File

@@ -132,10 +132,10 @@ namespace Tensorflow


public int[] _shape_tuple() public int[] _shape_tuple()
{ {
return NDims < 0 ? null : shape;
return rank < 0 ? null : shape;
} }


public TensorShape TensorShape => tensor_util.to_shape(shape);
public TensorShape TensorShape => rank < 0 ? new TensorShape() : tensor_util.to_shape(shape);


/// <summary> /// <summary>
/// Updates the shape of this tensor. /// Updates the shape of this tensor.
@@ -165,6 +165,7 @@ namespace Tensorflow


/// <summary> /// <summary>
/// number of dimensions <br></br> /// number of dimensions <br></br>
/// -1 Unknown <br></br>
/// 0 Scalar (magnitude only) <br></br> /// 0 Scalar (magnitude only) <br></br>
/// 1 Vector (magnitude and direction) <br></br> /// 1 Vector (magnitude and direction) <br></br>
/// 2 Matrix (table of numbers) <br></br> /// 2 Matrix (table of numbers) <br></br>
@@ -178,11 +179,13 @@ namespace Tensorflow
{ {
if (_handle == IntPtr.Zero) if (_handle == IntPtr.Zero)
{ {
var status = new Status();
var output = _as_tf_output();
int ndim = c_api.TF_GraphGetTensorNumDims(op.graph, output, status);
status.Check();
return ndim;
using (var status = new Status())
{
var output = _as_tf_output();
int ndim = c_api.TF_GraphGetTensorNumDims(op.graph, output, status);
status.Check();
return ndim;
}
} }


return c_api.TF_NumDims(_handle); return c_api.TF_NumDims(_handle);
@@ -440,16 +443,15 @@ namespace Tensorflow
public override string ToString() public override string ToString()
{ {
// this can throw IndexOutOfRangeException // this can throw IndexOutOfRangeException
//if(NDims == 0)
//{
// switch (dtype)
// {
// case TF_DataType.TF_INT32:
// return Data<int>()[0].ToString();
// }
//}

return $"tf.Tensor '{name}' shape=({string.Join(",", shape)}) dtype={dtype}";
switch (rank)
{
case -1:
return $"tf.Tensor '{name}' shape=<unknown> dtype={dtype}";
case 0:
return $"tf.Tensor '{name}' shape=() dtype={dtype}";
default:
return $"tf.Tensor '{name}' shape=({string.Join(",", shape)}) dtype={dtype}";
}
} }


/// <summary> /// <summary>


+ 12
- 5
src/TensorFlowNET.Core/Tensors/TensorShape.cs View File

@@ -24,12 +24,13 @@ namespace Tensorflow
/// <summary> /// <summary>
/// Returns the rank of this shape. /// Returns the rank of this shape.
/// </summary> /// </summary>
public int ndim => shape.NDim;
public int ndim => rank;


private int _rank;
/// <summary> /// <summary>
/// Returns the rank of this shape. /// Returns the rank of this shape.
/// </summary> /// </summary>
public int rank => shape.NDim;
public int rank => _rank > -1 ? shape.NDim : -1;


/// <summary> /// <summary>
/// Returns the size this shape represents. /// Returns the size this shape represents.
@@ -52,6 +53,12 @@ namespace Tensorflow
} }
} }


public TensorShape()
{
_rank = -1;
shape = new Shape();
}

public TensorShape(TensorShapeProto proto) public TensorShape(TensorShapeProto proto)
{ {
if (proto.UnknownRank) return; if (proto.UnknownRank) return;
@@ -77,7 +84,7 @@ namespace Tensorflow
switch (dims.Length) switch (dims.Length)
{ {
case 0: shape = new Shape(new int[0]); break; case 0: shape = new Shape(new int[0]); break;
case 1: shape = Shape.Vector((int)dims[0]); break;
case 1: shape = Shape.Vector(dims[0]); break;
case 2: shape = Shape.Matrix(dims[0], dims[1]); break; case 2: shape = Shape.Matrix(dims[0], dims[1]); break;
default: shape = new Shape(dims); break; default: shape = new Shape(dims); break;
} }
@@ -127,7 +134,7 @@ namespace Tensorflow
/// <returns></returns> /// <returns></returns>
public bool is_fully_defined() public bool is_fully_defined()
{ {
return dims != null && dims.Count(x => x < 1) == 0;
return rank > -1 && dims != null && dims.Count(x => x < 1) == 0;
} }


public bool is_compatible_with(TensorShape shape2) public bool is_compatible_with(TensorShape shape2)
@@ -204,7 +211,7 @@ namespace Tensorflow
/// <returns></returns> /// <returns></returns>
public TensorShape merge_with(TensorShape other) public TensorShape merge_with(TensorShape other)
{ {
if (dims.Length == 0)
if (dims == null)
return other; return other;


var new_dims = new List<int>(); var new_dims = new List<int>();


+ 7
- 0
test/TensorFlowNET.UnitTest/TensorShapeTest.cs View File

@@ -56,5 +56,12 @@ namespace TensorFlowNET.UnitTest
TensorShape shape = (Unknown, 1, 2, 3, Unknown); TensorShape shape = (Unknown, 1, 2, 3, Unknown);
shape.GetPrivate<Shape>("shape").Should().BeShaped(-1, 1, 2, 3, -1); shape.GetPrivate<Shape>("shape").Should().BeShaped(-1, 1, 2, 3, -1);
} }

[TestMethod]
public void Case7()
{
TensorShape shape = new TensorShape();
Assert.AreEqual(shape.rank, -1);
}
} }
} }

Loading…
Cancel
Save