From: @TFbunny Reviewed-by: @robingrosman,@tom__chen Signed-off-by: @robingrosmanpull/15190/MERGE
| @@ -53,5 +53,12 @@ MS_REG_GPU_KERNEL_ONE(Select, | |||||
| .AddInputAttr(kNumberTypeInt64) | .AddInputAttr(kNumberTypeInt64) | ||||
| .AddOutputAttr(kNumberTypeInt64), | .AddOutputAttr(kNumberTypeInt64), | ||||
| SelectGpuKernel, int64_t) | SelectGpuKernel, int64_t) | ||||
| MS_REG_GPU_KERNEL_ONE(Select, | |||||
| KernelAttr() | |||||
| .AddInputAttr(kNumberTypeBool) | |||||
| .AddInputAttr(kNumberTypeBool) | |||||
| .AddInputAttr(kNumberTypeBool) | |||||
| .AddOutputAttr(kNumberTypeBool), | |||||
| SelectGpuKernel, bool) | |||||
| } // namespace kernel | } // namespace kernel | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -1,5 +1,5 @@ | |||||
| /** | /** | ||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||||
| * | * | ||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| * you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
| @@ -44,4 +44,5 @@ template void CalSelect<half>(const size_t size, const bool* cond, const half* i | |||||
| half* output, cudaStream_t cuda_stream); | half* output, cudaStream_t cuda_stream); | ||||
| template void CalSelect<int64_t>(const size_t size, const bool* cond, const int64_t* input_X, const int64_t* input_y, | template void CalSelect<int64_t>(const size_t size, const bool* cond, const int64_t* input_X, const int64_t* input_y, | ||||
| int64_t* output, cudaStream_t cuda_stream); | int64_t* output, cudaStream_t cuda_stream); | ||||
| template void CalSelect<bool>(const size_t size, const bool *cond, const bool *input_X, const bool *input_y, | |||||
| bool *output, cudaStream_t cuda_stream); | |||||
| @@ -1,5 +1,5 @@ | |||||
| /** | /** | ||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||||
| * | * | ||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| * you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
| @@ -14,12 +14,12 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SELECT_IMPL_H_ | |||||
| #define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SELECT_IMPL_H_ | |||||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_SELECT_IMPL_H_ | |||||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_SELECT_IMPL_H_ | |||||
| #include "runtime/device/gpu/cuda_common.h" | #include "runtime/device/gpu/cuda_common.h" | ||||
| template <typename T> | template <typename T> | ||||
| void CalSelect(const size_t size, const bool* cond, const T* input_x, const T* input_y, T* output, | void CalSelect(const size_t size, const bool* cond, const T* input_x, const T* input_y, T* output, | ||||
| cudaStream_t cuda_stream); | cudaStream_t cuda_stream); | ||||
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SELECT_IMPL_H_ | |||||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_SELECT_IMPL_H_ | |||||
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -21,7 +21,6 @@ import mindspore.nn as nn | |||||
| from mindspore import Tensor | from mindspore import Tensor | ||||
| from mindspore.ops import operations as P | from mindspore.ops import operations as P | ||||
| class Net(nn.Cell): | class Net(nn.Cell): | ||||
| def __init__(self): | def __init__(self): | ||||
| super(Net, self).__init__() | super(Net, self).__init__() | ||||
| @@ -31,20 +30,32 @@ class Net(nn.Cell): | |||||
| return self.select(cond_op, input_x, input_y) | return self.select(cond_op, input_x, input_y) | ||||
| cond = np.array([[True, False], [True, False]]).astype(np.bool) | |||||
| x = np.array([[1.2, 1], [1, 0]]).astype(np.float32) | |||||
| y = np.array([[1, 2], [3, 4.0]]).astype(np.float32) | |||||
| @pytest.mark.level0 | @pytest.mark.level0 | ||||
| @pytest.mark.platform_x86_gpu_training | @pytest.mark.platform_x86_gpu_training | ||||
| @pytest.mark.env_onecard | @pytest.mark.env_onecard | ||||
| def test_select(): | def test_select(): | ||||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | ||||
| select = Net() | select = Net() | ||||
| cond = np.array([[True, False], [True, False]]).astype(np.bool) | |||||
| x = np.array([[1.2, 1], [1, 0]]).astype(np.float32) | |||||
| y = np.array([[1, 2], [3, 4.0]]).astype(np.float32) | |||||
| output = select(Tensor(cond), Tensor(x), Tensor(y)) | output = select(Tensor(cond), Tensor(x), Tensor(y)) | ||||
| expect = [[1.2, 2], [1, 4.0]] | expect = [[1.2, 2], [1, 4.0]] | ||||
| error = np.ones(shape=[2, 2]) * 1.0e-6 | error = np.ones(shape=[2, 2]) * 1.0e-6 | ||||
| diff = output.asnumpy() - expect | diff = output.asnumpy() - expect | ||||
| assert np.all(diff < error) | assert np.all(diff < error) | ||||
| assert np.all(-diff < error) | assert np.all(-diff < error) | ||||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||||
| x = np.array([[1, 0], [1, 0]]).astype(np.bool) | |||||
| y = np.array([[0, 0], [1, 1]]).astype(np.bool) | |||||
| output = select(Tensor(cond), Tensor(x), Tensor(y)) | |||||
| expect = np.array([[1, 0], [1, 1]]).astype(np.bool) | |||||
| assert np.all(output.asnumpy() == expect) | |||||
| context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") | |||||
| x = np.array([[1, 0], [1, 0]]).astype(np.bool) | |||||
| y = np.array([[0, 0], [1, 1]]).astype(np.bool) | |||||
| output = select(Tensor(cond), Tensor(x), Tensor(y)) | |||||
| expect = np.array([[1, 0], [1, 1]]).astype(np.bool) | |||||
| assert np.all(output.asnumpy() == expect) | |||||