| @@ -36,6 +36,8 @@ static std::map<string, string> tbe_func_adapter_map = { | |||
| {"re_lu6_grad", "relu6_grad"}, | |||
| {"re_lu", "relu"}, | |||
| {"re_luv2", "relu_v2"}, | |||
| {"p_re_lu", "prelu"}, | |||
| {"p_re_lu_grad", "prelu_grad"}, | |||
| {"tensor_add", "add"}, | |||
| {"reduce_mean", "reduce_mean_d"}, | |||
| {"reduce_max", "reduce_max_d"}, | |||
| @@ -184,3 +184,5 @@ from .bn_training_update_v2 import _bn_training_update_v2_tbe | |||
| from .square_sum_all import square_sum_all_op_info | |||
| from .pack import _pack_tbe | |||
| from .unpack import _unpack_tbe | |||
| from .prelu import _prelu_tbe | |||
| from .prelu_grad import _prelu_grad_tbe | |||
| @@ -0,0 +1,41 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """PReLU op""" | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| prelu_op_info = TBERegOp("PReLU") \ | |||
| .fusion_type("ELEMWISE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("prelu.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("prelu") \ | |||
| .partial_flag(True) \ | |||
| .input(0, "x", False, "required", "all") \ | |||
| .input(1, "weight", False, "required", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .dtype_format(DataType.F16_NCHW, DataType.F16_Default, DataType.F16_NCHW) \ | |||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F32_NCHW, DataType.F32_Default, DataType.F32_NCHW) \ | |||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ | |||
| .get_op_info() | |||
| @op_info_register(prelu_op_info) | |||
| def _prelu_tbe(): | |||
| """PReLU TBE register""" | |||
| return | |||
| @@ -0,0 +1,43 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """PReLUGrad op""" | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| prelu_grad_op_info = TBERegOp("PReLUGrad") \ | |||
| .fusion_type("ELEMWISE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("prelu_grad.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("prelu_grad") \ | |||
| .partial_flag(True) \ | |||
| .input(0, "grads", False, "required", "all") \ | |||
| .input(1, "features", False, "required", "all") \ | |||
| .input(2, "weights", False, "required", "all") \ | |||
| .output(0, "dx", False, "required", "all") \ | |||
| .output(0, "da", False, "required", "all") \ | |||
| .dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_Default, | |||
| DataType.F32_NCHW, DataType.F32_Default) \ | |||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, | |||
| DataType.F32_5HD, DataType.F32_5HD) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | |||
| DataType.F32_Default, DataType.F32_Default) \ | |||
| .get_op_info() | |||
| @op_info_register(prelu_grad_op_info) | |||
| def _prelu_grad_tbe(): | |||
| """PReLUGrad TBE register""" | |||
| return | |||
| @@ -24,6 +24,7 @@ from mindspore.common.initializer import initializer | |||
| from mindspore.ops import Primitive | |||
| from mindspore.ops import composite as C | |||
| from mindspore.ops import operations as P | |||
| from mindspore.ops.operations import _grad_ops as G | |||
| from mindspore.ops import prim_attr_register, PrimitiveWithInfer | |||
| from ..ut_filter import non_graph_engine | |||
| from ....mindspore_test_framework.mindspore_test import mindspore_test | |||
| @@ -456,6 +457,28 @@ class FlattenNet(nn.Cell): | |||
| return self.flatten(x) | |||
| class PReLUNet(nn.Cell): | |||
| """ PReLUNet definition """ | |||
| def __init__(self): | |||
| super(PReLUNet, self).__init__() | |||
| self.prelu = P.PReLU() | |||
| self.w = Tensor(np.ones(3, np.float32)) | |||
| def construct(self, x): | |||
| return self.prelu(x, self.w) | |||
| class PReLUGradNet(nn.Cell): | |||
| """ PReLUGradNet definition """ | |||
| def __init__(self): | |||
| super(PReLUGradNet, self).__init__() | |||
| self.prelu_grad = G.PReLUGrad() | |||
| def construct(self, dout, x, w): | |||
| return self.prelu_grad(dout, x, w) | |||
| test_cases = [ | |||
| ('SoftMaxGrad', { | |||
| 'block': SoftMaxGrad(VirtualNetWithLoss(P.Softmax())), | |||
| @@ -545,6 +568,16 @@ test_cases = [ | |||
| 'block': FlattenNet(), | |||
| 'desc_inputs': [Tensor(np.ones([1, 2, 3, 4], np.float32))], | |||
| }), | |||
| ('PReLUNet', { | |||
| 'block': PReLUNet(), | |||
| 'desc_inputs': [Tensor(np.ones([1, 3, 4, 4], np.float32))], | |||
| }), | |||
| ('PReLUGradNet', { | |||
| 'block': PReLUGradNet(), | |||
| 'desc_inputs': [Tensor(np.ones([1, 3, 4, 4], np.float32)), | |||
| Tensor(np.ones([1, 3, 4, 4], np.float32)), | |||
| Tensor(np.ones(3, np.float32))], | |||
| }), | |||
| ] | |||
| test_cases_for_verify_exception = [ | |||