| @@ -98,6 +98,7 @@ static std::map<string, string> tbe_func_adapter_map = { | |||
| {"n_ms_with_mask", "nms_with_mask"}, | |||
| {"square_sum_all", "square_sum_all"}, | |||
| {"cum_sum", "cumsum_d"}, | |||
| {"range", "range_d"}, | |||
| {"inv_grad", "inv_grad"}, | |||
| {"apply_rms_prop", "apply_rms_prop_d"}, | |||
| {"cum_prod", "cumprod_d"}, | |||
| @@ -244,3 +244,7 @@ from .confusion_matrix import _confusion_matrix_tbe | |||
| from .broadcast_to import _broadcast_to_tbe | |||
| from .strided_read import _strided_read_tbe | |||
| from .strided_write import _strided_write_tbe | |||
| from .range import _range_tbe | |||
| from .fused_mul_add_n_l2loss import _fused_mul_add_n_l2loss_tbe | |||
| from .fused_mul_apply_momentum_extern import _fused_mul_apply_momentum_extern_tbe | |||
| from .lamb_next_right import _lamb_next_right_tbe | |||
| @@ -0,0 +1,53 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """FusedMulAddNL2loss op""" | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
# Registration descriptor for the TBE "FusedMulAddNL2loss" operator.
# The builder chain is wrapped in parentheses so no backslash continuations
# are needed; the call order is unchanged.
# NOTE(review): every dtype_format row carries five entries, presumably one per
# declared input (x1, x2, x3) followed by one per output (y1, y2) -- confirm
# against the TBERegOp.dtype_format contract.
fused_mul_add_n_l2loss_op_info = (
    TBERegOp("FusedMulAddNL2loss")
    .fusion_type("OPAQUE")
    .async_flag(False)
    .binfile_name("fused_mul_addn_l2loss.so")
    .compute_cost(10)
    .kernel_name("fused_mul_addn_l2loss")
    .partial_flag(True)
    .input(0, "x1", False, "required", "all")
    .input(1, "x2", False, "required", "all")
    .input(2, "x3", False, "required", "all")
    .output(0, "y1", False, "required", "all")
    .output(1, "y2", False, "required", "all")
    # float16 combinations
    .dtype_format(
        DataType.F16_5HD, DataType.F16_5HD, DataType.F16_Default,
        DataType.F16_5HD, DataType.F16_Default,
    )
    .dtype_format(
        DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_Default,
        DataType.F16_C1HWNCoC0, DataType.F16_Default,
    )
    .dtype_format(
        DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
        DataType.F16_Default, DataType.F16_Default,
    )
    .dtype_format(
        DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_Default,
        DataType.F16_FracZ, DataType.F16_Default,
    )
    # float32 combinations
    .dtype_format(
        DataType.F32_5HD, DataType.F32_5HD, DataType.F32_Default,
        DataType.F32_5HD, DataType.F32_Default,
    )
    .dtype_format(
        DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_Default,
        DataType.F32_C1HWNCoC0, DataType.F32_Default,
    )
    .dtype_format(
        DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
        DataType.F32_Default, DataType.F32_Default,
    )
    .dtype_format(
        DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_Default,
        DataType.F32_FracZ, DataType.F32_Default,
    )
    .get_op_info()
)


@op_info_register(fused_mul_add_n_l2loss_op_info)
def _fused_mul_add_n_l2loss_tbe():
    """Placeholder whose decorator registers the FusedMulAddNL2loss TBE op info."""
    return
| @@ -0,0 +1,67 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """FusedMulApplyMomentumExtern op""" | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
# Registration descriptor for the TBE "FusedMulApplyMomentumExtern" operator.
# The builder chain is wrapped in parentheses so no backslash continuations
# are needed; the call order is unchanged.
# NOTE(review): every dtype_format row carries ten entries, presumably the
# seven declared inputs (var, accum, lr, x1, momentum, x2, var_copy) followed
# by the three outputs (var, var_copy, accum) -- confirm against the
# TBERegOp.dtype_format contract. Rows are laid out 4 + 3 + 3 on that basis.
fused_mul_apply_momentum_extern_op_info = (
    TBERegOp("FusedMulApplyMomentumExtern")
    .fusion_type("OPAQUE")
    .async_flag(False)
    .binfile_name("fused_mul_apply_momentum_extern.so")
    .compute_cost(10)
    .kernel_name("fused_mul_apply_momentum_extern")
    .partial_flag(True)
    .attr("use_nesterov", "optional", "bool", "true,false", "false")
    .input(0, "var", False, "required", "all")
    .input(1, "accum", False, "required", "all")
    .input(2, "lr", False, "required", "all")
    .input(3, "x1", False, "required", "all")
    .input(4, "momentum", False, "required", "all")
    .input(5, "x2", False, "required", "all")
    .input(6, "var_copy", False, "required", "all")
    .output(0, "var", False, "required", "all")
    .output(1, "var_copy", False, "required", "all")
    .output(2, "accum", False, "required", "all")
    # Rows whose second entry is float16
    .dtype_format(
        DataType.F32_5HD, DataType.F16_5HD, DataType.F16_Default, DataType.F16_5HD,
        DataType.F16_Default, DataType.F16_Default, DataType.F16_5HD,
        DataType.F32_5HD, DataType.F16_5HD, DataType.F16_5HD,
    )
    .dtype_format(
        DataType.F32_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_Default, DataType.F16_C1HWNCoC0,
        DataType.F16_Default, DataType.F16_Default, DataType.F16_C1HWNCoC0,
        DataType.F32_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0,
    )
    .dtype_format(
        DataType.F32_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
        DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
        DataType.F32_Default, DataType.F16_Default, DataType.F16_Default,
    )
    .dtype_format(
        DataType.F32_FracZ, DataType.F16_FracZ, DataType.F16_Default, DataType.F16_FracZ,
        DataType.F16_Default, DataType.F16_Default, DataType.F16_FracZ,
        DataType.F32_FracZ, DataType.F16_FracZ, DataType.F16_FracZ,
    )
    # Rows whose second entry is float32
    .dtype_format(
        DataType.F32_5HD, DataType.F32_5HD, DataType.F32_Default, DataType.F32_5HD,
        DataType.F32_Default, DataType.F32_Default, DataType.F16_5HD,
        DataType.F32_5HD, DataType.F16_5HD, DataType.F32_5HD,
    )
    .dtype_format(
        DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_Default, DataType.F32_C1HWNCoC0,
        DataType.F32_Default, DataType.F32_Default, DataType.F16_C1HWNCoC0,
        DataType.F32_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F32_C1HWNCoC0,
    )
    .dtype_format(
        DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
        DataType.F32_Default, DataType.F32_Default, DataType.F16_Default,
        DataType.F32_Default, DataType.F16_Default, DataType.F32_Default,
    )
    .dtype_format(
        DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_Default, DataType.F32_FracZ,
        DataType.F32_Default, DataType.F32_Default, DataType.F16_FracZ,
        DataType.F32_FracZ, DataType.F16_FracZ, DataType.F32_FracZ,
    )
    .get_op_info()
)


@op_info_register(fused_mul_apply_momentum_extern_op_info)
def _fused_mul_apply_momentum_extern_tbe():
    """Placeholder whose decorator registers the FusedMulApplyMomentumExtern TBE op info."""
    return
| @@ -0,0 +1,44 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """LambNextRight op""" | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
# Registration descriptor for the TBE "LambNextRight" operator.
# The builder chain is wrapped in parentheses so no backslash continuations
# are needed; the call order is unchanged.
# NOTE(review): each dtype_format row carries eight entries, presumably the six
# declared inputs followed by the two outputs (y1, y2) -- confirm against the
# TBERegOp.dtype_format contract. Rows are laid out 3 + 3 + 2 on that basis.
lamb_next_right_op_info = (
    TBERegOp("LambNextRight")
    .fusion_type("ELEMWISE")
    .async_flag(False)
    .binfile_name("lamb_next_right.so")
    .compute_cost(10)
    .kernel_name("lamb_next_right")
    .partial_flag(True)
    .input(0, "input_square", False, "required", "all")
    .input(1, "input_mul2", False, "required", "all")
    .input(2, "mul2_x", False, "required", "all")
    .input(3, "mul3_x", False, "required", "all")
    .input(4, "truediv1_recip", False, "required", "all")
    .input(5, "add2_y", False, "required", "all")
    .output(0, "y1", False, "required", "all")
    .output(1, "y2", False, "required", "all")
    # All-float16 combination
    .dtype_format(
        DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
        DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
        DataType.F16_Default, DataType.F16_Default,
    )
    # All-float32 combination
    .dtype_format(
        DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
        DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
        DataType.F32_Default, DataType.F32_Default,
    )
    .get_op_info()
)


@op_info_register(lamb_next_right_op_info)
def _lamb_next_right_tbe():
    """Placeholder whose decorator registers the LambNextRight TBE op info."""
    return
| @@ -0,0 +1,39 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Range op""" | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
# Registration descriptor for the TBE "Range" operator; the binary and kernel
# are both named "range_d".
# The builder chain is wrapped in parentheses so no backslash continuations
# are needed; the call order is unchanged.
# NOTE(review): each dtype_format row carries two entries, presumably the
# single input "x" followed by the single output "y" -- confirm against the
# TBERegOp.dtype_format contract.
range_op_info = (
    TBERegOp("Range")
    .fusion_type("OPAQUE")
    .async_flag(False)
    .binfile_name("range_d.so")
    .compute_cost(10)
    .kernel_name("range_d")
    .partial_flag(True)
    # Three required float attributes
    .attr("start", "required", "float", "all")
    .attr("limit", "required", "float", "all")
    .attr("delta", "required", "float", "all")
    .input(0, "x", False, "required", "all")
    .output(0, "y", False, "required", "all")
    .dtype_format(DataType.F32_Default, DataType.F32_Default)
    .dtype_format(DataType.I32_Default, DataType.I32_Default)
    .get_op_info()
)


@op_info_register(range_op_info)
def _range_tbe():
    """Placeholder whose decorator registers the Range TBE op info."""
    return
| @@ -292,7 +292,8 @@ __all__ = [ | |||
| "Atanh", | |||
| "BasicLSTMCell", | |||
| "ConfusionMatrix", | |||
| "BroadcastTo" | |||
| "BroadcastTo", | |||
| "Range" | |||
| ] | |||
| __all__.extend(_quant_ops.__all__) | |||
| @@ -295,6 +295,15 @@ class ConfusionMatrixNet(Cell): | |||
| return self.confusion_matrix(x, y) | |||
class RangeNet(Cell):
    """Test cell that passes its input through a ``P.Range(1.0, 8.0, 2.0)`` primitive."""

    def __init__(self):
        super().__init__()
        # Range primitive built with the fixed arguments (1.0, 8.0, 2.0).
        self.range_ops = P.Range(1.0, 8.0, 2.0)

    def construct(self, x):
        # Forward the input to the Range primitive and return its result unchanged.
        return self.range_ops(x)
| test_case_array_ops = [ | |||
| ('CustNet1', { | |||
| 'block': CustNet1(), | |||
| @@ -338,6 +347,9 @@ test_case_array_ops = [ | |||
| ('ConfusionMatrixNet', { | |||
| 'block': ConfusionMatrixNet(), | |||
| 'desc_inputs': [Tensor([0, 1, 1, 3], ms.int32), Tensor([0, 1, 1, 3], ms.int32)]}), | |||
| ('RangeNet', { | |||
| 'block': RangeNet(), | |||
| 'desc_inputs': [Tensor(np.array([1, 2, 3, 2]), ms.int32)]}), | |||
| ] | |||
| test_case_lists = [test_case_array_ops] | |||