Merge pull request !313 from VectorSL/float_statustags/v0.2.0-alpha
| @@ -0,0 +1,138 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "include/cuda_runtime.h" | |||
| #include "kernel/gpu/cuda_impl/float_status_impl.cuh" | |||
| template <typename T> | |||
| __global__ void IsNan(const size_t size, const T* input, bool* out) { | |||
| for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { | |||
| if (isnan(input[pos])) { | |||
| out[pos] = true; | |||
| } else { | |||
| out[pos] = false; | |||
| } | |||
| } | |||
| return; | |||
| } | |||
| template <> | |||
| __global__ void IsNan(const size_t size, const half* input, bool* out) { | |||
| for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { | |||
| if (__hisnan(input[pos])) { | |||
| out[pos] = true; | |||
| } else { | |||
| out[pos] = false; | |||
| } | |||
| } | |||
| return; | |||
| } | |||
| template <typename T> | |||
| __global__ void IsInf(const size_t size, const T* input, bool* out) { | |||
| for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { | |||
| if (isinf(input[pos]) != 0) { | |||
| out[pos] = true; | |||
| } else { | |||
| out[pos] = false; | |||
| } | |||
| } | |||
| return; | |||
| } | |||
| template <> | |||
| __global__ void IsInf(const size_t size, const half* input, bool* out) { | |||
| for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { | |||
| if (__hisinf(input[pos]) != 0) { | |||
| out[pos] = true; | |||
| } else { | |||
| out[pos] = false; | |||
| } | |||
| } | |||
| return; | |||
| } | |||
| template <typename T> | |||
| __global__ void IsFinite(const size_t size, const T* input, bool* out) { | |||
| for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { | |||
| if (isinf(input[pos]) == 0 && !isnan(input[pos])) { | |||
| out[pos] = true; | |||
| } else { | |||
| out[pos] = false; | |||
| } | |||
| } | |||
| return; | |||
| } | |||
| template <> | |||
| __global__ void IsFinite(const size_t size, const half* input, bool* out) { | |||
| for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { | |||
| if (__hisinf(input[pos]) == 0 && !__hisnan(input[pos])) { | |||
| out[pos] = true; | |||
| } else { | |||
| out[pos] = false; | |||
| } | |||
| } | |||
| return; | |||
| } | |||
| template <typename T> | |||
| __global__ void FloatStatus(const size_t size, const T* input, T* out) { | |||
| out[0] = 0; | |||
| for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { | |||
| if (isinf(input[pos]) != 0 || isnan(input[pos])) { | |||
| out[0] = 1; | |||
| } | |||
| } | |||
| return; | |||
| } | |||
| template <> | |||
| __global__ void FloatStatus(const size_t size, const half* input, half* out) { | |||
| out[0] = 0; | |||
| for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { | |||
| if (__hisinf(input[pos]) != 0 || __hisnan(input[pos])) { | |||
| out[0] = 1; | |||
| } | |||
| } | |||
| return; | |||
| } | |||
| template <typename T> | |||
| void CalFloatStatus(const size_t size, const T* input, T* output, cudaStream_t cuda_stream) { | |||
| FloatStatus<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, input, output); | |||
| return; | |||
| } | |||
| template <typename T> | |||
| void CalIsNan(const size_t size, const T* input, bool* output, cudaStream_t cuda_stream) { | |||
| IsNan<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, input, output); | |||
| return; | |||
| } | |||
| template <typename T> | |||
| void CalIsInf(const size_t size, const T* input, bool* output, cudaStream_t cuda_stream) { | |||
| IsInf<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, input, output); | |||
| return; | |||
| } | |||
| template <typename T> | |||
| void CalIsFinite(const size_t size, const T* input, bool* output, cudaStream_t cuda_stream) { | |||
| IsFinite<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, input, output); | |||
| return; | |||
| } | |||
| template void CalFloatStatus<float>(const size_t size, const float* input, float* output, cudaStream_t cuda_stream); | |||
| template void CalFloatStatus<half>(const size_t size, const half* input, half* output, cudaStream_t cuda_stream); | |||
| template void CalIsInf<float>(const size_t size, const float* input, bool* output, cudaStream_t cuda_stream); | |||
| template void CalIsInf<half>(const size_t size, const half* input, bool* output, cudaStream_t cuda_stream); | |||
| template void CalIsNan<float>(const size_t size, const float* input, bool* output, cudaStream_t cuda_stream); | |||
| template void CalIsNan<half>(const size_t size, const half* input, bool* output, cudaStream_t cuda_stream); | |||
| template void CalIsFinite<float>(const size_t size, const float* input, bool* output, cudaStream_t cuda_stream); | |||
| template void CalIsFinite<half>(const size_t size, const half* input, bool* output, cudaStream_t cuda_stream); | |||
| @@ -0,0 +1,28 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_FLOATSTATUS_H_ | |||
| #define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_FLOATSTATUS_H_ | |||
| #include "device/gpu/cuda_common.h" | |||
| template <typename T> | |||
| void CalFloatStatus(const size_t size, const T *input, T *output, cudaStream_t stream); | |||
| template <typename T> | |||
| void CalIsNan(const size_t size, const T *input, bool *output, cudaStream_t stream); | |||
| template <typename T> | |||
| void CalIsInf(const size_t size, const T *input, bool *output, cudaStream_t stream); | |||
| template <typename T> | |||
| void CalIsFinite(const size_t size, const T *input, bool *output, cudaStream_t stream); | |||
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_FLOATSTATUS_H_ | |||
| @@ -0,0 +1,38 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "kernel/gpu/math/float_status_gpu_kernel.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| MS_REG_GPU_KERNEL_ONE(FloatStatus, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||
| FloatStatusGpuKernel, float) | |||
| MS_REG_GPU_KERNEL_ONE(FloatStatus, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), | |||
| FloatStatusGpuKernel, half) | |||
| MS_REG_GPU_KERNEL_ONE(IsInf, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool), | |||
| FloatStatusGpuKernel, float) | |||
| MS_REG_GPU_KERNEL_ONE(IsInf, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool), | |||
| FloatStatusGpuKernel, half) | |||
| MS_REG_GPU_KERNEL_ONE(IsNan, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool), | |||
| FloatStatusGpuKernel, float) | |||
| MS_REG_GPU_KERNEL_ONE(IsNan, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool), | |||
| FloatStatusGpuKernel, half) | |||
| MS_REG_GPU_KERNEL_ONE(IsFinite, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool), | |||
| FloatStatusGpuKernel, float) | |||
| MS_REG_GPU_KERNEL_ONE(IsFinite, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool), | |||
| FloatStatusGpuKernel, half) | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,130 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_KERNEL_GPU_FLOAT_STATUS_GPU_KERNEL_H | |||
| #define MINDSPORE_CCSRC_KERNEL_GPU_FLOAT_STATUS_GPU_KERNEL_H | |||
| #include <memory> | |||
| #include <vector> | |||
| #include <map> | |||
| #include <string> | |||
| #include "kernel/gpu/gpu_kernel.h" | |||
| #include "kernel/gpu/gpu_kernel_factory.h" | |||
| #include "kernel/gpu/cuda_impl/float_status_impl.cuh" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| enum Optype { OP_STATUS = 0, OP_INF, OP_NAN, OP_FINITE, OP_INVALID = 255 }; | |||
| static const std::map<std::string, Optype> kOpTypeMap = { | |||
| {"FloatStatus", OP_STATUS}, {"IsInf", OP_INF}, {"IsNan", OP_NAN}, {"IsFinite", OP_FINITE}}; | |||
| template <typename T> | |||
| class FloatStatusGpuKernel : public GpuKernel { | |||
| public: | |||
| FloatStatusGpuKernel() : kernel_name_(OP_INVALID), input_size_(0), output_size_(0) {} | |||
| ~FloatStatusGpuKernel() override = default; | |||
| const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | |||
| const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } | |||
| const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; } | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, | |||
| const std::vector<AddressPtr> &outputs, uintptr_t stream_ptr) override { | |||
| T *input = GetDeviceAddress<T>(inputs, 0); | |||
| switch (kernel_name_) { | |||
| case OP_STATUS: { | |||
| T *output = GetDeviceAddress<T>(outputs, 0); | |||
| CalFloatStatus(input_size_ / sizeof(T), input, output, reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| break; | |||
| } | |||
| case OP_INF: { | |||
| bool *output = GetDeviceAddress<bool>(outputs, 0); | |||
| CalIsInf(input_size_ / sizeof(T), input, output, reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| break; | |||
| } | |||
| case OP_NAN: { | |||
| bool *output = GetDeviceAddress<bool>(outputs, 0); | |||
| CalIsNan(input_size_ / sizeof(T), input, output, reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| break; | |||
| } | |||
| case OP_FINITE: { | |||
| bool *output = GetDeviceAddress<bool>(outputs, 0); | |||
| CalIsFinite(input_size_ / sizeof(T), input, output, reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| break; | |||
| } | |||
| default: { | |||
| MS_LOG(EXCEPTION) << "FloatStatus type " << kernel_name_ << " is not supported."; | |||
| } | |||
| } | |||
| return true; | |||
| } | |||
| bool Init(const CNodePtr &kernel_node) override { | |||
| if (!CheckParam(kernel_node)) { | |||
| return false; | |||
| } | |||
| auto shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||
| input_size_ = sizeof(T); | |||
| for (size_t x : shape) { | |||
| input_size_ = input_size_ * x; | |||
| } | |||
| auto kernel_name = AnfAlgo::GetCNodeName(kernel_node); | |||
| auto iter = kOpTypeMap.find(kernel_name); | |||
| if (iter == kOpTypeMap.end()) { | |||
| MS_LOG(EXCEPTION) << "FloatStatus kernel " << kernel_name << " is not supported."; | |||
| } else { | |||
| kernel_name_ = iter->second; | |||
| } | |||
| if (kernel_name_ == OP_STATUS) { | |||
| output_size_ = sizeof(T); | |||
| } else { | |||
| output_size_ = input_size_ / sizeof(T) * sizeof(bool); | |||
| } | |||
| InitSizeLists(); | |||
| return true; | |||
| } | |||
| protected: | |||
| void InitSizeLists() override { | |||
| input_size_list_.push_back(input_size_); | |||
| output_size_list_.push_back(output_size_); | |||
| } | |||
| private: | |||
| bool CheckParam(const CNodePtr &kernel_node) { | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||
| if (input_num != 1) { | |||
| MS_LOG(ERROR) << "Input number is " << input_num << ", but FloatStatusGpuKernel needs 1 output."; | |||
| return false; | |||
| } | |||
| size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); | |||
| if (output_num != 1) { | |||
| MS_LOG(ERROR) << "Output number is " << output_num << ", but FloatStatusGpuKernel needs 1 output."; | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| std::vector<size_t> input_size_list_; | |||
| std::vector<size_t> output_size_list_; | |||
| std::vector<size_t> workspace_size_list_; | |||
| Optype kernel_name_; | |||
| size_t input_size_; | |||
| size_t output_size_; | |||
| }; | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_FLOAT_STATUS_GPU_KERNEL_H | |||
| @@ -0,0 +1,118 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import pytest | |||
| from mindspore import Tensor | |||
| from mindspore.ops import operations as P | |||
| import mindspore.nn as nn | |||
| import numpy as np | |||
| import mindspore.context as context | |||
| class Net(nn.Cell): | |||
| def __init__(self): | |||
| super(Net, self).__init__() | |||
| self.status = P.FloatStatus() | |||
| def construct(self, x): | |||
| return self.status(x) | |||
| class Netnan(nn.Cell): | |||
| def __init__(self): | |||
| super(Netnan, self).__init__() | |||
| self.isnan = P.IsNan() | |||
| def construct(self, x): | |||
| return self.isnan(x) | |||
| class Netinf(nn.Cell): | |||
| def __init__(self): | |||
| super(Netinf, self).__init__() | |||
| self.isinf = P.IsInf() | |||
| def construct(self, x): | |||
| return self.isinf(x) | |||
| class Netfinite(nn.Cell): | |||
| def __init__(self): | |||
| super(Netfinite, self).__init__() | |||
| self.isfinite = P.IsFinite() | |||
| def construct(self, x): | |||
| return self.isfinite(x) | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
| x1 = np.array([[1.2, 2, np.nan, 88]]).astype(np.float32) | |||
| x2 = np.array([[np.inf, 1, 88.0, 0]]).astype(np.float32) | |||
| x3 = np.array([[1, 2], [3, 4], [5.0, 88.0]]).astype(np.float32) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_status(): | |||
| ms_status = Net(); | |||
| output1 = ms_status(Tensor(x1)) | |||
| output2 = ms_status(Tensor(x2)) | |||
| output3 = ms_status(Tensor(x3)) | |||
| expect1 = 1 | |||
| expect2 = 1 | |||
| expect3 = 0 | |||
| assert output1.asnumpy()[0] == expect1 | |||
| assert output2.asnumpy()[0] == expect2 | |||
| assert output3.asnumpy()[0] == expect3 | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_nan(): | |||
| ms_isnan = Netnan(); | |||
| output1 = ms_isnan(Tensor(x1)) | |||
| output2 = ms_isnan(Tensor(x2)) | |||
| output3 = ms_isnan(Tensor(x3)) | |||
| expect1 = [[False, False, True, False]] | |||
| expect2 = [[False, False, False, False]] | |||
| expect3 = [[False, False], [False, False], [False, False]] | |||
| assert (output1.asnumpy() == expect1).all() | |||
| assert (output2.asnumpy() == expect2).all() | |||
| assert (output3.asnumpy() == expect3).all() | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_inf(): | |||
| ms_isinf = Netinf(); | |||
| output1 = ms_isinf(Tensor(x1)) | |||
| output2 = ms_isinf(Tensor(x2)) | |||
| output3 = ms_isinf(Tensor(x3)) | |||
| expect1 = [[False, False, False, False]] | |||
| expect2 = [[True, False, False, False]] | |||
| expect3 = [[False, False], [False, False], [False, False]] | |||
| assert (output1.asnumpy() == expect1).all() | |||
| assert (output2.asnumpy() == expect2).all() | |||
| assert (output3.asnumpy() == expect3).all() | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_finite(): | |||
| ms_isfinite = Netfinite(); | |||
| output1 = ms_isfinite(Tensor(x1)) | |||
| output2 = ms_isfinite(Tensor(x2)) | |||
| output3 = ms_isfinite(Tensor(x3)) | |||
| expect1 = [[True, True, False, True]] | |||
| expect2 = [[False, True, True, True]] | |||
| expect3 = [[True, True], [True, True], [True, True]] | |||
| assert (output1.asnumpy() == expect1).all() | |||
| assert (output2.asnumpy() == expect2).all() | |||
| assert (output3.asnumpy() == expect3).all() | |||