@@ -0,0 +1,79 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <algorithm>
#include <random>
#include "runtime/device/cpu/cpu_device_address.h"
#include "backend/kernel_compiler/cpu/dropout_cpu_kernel.h"

namespace mindspore {
namespace kernel {
void DropoutCPUKernel::InitKernel(const CNodePtr &kernel_node) {
  CheckParam(kernel_node);
  input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
  output_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0);
  mask_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 1);
  keep_prob_ = AnfAlgo::GetNodeAttr<float>(kernel_node, "keep_prob");
  if (keep_prob_ <= 0.0) {
    MS_LOG(EXCEPTION) << "The keep_prob attr of DropoutCPUKernel must be greater than 0, but got " << keep_prob_;
  }
  if (keep_prob_ > 1.0) {
    MS_LOG(EXCEPTION) << "The keep_prob attr of DropoutCPUKernel must be less than or equal to 1, but got " << keep_prob_;
  }
  dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0);
  for (const uint64_t &d : input_shape_) {
    tensor_size_ *= d;
  }
}

bool DropoutCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
                              const std::vector<kernel::AddressPtr> & /*workspace*/,
                              const std::vector<kernel::AddressPtr> &outputs) {
  if (dtype_ == kNumberTypeFloat16) {
    LaunchKernel<float16>(inputs, outputs);
  } else if (dtype_ == kNumberTypeFloat32) {
    LaunchKernel<float>(inputs, outputs);
  }
  return true;
}

template <typename T>
void DropoutCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs) {
  auto input_addr = reinterpret_cast<T *>(inputs[0]->addr);
  auto output_addr = reinterpret_cast<T *>(outputs[0]->addr);
  auto mask_addr = reinterpret_cast<T *>(outputs[1]->addr);
  std::random_device rd;
  std::mt19937 gen(rd());
  std::bernoulli_distribution dis(keep_prob_);
  T scale = (T)(1.f / keep_prob_);
  for (uint64_t i = 0; i < tensor_size_; ++i) {
    mask_addr[i] = (T)dis(gen);
    output_addr[i] = mask_addr[i] * input_addr[i] * scale;
  }
}
void DropoutCPUKernel::CheckParam(const CNodePtr &kernel_node) {
  size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
  if (input_num != 1) {
    MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but DropoutCPUKernel needs 1 input.";
  }
  size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
  if (output_num != 2) {
    MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but DropoutCPUKernel needs 2 outputs.";
  }
}
}  // namespace kernel
}  // namespace mindspore
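
The `LaunchKernel` loop above draws a Bernoulli(keep_prob) mask and scales the kept elements by 1 / keep_prob, so the expected value of each output element equals the input. A minimal NumPy sketch of the same computation (illustrative only, not part of the patch; the shape and keep_prob value are arbitrary):

```python
import numpy as np

def dropout_reference(x, keep_prob=0.9):
    """Mirrors DropoutCPUKernel::LaunchKernel: Bernoulli mask plus 1/keep_prob scaling."""
    # Each mask element is 1 with probability keep_prob, otherwise 0.
    mask = np.random.binomial(1, keep_prob, size=x.shape).astype(x.dtype)
    # Scale the kept elements so the expected output matches the input.
    return x * mask / keep_prob, mask

out, mask = dropout_reference(np.ones((2, 2, 3), dtype=np.float32), keep_prob=0.9)
print(out)
print(mask)
```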
@@ -0,0 +1,60 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_DROPOUT_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_DROPOUT_CPU_KERNEL_H_

#include <memory>
#include <vector>
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"

namespace mindspore {
namespace kernel {
class DropoutCPUKernel : public CPUKernel {
 public:
  DropoutCPUKernel() = default;
  ~DropoutCPUKernel() override = default;

  void InitKernel(const CNodePtr &kernel_node) override;

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs) override;

  template <typename T>
  void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);

 private:
  void CheckParam(const CNodePtr &kernel_node);
  std::vector<size_t> input_shape_;
  std::vector<size_t> output_shape_;
  std::vector<size_t> mask_shape_;
  TypeId dtype_{kTypeUnknown};
  float keep_prob_ = 0.0;
  uint64_t tensor_size_ = 1;
};

MS_REG_CPU_KERNEL(
  Dropout,
  KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
  DropoutCPUKernel);

MS_REG_CPU_KERNEL(
  Dropout,
  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
  DropoutCPUKernel);
}  // namespace kernel
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_DROPOUT_CPU_KERNEL_H_
@@ -73,7 +73,7 @@ class Dropout(Cell):
         Tensor, output tensor with the same shape as the input.

     Supported Platforms:
-        ``Ascend`` ``GPU``
+        ``Ascend`` ``GPU`` ``CPU``

     Examples:
         >>> x = Tensor(np.ones([2, 2, 3]), mindspore.float32)
@@ -102,14 +102,14 @@ class Dropout(Cell):
         self.dropout_gen_mask = P.DropoutGenMask(Seed0=self.seed0, Seed1=self.seed1)
         self.dropout_do_mask = P.DropoutDoMask()
         self.cast = P.Cast()
-        self.is_gpu = context.get_context('device_target') in ["GPU"]
+        self.is_ascend = context.get_context('device_target') in ["Ascend"]
         self.dropout = P.Dropout(keep_prob)

     def construct(self, x):
         if not self.training:
             return x

-        if self.is_gpu:
+        if not self.is_ascend:
             out, _ = self.dropout(x)
             return out
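
With ``CPU`` added to the supported platforms and the device check inverted to `is_ascend`, `nn.Dropout` now takes the `P.Dropout` branch (and therefore the new CPU kernel) whenever the device target is not Ascend. A minimal usage sketch on CPU (the keep_prob value and shape below are illustrative):

```python
import numpy as np
import mindspore
import mindspore.nn as nn
from mindspore import Tensor, context

context.set_context(mode=context.GRAPH_MODE, device_target='CPU')

net = nn.Dropout(keep_prob=0.8)
net.set_train()  # dropout is a no-op unless the cell is in training mode
x = Tensor(np.ones([2, 2, 3]), mindspore.float32)
print(net(x))  # roughly 80% of elements kept and scaled by 1 / 0.8
```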
@@ -0,0 +1,93 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest

import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE, device_target='CPU')


class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.dropout = P.Dropout()

    def construct(self, x):
        return self.dropout(x)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_net():
    x = np.random.randn(3, 3, 4).astype(np.float32)
    dropout = Net()
    output, mask = dropout(Tensor(x))
    print(x)
    print(output)
    print(mask)


class Net1(nn.Cell):
    def __init__(self):
        super(Net1, self).__init__()
        self.dropout = P.Dropout(keep_prob=0.1)

    def construct(self, x):
        return self.dropout(x)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_net1():
    x = np.arange(0, 16).reshape(2, 2, 4).astype(np.float32)
    dropout = Net1()
    output, mask = dropout(Tensor(x))
    print(x)
    print(output)
    print(mask)


class Net2(nn.Cell):
    def __init__(self):
        super(Net2, self).__init__()
        self.dropout = P.Dropout(keep_prob=1.0)

    def construct(self, x):
        return self.dropout(x)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_net2():
    x = np.arange(0, 12).reshape(3, 4).astype(np.float16)
    dropout = Net2()
    output, mask = dropout(Tensor(x))
    print(x)
    print(output)
    print(mask)


if __name__ == '__main__':
    test_net()
    test_net1()
    test_net2()
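
The tests above print the results for manual inspection. If an automated check were wanted, one possible extension (a sketch, not part of the patch; it reuses the imports already in the test file, assumes the mask is returned with the input dtype, and the shape and tolerance are arbitrary) is to assert that the fraction of kept elements roughly matches keep_prob:

```python
def check_keep_ratio(keep_prob=0.4, shape=(64, 64), tol=0.05):
    """Sketch: for large tensors, the mask mean should be close to keep_prob."""
    class DropNet(nn.Cell):
        def __init__(self, p):
            super(DropNet, self).__init__()
            self.dropout = P.Dropout(keep_prob=p)

        def construct(self, x):
            return self.dropout(x)

    x = np.ones(shape).astype(np.float32)
    _, mask = DropNet(keep_prob)(Tensor(x))
    ratio = mask.asnumpy().astype(np.float32).mean()
    assert abs(ratio - keep_prob) < tol
```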