Browse Source

fix: fix the bug of boolean_mask

tags/v0.150.0-BERT-Model
Wanglongzhi2001 2 years ago
parent
commit
4e42d7f3a8
4 changed files with 16 additions and 10 deletions
  1. +2
    -2
      src/TensorFlowNET.Core/Operations/NnOps/rnn.cs
  2. +9
    -4
      src/TensorFlowNET.Core/Operations/array_ops.cs
  3. +1
    -1
      src/TensorFlowNET.Core/Operations/nn_ops.cs
  4. +4
    -3
      test/TensorFlowNET.Graph.UnitTest/Basics/TensorTest.cs

+ 2
- 2
src/TensorFlowNET.Core/Operations/NnOps/rnn.cs View File

@@ -428,9 +428,9 @@ namespace Tensorflow.Operations
return x; return x;


var x_rank = array_ops.rank(x); var x_rank = array_ops.rank(x);
var con1 = new object[]
var con1 = new Tensor[]
{ {
new []{1, 0 },
new Tensor(new int[]{0, 2}),
math_ops.range(2, x_rank) math_ops.range(2, x_rank)
}; };
var x_t = array_ops.transpose(x, array_ops.concat(con1, 0)); var x_t = array_ops.transpose(x, array_ops.concat(con1, 0));


+ 9
- 4
src/TensorFlowNET.Core/Operations/array_ops.cs View File

@@ -166,6 +166,11 @@ namespace Tensorflow
throw new ValueError("mask cannot be scalar."); throw new ValueError("mask cannot be scalar.");


var leading_size = gen_math_ops.prod(shape(tensor_tensor)[$"{axis}:{axis + ndims_mask}"], ops.convert_to_tensor(new[] { 0 })); var leading_size = gen_math_ops.prod(shape(tensor_tensor)[$"{axis}:{axis + ndims_mask}"], ops.convert_to_tensor(new[] { 0 }));
if (leading_size.rank == 0)
{
leading_size = expand_dims(leading_size, 0);
}

var shape1 = concat(new[] var shape1 = concat(new[]
{ {
shape(tensor_tensor)[$":{axis}"], shape(tensor_tensor)[$":{axis}"],
@@ -185,7 +190,7 @@ namespace Tensorflow


private static Tensor _apply_mask_1d(Tensor reshaped_tensor, Tensor mask, int axis = 0) private static Tensor _apply_mask_1d(Tensor reshaped_tensor, Tensor mask, int axis = 0)
{ {
var indices = squeeze(where(mask), axis: new[] { 1 });
var indices = squeeze(where_v2(mask), axis: new[] { 1 });
return gather(reshaped_tensor, indices, axis: ops.convert_to_tensor(axis)); return gather(reshaped_tensor, indices, axis: ops.convert_to_tensor(axis));
} }


@@ -940,12 +945,12 @@ namespace Tensorflow
/// <returns></returns> /// <returns></returns>
public static Tensor concat(Tensor[] values, Tensor axis, string name = "concat") public static Tensor concat(Tensor[] values, Tensor axis, string name = "concat")
{ {
return tf.Context.ExecuteOp("ConcatV2", name, new ExecuteOpArgs(values, axis));
return gen_array_ops.concat_v2(values, axis, name: name);
} }


public static Tensor concat(object[] values, int axis, string name = "concat")
public static Tensor concat(Tensor[] values, Axis axis, string name = "concat")
{ {
return tf.Context.ExecuteOp("ConcatV2", name, new ExecuteOpArgs(values, axis));
return gen_array_ops.concat_v2(values, axis, name: name);
} }


/// <summary> /// <summary>


+ 1
- 1
src/TensorFlowNET.Core/Operations/nn_ops.cs View File

@@ -287,7 +287,7 @@ namespace Tensorflow
new[] { math_ops.subtract(rank, 1) }, new[] { math_ops.subtract(rank, 1) },
new[] { constant_op.constant(1) }); new[] { constant_op.constant(1) });


var ops = array_ops.concat(new[] { new[] { -1 }, (object)last_dim_size }, 0);
var ops = array_ops.concat(new Tensor[] { new Tensor(new int[] {1}), last_dim_size }, 0);
var output = array_ops.reshape(logits, ops); var output = array_ops.reshape(logits, ops);


// Set output shape if known. // Set output shape if known.


+ 4
- 3
test/TensorFlowNET.Graph.UnitTest/Basics/TensorTest.cs View File

@@ -3,6 +3,7 @@ using Tensorflow.NumPy;
using System; using System;
using System.Linq; using System.Linq;
using static Tensorflow.Binding; using static Tensorflow.Binding;
using Tensorflow;


namespace TensorFlowNET.UnitTest.Basics namespace TensorFlowNET.UnitTest.Basics
{ {
@@ -60,14 +61,14 @@ namespace TensorFlowNET.UnitTest.Basics
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 15, 21, 16, 22, 17, 23 }, result[0, 3].ToArray<int>())); Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 15, 21, 16, 22, 17, 23 }, result[0, 3].ToArray<int>()));
} }


[TestMethod, Ignore]
[TestMethod]
public void boolean_mask() public void boolean_mask()
{ {
if (!tf.executing_eagerly())
tf.enable_eager_execution();
var tensor = new[] { 0, 1, 2, 3 }; var tensor = new[] { 0, 1, 2, 3 };
var mask = np.array(new[] { true, false, true, false }); var mask = np.array(new[] { true, false, true, false });
var masked = tf.boolean_mask(tensor, mask); var masked = tf.boolean_mask(tensor, mask);
var sess = tf.Session();
var result = sess.run(masked);
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 2 }, masked.ToArray<int>())); Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 2 }, masked.ToArray<int>()));
} }
} }

Loading…
Cancel
Save