Browse Source

tf.boolean_mask #396

tags/v0.12
Oceania2018 6 years ago
parent
commit
f1a3881aff
8 changed files with 84 additions and 3 deletions
  1. +13
    -0
      src/TensorFlowNET.Core/APIs/tf.array.cs
  2. +0
    -0
      src/TensorFlowNET.Core/APIs/tf.control_flow.cs
  3. +3
    -0
      src/TensorFlowNET.Core/APIs/tf.ops.cs
  4. +5
    -0
      src/TensorFlowNET.Core/Graphs/Graph.cs
  5. +46
    -1
      src/TensorFlowNET.Core/Operations/array_ops.py.cs
  6. +3
    -2
      src/TensorFlowNET.Core/Operations/gen_array_ops.cs
  7. +1
    -0
      src/TensorFlowNET.Core/Tensors/dtypes.cs
  8. +13
    -0
      test/TensorFlowNET.UnitTest/TensorTest.cs

+ 13
- 0
src/TensorFlowNET.Core/APIs/tf.array.cs View File

@@ -39,6 +39,19 @@ namespace Tensorflow
public Tensor batch_to_space_nd<T>(T input, int[] block_shape, int[,] crops, string name = null) public Tensor batch_to_space_nd<T>(T input, int[] block_shape, int[,] crops, string name = null)
=> gen_array_ops.batch_to_space_nd(input, block_shape, crops, name: name); => gen_array_ops.batch_to_space_nd(input, block_shape, crops, name: name);


/// <summary>
/// Apply boolean mask to tensor.
/// </summary>
/// <typeparam name="T1"></typeparam>
/// <typeparam name="T2"></typeparam>
/// <param name="tensor">N-D tensor.</param>
/// <param name="mask">K-D boolean tensor, K <= N and K must be known statically.</param>
/// <param name="name"></param>
/// <param name="axis">A 0-D int Tensor representing the axis in tensor to mask from. </param>
/// <returns>(N-K+1)-dimensional tensor populated by entries in tensor corresponding to True values in mask.</returns>
public Tensor boolean_mask<T1, T2>(T1 tensor, T2 mask, string name = "boolean_mask", int axis = 0)
=> array_ops.boolean_mask(tensor, mask, name: name, axis: axis);

public Tensor check_numerics(Tensor tensor, string message, string name = null) public Tensor check_numerics(Tensor tensor, string message, string name = null)
=> gen_array_ops.check_numerics(tensor, message, name: name); => gen_array_ops.check_numerics(tensor, message, name: name);




src/TensorFlowNET.Core/APIs/tf.control.cs → src/TensorFlowNET.Core/APIs/tf.control_flow.cs View File


+ 3
- 0
src/TensorFlowNET.Core/APIs/tf.ops.cs View File

@@ -21,6 +21,9 @@ namespace Tensorflow
public Tensor assign(Tensor @ref, object value, bool validate_shape = true, bool use_locking = true, string name = null) public Tensor assign(Tensor @ref, object value, bool validate_shape = true, bool use_locking = true, string name = null)
=> state_ops.assign(@ref, value, validate_shape, use_locking, name); => state_ops.assign(@ref, value, validate_shape, use_locking, name);


public void device(string device_name)
=> get_default_graph().device(device_name);

public object get_collection(string key, string scope = "") public object get_collection(string key, string scope = "")
=> get_default_graph().get_collection(key, scope: scope); => get_default_graph().get_collection(key, scope: scope);




+ 5
- 0
src/TensorFlowNET.Core/Graphs/Graph.cs View File

@@ -288,6 +288,11 @@ namespace Tensorflow
return op; return op;
} }


public void device(string device_name)
{
throw new NotImplementedException("");
}

