@@ -14,7 +14,7 @@
  * limitations under the License.
  */
-#include "backend/kernel_compiler/gpu/math/bias_add_gpu_kernel.h"
+#include "backend/kernel_compiler/gpu/nn/bias_add_gpu_kernel.h"
 namespace mindspore {
 namespace kernel {
@@ -14,8 +14,8 @@
  * limitations under the License.
  */
-#ifndef MINDSPORE_BIAS_ADD_GPU_KERNEL_H
-#define MINDSPORE_BIAS_ADD_GPU_KERNEL_H
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_BIAS_ADD_GPU_KERNEL_H_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_BIAS_ADD_GPU_KERNEL_H_
 #include <cuda_runtime_api.h>
 #include <string>
 #include <algorithm>
@@ -30,13 +30,7 @@ namespace kernel {
 template <typename T>
 class BiasAddGpuKernel : public GpuKernel {
  public:
-  BiasAddGpuKernel()
-      : cudnn_handle_(nullptr),
-        cudnn_data_type_(CUDNN_DATA_FLOAT),
-        x_desc_(nullptr),
-        b_desc_(nullptr),
-        op_desc_(nullptr),
-        is_null_input_(false) {}
+  BiasAddGpuKernel() { ResetResource(); }
   ~BiasAddGpuKernel() override { DestroyResource(); }
   const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
@@ -117,6 +111,18 @@ class BiasAddGpuKernel : public GpuKernel {
     return true;
   }
+  void ResetResource() noexcept override {
+    cudnn_handle_ = nullptr;
+    cudnn_data_type_ = CUDNN_DATA_FLOAT;
+    x_desc_ = nullptr;
+    b_desc_ = nullptr;
+    op_desc_ = nullptr;
+    is_null_input_ = false;
+    input_size_list_.clear();
+    output_size_list_.clear();
+    workspace_size_list_.clear();
+  }
+
   void DestroyResource() noexcept override {
     CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyOpTensorDescriptor(op_desc_),
                                "cudnnDestroyTensorDescriptor failed");
@@ -136,6 +142,7 @@ class BiasAddGpuKernel : public GpuKernel {
     CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateOpTensorDescriptor(&op_desc_),
                                 "cudnnCreateOpTensorDescriptor failed");
   }
+
   void InitSizeLists() override {
     size_t x_size, b_size;
     CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnGetTensorSizeInBytes(x_desc_, &x_size),
@@ -161,4 +168,4 @@ class BiasAddGpuKernel : public GpuKernel {
 }  // namespace kernel
 }  // namespace mindspore
-#endif  // MINDSPORE_BIAS_ADD_GPU_KERNEL_H
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_BIAS_ADD_GPU_KERNEL_H_
@@ -63,6 +63,8 @@ AbstractBasePtr InferImplConv2DBackpropInput(const AnalysisEnginePtr &, const Pr
                                              const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplConv2DBackpropFilter(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                               const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplBiasAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                 const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplBiasAddGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                      const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplGelu(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
@@ -470,6 +470,41 @@ AbstractBasePtr InferImplConv2DBackpropFilter(const AnalysisEnginePtr &, const P
   return args_spec_list[2]->Broaden();
 }
 
+AbstractBasePtr InferImplBiasAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                 const AbstractBasePtrList &args_spec_list) {
+  const std::string op_name = primitive->name();
+  CheckArgsSize(op_name, args_spec_list, 2);
+  auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
+  auto bias = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
+  MS_EXCEPTION_IF_NULL(x);
+  MS_EXCEPTION_IF_NULL(x->shape());
+  ShapeVector x_shape = x->shape()->shape();
+  MS_EXCEPTION_IF_NULL(bias);
+  MS_EXCEPTION_IF_NULL(bias->shape());
+  ShapeVector bias_shape = bias->shape()->shape();
+  ShapeVector x_min_shape = x->shape()->min_shape();
+  ShapeVector x_max_shape = x->shape()->max_shape();
+  std::set<std::string> available_data_format{"NCHW", "NHWC"};
+  auto data_format_ptr = primitive->GetAttr("data_format");
+  std::string data_format = "NCHW";
+  if ((data_format_ptr != nullptr) && data_format_ptr->isa<StringImm>()) {
+    data_format = data_format_ptr->cast<StringImmPtr>()->value();
+  }
+  if (available_data_format.find(data_format) == available_data_format.end()) {
+    MS_LOG(EXCEPTION) << "Unsupported data format: " << data_format << ", use NCHW or NHWC.";
+  }
+  auto x_channel = data_format == "NHWC" ? x_shape[x_shape.size() - 1] : x_shape[1];
+  // Additional check for dynamic shape
+  // Last infer will be real shape values
+  bool x_not_dyn = std::all_of(x_shape.begin(), x_shape.end(), [](int64_t value) { return value != Shape::SHP_ANY; });
+  if (x_not_dyn && bias_shape[0] != x_channel) {
+    MS_LOG(EXCEPTION) << "BiasAdd shape error, data format is " << data_format
+                      << ", got bias_shape[0]: " << bias_shape[0] << ", x_channel: " << x_channel << ".";
+  }
+  (void)CheckMinMaxShape(x_shape, &x_min_shape, &x_max_shape);
+  return std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(x_shape, x_min_shape, x_max_shape));
+}
+
 AbstractBasePtr InferImplBiasAddGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                      const AbstractBasePtrList &args_spec_list) {
   // Inputs: at least one tensor(y_backprop)
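Note: the shape rule enforced by InferImplBiasAdd above can be sketched outside MindSpore in a few lines of plain Python. This is an illustrative reimplementation only (the name bias_add_infer_shape is made up here, and -1 stands in for Shape::SHP_ANY to mark an unknown dimension); the C++ in the hunk is the authoritative logic.

def bias_add_infer_shape(x_shape, bias_shape, data_format="NCHW"):
    # bias is 1-D and must match the channel dimension of x; the channel axis
    # is the last one for NHWC and axis 1 for NCHW.
    if data_format not in ("NCHW", "NHWC"):
        raise ValueError("Unsupported data format: " + data_format + ", use NCHW or NHWC.")
    x_channel = x_shape[-1] if data_format == "NHWC" else x_shape[1]
    # Skip the equality check while any dimension of x is still dynamic (-1).
    if all(dim != -1 for dim in x_shape) and bias_shape[0] != x_channel:
        raise ValueError("BiasAdd shape error: bias_shape[0]=%d, x_channel=%d"
                         % (bias_shape[0], x_channel))
    return list(x_shape)  # the output shape equals the input shape

# e.g. bias_add_infer_shape([32, 3, 224, 224], [3]) -> [32, 3, 224, 224]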
@@ -114,6 +114,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
     {prim::kPrimConv2D, {InferImplConv2D, true}},
     {prim::kPrimConv2DBackpropInput, {InferImplConv2DBackpropInput, true}},
     {prim::kPrimConv2DBackpropFilter, {InferImplConv2DBackpropFilter, true}},
+    {prim::kPrimBiasAdd, {InferImplBiasAdd, true}},
     {prim::kPrimBiasAddGrad, {InferImplBiasAddGrad, true}},
     {prim::kPrimRelu, {InferImplRelu, true}},
     {prim::kPrimZerosLike, {InferImplZerosLike, true}},
@@ -1887,19 +1887,21 @@ class Conv2DBackpropInput(PrimitiveWithInfer):
         return out
 
 
-class BiasAdd(PrimitiveWithInfer):
+class BiasAdd(PrimitiveWithCheck):
     r"""
     Returns sum of input and bias tensor.
 
     Adds the 1-D bias tensor to the input tensor, and broadcasts the shape on all axis
     except for the channel axis.
 
+    Args:
+        data_format (str): The format of input and output data. It should be 'NHWC' or 'NCHW',
+            default is 'NCHW'.
+
     Inputs:
         - **input_x** (Tensor) - The input tensor. The shape can be 2-4 dimensions.
-        - **bias** (Tensor) - The bias tensor, with shape :math:`(C)`.
-        - **data_format** (str) - The format of input and output data. It should be 'NHWC' or 'NCHW',
-          default is 'NCHW'.
-          The shape of `bias` must be the same as `input_x` in the second dimension.
+        - **bias** (Tensor) - The bias tensor, with shape :math:`(C)`. The shape of
+          `bias` must be the same as `input_x`'s channel dimension.
 
     Outputs:
         Tensor, with the same shape and type as `input_x`.
@@ -1924,17 +1926,16 @@ class BiasAdd(PrimitiveWithInfer):
             raise ValueError("NHWC format only support in GPU target.")
         self.add_prim_attr('data_format', self.format)
 
-    def infer_shape(self, x_shape, b_shape):
+    def check_shape(self, x_shape, b_shape):
         validator.check_int(len(x_shape), 2, Rel.GE, "x rank", self.name)
         validator.check_equal_int(len(b_shape), 1, "bias rank", self.name)
         x_channel = x_shape[1] if self.format == "NCHW" else x_shape[-1]
-        validator.check("b_shape[0]", b_shape[0], "x_shape[1]", x_channel, Rel.EQ, self.name)
-        return x_shape
+        if np.all(np.array(x_shape) != -1):
+            validator.check("b_shape[0]", b_shape[0], "x_channel", x_channel, Rel.EQ, self.name)
 
-    def infer_dtype(self, x_type, b_type):
+    def check_dtype(self, x_type, b_type):
         args = {"input_x": x_type, "bias": b_type}
         validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
-        return x_type
 
 
 class TopK(PrimitiveWithInfer):
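For readers unfamiliar with the op documented above, the semantics are a channel-wise broadcast add. The following NumPy sketch is illustrative only (bias_add_ref is a made-up name, not the MindSpore kernel) and shows what BiasAdd computes for both supported layouts:

import numpy as np

def bias_add_ref(x, bias, data_format="NCHW"):
    # NHWC: the channel axis is last, so plain broadcasting already does the job.
    if data_format == "NHWC":
        return x + bias
    # NCHW: reshape the 1-D bias to (1, C, 1, ...) so it broadcasts along axis 1.
    return x + bias.reshape([1, bias.size] + [1] * (x.ndim - 2))

x = np.arange(24, dtype=np.float32).reshape(2, 3, 2, 2)   # N=2, C=3, H=W=2
b = np.array([1.0, 2.0, 3.0], dtype=np.float32)
assert bias_add_ref(x, b).shape == x.shape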
@@ -23,7 +23,7 @@ from mindspore.common.parameter import ParameterTuple
 from mindspore.ops import composite as C
 from mindspore.ops import operations as P
 from mindspore.ops.composite import GradOperation
+from mindspore.ops.operations import _inner_ops as inner
 
 
 class BiasAdd(nn.Cell):
     def __init__(self):
@@ -442,3 +442,66 @@ def test_biasadd_4d():
     error = np.ones(shape=[3]) * 1.0e-6
     assert np.all(diff < error)
     assert np.all(-diff < error)
+
+
+class BiasAddDynamic(nn.Cell):
+    def __init__(self):
+        super(BiasAddDynamic, self).__init__()
+        self.ba = P.BiasAdd()
+        self.test_dynamic = inner.GpuConvertToDynamicShape()
+
+    def construct(self, x, b):
+        x = self.test_dynamic(x)
+        output = self.ba(x, b)
+        return output
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_bias_add_dynamic_two_inputs():
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    net = BiasAddDynamic()
+    x_1 = Tensor(np.array([[0.1, 0.2, 0.3, 0.4],
+                           [0.5, 0.6, 0.7, 0.8],
+                           [0.9, 1.0, 1.1, 1.2]]).astype(np.float32))
+    b_1 = Tensor(np.array([0.1, 0.2, 0.3, 0.4]).astype(np.float32))
+    expect_1 = np.array([[0.2, 0.4, 0.6, 0.8],
+                         [0.6, 0.8, 1.0, 1.2],
+                         [1.0, 1.2, 1.4, 1.6]])
+    error_1 = np.ones(shape=[3, 4]) * 1.0e-6
+    result_1 = net(x_1, b_1)
+    diff_1 = result_1.asnumpy() - expect_1
+    assert np.all(diff_1 < error_1)
+    assert np.all(-diff_1 < error_1)
+
+    x_2 = Tensor(np.array([[[1, 2, 3, 4, 5, 6, 7, 8],
+                            [9, 10, 11, 12, 13, 14, 15, 16],
+                            [17, 18, 19, 20, 21, 22, 23, 24],
+                            [25, 26, 27, 28, 29, 30, 31, 32]],
+                           [[33, 34, 35, 36, 37, 38, 39, 40],
+                            [41, 42, 43, 44, 45, 46, 47, 48],
+                            [49, 50, 51, 52, 53, 54, 55, 56],
+                            [57, 58, 59, 60, 61, 62, 63, 64]],
+                           [[65, 66, 67, 68, 69, 70, 71, 72],
+                            [73, 74, 75, 76, 77, 78, 79, 80],
+                            [81, 82, 83, 84, 85, 86, 87, 88],
+                            [89, 90, 91, 92, 93, 94, 95, 96]]]).astype(np.float32))
+    b_2 = Tensor(np.array([1, 2, 3, 4]).astype(np.float32))
+    expect_2 = np.array([[[2, 3, 4, 5, 6, 7, 8, 9],
+                          [11, 12, 13, 14, 15, 16, 17, 18],
+                          [20, 21, 22, 23, 24, 25, 26, 27],
+                          [29, 30, 31, 32, 33, 34, 35, 36]],
+                         [[34, 35, 36, 37, 38, 39, 40, 41],
+                          [43, 44, 45, 46, 47, 48, 49, 50],
+                          [52, 53, 54, 55, 56, 57, 58, 59],
+                          [61, 62, 63, 64, 65, 66, 67, 68]],
+                         [[66, 67, 68, 69, 70, 71, 72, 73],
+                          [75, 76, 77, 78, 79, 80, 81, 82],
+                          [84, 85, 86, 87, 88, 89, 90, 91],
+                          [93, 94, 95, 96, 97, 98, 99, 100]]])
+    error_2 = np.ones(shape=[3, 4, 8]) * 1.0e-6
+    result_2 = net(x_2, b_2)
+    diff_2 = result_2.asnumpy() - expect_2
+    assert np.all(diff_2 < error_2)
+    assert np.all(-diff_2 < error_2)
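The expected arrays in the new test can be reproduced with NumPy, which makes the broadcast rule explicit: with the default NCHW format and a 3-D input of shape (3, 4, 8), the bias of shape (4,) is added along axis 1. A quick standalone check (not part of the PR):

import numpy as np

x_2 = np.arange(1, 97, dtype=np.float32).reshape(3, 4, 8)   # same values as the test input
b_2 = np.array([1, 2, 3, 4], dtype=np.float32)
expect_2 = x_2 + b_2.reshape(1, 4, 1)      # broadcast the bias along axis 1
print(expect_2[0, 1])                      # [11. 12. 13. 14. 15. 16. 17. 18.]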