From f1a3881affbbbb4bc2b34b2090355b6f0947c7e9 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Wed, 18 Sep 2019 13:26:38 -0500 Subject: [PATCH] tf.boolean_mask #396 --- src/TensorFlowNET.Core/APIs/tf.array.cs | 13 +++++ .../{tf.control.cs => tf.control_flow.cs} | 0 src/TensorFlowNET.Core/APIs/tf.ops.cs | 3 ++ src/TensorFlowNET.Core/Graphs/Graph.cs | 5 ++ .../Operations/array_ops.py.cs | 47 ++++++++++++++++++- .../Operations/gen_array_ops.cs | 5 +- src/TensorFlowNET.Core/Tensors/dtypes.cs | 1 + test/TensorFlowNET.UnitTest/TensorTest.cs | 13 +++++ 8 files changed, 84 insertions(+), 3 deletions(-) rename src/TensorFlowNET.Core/APIs/{tf.control.cs => tf.control_flow.cs} (100%) diff --git a/src/TensorFlowNET.Core/APIs/tf.array.cs b/src/TensorFlowNET.Core/APIs/tf.array.cs index bef72417..fbc01a8b 100644 --- a/src/TensorFlowNET.Core/APIs/tf.array.cs +++ b/src/TensorFlowNET.Core/APIs/tf.array.cs @@ -39,6 +39,19 @@ namespace Tensorflow public Tensor batch_to_space_nd(T input, int[] block_shape, int[,] crops, string name = null) => gen_array_ops.batch_to_space_nd(input, block_shape, crops, name: name); + /// + /// Apply boolean mask to tensor. + /// + /// + /// + /// N-D tensor. + /// K-D boolean tensor, K <= N and K must be known statically. + /// + /// A 0-D int Tensor representing the axis in tensor to mask from. + /// (N-K+1)-dimensional tensor populated by entries in tensor corresponding to True values in mask. + public Tensor boolean_mask(T1 tensor, T2 mask, string name = "boolean_mask", int axis = 0) + => array_ops.boolean_mask(tensor, mask, name: name, axis: axis); + public Tensor check_numerics(Tensor tensor, string message, string name = null) => gen_array_ops.check_numerics(tensor, message, name: name); diff --git a/src/TensorFlowNET.Core/APIs/tf.control.cs b/src/TensorFlowNET.Core/APIs/tf.control_flow.cs similarity index 100% rename from src/TensorFlowNET.Core/APIs/tf.control.cs rename to src/TensorFlowNET.Core/APIs/tf.control_flow.cs diff --git a/src/TensorFlowNET.Core/APIs/tf.ops.cs b/src/TensorFlowNET.Core/APIs/tf.ops.cs index fd6efd8a..fe790826 100644 --- a/src/TensorFlowNET.Core/APIs/tf.ops.cs +++ b/src/TensorFlowNET.Core/APIs/tf.ops.cs @@ -21,6 +21,9 @@ namespace Tensorflow public Tensor assign(Tensor @ref, object value, bool validate_shape = true, bool use_locking = true, string name = null) => state_ops.assign(@ref, value, validate_shape, use_locking, name); + public void device(string device_name) + => get_default_graph().device(device_name); + public object get_collection(string key, string scope = "") => get_default_graph().get_collection(key, scope: scope); diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs index 4063453c..48cec7a9 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.cs @@ -288,6 +288,11 @@ namespace Tensorflow return op; } + public void device(string device_name) + { + throw new NotImplementedException(""); + } + private void _create_op_helper(Operation op, bool compute_device = true) { _record_op_seen_by_control_dependencies(op); diff --git a/src/TensorFlowNET.Core/Operations/array_ops.py.cs b/src/TensorFlowNET.Core/Operations/array_ops.py.cs index 5e82f50e..3e2276c6 100644 --- a/src/TensorFlowNET.Core/Operations/array_ops.py.cs +++ b/src/TensorFlowNET.Core/Operations/array_ops.py.cs @@ -17,6 +17,8 @@ using NumSharp; using System; using System.Collections.Generic; +using System.Linq; +using Tensorflow.Framework; using static Tensorflow.Binding; namespace Tensorflow @@ -66,6 +68,44 @@ namespace Tensorflow }); } + public static Tensor boolean_mask(T1 tensor, T2 mask, string name = "boolean_mask", int axis = 0) + { + return tf_with(ops.name_scope(name, values: new { tensor, mask }), delegate + { + var tensor_tensor = ops.convert_to_tensor(tensor, name: "tensor"); + var mask_tensor = ops.convert_to_tensor(mask, name: "mask"); + + var shape_mask = mask_tensor.TensorShape; + var ndims_mask = shape_mask.ndim; + var shape_tensor = tensor_tensor.TensorShape; + + if (ndims_mask < 1) + throw new ValueError("mask cannot be scalar."); + + var leading_size = gen_math_ops.prod(shape(tensor_tensor)[$"{axis}:{axis + ndims_mask}"], new[] { 0 }); + var shape1 = concat(new[] + { + shape(tensor_tensor)[$":{axis}"], + tf.expand_dims(leading_size, 0), + shape(tensor_tensor)[$"{axis + ndims_mask}:"] + }, 0); + tensor_tensor = reshape(tensor, shape1); + var first_dim = shape_tensor.dims.Skip(axis).Take(ndims_mask).First(); + var s1 = tensor_shape.as_shape(shape_tensor.dims.Take(axis).ToArray()); + var s2 = s1.concatenate(new[] { first_dim }).concatenate(shape_tensor.dims.Skip(axis + ndims_mask).ToArray()); + tensor_tensor.set_shape(s2); + + mask_tensor = reshape(mask_tensor, new[] { -1 }); + return _apply_mask_1d(tensor_tensor, mask_tensor, axis); + }); + } + + private static Tensor _apply_mask_1d(Tensor reshaped_tensor, Tensor mask, int axis = 0) + { + var indices = squeeze(where(mask), axis: new[] { 1 }); + return gather(reshaped_tensor, indices, axis: axis); + } + public static Tensor zeros(Tensor shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null) { dtype = dtype.as_base_dtype(); @@ -336,7 +376,12 @@ namespace Tensorflow { if( x == null && y == null) { - throw new NotImplementedException("where"); + return tf_with(ops.name_scope(name, "Where", new { condition }), scope => + { + name = scope; + condition = ops.convert_to_tensor(condition, preferred_dtype: dtypes.@bool, name: "condition"); + return gen_array_ops.where(condition: condition, name: name); + }); } else if(x != null && y != null) { diff --git a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs index 5d037a09..59b43766 100644 --- a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs @@ -274,9 +274,10 @@ namespace Tensorflow return _op.outputs; } - public static Tensor where() + public static Tensor where(Tensor condition, string name = null) { - throw new NotImplementedException("where"); + var _op = _op_def_lib._apply_op_helper("Where", name, new { input = condition }); + return _op.output; } public static Tensor one_hot(Tensor indices, int depth, diff --git a/src/TensorFlowNET.Core/Tensors/dtypes.cs b/src/TensorFlowNET.Core/Tensors/dtypes.cs index 90b1b80d..fe0dc5e9 100644 --- a/src/TensorFlowNET.Core/Tensors/dtypes.cs +++ b/src/TensorFlowNET.Core/Tensors/dtypes.cs @@ -23,6 +23,7 @@ namespace Tensorflow { public static class dtypes { + public static TF_DataType @bool = TF_DataType.TF_BOOL; public static TF_DataType int8 = TF_DataType.TF_INT8; public static TF_DataType int32 = TF_DataType.TF_INT32; public static TF_DataType int64 = TF_DataType.TF_INT64; diff --git a/test/TensorFlowNET.UnitTest/TensorTest.cs b/test/TensorFlowNET.UnitTest/TensorTest.cs index b73f15a9..fe68d718 100644 --- a/test/TensorFlowNET.UnitTest/TensorTest.cs +++ b/test/TensorFlowNET.UnitTest/TensorTest.cs @@ -260,5 +260,18 @@ namespace TensorFlowNET.UnitTest Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 15, 21, 16, 22, 17, 23 }, result[0, 3].ToArray())); } } + + [TestMethod] + public void boolean_mask() + { + var tensor = new[] { 0, 1, 2, 3 }; + var mask = np.array(new[] { true, false, true, false }); + var masked = tf.boolean_mask(tensor, mask); + using (var sess = tf.Session()) + { + var result = sess.run(masked); + Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 2 }, result.ToArray())); + } + } } } \ No newline at end of file