| @@ -0,0 +1,79 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <algorithm> | |||
| #include <random> | |||
| #include "runtime/device/cpu/cpu_device_address.h" | |||
| #include "backend/kernel_compiler/cpu/dropout_cpu_kernel.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| void DropoutCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||
| CheckParam(kernel_node); | |||
| input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||
| output_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0); | |||
| mask_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 1); | |||
| keep_prob_ = AnfAlgo::GetNodeAttr<float>(kernel_node, "keep_prob"); | |||
| if (keep_prob_ <= 0.0) { | |||
| MS_LOG(EXCEPTION) << "Keep_prob is smaller or equal to zero but DropoutCPUKernel needs greater than 0"; | |||
| } | |||
| if (keep_prob_ > 1.0) { | |||
| MS_LOG(EXCEPTION) << "Keep_prob greater than one but DropoutCPUKernel needs smaller or equal to one"; | |||
| } | |||
| dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0); | |||
| for (const uint64_t &d : input_shape_) { | |||
| tensor_size_ *= d; | |||
| } | |||
| } | |||
| bool DropoutCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, | |||
| const std::vector<kernel::AddressPtr> & /*workspace*/, | |||
| const std::vector<kernel::AddressPtr> &outputs) { | |||
| if (dtype_ == kNumberTypeFloat16) { | |||
| LaunchKernel<float16>(inputs, outputs); | |||
| } else if (dtype_ == kNumberTypeFloat32) { | |||
| LaunchKernel<float>(inputs, outputs); | |||
| } | |||
| return true; | |||
| } | |||
| template <typename T> | |||
| void DropoutCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs) { | |||
| auto input_addr = reinterpret_cast<T *>(inputs[0]->addr); | |||
| auto output_addr = reinterpret_cast<T *>(outputs[0]->addr); | |||
| auto mask_addr = reinterpret_cast<T *>(outputs[1]->addr); | |||
| std::random_device rd; | |||
| std::mt19937 gen(rd()); | |||
| std::bernoulli_distribution dis(keep_prob_); | |||
| T scale = (T)(1.f / keep_prob_); | |||
| for (uint64_t i = 0; i < tensor_size_; ++i) { | |||
| mask_addr[i] = (T)dis(gen); | |||
| output_addr[i] = mask_addr[i] * input_addr[i] * scale; | |||
| } | |||
| } | |||
| void DropoutCPUKernel::CheckParam(const CNodePtr &kernel_node) { | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||
| if (input_num != 1) { | |||
| MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but DropoutCPUKernel needs 1 input."; | |||
| } | |||
| size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); | |||
| if (output_num != 2) { | |||
| MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but DropoutCPUKernel needs 1 output."; | |||
| } | |||
| } | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,60 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_DROPOUT_CPU_KERNEL_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_DROPOUT_CPU_KERNEL_H_ | |||
| #include <memory> | |||
| #include <vector> | |||
| #include "backend/kernel_compiler/cpu/cpu_kernel.h" | |||
| #include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| class DropoutCPUKernel : public CPUKernel { | |||
| public: | |||
| DropoutCPUKernel() = default; | |||
| ~DropoutCPUKernel() override = default; | |||
| void InitKernel(const CNodePtr &kernel_node) override; | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| const std::vector<AddressPtr> &outputs) override; | |||
| template <typename T> | |||
| void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs); | |||
| private: | |||
| void CheckParam(const CNodePtr &kernel_node); | |||
| std::vector<size_t> input_shape_; | |||
| std::vector<size_t> output_shape_; | |||
| std::vector<size_t> mask_shape_; | |||
| TypeId dtype_{kTypeUnknown}; | |||
| float keep_prob_ = 0.0; | |||
| uint64_t tensor_size_ = 1; | |||
| }; | |||
| MS_REG_CPU_KERNEL( | |||
| Dropout, | |||
| KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), | |||
| DropoutCPUKernel); | |||
| MS_REG_CPU_KERNEL( | |||
| Dropout, | |||
| KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||
| DropoutCPUKernel); | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_DROPOUT_CPU_KERNEL_H_ | |||
| @@ -73,7 +73,7 @@ class Dropout(Cell): | |||
| Tensor, output tensor with the same shape as the input. | |||
| Supported Platforms: | |||
| ``Ascend`` ``GPU`` | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| Examples: | |||
| >>> x = Tensor(np.ones([2, 2, 3]), mindspore.float32) | |||
| @@ -102,14 +102,14 @@ class Dropout(Cell): | |||
| self.dropout_gen_mask = P.DropoutGenMask(Seed0=self.seed0, Seed1=self.seed1) | |||
| self.dropout_do_mask = P.DropoutDoMask() | |||
| self.cast = P.Cast() | |||
| self.is_gpu = context.get_context('device_target') in ["GPU"] | |||
| self.is_ascend = context.get_context('device_target') in ["Ascend"] | |||
| self.dropout = P.Dropout(keep_prob) | |||
| def construct(self, x): | |||
| if not self.training: | |||
| return x | |||
| if self.is_gpu: | |||
| if not self.is_ascend: | |||
| out, _ = self.dropout(x) | |||
| return out | |||
| @@ -0,0 +1,93 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore.context as context | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor | |||
| from mindspore.ops import operations as P | |||
| context.set_context(mode=context.GRAPH_MODE, device_target='CPU') | |||
| class Net(nn.Cell): | |||
| def __init__(self): | |||
| super(Net, self).__init__() | |||
| self.dropout = P.Dropout() | |||
| def construct(self, x): | |||
| return self.dropout(x) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_net(): | |||
| x = np.random.randn(3, 3, 4).astype(np.float32) | |||
| dropout = Net() | |||
| output, mask = dropout(Tensor(x)) | |||
| print(x) | |||
| print(output) | |||
| print(mask) | |||
| class Net1(nn.Cell): | |||
| def __init__(self): | |||
| super(Net1, self).__init__() | |||
| self.dropout = P.Dropout(keep_prob=0.1) | |||
| def construct(self, x): | |||
| return self.dropout(x) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_net1(): | |||
| x = np.arange(0, 16).reshape(2, 2, 4).astype(np.float32) | |||
| dropout = Net1() | |||
| output, mask = dropout(Tensor(x)) | |||
| print(x) | |||
| print(output) | |||
| print(mask) | |||
| class Net2(nn.Cell): | |||
| def __init__(self): | |||
| super(Net2, self).__init__() | |||
| self.dropout = P.Dropout(keep_prob=1.0) | |||
| def construct(self, x): | |||
| return self.dropout(x) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_net2(): | |||
| x = np.arange(0, 12).reshape(3, 4).astype(np.float16) | |||
| dropout = Net2() | |||
| output, mask = dropout(Tensor(x)) | |||
| print(x) | |||
| print(output) | |||
| print(mask) | |||
| if __name__ == '__main__': | |||
| test_net() | |||
| test_net1() | |||
| test_net2() | |||