| @@ -31,3 +31,4 @@ from .dropout_grad import expand_dropoutgrad | |||||
| from .layernorm_grad import expand_layernormgrad | from .layernorm_grad import expand_layernormgrad | ||||
| from .logsoftmax import expand_logsoftmax | from .logsoftmax import expand_logsoftmax | ||||
| from .logsoftmax_grad import expand_logsoftmaxgrad | from .logsoftmax_grad import expand_logsoftmaxgrad | ||||
| from .gkdropout import expand_gkdropout | |||||
| @@ -0,0 +1,49 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # =========================================================================== | |||||
| """generate json desc for GkDropOut""" | |||||
| from mindspore._extends.graph_kernel.model import model_builder as builder | |||||
| def expand_gkdropout(expand_info): | |||||
| """GkDropOut expander""" | |||||
| # get op info. | |||||
| input_desc = expand_info['input_desc'][0] | |||||
| maks_desc = expand_info['input_desc'][1] | |||||
| keep_prob = None | |||||
| for attr in expand_info['attr']: | |||||
| if 'keep_prob' in attr: | |||||
| keep_prob = attr['keep_prob'] | |||||
| if keep_prob is None: | |||||
| raise RuntimeError("keep_prob does not exist in attrs.") | |||||
| # generate a graph. | |||||
| graph_builder = builder.GraphBuilder() | |||||
| with graph_builder.graph_scope('main') as graph_scope: | |||||
| # create tensor input. | |||||
| input_x = graph_builder.tensor(input_desc['shape'], input_desc['data_type'], input_desc['format']) | |||||
| input_mask = graph_builder.tensor(maks_desc['shape'], maks_desc['data_type'], maks_desc['format']) | |||||
| graph_scope.set_input(input_x, input_mask) | |||||
| keep_prob_v = graph_builder.value(input_x.dtype, keep_prob, "DefaultFormat") | |||||
| r_keep_prob = graph_builder.value(input_x.dtype, 1.0 / keep_prob, "DefaultFormat") | |||||
| mask = graph_builder.emit('LessEqual', [input_mask, keep_prob_v]) | |||||
| mask = graph_builder.emit('Cast', [mask], attrs={'dst_type': input_x.dtype}) | |||||
| # compute result | |||||
| result = graph_builder.emit('Mul', [r_keep_prob, input_x]) | |||||
| result = graph_builder.emit('Mul', [result, mask]) | |||||
| # set graph output. | |||||
| graph_scope.set_output(result, mask) | |||||
| graph = graph_builder.get()[0] | |||||
| return graph | |||||
| @@ -29,5 +29,7 @@ MS_REG_GPU_KERNEL_ONE(UniformInt, | |||||
| RandomOpGpuKernel, int) | RandomOpGpuKernel, int) | ||||
| MS_REG_GPU_KERNEL_ONE(UniformReal, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), | MS_REG_GPU_KERNEL_ONE(UniformReal, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), | ||||
| RandomOpGpuKernel, float) | RandomOpGpuKernel, float) | ||||
| MS_REG_GPU_KERNEL_ONE(CudnnUniformReal, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), | |||||
| RandomOpGpuKernel, float) | |||||
| } // namespace kernel | } // namespace kernel | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -25,13 +25,22 @@ | |||||
| #include "backend/kernel_compiler/gpu/gpu_kernel.h" | #include "backend/kernel_compiler/gpu/gpu_kernel.h" | ||||
| #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" | #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" | ||||
| #include "backend/kernel_compiler/gpu/cuda_impl/random_op_impl.cuh" | #include "backend/kernel_compiler/gpu/cuda_impl/random_op_impl.cuh" | ||||
| #include "include/curand.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace kernel { | namespace kernel { | ||||
| enum RandomOptype { RANDOM_OP_NORMAL = 0, RANDOM_OP_UNIFORM_INT, RANDOM_OP_UNIFORM_REAL, RANDOM_OP_INVALID_TYPE = 255 }; | |||||
| enum RandomOptype { | |||||
| RANDOM_OP_NORMAL = 0, | |||||
| RANDOM_OP_UNIFORM_INT, | |||||
| RANDOM_OP_UNIFORM_REAL, | |||||
| RANDOM_OP_CUDNN_UNIFORM_REAL, | |||||
| RANDOM_OP_INVALID_TYPE = 255 | |||||
| }; | |||||
| const std::map<std::string, RandomOptype> kRandomOpTypeMap = { | |||||
| {"StandardNormal", RANDOM_OP_NORMAL}, {"UniformInt", RANDOM_OP_UNIFORM_INT}, {"UniformReal", RANDOM_OP_UNIFORM_REAL}}; | |||||
| const std::map<std::string, RandomOptype> kRandomOpTypeMap = {{"StandardNormal", RANDOM_OP_NORMAL}, | |||||
| {"UniformInt", RANDOM_OP_UNIFORM_INT}, | |||||
| {"UniformReal", RANDOM_OP_UNIFORM_REAL}, | |||||
| {"CudnnUniformReal", RANDOM_OP_CUDNN_UNIFORM_REAL}}; | |||||
| template <typename T> | template <typename T> | ||||
| class RandomOpGpuKernel : public GpuKernel { | class RandomOpGpuKernel : public GpuKernel { | ||||
| @@ -76,6 +85,23 @@ class RandomOpGpuKernel : public GpuKernel { | |||||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | reinterpret_cast<cudaStream_t>(stream_ptr)); | ||||
| break; | break; | ||||
| } | } | ||||
| case RANDOM_OP_CUDNN_UNIFORM_REAL: { | |||||
| float *mask_f = GetDeviceAddress<float>(outputs, 0); | |||||
| if (!states_init_) { | |||||
| CHECK_CURAND_RET_WITH_EXCEPT(curandCreateGenerator(&mask_generator_, CURAND_RNG_PSEUDO_DEFAULT), | |||||
| "Failed to create generator"); | |||||
| CHECK_CURAND_RET_WITH_EXCEPT(curandSetPseudoRandomGeneratorSeed(mask_generator_, seed_), | |||||
| "Failed to SetPseudoRandomGeneratorSeed"); | |||||
| MS_EXCEPTION_IF_NULL(mask_generator_); | |||||
| states_init_ = true; | |||||
| } | |||||
| CHECK_CURAND_RET_WITH_EXCEPT(curandSetStream(mask_generator_, reinterpret_cast<cudaStream_t>(stream_ptr)), | |||||
| "Failed to set stream for generator"); | |||||
| // curandGen only support float or double for mask. | |||||
| CHECK_CURAND_RET_WITH_EXCEPT(curandGenerateUniform(mask_generator_, mask_f, outputs[0]->size / sizeof(float)), | |||||
| "Failed to generate uniform"); | |||||
| break; | |||||
| } | |||||
| default: { | default: { | ||||
| MS_LOG(EXCEPTION) << "Random operation " << random_op_type_ << " is not supported."; | MS_LOG(EXCEPTION) << "Random operation " << random_op_type_ << " is not supported."; | ||||
| } | } | ||||
| @@ -148,6 +174,8 @@ class RandomOpGpuKernel : public GpuKernel { | |||||
| std::vector<size_t> input_size_list_; | std::vector<size_t> input_size_list_; | ||||
| std::vector<size_t> output_size_list_; | std::vector<size_t> output_size_list_; | ||||
| std::vector<size_t> workspace_size_list_; | std::vector<size_t> workspace_size_list_; | ||||
| curandGenerator_t mask_generator_; | |||||
| bool states_init_{false}; | |||||
| }; | }; | ||||
| } // namespace kernel | } // namespace kernel | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -44,6 +44,7 @@ constexpr size_t kMulInputNum = 3; | |||||
| constexpr size_t kRsqrtInputNum = 2; | constexpr size_t kRsqrtInputNum = 2; | ||||
| constexpr size_t kSubInputNum = 3; | constexpr size_t kSubInputNum = 3; | ||||
| constexpr size_t kAssignSubInputNum = 3; | constexpr size_t kAssignSubInputNum = 3; | ||||
| constexpr size_t kDropoutInputNum = 2; | |||||
| constexpr size_t kConvBn1OutputNum = 3; | constexpr size_t kConvBn1OutputNum = 3; | ||||
| constexpr size_t kBn2ReluOutputNum = 4; | constexpr size_t kBn2ReluOutputNum = 4; | ||||
| @@ -25,6 +25,7 @@ | |||||
| #include "backend/kernel_compiler/common_utils.h" | #include "backend/kernel_compiler/common_utils.h" | ||||
| #include "backend/kernel_compiler/kernel_build_info.h" | #include "backend/kernel_compiler/kernel_build_info.h" | ||||
| #include "backend/optimizer/graph_kernel/graph_kernel_helper.h" | #include "backend/optimizer/graph_kernel/graph_kernel_helper.h" | ||||
| #include "backend/optimizer/graph_kernel/substitute_dropout.h" | |||||
| #include "backend/session/anf_runtime_algorithm.h" | #include "backend/session/anf_runtime_algorithm.h" | ||||
| #include "mindspore/core/ir/graph_utils.h" | #include "mindspore/core/ir/graph_utils.h" | ||||
| #include "pipeline/jit/parse/python_adapter.h" | #include "pipeline/jit/parse/python_adapter.h" | ||||
| @@ -242,6 +243,10 @@ void GraphKernelExpander::ToPrimitive(const FuncGraphPtr &func_graph) const { | |||||
| bool GraphKernelExpander::Run(const FuncGraphPtr &func_graph) { | bool GraphKernelExpander::Run(const FuncGraphPtr &func_graph) { | ||||
| expand_ops_ = GetExpandOps(); | expand_ops_ = GetExpandOps(); | ||||
| MS_EXCEPTION_IF_NULL(func_graph); | MS_EXCEPTION_IF_NULL(func_graph); | ||||
| if (expand_ops_.count(prim::kPrimGkDropout) > 0) { | |||||
| std::shared_ptr<Pass> pass = std::make_shared<opt::SubstituteDropout>(); | |||||
| pass->Run(func_graph); | |||||
| } | |||||
| auto mng = func_graph->manager(); | auto mng = func_graph->manager(); | ||||
| if (mng == nullptr) { | if (mng == nullptr) { | ||||
| mng = Manage(func_graph, true); | mng = Manage(func_graph, true); | ||||
| @@ -711,7 +711,8 @@ std::unordered_set<PrimitivePtr> GetExpandOps() { | |||||
| prim::kPrimTanhGrad, | prim::kPrimTanhGrad, | ||||
| prim::kPrimReduceMean, | prim::kPrimReduceMean, | ||||
| prim::kPrimMaximumGrad, | prim::kPrimMaximumGrad, | ||||
| prim::kPrimMinimumGrad | |||||
| prim::kPrimMinimumGrad, | |||||
| prim::kPrimGkDropout | |||||
| #endif | #endif | ||||
| }; | }; | ||||
| return expand_ops; | return expand_ops; | ||||
| @@ -26,11 +26,15 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include "ir/anf.h" | #include "ir/anf.h" | ||||
| #include "ir/func_graph.h" | #include "ir/func_graph.h" | ||||
| #include "ir/primitive.h" | |||||
| #include "backend/session/kernel_graph.h" | #include "backend/session/kernel_graph.h" | ||||
| #include "backend/kernel_compiler/akg/akg_kernel_json_generator.h" | #include "backend/kernel_compiler/akg/akg_kernel_json_generator.h" | ||||
| #include <nlohmann/json.hpp> | #include <nlohmann/json.hpp> | ||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace prim { | |||||
| inline const PrimitivePtr kPrimGkDropout = std::make_shared<Primitive>("GkDropout"); | |||||
| } // namespace prim | |||||
| namespace opt { | namespace opt { | ||||
| using kernel::DumpOption; | using kernel::DumpOption; | ||||
| @@ -0,0 +1,120 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "backend/optimizer/graph_kernel/substitute_dropout.h" | |||||
| #include <vector> | |||||
| #include <string> | |||||
| #include <algorithm> | |||||
| #include <memory> | |||||
| #include "base/core_ops.h" | |||||
| #include "utils/utils.h" | |||||
| #include "backend/optimizer/common/helper.h" | |||||
| #include "backend/optimizer/graph_kernel/graph_kernel_helper.h" | |||||
| #include "backend/session/anf_runtime_algorithm.h" | |||||
| #include "ir/tensor.h" | |||||
| #include "backend/kernel_compiler/kernel_build_info.h" | |||||
| #include "runtime/device/kernel_info.h" | |||||
| namespace mindspore { | |||||
| namespace opt { | |||||
| unsigned int SubstituteDropout::seed_ = time(NULL); | |||||
| const BaseRef SubstituteDropout::DefinePattern() const { | |||||
| VarPtr Xs = std::make_shared<Var>(); | |||||
| return VectorRef({prim::kPrimDropout, Xs}); | |||||
| } | |||||
| void SetNewKernelInfo(const CNodePtr &kernel_node) { | |||||
| std::vector<std::string> inputs_format; | |||||
| std::vector<TypeId> inputs_type; | |||||
| for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) { | |||||
| inputs_format.emplace_back(AnfAlgo::GetPrevNodeOutputFormat(kernel_node, input_index)); | |||||
| inputs_type.push_back(AnfAlgo::GetPrevNodeOutputDeviceDataType(kernel_node, input_index)); | |||||
| } | |||||
| std::vector<std::string> outputs_format; | |||||
| std::vector<TypeId> outputs_type; | |||||
| for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(kernel_node); ++output_index) { | |||||
| outputs_format.emplace_back(AnfAlgo::GetPrevNodeOutputFormat(kernel_node, output_index)); | |||||
| outputs_type.push_back(AnfAlgo::GetOutputInferDataType(kernel_node, output_index)); | |||||
| } | |||||
| std::string origin_data_format = kOpFormat_DEFAULT; | |||||
| auto cnode_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(); | |||||
| cnode_info_builder->SetOriginDataFormat(origin_data_format); | |||||
| cnode_info_builder->SetInputsFormat(inputs_format); | |||||
| cnode_info_builder->SetInputsDeviceType(inputs_type); | |||||
| cnode_info_builder->SetOutputsFormat(outputs_format); | |||||
| cnode_info_builder->SetOutputsDeviceType(outputs_type); | |||||
| cnode_info_builder->SetKernelType(KernelType::UNKNOWN_KERNEL_TYPE); | |||||
| cnode_info_builder->SetProcessor(kernel::Processor::CUDA); | |||||
| auto cnode_selected_info = cnode_info_builder->Build(); | |||||
| AnfAlgo::SetSelectKernelBuildInfo(cnode_selected_info, kernel_node.get()); | |||||
| } | |||||
| const AnfNodePtr SubstituteDropout::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | |||||
| const EquivPtr &) const { | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| CNodePtr cnode = node->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(cnode); | |||||
| if (cnode->inputs().size() < kDropoutInputNum) { | |||||
| MS_LOG(EXCEPTION) << "Dropout's input num is wrong"; | |||||
| } | |||||
| AbstractBasePtr old_abstract = cnode->abstract()->Clone(); | |||||
| auto shape = AnfAlgo::GetInputDeviceShape(cnode, 0); | |||||
| ShapeVector shape_i64; | |||||
| std::transform(shape.begin(), shape.end(), std::back_inserter(shape_i64), [](size_t x) { return SizeToLong(x); }); | |||||
| // Create new tensor | |||||
| AnfNodePtrList uniform_input = {NewValueNode(prim::kPrimCudnnUniformReal)}; | |||||
| auto tensor = std::make_shared<tensor::Tensor>(kNumberTypeInt64, ShapeVector(1, SizeToLong(shape.size())), | |||||
| static_cast<void *>(&shape[0]), kNumberTypeInt64); | |||||
| uniform_input.push_back(NewValueNode(tensor)); | |||||
| uniform_input[1]->set_abstract(tensor->ToAbstract()); | |||||
| uniform_input[1]->set_kernel_info(std::make_shared<device::KernelInfo>()); | |||||
| std::string origin_data_format = kOpFormat_DEFAULT; | |||||
| std::vector<std::string> outputs_format = {origin_data_format}; | |||||
| std::vector<TypeId> outputs_type = {kNumberTypeInt32}; | |||||
| auto tensor_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(); | |||||
| tensor_info_builder->SetOriginDataFormat(origin_data_format); | |||||
| tensor_info_builder->SetOutputsFormat(outputs_format); | |||||
| tensor_info_builder->SetOutputsDeviceType(outputs_type); | |||||
| tensor_info_builder->SetKernelType(KernelType::UNKNOWN_KERNEL_TYPE); | |||||
| tensor_info_builder->SetProcessor(kernel::Processor::CUDA); | |||||
| auto tensor_selected_info = tensor_info_builder->Build(); | |||||
| AnfAlgo::SetSelectKernelBuildInfo(tensor_selected_info, uniform_input[1].get()); | |||||
| // create new uniform_real_node | |||||
| auto uniform_real_node = func_graph->NewCNode(uniform_input); | |||||
| AnfAlgo::GetCNodePrimitive(uniform_real_node)->set_attr("seed", MakeValue(SizeToLong(rand_r(&seed_)))); | |||||
| AnfAlgo::GetCNodePrimitive(uniform_real_node)->set_attr("seed2", MakeValue(SizeToLong(rand_r(&seed_)))); | |||||
| auto uniform_abstract = std::make_shared<abstract::AbstractTensor>(std::make_shared<Float>(32), shape_i64); | |||||
| uniform_real_node->set_abstract(uniform_abstract); | |||||
| uniform_real_node->set_kernel_info(std::make_shared<device::KernelInfo>()); | |||||
| SetNewKernelInfo(uniform_real_node); | |||||
| // create new_node, has two input, first is cnode->input[1], second is unifom_real_node | |||||
| AnfNodePtrList new_node_inputs = {NewValueNode(prim::kPrimGkDropout)}; | |||||
| new_node_inputs.push_back(cnode->input(1)); | |||||
| new_node_inputs.push_back(uniform_real_node); | |||||
| auto new_node = func_graph->NewCNode(new_node_inputs); | |||||
| AnfAlgo::GetCNodePrimitive(new_node)->set_attr("keep_prob", AnfAlgo::GetCNodePrimitive(cnode)->GetAttr("keep_prob")); | |||||
| new_node->set_abstract(old_abstract); | |||||
| new_node->set_kernel_info(std::make_shared<device::KernelInfo>()); | |||||
| SetNewKernelInfo(new_node); | |||||
| return new_node; | |||||
| } | |||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,35 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_SUBSTITUTE_DROPOUT_H_ | |||||
| #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_SUBSTITUTE_DROPOUT_H_ | |||||
| #include "backend/optimizer/common/optimizer.h" | |||||
| namespace mindspore { | |||||
| namespace opt { | |||||
| class SubstituteDropout : public PatternProcessPass { | |||||
| public: | |||||
| explicit SubstituteDropout(bool multigraph = true) : PatternProcessPass("substitute_dropout", multigraph) {} | |||||
| ~SubstituteDropout() override = default; | |||||
| const BaseRef DefinePattern() const override; | |||||
| const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | |||||
| private: | |||||
| static unsigned int seed_; | |||||
| }; | |||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_SUBSTITUTE_DROPOUT_H_ | |||||
| @@ -164,6 +164,9 @@ inline const PrimitivePtr kPrimLayerNormBetaGammaBackprop = std::make_shared<Pri | |||||
| inline const PrimitivePtr kPrimDropoutGenMask = std::make_shared<Primitive>("DropoutGenMask"); | inline const PrimitivePtr kPrimDropoutGenMask = std::make_shared<Primitive>("DropoutGenMask"); | ||||
| inline const PrimitivePtr kPrimDropoutDoMask = std::make_shared<Primitive>("DropoutDoMask"); | inline const PrimitivePtr kPrimDropoutDoMask = std::make_shared<Primitive>("DropoutDoMask"); | ||||
| inline const PrimitivePtr kPrimDropoutGrad = std::make_shared<Primitive>("DropoutGrad"); | inline const PrimitivePtr kPrimDropoutGrad = std::make_shared<Primitive>("DropoutGrad"); | ||||
| inline const PrimitivePtr kPrimDropout = std::make_shared<Primitive>("Dropout"); | |||||
| inline const PrimitivePtr kPrimUniformReal = std::make_shared<Primitive>("UniformReal"); | |||||
| inline const PrimitivePtr kPrimCudnnUniformReal = std::make_shared<Primitive>("CudnnUniformReal"); | |||||
| inline const PrimitivePtr kPrimOneHot = std::make_shared<Primitive>("OneHot"); | inline const PrimitivePtr kPrimOneHot = std::make_shared<Primitive>("OneHot"); | ||||
| inline const PrimitivePtr kPrimGelu = std::make_shared<Primitive>("Gelu"); | inline const PrimitivePtr kPrimGelu = std::make_shared<Primitive>("Gelu"); | ||||
| inline const PrimitivePtr kPrimGeluGrad = std::make_shared<Primitive>("GeluGrad"); | inline const PrimitivePtr kPrimGeluGrad = std::make_shared<Primitive>("GeluGrad"); | ||||
| @@ -118,7 +118,6 @@ class StandardLaplace(PrimitiveWithInfer): | |||||
| return out | return out | ||||
| class Gamma(PrimitiveWithInfer): | class Gamma(PrimitiveWithInfer): | ||||
| r""" | r""" | ||||
| Produces random positive floating-point values x, distributed according to probability density function: | Produces random positive floating-point values x, distributed according to probability density function: | ||||
| @@ -532,6 +531,7 @@ class Multinomial(PrimitiveWithInfer): | |||||
| "value": None} | "value": None} | ||||
| return out | return out | ||||
| class UniformCandidateSampler(PrimitiveWithInfer): | class UniformCandidateSampler(PrimitiveWithInfer): | ||||
| r""" | r""" | ||||
| Uniform candidate sampler. | Uniform candidate sampler. | ||||
| @@ -0,0 +1,55 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| import numpy as np | |||||
| import pytest | |||||
| import mindspore.context as context | |||||
| import mindspore.nn as nn | |||||
| from mindspore import Tensor | |||||
| from mindspore.ops import operations as P | |||||
| class Net(nn.Cell): | |||||
| def __init__(self, keep_prob): | |||||
| super(Net, self).__init__() | |||||
| self.drop = P.Dropout(keep_prob) | |||||
| def construct(self, x_): | |||||
| return self.drop(x_) | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_gpu_training | |||||
| @pytest.mark.env_onecard | |||||
| def test_dropout(): | |||||
| context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="GPU") | |||||
| x_shape = [4096, 768] | |||||
| x = np.ones(x_shape).astype(np.float32) | |||||
| keep_prob = 0.9 | |||||
| dropout = Net(keep_prob) | |||||
| tx = Tensor(x) | |||||
| output, mask = dropout(tx) | |||||
| output_np = output.asnumpy() | |||||
| elem_count = x.size | |||||
| nonzero_count = np.count_nonzero(output_np) | |||||
| assert (elem_count * (keep_prob - 0.1)) < nonzero_count < (elem_count * (keep_prob + 0.1)) | |||||
| output_sum = np.sum(output_np) | |||||
| x_sum = np.sum(x) | |||||
| assert abs(output_sum - x_sum)/x_sum < 0.1 | |||||
| # check mask | |||||
| mask_np = mask.asnumpy() | |||||
| mask_sum = np.sum(mask_np) | |||||
| assert np.count_nonzero(mask_np) == nonzero_count | |||||
| assert abs(mask_sum - nonzero_count)/nonzero_count < 0.1 | |||||