| @@ -36,6 +36,8 @@ static std::map<string, string> tbe_func_adapter_map = { | |||||
| {"re_lu6_grad", "relu6_grad"}, | {"re_lu6_grad", "relu6_grad"}, | ||||
| {"re_lu", "relu"}, | {"re_lu", "relu"}, | ||||
| {"re_luv2", "relu_v2"}, | {"re_luv2", "relu_v2"}, | ||||
| {"p_re_lu", "prelu"}, | |||||
| {"p_re_lu_grad", "prelu_grad"}, | |||||
| {"tensor_add", "add"}, | {"tensor_add", "add"}, | ||||
| {"reduce_mean", "reduce_mean_d"}, | {"reduce_mean", "reduce_mean_d"}, | ||||
| {"reduce_max", "reduce_max_d"}, | {"reduce_max", "reduce_max_d"}, | ||||
| @@ -184,3 +184,5 @@ from .bn_training_update_v2 import _bn_training_update_v2_tbe | |||||
| from .square_sum_all import square_sum_all_op_info | from .square_sum_all import square_sum_all_op_info | ||||
| from .pack import _pack_tbe | from .pack import _pack_tbe | ||||
| from .unpack import _unpack_tbe | from .unpack import _unpack_tbe | ||||
| from .prelu import _prelu_tbe | |||||
| from .prelu_grad import _prelu_grad_tbe | |||||
| @@ -0,0 +1,41 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """PReLU op""" | |||||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||||
| prelu_op_info = TBERegOp("PReLU") \ | |||||
| .fusion_type("ELEMWISE") \ | |||||
| .async_flag(False) \ | |||||
| .binfile_name("prelu.so") \ | |||||
| .compute_cost(10) \ | |||||
| .kernel_name("prelu") \ | |||||
| .partial_flag(True) \ | |||||
| .input(0, "x", False, "required", "all") \ | |||||
| .input(1, "weight", False, "required", "all") \ | |||||
| .output(0, "y", False, "required", "all") \ | |||||
| .dtype_format(DataType.F16_NCHW, DataType.F16_Default, DataType.F16_NCHW) \ | |||||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \ | |||||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ | |||||
| .dtype_format(DataType.F32_NCHW, DataType.F32_Default, DataType.F32_NCHW) \ | |||||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ | |||||
| .get_op_info() | |||||
| @op_info_register(prelu_op_info) | |||||
| def _prelu_tbe(): | |||||
| """PReLU TBE register""" | |||||
| return | |||||
| @@ -0,0 +1,43 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """PReLUGrad op""" | |||||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||||
| prelu_grad_op_info = TBERegOp("PReLUGrad") \ | |||||
| .fusion_type("ELEMWISE") \ | |||||
| .async_flag(False) \ | |||||
| .binfile_name("prelu_grad.so") \ | |||||
| .compute_cost(10) \ | |||||
| .kernel_name("prelu_grad") \ | |||||
| .partial_flag(True) \ | |||||
| .input(0, "grads", False, "required", "all") \ | |||||
| .input(1, "features", False, "required", "all") \ | |||||
| .input(2, "weights", False, "required", "all") \ | |||||
| .output(0, "dx", False, "required", "all") \ | |||||
| .output(0, "da", False, "required", "all") \ | |||||
| .dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_Default, | |||||
| DataType.F32_NCHW, DataType.F32_Default) \ | |||||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, | |||||
| DataType.F32_5HD, DataType.F32_5HD) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | |||||
| DataType.F32_Default, DataType.F32_Default) \ | |||||
| .get_op_info() | |||||
| @op_info_register(prelu_grad_op_info) | |||||
| def _prelu_grad_tbe(): | |||||
| """PReLUGrad TBE register""" | |||||
| return | |||||
| @@ -24,6 +24,7 @@ from mindspore.common.initializer import initializer | |||||
| from mindspore.ops import Primitive | from mindspore.ops import Primitive | ||||
| from mindspore.ops import composite as C | from mindspore.ops import composite as C | ||||
| from mindspore.ops import operations as P | from mindspore.ops import operations as P | ||||
| from mindspore.ops.operations import _grad_ops as G | |||||
| from mindspore.ops import prim_attr_register, PrimitiveWithInfer | from mindspore.ops import prim_attr_register, PrimitiveWithInfer | ||||
| from ..ut_filter import non_graph_engine | from ..ut_filter import non_graph_engine | ||||
| from ....mindspore_test_framework.mindspore_test import mindspore_test | from ....mindspore_test_framework.mindspore_test import mindspore_test | ||||
| @@ -456,6 +457,28 @@ class FlattenNet(nn.Cell): | |||||
| return self.flatten(x) | return self.flatten(x) | ||||
| class PReLUNet(nn.Cell): | |||||
| """ PReLUNet definition """ | |||||
| def __init__(self): | |||||
| super(PReLUNet, self).__init__() | |||||
| self.prelu = P.PReLU() | |||||
| self.w = Tensor(np.ones(3, np.float32)) | |||||
| def construct(self, x): | |||||
| return self.prelu(x, self.w) | |||||
| class PReLUGradNet(nn.Cell): | |||||
| """ PReLUGradNet definition """ | |||||
| def __init__(self): | |||||
| super(PReLUGradNet, self).__init__() | |||||
| self.prelu_grad = G.PReLUGrad() | |||||
| def construct(self, dout, x, w): | |||||
| return self.prelu_grad(dout, x, w) | |||||
| test_cases = [ | test_cases = [ | ||||
| ('SoftMaxGrad', { | ('SoftMaxGrad', { | ||||
| 'block': SoftMaxGrad(VirtualNetWithLoss(P.Softmax())), | 'block': SoftMaxGrad(VirtualNetWithLoss(P.Softmax())), | ||||
| @@ -545,6 +568,16 @@ test_cases = [ | |||||
| 'block': FlattenNet(), | 'block': FlattenNet(), | ||||
| 'desc_inputs': [Tensor(np.ones([1, 2, 3, 4], np.float32))], | 'desc_inputs': [Tensor(np.ones([1, 2, 3, 4], np.float32))], | ||||
| }), | }), | ||||
| ('PReLUNet', { | |||||
| 'block': PReLUNet(), | |||||
| 'desc_inputs': [Tensor(np.ones([1, 3, 4, 4], np.float32))], | |||||
| }), | |||||
| ('PReLUGradNet', { | |||||
| 'block': PReLUGradNet(), | |||||
| 'desc_inputs': [Tensor(np.ones([1, 3, 4, 4], np.float32)), | |||||
| Tensor(np.ones([1, 3, 4, 4], np.float32)), | |||||
| Tensor(np.ones(3, np.float32))], | |||||
| }), | |||||
| ] | ] | ||||
| test_cases_for_verify_exception = [ | test_cases_for_verify_exception = [ | ||||