From 2fc5ebd07771413d0544ba5de8ba0f6e5458e86f Mon Sep 17 00:00:00 2001
From: TFBunny
Date: Wed, 27 Jan 2021 17:44:26 -0500
Subject: [PATCH] add dynamic shape support and test cases to GPU Dropout

---
 .../gpu/nn/dropout_gpu_kernel.h               | 23 +++++----
 mindspore/core/abstract/infer_functions.h     |  2 +
 mindspore/core/abstract/prim_nn.cc            | 19 ++++++-
 .../core/abstract/primitive_infer_map.cc      |  1 +
 mindspore/ops/operations/nn_ops.py            |  9 ++--
 tests/st/ops/gpu/test_dropout.py              | 49 ++++++++++++++++++-
 6 files changed, 85 insertions(+), 18 deletions(-)

diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/dropout_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/dropout_gpu_kernel.h
index 1fe3cff151..926e351eae 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/dropout_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/dropout_gpu_kernel.h
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -28,14 +28,7 @@ namespace kernel {
 template <typename T>
 class DropoutGpuFwdKernel : public GpuKernel {
  public:
-  DropoutGpuFwdKernel()
-      : cudnn_handle_(nullptr),
-        is_null_input_(false),
-        num_count_(0),
-        keep_prob_(0.0),
-        states_init_(false),
-        mask_generator_(nullptr) {}
-
+  DropoutGpuFwdKernel() { ResetResource(); }
   ~DropoutGpuFwdKernel() override = default;
 
   const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
@@ -96,6 +89,18 @@ class DropoutGpuFwdKernel : public GpuKernel {
     return true;
   }
 
+  void ResetResource() noexcept override {
+    cudnn_handle_ = nullptr;
+    is_null_input_ = false;
+    num_count_ = 0;
+    keep_prob_ = 0.0;
+    states_init_ = false;
+    mask_generator_ = nullptr;
+    input_size_list_.clear();
+    output_size_list_.clear();
+    workspace_size_list_.clear();
+  }
+
  protected:
   void InitResource() override { cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); }

diff --git a/mindspore/core/abstract/infer_functions.h b/mindspore/core/abstract/infer_functions.h
index a84a0d85a4..8703579a3b 100644
--- a/mindspore/core/abstract/infer_functions.h
+++ b/mindspore/core/abstract/infer_functions.h
@@ -85,6 +85,8 @@ AbstractBasePtr InferImplLayerNorm(const AnalysisEnginePtr &, const PrimitivePtr
                                    const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplLayerNormGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                        const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplDropout(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                 const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplDropoutGenMask(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                         const AbstractBasePtrList &args_spec_list);
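Note on the kernel refactor above: moving the member initialization out of the constructor and into ResetResource() is what lets the runtime reuse one cached kernel object across shape changes. When a dynamic-shape graph delivers an input with a new shape, the framework can clear the per-shape state (element count, size lists) and run initialization again instead of constructing a fresh kernel. A minimal standalone sketch of the idiom, with toy names rather than MindSpore's actual kernel interfaces:

```cpp
// Toy illustration of the constructor-delegates-to-ResetResource idiom.
#include <cstddef>
#include <iostream>
#include <vector>

class ToyDropoutKernel {
 public:
  ToyDropoutKernel() { ResetResource(); }  // same idiom as the patch above

  // Clears every piece of per-shape state so Init() can run again safely.
  void ResetResource() noexcept {
    num_count_ = 0;
    keep_prob_ = 0.0f;
    input_size_list_.clear();
    output_size_list_.clear();
  }

  // Recomputes the size lists for the (possibly new) input shape.
  void Init(const std::vector<size_t> &shape, float keep_prob) {
    keep_prob_ = keep_prob;
    num_count_ = 1;
    for (size_t dim : shape) num_count_ *= dim;
    input_size_list_.push_back(num_count_ * sizeof(float));
    output_size_list_.push_back(num_count_ * sizeof(float));  // output
    output_size_list_.push_back(num_count_ * sizeof(float));  // mask
  }

  size_t num_count() const { return num_count_; }

 private:
  size_t num_count_ = 0;
  float keep_prob_ = 0.0f;
  std::vector<size_t> input_size_list_;
  std::vector<size_t> output_size_list_;
};

int main() {
  ToyDropoutKernel k;
  k.Init({32, 16, 2, 5}, 0.4f);        // first launch: rank-4 input
  std::cout << k.num_count() << "\n";  // 5120
  k.ResetResource();                   // shape changed: drop stale sizes
  k.Init({32, 16, 2, 5, 6}, 0.4f);     // second launch: rank-5 input
  std::cout << k.num_count() << "\n";  // 30720
}
```

Delegating the constructor to ResetResource() keeps first-time construction and later re-initialization on a single code path, so the two cannot drift apart.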
diff --git a/mindspore/core/abstract/prim_nn.cc b/mindspore/core/abstract/prim_nn.cc
index a92094f50d..6bf68eccae 100644
--- a/mindspore/core/abstract/prim_nn.cc
+++ b/mindspore/core/abstract/prim_nn.cc
@@ -1,5 +1,5 @@
 /**
- * Copyright 2019 Huawei Technologies Co., Ltd
+ * Copyright 2019-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -573,6 +573,23 @@ AbstractBasePtr InferImplLayerNormGrad(const AnalysisEnginePtr &, const Primitiv
   return std::make_shared<AbstractTuple>(args_list);
 }
 
+AbstractBasePtr InferImplDropout(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                 const AbstractBasePtrList &args_spec_list) {
+  const std::string op_name = primitive->name();
+  CheckArgsSize(op_name, args_spec_list, 1);
+  auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
+  MS_EXCEPTION_IF_NULL(x);
+  MS_EXCEPTION_IF_NULL(x->shape());
+  ShapeVector shape = x->shape()->shape();
+  ShapeVector min_shape = x->shape()->min_shape();
+  ShapeVector max_shape = x->shape()->max_shape();
+  (void)CheckMinMaxShape(shape, &min_shape, &max_shape);
+  auto output_shape =
+    std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(shape, min_shape, max_shape));
+  AbstractBasePtrList ret = {output_shape, output_shape};
+  return std::make_shared<AbstractTuple>(ret);
+}
+
 AbstractBasePtr InferImplDropoutGenMask(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                         const AbstractBasePtrList &args_spec_list) {
   // Inputs: a tuple and a tensor.

diff --git a/mindspore/core/abstract/primitive_infer_map.cc b/mindspore/core/abstract/primitive_infer_map.cc
index 88b59947cf..81df673962 100644
--- a/mindspore/core/abstract/primitive_infer_map.cc
+++ b/mindspore/core/abstract/primitive_infer_map.cc
@@ -123,6 +123,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
     {prim::kPrimBpropCut, {InferImplBpropCut, true}},
     {prim::kPrimLayerNorm, {InferImplLayerNorm, true}},
     {prim::kPrimLayerNormGrad, {InferImplLayerNormGrad, true}},
+    {prim::kPrimDropout, {InferImplDropout, true}},
     {prim::kPrimDropoutGenMask, {InferImplDropoutGenMask, true}},
     {prim::kPrimSparseApplyFtrl, {InferImplSparseApplyFtrl, true}},
     {prim::kPrimSparseApplyProximalAdagrad, {InferImplSparseApplyProximalAdagrad, true}},

diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py
index 60f140b328..f8e4573acc 100644
--- a/mindspore/ops/operations/nn_ops.py
+++ b/mindspore/ops/operations/nn_ops.py
@@ -6336,7 +6336,7 @@ class SparseApplyFtrlV2(PrimitiveWithInfer):
         return var_dtype, accum_dtype, linear_dtype
 
 
-class Dropout(PrimitiveWithInfer):
+class Dropout(PrimitiveWithCheck):
     """
     During training, randomly zeroes some of the elements of the input tensor with probability 1 - `keep_prob`.
@@ -6367,15 +6367,12 @@ class Dropout(PrimitiveWithInfer):
         self.seed1 = validator.check_value_type("Seed1", Seed1, [int], self.name)
         self.keep_prob = validator.check_float_range(keep_prob, 0, 1, Rel.INC_RIGHT, "keep_prob", self.name)
 
-    def infer_shape(self, x_shape):
+    def check_shape(self, x_shape):
         validator.check_int(len(x_shape), 1, Rel.GE, "x_shape", self.name)
-        mask_shape = x_shape
-        return x_shape, mask_shape
 
-    def infer_dtype(self, x_dtype):
+    def check_dtype(self, x_dtype):
         valid_dtypes = (mstype.float16, mstype.float32)
         validator.check_tensor_dtype_valid("x", x_dtype, valid_dtypes, self.name)
-        return x_dtype, x_dtype
 
 
 class Dropout3d(PrimitiveWithInfer):
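The new InferImplDropout is intentionally thin: Dropout is elementwise, so the output tensor and the mask both inherit the input's shape together with the min_shape/max_shape bounds that carry dynamic-shape information through abstract inference (which is also why the Python primitive can drop its infer_shape/infer_dtype methods and become a PrimitiveWithCheck). A toy model of that shape triple, assuming the convention that -1 marks a dimension whose concrete size is only known at runtime; the structs below are illustrative stand-ins, not MindSpore's AbstractTensor/Shape classes:

```cpp
// Toy model of the (shape, min_shape, max_shape) triple forwarded by the
// infer function above; illustration only.
#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

struct AbstractShape {
  std::vector<int64_t> shape;      // -1 = dynamic dimension (assumed marker)
  std::vector<int64_t> min_shape;  // lower bound per dimension
  std::vector<int64_t> max_shape;  // upper bound per dimension
};

// Dropout is elementwise, so the output and the mask both inherit the
// input's shape triple unchanged: the essence of InferImplDropout.
std::pair<AbstractShape, AbstractShape> InferDropoutShapes(const AbstractShape &x) {
  return {x, x};
}

int main() {
  AbstractShape x{{-1, 16, 2, 5}, {1, 16, 2, 5}, {64, 16, 2, 5}};
  auto [out, mask] = InferDropoutShapes(x);
  std::cout << out.shape[0] << " " << mask.max_shape[0] << "\n";  // -1 64
}
```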
diff --git a/tests/st/ops/gpu/test_dropout.py b/tests/st/ops/gpu/test_dropout.py
index 9a9f9d5b09..91382de4bd 100644
--- a/tests/st/ops/gpu/test_dropout.py
+++ b/tests/st/ops/gpu/test_dropout.py
@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -17,8 +17,9 @@ import pytest
 
 import mindspore.nn as nn
 from mindspore import Tensor
+import mindspore.context as context
 from mindspore.ops import operations as P
-
+from mindspore.ops.operations import _inner_ops as inner
 
 class Net(nn.Cell):
     def __init__(self, keep_prob):
@@ -52,3 +53,47 @@ def test_dropout():
     mask_sum = np.sum(mask_np)
     assert np.count_nonzero(mask_np) == nonzero_count
     assert abs(mask_sum - nonzero_count) / nonzero_count < 0.1
+
+
+class DropoutDynamic(nn.Cell):
+    def __init__(self, keep_prob):
+        super(DropoutDynamic, self).__init__()
+        self.test_dynamic = inner.GpuConvertToDynamicShape()
+        self.drop = P.Dropout(keep_prob)
+
+    def construct(self, x):
+        x = self.test_dynamic(x)
+        return self.drop(x)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_dropout_dynamic():
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    x_1 = np.ones([32, 16, 2, 5]).astype(np.float32)
+    x_2 = np.ones([32, 16, 2, 5, 6]).astype(np.float32)
+    keep_prob = 0.4
+    net = DropoutDynamic(keep_prob)
+
+    output_1, mask_1 = net(Tensor(x_1))
+    elem_count_1 = x_1.size
+    nonzero_count_1 = np.count_nonzero(output_1.asnumpy())
+    assert (elem_count_1 * (keep_prob - 0.1)) < nonzero_count_1 < (elem_count_1 * (keep_prob + 0.1))
+    output_sum_1 = np.sum(output_1.asnumpy())
+    x_sum_1 = np.sum(x_1)
+    assert abs(output_sum_1 - x_sum_1) / x_sum_1 < 0.1
+    mask_sum_1 = np.sum(mask_1.asnumpy())
+    assert np.count_nonzero(mask_1.asnumpy()) == nonzero_count_1
+    assert abs(mask_sum_1 - nonzero_count_1) / nonzero_count_1 < 0.1
+
+    output_2, mask_2 = net(Tensor(x_2))
+    elem_count_2 = x_2.size
+    nonzero_count_2 = np.count_nonzero(output_2.asnumpy())
+    assert (elem_count_2 * (keep_prob - 0.1)) < nonzero_count_2 < (elem_count_2 * (keep_prob + 0.1))
+    output_sum_2 = np.sum(output_2.asnumpy())
+    x_sum_2 = np.sum(x_2)
+    assert abs(output_sum_2 - x_sum_2) / x_sum_2 < 0.1
+    mask_sum_2 = np.sum(mask_2.asnumpy())
+    assert np.count_nonzero(mask_2.asnumpy()) == nonzero_count_2
+    assert abs(mask_sum_2 - nonzero_count_2) / nonzero_count_2 < 0.1
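Why the 0.1 tolerances in these tests are reasonable: with keep probability p, each element survives independently with probability p, and survivors are scaled by 1/p, so the expected nonzero fraction is p and the expected output sum equals the input sum. A standalone simulation of those statistics on an all-ones input, mirroring the test setup (illustrative only, no MindSpore dependency):

```cpp
// Sanity check of the statistics the tests assert: kept fraction ~= keep_prob,
// and output sum ~= input sum because kept values are scaled by 1/keep_prob.
#include <iostream>
#include <random>

int main() {
  const double keep_prob = 0.4;
  const std::size_t n = 32 * 16 * 2 * 5;  // same element count as x_1
  std::mt19937 gen(0);
  std::bernoulli_distribution keep(keep_prob);

  std::size_t nonzero = 0;
  double out_sum = 0.0;
  for (std::size_t i = 0; i < n; ++i) {
    if (keep(gen)) {               // element survives dropout
      ++nonzero;
      out_sum += 1.0 / keep_prob;  // all-ones input, scaled by 1/keep_prob
    }
  }
  std::cout << "kept fraction: " << double(nonzero) / n << "\n";   // ~0.4
  std::cout << "output sum / input sum: " << out_sum / n << "\n";  // ~1.0
}
```

With 5120 or 30720 independent elements, the observed fraction concentrates tightly around keep_prob, so the ±0.1 windows in the assertions pass reliably despite the randomness.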