private void _create_op_helper(Operation op, bool compute_device = true) private void _create_op_helper(Operation op, bool compute_device = true)
{ {
_record_op_seen_by_control_dependencies(op); _record_op_seen_by_control_dependencies(op);


+ 46
- 1
src/TensorFlowNET.Core/Operations/array_ops.py.cs View File

@@ -17,6 +17,8 @@
using NumSharp; using NumSharp;
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq;
using Tensorflow.Framework;
using static Tensorflow.Binding; using static Tensorflow.Binding;
namespace Tensorflow namespace Tensorflow
@@ -66,6 +68,44 @@ namespace Tensorflow
}); });
} }
public static Tensor boolean_mask<T1, T2>(T1 tensor, T2 mask, string name = "boolean_mask", int axis = 0)
{
return tf_with(ops.name_scope(name, values: new { tensor, mask }), delegate
{
var tensor_tensor = ops.convert_to_tensor(tensor, name: "tensor");
var mask_tensor = ops.convert_to_tensor(mask, name: "mask");
var shape_mask = mask_tensor.TensorShape;
var ndims_mask = shape_mask.ndim;
var shape_tensor = tensor_tensor.TensorShape;
if (ndims_mask < 1)
throw new ValueError("mask cannot be scalar.");
var leading_size = gen_math_ops.prod(shape(tensor_tensor)[$"{axis}:{axis + ndims_mask}"], new[] { 0 });
var shape1 = concat(new[]
{
shape(tensor_tensor)[$":{axis}"],
tf.expand_dims(leading_size, 0),
shape(tensor_tensor)[$"{axis + ndims_mask}:"]
}, 0);
tensor_tensor = reshape(tensor, shape1);
var first_dim = shape_tensor.dims.Skip(axis).Take(ndims_mask).First();
var s1 = tensor_shape.as_shape(shape_tensor.dims.Take(axis).ToArray());
var s2 = s1.concatenate(new[] { first_dim }).concatenate(shape_tensor.dims.Skip(axis + ndims_mask).ToArray());
tensor_tensor.set_shape(s2);
mask_tensor = reshape(mask_tensor, new[] { -1 });
return _apply_mask_1d(tensor_tensor, mask_tensor, axis);
});
}
private static Tensor _apply_mask_1d(Tensor reshaped_tensor, Tensor mask, int axis = 0)
{
var indices = squeeze(where(mask), axis: new[] { 1 });
return gather(reshaped_tensor, indices, axis: axis);
}
public static Tensor zeros(Tensor shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null) public static Tensor zeros(Tensor shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null)
{ {
dtype = dtype.as_base_dtype(); dtype = dtype.as_base_dtype();
@@ -336,7 +376,12 @@ namespace Tensorflow
{ {
if( x == null && y == null) if( x == null && y == null)
{ {
throw new NotImplementedException("where");
return tf_with(ops.name_scope(name, "Where", new { condition }), scope =>
{
name = scope;
condition = ops.convert_to_tensor(condition, preferred_dtype: dtypes.@bool, name: "condition");
return gen_array_ops.where(condition: condition, name: name);
});
} }
else if(x != null && y != null) else if(x != null && y != null)
{ {


+ 3
- 2
src/TensorFlowNET.Core/Operations/gen_array_ops.cs View File

@@ -274,9 +274,10 @@ namespace Tensorflow
return _op.outputs; return _op.outputs;
} }


public static Tensor where()
public static Tensor where(Tensor condition, string name = null)
{ {
throw new NotImplementedException("where");
var _op = _op_def_lib._apply_op_helper("Where", name, new { input = condition });
return _op.output;
} }


public static Tensor one_hot(Tensor indices, int depth, public static Tensor one_hot(Tensor indices, int depth,


+ 1
- 0
src/TensorFlowNET.Core/Tensors/dtypes.cs View File

@@ -23,6 +23,7 @@ namespace Tensorflow
{ {
public static class dtypes public static class dtypes
{ {
public static TF_DataType @bool = TF_DataType.TF_BOOL;
public static TF_DataType int8 = TF_DataType.TF_INT8; public static TF_DataType int8 = TF_DataType.TF_INT8;
public static TF_DataType int32 = TF_DataType.TF_INT32; public static TF_DataType int32 = TF_DataType.TF_INT32;
public static TF_DataType int64 = TF_DataType.TF_INT64; public static TF_DataType int64 = TF_DataType.TF_INT64;


+ 13
- 0
test/TensorFlowNET.UnitTest/TensorTest.cs View File

@@ -260,5 +260,18 @@ namespace TensorFlowNET.UnitTest
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 15, 21, 16, 22, 17, 23 }, result[0, 3].ToArray<int>())); Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 15, 21, 16, 22, 17, 23 }, result[0, 3].ToArray<int>()));
} }
} }

[TestMethod]
public void boolean_mask()
{
var tensor = new[] { 0, 1, 2, 3 };
var mask = np.array(new[] { true, false, true, false });
var masked = tf.boolean_mask(tensor, mask);
using (var sess = tf.Session())
{
var result = sess.run(masked);
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 2 }, result.ToArray<int>()));
}
}
} }
} }

Loading…
Cancel
Save