| @@ -39,6 +39,19 @@ namespace Tensorflow | |||||
| public Tensor batch_to_space_nd<T>(T input, int[] block_shape, int[,] crops, string name = null) | public Tensor batch_to_space_nd<T>(T input, int[] block_shape, int[,] crops, string name = null) | ||||
| => gen_array_ops.batch_to_space_nd(input, block_shape, crops, name: name); | => gen_array_ops.batch_to_space_nd(input, block_shape, crops, name: name); | ||||
| /// <summary> | |||||
| /// Apply boolean mask to tensor. | |||||
| /// </summary> | |||||
| /// <typeparam name="T1"></typeparam> | |||||
| /// <typeparam name="T2"></typeparam> | |||||
| /// <param name="tensor">N-D tensor.</param> | |||||
| /// <param name="mask">K-D boolean tensor, K <= N and K must be known statically.</param> | |||||
| /// <param name="name"></param> | |||||
| /// <param name="axis">A 0-D int Tensor representing the axis in tensor to mask from. </param> | |||||
| /// <returns>(N-K+1)-dimensional tensor populated by entries in tensor corresponding to True values in mask.</returns> | |||||
| public Tensor boolean_mask<T1, T2>(T1 tensor, T2 mask, string name = "boolean_mask", int axis = 0) | |||||
| => array_ops.boolean_mask(tensor, mask, name: name, axis: axis); | |||||
| public Tensor check_numerics(Tensor tensor, string message, string name = null) | public Tensor check_numerics(Tensor tensor, string message, string name = null) | ||||
| => gen_array_ops.check_numerics(tensor, message, name: name); | => gen_array_ops.check_numerics(tensor, message, name: name); | ||||
| @@ -21,6 +21,9 @@ namespace Tensorflow | |||||
| public Tensor assign(Tensor @ref, object value, bool validate_shape = true, bool use_locking = true, string name = null) | public Tensor assign(Tensor @ref, object value, bool validate_shape = true, bool use_locking = true, string name = null) | ||||
| => state_ops.assign(@ref, value, validate_shape, use_locking, name); | => state_ops.assign(@ref, value, validate_shape, use_locking, name); | ||||
| public void device(string device_name) | |||||
| => get_default_graph().device(device_name); | |||||
| public object get_collection(string key, string scope = "") | public object get_collection(string key, string scope = "") | ||||
| => get_default_graph().get_collection(key, scope: scope); | => get_default_graph().get_collection(key, scope: scope); | ||||
| @@ -288,6 +288,11 @@ namespace Tensorflow | |||||
| return op; | return op; | ||||
| } | } | ||||
| public void device(string device_name) | |||||
| { | |||||
| throw new NotImplementedException(""); | |||||
| } | |||||
| private void _create_op_helper(Operation op, bool compute_device = true) | private void _create_op_helper(Operation op, bool compute_device = true) | ||||
| { | { | ||||
| _record_op_seen_by_control_dependencies(op); | _record_op_seen_by_control_dependencies(op); | ||||
| @@ -17,6 +17,8 @@ | |||||
| using NumSharp; | using NumSharp; | ||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Linq; | |||||
| using Tensorflow.Framework; | |||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| namespace Tensorflow | namespace Tensorflow | ||||
| @@ -66,6 +68,44 @@ namespace Tensorflow | |||||
| }); | }); | ||||
| } | } | ||||
| public static Tensor boolean_mask<T1, T2>(T1 tensor, T2 mask, string name = "boolean_mask", int axis = 0) | |||||
| { | |||||
| return tf_with(ops.name_scope(name, values: new { tensor, mask }), delegate | |||||
| { | |||||
| var tensor_tensor = ops.convert_to_tensor(tensor, name: "tensor"); | |||||
| var mask_tensor = ops.convert_to_tensor(mask, name: "mask"); | |||||
| var shape_mask = mask_tensor.TensorShape; | |||||
| var ndims_mask = shape_mask.ndim; | |||||
| var shape_tensor = tensor_tensor.TensorShape; | |||||
| if (ndims_mask < 1) | |||||
| throw new ValueError("mask cannot be scalar."); | |||||
| var leading_size = gen_math_ops.prod(shape(tensor_tensor)[$"{axis}:{axis + ndims_mask}"], new[] { 0 }); | |||||
| var shape1 = concat(new[] | |||||
| { | |||||
| shape(tensor_tensor)[$":{axis}"], | |||||
| tf.expand_dims(leading_size, 0), | |||||
| shape(tensor_tensor)[$"{axis + ndims_mask}:"] | |||||
| }, 0); | |||||
| tensor_tensor = reshape(tensor, shape1); | |||||
| var first_dim = shape_tensor.dims.Skip(axis).Take(ndims_mask).First(); | |||||
| var s1 = tensor_shape.as_shape(shape_tensor.dims.Take(axis).ToArray()); | |||||
| var s2 = s1.concatenate(new[] { first_dim }).concatenate(shape_tensor.dims.Skip(axis + ndims_mask).ToArray()); | |||||
| tensor_tensor.set_shape(s2); | |||||
| mask_tensor = reshape(mask_tensor, new[] { -1 }); | |||||
| return _apply_mask_1d(tensor_tensor, mask_tensor, axis); | |||||
| }); | |||||
| } | |||||
| private static Tensor _apply_mask_1d(Tensor reshaped_tensor, Tensor mask, int axis = 0) | |||||
| { | |||||
| var indices = squeeze(where(mask), axis: new[] { 1 }); | |||||
| return gather(reshaped_tensor, indices, axis: axis); | |||||
| } | |||||
| public static Tensor zeros(Tensor shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null) | public static Tensor zeros(Tensor shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null) | ||||
| { | { | ||||
| dtype = dtype.as_base_dtype(); | dtype = dtype.as_base_dtype(); | ||||
| @@ -336,7 +376,12 @@ namespace Tensorflow | |||||
| { | { | ||||
| if( x == null && y == null) | if( x == null && y == null) | ||||
| { | { | ||||
| throw new NotImplementedException("where"); | |||||
| return tf_with(ops.name_scope(name, "Where", new { condition }), scope => | |||||
| { | |||||
| name = scope; | |||||
| condition = ops.convert_to_tensor(condition, preferred_dtype: dtypes.@bool, name: "condition"); | |||||
| return gen_array_ops.where(condition: condition, name: name); | |||||
| }); | |||||
| } | } | ||||
| else if(x != null && y != null) | else if(x != null && y != null) | ||||
| { | { | ||||
| @@ -274,9 +274,10 @@ namespace Tensorflow | |||||
| return _op.outputs; | return _op.outputs; | ||||
| } | } | ||||
| public static Tensor where() | |||||
| public static Tensor where(Tensor condition, string name = null) | |||||
| { | { | ||||
| throw new NotImplementedException("where"); | |||||
| var _op = _op_def_lib._apply_op_helper("Where", name, new { input = condition }); | |||||
| return _op.output; | |||||
| } | } | ||||
| public static Tensor one_hot(Tensor indices, int depth, | public static Tensor one_hot(Tensor indices, int depth, | ||||
| @@ -23,6 +23,7 @@ namespace Tensorflow | |||||
| { | { | ||||
| public static class dtypes | public static class dtypes | ||||
| { | { | ||||
| public static TF_DataType @bool = TF_DataType.TF_BOOL; | |||||
| public static TF_DataType int8 = TF_DataType.TF_INT8; | public static TF_DataType int8 = TF_DataType.TF_INT8; | ||||
| public static TF_DataType int32 = TF_DataType.TF_INT32; | public static TF_DataType int32 = TF_DataType.TF_INT32; | ||||
| public static TF_DataType int64 = TF_DataType.TF_INT64; | public static TF_DataType int64 = TF_DataType.TF_INT64; | ||||
| @@ -260,5 +260,18 @@ namespace TensorFlowNET.UnitTest | |||||
| Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 15, 21, 16, 22, 17, 23 }, result[0, 3].ToArray<int>())); | Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 15, 21, 16, 22, 17, 23 }, result[0, 3].ToArray<int>())); | ||||
| } | } | ||||
| } | } | ||||
| [TestMethod] | |||||
| public void boolean_mask() | |||||
| { | |||||
| var tensor = new[] { 0, 1, 2, 3 }; | |||||
| var mask = np.array(new[] { true, false, true, false }); | |||||
| var masked = tf.boolean_mask(tensor, mask); | |||||
| using (var sess = tf.Session()) | |||||
| { | |||||
| var result = sess.run(masked); | |||||
| Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 2 }, result.ToArray<int>())); | |||||
| } | |||||
| } | |||||
| } | } | ||||
| } | } | ||||