Browse Source

even more bug fixes

pull/571/head
carb0n GitHub 5 years ago
parent
commit
d3f535893e
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 16 additions and 13 deletions
  1. +16
    -13
      src/TensorFlowNET.Core/Operations/image_ops_impl.cs

+ 16
- 13
src/TensorFlowNET.Core/Operations/image_ops_impl.cs View File

@@ -34,7 +34,7 @@ namespace Tensorflow
TensorShape image_shape;
try
{
if ( image.shape.NDims == null )
if ( image.shape.ndim == null )
{
image_shape = image.shape.with_rank(3);
} else {
@@ -49,19 +49,22 @@ namespace Tensorflow
{
throw new ValueError("\'image\' must be fully defined.");
}
foreach (int x in image_shape[-3..])
for ( int x = 1; x < 4; x++ )
{
throw new ValueError("inner 3 dims of \'image.shape\' must be > 0: %s" %
image_shape);
if ( image_shape[image_shape.Length - x] == 0 )
{
throw new ValueError(String.Format("inner 3 dims of \'image.shape\' must be > 0: {0}", image_shape));
}
}
if ( !image_shape[-3..].is_fully_defined() )
if ( !image_shape.Skip(image_shape.Length - 3).ToArray().is_fully_defined())
{
var temp_array = array_ops.shape(image);
return new Operation[] {
check_ops.assert_positive(
array_ops.shape(image)[-3..],
positive = check_ops.assert_positive(
temp_array.Skip(temp_array.Length -3).ToArray(),
new {@"inner 3 dims of 'image.shape'
must be > 0."}),
check_ops.assert_greater_equal(
greater_equal = check_ops.assert_greater_equal(
array_ops.rank(image),
ops.convert_to_tensor(3),
message: "'image' must be at least three-dimensional.")
@@ -76,7 +79,7 @@ namespace Tensorflow
TensorShape image_shape = image.shape;
if (image_shape == tensor_shape.unknown_shape())
{
result.set_shape(new { null, null, null });
result.set_shape(new TensorShape { null, null, null });
} else {
result.set_shape(image_shape);
}
@@ -103,7 +106,7 @@ namespace Tensorflow
image = ops.convert_to_tensor(image, name: "image");
image = _AssertAtLeast3DImage(image);
Tensor shape = image.shape;
if ( shape.NDims == 3 || shape.NDims == null )
if ( shape.ndim == 3 || shape.ndim == null )
{
var uniform_random = random_ops.random_uniform(new {}, 0, 1.0, seed: seed);
var mirror_cond = math_ops.less(uniform_random, .5);
@@ -114,7 +117,7 @@ namespace Tensorflow
name: scope
);
return fix_image_flip_shape(image, result);
} else if ( shape.NDims == 4 )
} else if ( shape.ndim == 4 )
{
var batch_size = array_ops.shape(image);
var uniform_random = random_ops.random_uniform(batch_size[0],
@@ -146,10 +149,10 @@ namespace Tensorflow
image = ops.convert_to_tensor(image, name: "image");
image = _AssertAtLeast3DImage(image);
Tensor shape = image.shape;
if ( shape.NDims == 3 || shape.NDims == null )
if ( shape.ndim == 3 || shape.ndim == null )
{
return fix_image_flip_shape(image, array_ops.reverse(image, new { flip_index }));
} else if ( shape.NDims == 4 )
} else if ( shape.ndim == 4 )
{
return array_ops.reverse(image, new { flip_index + 1 });
} else


Loading…
Cancel
Save