| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -28,14 +28,7 @@ namespace kernel { | |||
| template <typename T> | |||
| class DropoutGpuFwdKernel : public GpuKernel { | |||
| public: | |||
| DropoutGpuFwdKernel() | |||
| : cudnn_handle_(nullptr), | |||
| is_null_input_(false), | |||
| num_count_(0), | |||
| keep_prob_(0.0), | |||
| states_init_(false), | |||
| mask_generator_(nullptr) {} | |||
| DropoutGpuFwdKernel() { ResetResource(); } | |||
| ~DropoutGpuFwdKernel() override = default; | |||
| const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | |||
| @@ -96,6 +89,18 @@ class DropoutGpuFwdKernel : public GpuKernel { | |||
| return true; | |||
| } | |||
| void ResetResource() noexcept override { | |||
| cudnn_handle_ = nullptr; | |||
| is_null_input_ = false; | |||
| num_count_ = 0; | |||
| keep_prob_ = 0.0; | |||
| states_init_ = false; | |||
| mask_generator_ = nullptr; | |||
| input_size_list_.clear(); | |||
| output_size_list_.clear(); | |||
| workspace_size_list_.clear(); | |||
| } | |||
| protected: | |||
| void InitResource() override { cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); } | |||
| @@ -85,6 +85,8 @@ AbstractBasePtr InferImplLayerNorm(const AnalysisEnginePtr &, const PrimitivePtr | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplLayerNormGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplDropout(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplDropoutGenMask(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -573,6 +573,23 @@ AbstractBasePtr InferImplLayerNormGrad(const AnalysisEnginePtr &, const Primitiv | |||
| return std::make_shared<AbstractTuple>(args_list); | |||
| } | |||
| AbstractBasePtr InferImplDropout(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| const std::string op_name = primitive->name(); | |||
| CheckArgsSize(op_name, args_spec_list, 1); | |||
| auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0); | |||
| MS_EXCEPTION_IF_NULL(x); | |||
| MS_EXCEPTION_IF_NULL(x->shape()); | |||
| ShapeVector shape = x->shape()->shape(); | |||
| ShapeVector min_shape = x->shape()->min_shape(); | |||
| ShapeVector max_shape = x->shape()->max_shape(); | |||
| (void)CheckMinMaxShape(shape, &min_shape, &max_shape); | |||
| auto output_shape = | |||
| std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(shape, min_shape, max_shape)); | |||
| AbstractBasePtrList ret = {output_shape, output_shape}; | |||
| return std::make_shared<AbstractTuple>(ret); | |||
| } | |||
| AbstractBasePtr InferImplDropoutGenMask(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| // Inputs: a tuple and a tensor. | |||
| @@ -123,6 +123,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { | |||
| {prim::kPrimBpropCut, {InferImplBpropCut, true}}, | |||
| {prim::kPrimLayerNorm, {InferImplLayerNorm, true}}, | |||
| {prim::kPrimLayerNormGrad, {InferImplLayerNormGrad, true}}, | |||
| {prim::kPrimDropout, {InferImplDropout, true}}, | |||
| {prim::kPrimDropoutGenMask, {InferImplDropoutGenMask, true}}, | |||
| {prim::kPrimSparseApplyFtrl, {InferImplSparseApplyFtrl, true}}, | |||
| {prim::kPrimSparseApplyProximalAdagrad, {InferImplSparseApplyProximalAdagrad, true}}, | |||
| @@ -6336,7 +6336,7 @@ class SparseApplyFtrlV2(PrimitiveWithInfer): | |||
| return var_dtype, accum_dtype, linear_dtype | |||
| class Dropout(PrimitiveWithInfer): | |||
| class Dropout(PrimitiveWithCheck): | |||
| """ | |||
| During training, randomly zeroes some of the elements of the input tensor with probability. | |||
| @@ -6367,15 +6367,12 @@ class Dropout(PrimitiveWithInfer): | |||
| self.seed1 = validator.check_value_type("Seed1", Seed1, [int], self.name) | |||
| self.keep_prob = validator.check_float_range(keep_prob, 0, 1, Rel.INC_RIGHT, "keep_prob", self.name) | |||
| def infer_shape(self, x_shape): | |||
| def check_shape(self, x_shape): | |||
| validator.check_int(len(x_shape), 1, Rel.GE, "x_shape", self.name) | |||
| mask_shape = x_shape | |||
| return x_shape, mask_shape | |||
| def infer_dtype(self, x_dtype): | |||
| def check_dtype(self, x_dtype): | |||
| valid_dtypes = (mstype.float16, mstype.float32) | |||
| validator.check_tensor_dtype_valid("x", x_dtype, valid_dtypes, self.name) | |||
| return x_dtype, x_dtype | |||
| class Dropout3d(PrimitiveWithInfer): | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -17,8 +17,9 @@ import pytest | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor | |||
| import mindspore.context as context | |||
| from mindspore.ops import operations as P | |||
| from mindspore.ops.operations import _inner_ops as inner | |||
| class Net(nn.Cell): | |||
| def __init__(self, keep_prob): | |||
| @@ -52,3 +53,47 @@ def test_dropout(): | |||
| mask_sum = np.sum(mask_np) | |||
| assert np.count_nonzero(mask_np) == nonzero_count | |||
| assert abs(mask_sum - nonzero_count)/nonzero_count < 0.1 | |||
| class DropoutDynamic(nn.Cell): | |||
| def __init__(self, keep_prob): | |||
| super(DropoutDynamic, self).__init__() | |||
| self.test_dynamic = inner.GpuConvertToDynamicShape() | |||
| self.drop = P.Dropout(keep_prob) | |||
| def construct(self, x): | |||
| x = self.test_dynamic(x) | |||
| return self.drop(x) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_dropout_dynamic(): | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
| x_1 = np.ones([32, 16, 2, 5]).astype(np.float32) | |||
| x_2 = np.ones([32, 16, 2, 5, 6]).astype(np.float32) | |||
| keep_prob = 0.4 | |||
| net = DropoutDynamic(keep_prob) | |||
| output_1, mask_1 = net(Tensor(x_1)) | |||
| elem_count_1 = x_1.size | |||
| nonzero_count_1 = np.count_nonzero(output_1.asnumpy()) | |||
| assert (elem_count_1 * (keep_prob - 0.1)) < nonzero_count_1 < (elem_count_1 * (keep_prob + 0.1)) | |||
| output_sum_1 = np.sum(output_1.asnumpy()) | |||
| x_sum_1 = np.sum(x_1) | |||
| assert abs(output_sum_1 - x_sum_1)/x_sum_1 < 0.1 | |||
| mask_sum_1 = np.sum(mask_1.asnumpy()) | |||
| assert np.count_nonzero(mask_1.asnumpy()) == nonzero_count_1 | |||
| assert abs(mask_sum_1 - nonzero_count_1)/nonzero_count_1 < 0.1 | |||
| output_2, mask_2 = net(Tensor(x_2)) | |||
| elem_count_2 = x_2.size | |||
| nonzero_count_2 = np.count_nonzero(output_2.asnumpy()) | |||
| assert (elem_count_2 * (keep_prob - 0.1)) < nonzero_count_2 < (elem_count_2 * (keep_prob + 0.1)) | |||
| output_sum_2 = np.sum(output_2.asnumpy()) | |||
| x_sum_2 = np.sum(x_2) | |||
| assert abs(output_sum_2 - x_sum_2)/x_sum_2 < 0.1 | |||
| mask_sum_2 = np.sum(mask_2.asnumpy()) | |||
| assert np.count_nonzero(mask_2.asnumpy()) == nonzero_count_2 | |||
| assert abs(mask_sum_2 - nonzero_count_2)/nonzero_count_2 < 0.1 | |||