|
|
|
@@ -17,6 +17,8 @@ |
|
|
|
using NumSharp;
|
|
|
|
using System;
|
|
|
|
using System.Collections.Generic;
|
|
|
|
using System.Linq;
|
|
|
|
using Tensorflow.Framework;
|
|
|
|
using static Tensorflow.Binding;
|
|
|
|
|
|
|
|
namespace Tensorflow
|
|
|
|
@@ -66,6 +68,44 @@ namespace Tensorflow |
|
|
|
});
|
|
|
|
}
|
|
|
|
|
|
|
|
public static Tensor boolean_mask<T1, T2>(T1 tensor, T2 mask, string name = "boolean_mask", int axis = 0)
|
|
|
|
{
|
|
|
|
return tf_with(ops.name_scope(name, values: new { tensor, mask }), delegate
|
|
|
|
{
|
|
|
|
var tensor_tensor = ops.convert_to_tensor(tensor, name: "tensor");
|
|
|
|
var mask_tensor = ops.convert_to_tensor(mask, name: "mask");
|
|
|
|
|
|
|
|
var shape_mask = mask_tensor.TensorShape;
|
|
|
|
var ndims_mask = shape_mask.ndim;
|
|
|
|
var shape_tensor = tensor_tensor.TensorShape;
|
|
|
|
|
|
|
|
if (ndims_mask < 1)
|
|
|
|
throw new ValueError("mask cannot be scalar.");
|
|
|
|
|
|
|
|
var leading_size = gen_math_ops.prod(shape(tensor_tensor)[$"{axis}:{axis + ndims_mask}"], new[] { 0 });
|
|
|
|
var shape1 = concat(new[]
|
|
|
|
{
|
|
|
|
shape(tensor_tensor)[$":{axis}"],
|
|
|
|
tf.expand_dims(leading_size, 0),
|
|
|
|
shape(tensor_tensor)[$"{axis + ndims_mask}:"]
|
|
|
|
}, 0);
|
|
|
|
tensor_tensor = reshape(tensor, shape1);
|
|
|
|
var first_dim = shape_tensor.dims.Skip(axis).Take(ndims_mask).First();
|
|
|
|
var s1 = tensor_shape.as_shape(shape_tensor.dims.Take(axis).ToArray());
|
|
|
|
var s2 = s1.concatenate(new[] { first_dim }).concatenate(shape_tensor.dims.Skip(axis + ndims_mask).ToArray());
|
|
|
|
tensor_tensor.set_shape(s2);
|
|
|
|
|
|
|
|
mask_tensor = reshape(mask_tensor, new[] { -1 });
|
|
|
|
return _apply_mask_1d(tensor_tensor, mask_tensor, axis);
|
|
|
|
});
|
|
|
|
}
|
|
|
|
|
|
|
|
private static Tensor _apply_mask_1d(Tensor reshaped_tensor, Tensor mask, int axis = 0)
|
|
|
|
{
|
|
|
|
var indices = squeeze(where(mask), axis: new[] { 1 });
|
|
|
|
return gather(reshaped_tensor, indices, axis: axis);
|
|
|
|
}
|
|
|
|
|
|
|
|
public static Tensor zeros(Tensor shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null)
|
|
|
|
{
|
|
|
|
dtype = dtype.as_base_dtype();
|
|
|
|
@@ -336,7 +376,12 @@ namespace Tensorflow |
|
|
|
{
|
|
|
|
if( x == null && y == null)
|
|
|
|
{
|
|
|
|
throw new NotImplementedException("where");
|
|
|
|
return tf_with(ops.name_scope(name, "Where", new { condition }), scope =>
|
|
|
|
{
|
|
|
|
name = scope;
|
|
|
|
condition = ops.convert_to_tensor(condition, preferred_dtype: dtypes.@bool, name: "condition");
|
|
|
|
return gen_array_ops.where(condition: condition, name: name);
|
|
|
|
});
|
|
|
|
}
|
|
|
|
else if(x != null && y != null)
|
|
|
|
{
|
|
|
|
|