Merge pull request !293 from zjun/Modify_all_tbe_optags/v0.2.0-alpha
| @@ -14,71 +14,28 @@ | |||
| # ============================================================================ | |||
| """Add op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| add_op_info = TBERegOp("Add") \ | |||
| .fusion_type("ELEMWISE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("add.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("add") \ | |||
| .partial_flag(True) \ | |||
| .input(0, "x1", False, "required", "all") \ | |||
| .input(1, "x2", False, "required", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ | |||
| .dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD) \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ | |||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \ | |||
| .get_op_info() | |||
| @op_info_register("""{ | |||
| "op_name": "Add", | |||
| "imply_type": "TBE", | |||
| "fusion_type": "ELEMWISE", | |||
| "async_flag": false, | |||
| "binfile_name": "add.so", | |||
| "compute_cost": 10, | |||
| "kernel_name": "add", | |||
| "partial_flag": true, | |||
| "attr": [ | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16", "float16", "float16", "float16", "float", "float", "float", | |||
| "float", "int32", "int32", "int32", "int32" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", | |||
| "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat" | |||
| ], | |||
| "name": "x1", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 1, | |||
| "dtype": [ | |||
| "float16", "float16", "float16", "float16", "float", "float", "float", "float", "int32", | |||
| "int32", "int32", "int32" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", | |||
| "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat" | |||
| ], | |||
| "name": "x2", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16", "float16", "float16", "float16", "float", "float", "float", "float", "int32", | |||
| "int32", "int32", "int32" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", | |||
| "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat" | |||
| ], | |||
| "name": "y", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ] | |||
| }""") | |||
| @op_info_register(add_op_info) | |||
| def _add_tbe(): | |||
| """Add TBE register""" | |||
| return | |||
| @@ -14,61 +14,33 @@ | |||
| # ============================================================================ | |||
| """AddN op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| add_n_op_info = TBERegOp("AddN") \ | |||
| .fusion_type("ELEMWISE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("add_n.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("add_n") \ | |||
| .partial_flag(True) \ | |||
| .attr("n", "required", "int", "all") \ | |||
| .input(0, "x", False, "dynamic", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \ | |||
| .dtype_format(DataType.F16_FracZ, DataType.F16_FracZ) \ | |||
| .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default) \ | |||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD) \ | |||
| .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ) \ | |||
| .dtype_format(DataType.F32_FracNZ, DataType.F32_FracNZ) \ | |||
| .dtype_format(DataType.I32_Default, DataType.I32_Default) \ | |||
| .dtype_format(DataType.I32_5HD, DataType.I32_5HD) \ | |||
| .dtype_format(DataType.I32_FracZ, DataType.I32_FracZ) \ | |||
| .get_op_info() | |||
| @op_info_register("""{ | |||
| "op_name": "AddN", | |||
| "imply_type": "TBE", | |||
| "fusion_type": "ELEMWISE", | |||
| "async_flag": false, | |||
| "binfile_name": "add_n.so", | |||
| "compute_cost": 10, | |||
| "kernel_name": "add_n", | |||
| "partial_flag": true, | |||
| "attr": [ | |||
| { | |||
| "name": "n", | |||
| "param_type": "required", | |||
| "type": "int", | |||
| "value": "all" | |||
| } | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16","float16","float16","float16", | |||
| "float","float","float","float","int32","int32","int32" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat","NC1HWC0","FracZ","FRACTAL_NZ", | |||
| "DefaultFormat","NC1HWC0","FracZ","FRACTAL_NZ","DefaultFormat","NC1HWC0","FracZ" | |||
| ], | |||
| "name": "x", | |||
| "need_compile": false, | |||
| "param_type": "dynamic", | |||
| "shape": "all" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16","float16","float16","float16", | |||
| "float","float","float","float","int32","int32","int32" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat","NC1HWC0","FracZ","FRACTAL_NZ", | |||
| "DefaultFormat","NC1HWC0","FracZ","FRACTAL_NZ","DefaultFormat","NC1HWC0","FracZ" | |||
| ], | |||
| "name": "y", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ] | |||
| }""") | |||
| @op_info_register(add_n_op_info) | |||
| def _add_n_tbe(): | |||
| """AddN TBE register""" | |||
| return | |||
| @@ -14,214 +14,66 @@ | |||
| # ============================================================================ | |||
| """ApplyAdam op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| apply_adam_op_info = TBERegOp("Adam") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("apply_adam.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("apply_adam") \ | |||
| .partial_flag(True) \ | |||
| .attr("use_locking", "optional", "bool", "true,false", "false") \ | |||
| .attr("use_nesterov", "optional", "bool", "true,false", "false") \ | |||
| .input(0, "var", False, "required", "all") \ | |||
| .input(1, "m", False, "required", "all") \ | |||
| .input(2, "v", False, "required", "all") \ | |||
| .input(3, "beta1_power", False, "required", "all") \ | |||
| .input(4, "beta2_power", False, "required", "all") \ | |||
| .input(5, "lr", False, "required", "all") \ | |||
| .input(6, "beta1", False, "required", "all") \ | |||
| .input(7, "beta2", False, "required", "all") \ | |||
| .input(8, "epsilon", False, "required", "all") \ | |||
| .input(9, "grad", False, "required", "all") \ | |||
| .output(0, "var", False, "required", "all") \ | |||
| .output(1, "m", False, "required", "all") \ | |||
| .output(2, "v", False, "required", "all") \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, | |||
| DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, | |||
| DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, | |||
| DataType.F16_Default) \ | |||
| .dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_Default, | |||
| DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, | |||
| DataType.F16_Default, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, | |||
| DataType.F16_C1HWNCoC0) \ | |||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_Default, | |||
| DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, | |||
| DataType.F16_Default, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, | |||
| DataType.F16_5HD) \ | |||
| .dtype_format(DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_Default, | |||
| DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, | |||
| DataType.F16_Default, DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ, | |||
| DataType.F16_FracZ) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | |||
| DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | |||
| DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | |||
| DataType.F32_Default) \ | |||
| .dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_Default, | |||
| DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | |||
| DataType.F32_Default, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, | |||
| DataType.F32_C1HWNCoC0) \ | |||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_Default, | |||
| DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | |||
| DataType.F32_Default, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, | |||
| DataType.F32_5HD) \ | |||
| .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_Default, | |||
| DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | |||
| DataType.F32_Default, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, | |||
| DataType.F32_FracZ) \ | |||
| .get_op_info() | |||
| @op_info_register("""{ | |||
| "op_name": "Adam", | |||
| "imply_type": "TBE", | |||
| "fusion_type": "OPAQUE", | |||
| "async_flag": false, | |||
| "binfile_name": "apply_adam.so", | |||
| "compute_cost": 10, | |||
| "kernel_name": "apply_adam", | |||
| "partial_flag": true, | |||
| "attr": [ | |||
| { | |||
| "name": "use_locking", | |||
| "param_type": "optional", | |||
| "type": "bool", | |||
| "value": "true,false", | |||
| "default_value":"false" | |||
| }, | |||
| { | |||
| "name": "use_nesterov", | |||
| "param_type": "optional", | |||
| "type": "bool", | |||
| "value": "true,false", | |||
| "default_value":"false" | |||
| } | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16","float16","float16","float16","float","float","float","float" | |||
| ], | |||
| "format": [ | |||
| "NC1HWC0", "C1HWNCoC0", "DefaultFormat", "FracZ", "NC1HWC0", "C1HWNCoC0", "DefaultFormat", "FracZ" | |||
| ], | |||
| "name": "var", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 1, | |||
| "dtype": [ | |||
| "float16","float16","float16","float16","float","float","float","float" | |||
| ], | |||
| "format": [ | |||
| "NC1HWC0", "C1HWNCoC0", "DefaultFormat", "FracZ", "NC1HWC0", "C1HWNCoC0", "DefaultFormat", "FracZ" | |||
| ], | |||
| "name": "m", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 2, | |||
| "dtype": [ | |||
| "float16","float16","float16","float16","float","float","float","float" | |||
| ], | |||
| "format": [ | |||
| "NC1HWC0", "C1HWNCoC0", "DefaultFormat", "FracZ", "NC1HWC0", "C1HWNCoC0", "DefaultFormat", "FracZ" | |||
| ], | |||
| "name": "v", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 3, | |||
| "dtype": [ | |||
| "float16","float16","float16","float16","float","float","float", "float" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", | |||
| "DefaultFormat", "DefaultFormat" | |||
| ], | |||
| "name": "beta1_power", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 4, | |||
| "dtype": [ | |||
| "float16","float16","float16","float16","float","float","float","float" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", | |||
| "DefaultFormat", "DefaultFormat" | |||
| ], | |||
| "name": "beta2_power", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 5, | |||
| "dtype": [ | |||
| "float16","float16","float16","float16","float","float","float", "float" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", | |||
| "DefaultFormat", "DefaultFormat" | |||
| ], | |||
| "name": "lr", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 6, | |||
| "dtype": [ | |||
| "float16","float16","float16","float16","float","float","float", "float" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", | |||
| "DefaultFormat", "DefaultFormat" | |||
| ], | |||
| "name": "beta1", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 7, | |||
| "dtype": [ | |||
| "float16","float16","float16","float16","float","float","float", "float" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", | |||
| "DefaultFormat", "DefaultFormat" | |||
| ], | |||
| "name": "beta2", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 8, | |||
| "dtype": [ | |||
| "float16","float16","float16","float16","float","float","float", "float" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", | |||
| "DefaultFormat", "DefaultFormat" | |||
| ], | |||
| "name": "epsilon", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 9, | |||
| "dtype": [ | |||
| "float16","float16","float16","float16","float","float","float", "float" | |||
| ], | |||
| "format": [ | |||
| "NC1HWC0", "C1HWNCoC0", "DefaultFormat", "FracZ", "NC1HWC0", "C1HWNCoC0", "DefaultFormat", "FracZ" | |||
| ], | |||
| "name": "grad", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16","float16","float16","float16","float","float","float","float" | |||
| ], | |||
| "format": [ | |||
| "NC1HWC0", "C1HWNCoC0", "DefaultFormat", "FracZ", "NC1HWC0", "C1HWNCoC0", "DefaultFormat", "FracZ" | |||
| ], | |||
| "name": "var", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 1, | |||
| "dtype": [ | |||
| "float16","float16","float16","float16","float","float","float","float" | |||
| ], | |||
| "format": [ | |||
| "NC1HWC0", "C1HWNCoC0", "DefaultFormat", "FracZ", "NC1HWC0", "C1HWNCoC0", "DefaultFormat", "FracZ" | |||
| ], | |||
| "name": "m", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 2, | |||
| "dtype": [ | |||
| "float16","float16","float16","float16","float","float","float","float" | |||
| ], | |||
| "format": [ | |||
| "NC1HWC0", "C1HWNCoC0", "DefaultFormat", "FracZ", "NC1HWC0", "C1HWNCoC0", "DefaultFormat", "FracZ" | |||
| ], | |||
| "name": "v", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ] | |||
| }""") | |||
| @op_info_register(apply_adam_op_info) | |||
| def _apply_adam_tbe(): | |||
| """ApplyAdam TBE register""" | |||
| return | |||
| @@ -14,112 +14,42 @@ | |||
| # ============================================================================ | |||
| """ApplyMomentum op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| apply_momentum_op_info = TBERegOp("ApplyMomentum") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("apply_momentum.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("apply_momentum") \ | |||
| .partial_flag(True) \ | |||
| .attr("use_nesterov", "optional", "bool", "true,false", "false") \ | |||
| .input(0, "var", False, "required", "all") \ | |||
| .input(1, "accum", False, "required", "all") \ | |||
| .input(2, "lr", False, "required", "all") \ | |||
| .input(3, "grad", False, "required", "all") \ | |||
| .input(4, "momentum", False, "required", "all") \ | |||
| .output(0, "var", False, "required", "all") \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, | |||
| DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_Default, DataType.F16_5HD, | |||
| DataType.F16_Default, DataType.F16_5HD) \ | |||
| .dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_Default, DataType.F16_C1HWNCoC0, | |||
| DataType.F16_Default, DataType.F16_C1HWNCoC0) \ | |||
| .dtype_format(DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_Default, DataType.F16_FracZ, | |||
| DataType.F16_Default, DataType.F16_FracZ) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | |||
| DataType.F32_Default, DataType.F32_Default) \ | |||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_Default, DataType.F32_5HD, | |||
| DataType.F32_Default, DataType.F32_5HD) \ | |||
| .dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_Default, DataType.F32_C1HWNCoC0, | |||
| DataType.F32_Default, DataType.F32_C1HWNCoC0) \ | |||
| .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_Default, DataType.F32_FracZ, | |||
| DataType.F32_Default, DataType.F32_FracZ) \ | |||
| .get_op_info() | |||
| @op_info_register("""{ | |||
| "op_name": "ApplyMomentum", | |||
| "imply_type": "TBE", | |||
| "fusion_type": "OPAQUE", | |||
| "async_flag": false, | |||
| "binfile_name": "apply_momentum.so", | |||
| "compute_cost": 10, | |||
| "kernel_name": "apply_momentum", | |||
| "partial_flag": true, | |||
| "attr": [ | |||
| { | |||
| "name": "use_nesterov", | |||
| "param_type": "optional", | |||
| "type": "bool", | |||
| "value": "true,false", | |||
| "default_value":"false" | |||
| } | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16","float16","float16","float16","float","float","float","float" | |||
| ], | |||
| "format": [ | |||
| "NC1HWC0", "C1HWNCoC0", "DefaultFormat", "FracZ", "NC1HWC0", "DefaultFormat", "FracZ", "C1HWNCoC0" | |||
| ], | |||
| "name": "var", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 1, | |||
| "dtype": [ | |||
| "float16","float16","float16","float16","float","float","float","float" | |||
| ], | |||
| "format": [ | |||
| "NC1HWC0", "C1HWNCoC0", "DefaultFormat", "FracZ", "NC1HWC0", "DefaultFormat", "FracZ", "C1HWNCoC0" | |||
| ], | |||
| "name": "accum", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 2, | |||
| "dtype": [ | |||
| "float16","float16","float16","float16","float","float","float","float" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", | |||
| "DefaultFormat", "DefaultFormat" | |||
| ], | |||
| "name": "lr", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 3, | |||
| "dtype": [ | |||
| "float16","float16","float16","float16","float","float","float","float" | |||
| ], | |||
| "format": [ | |||
| "NC1HWC0", "C1HWNCoC0", "DefaultFormat", "FracZ", "NC1HWC0", "DefaultFormat", "FracZ", "C1HWNCoC0" | |||
| ], | |||
| "name": "grad", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 4, | |||
| "dtype": [ | |||
| "float16","float16","float16","float16","float","float","float", "float" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", | |||
| "DefaultFormat", "DefaultFormat" | |||
| ], | |||
| "name": "momentum", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16","float16","float16","float16","float","float","float","float" | |||
| ], | |||
| "format": [ | |||
| "NC1HWC0", "C1HWNCoC0", "DefaultFormat", "FracZ", "NC1HWC0", "DefaultFormat", "FracZ", "C1HWNCoC0" | |||
| ], | |||
| "name": "var", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ] | |||
| }""") | |||
| @op_info_register(apply_momentum_op_info) | |||
| def _apply_momentum_tbe(): | |||
| """ApplyMomentum TBE register""" | |||
| return | |||
| @@ -14,70 +14,25 @@ | |||
| # ============================================================================ | |||
| """ArgMaxWithValue op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| arg_max_with_value_op_info = TBERegOp("ArgMaxWithValue") \ | |||
| .fusion_type("ELEMWISE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("arg_max_with_value.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("arg_max_with_value") \ | |||
| .partial_flag(True) \ | |||
| .attr("axis", "required", "int", "all") \ | |||
| .input(0, "x", False, "required", "all") \ | |||
| .output(0, "indice", False, "required", "all") \ | |||
| .output(1, "values", False, "required", "all") \ | |||
| .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default) \ | |||
| .get_op_info() | |||
| @op_info_register("""{ | |||
| "op_name": "ArgMaxWithValue", | |||
| "imply_type": "TBE", | |||
| "fusion_type": "ELEMWISE", | |||
| "async_flag": false, | |||
| "binfile_name": "arg_max_with_value.so", | |||
| "compute_cost": 10, | |||
| "kernel_name": "arg_max_with_value", | |||
| "partial_flag": true, | |||
| "attr": [ | |||
| { | |||
| "name": "axis", | |||
| "param_type": "required", | |||
| "type": "int", | |||
| "value": "all" | |||
| } | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16", "float" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "DefaultFormat" | |||
| ], | |||
| "name": "x", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "int32", "int32" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "DefaultFormat" | |||
| ], | |||
| "name": "indice", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16", "float" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "DefaultFormat" | |||
| ], | |||
| "name": "values", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ] | |||
| }""") | |||
| @op_info_register(arg_max_with_value_op_info) | |||
| def _arg_max_with_value_tbe(): | |||
| """ArgMaxWithValue TBE register""" | |||
| return | |||
| @@ -14,70 +14,25 @@ | |||
| # ============================================================================ | |||
| """ArgMinWithValue op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| arg_min_with_value_op_info = TBERegOp("ArgMaxWithValue") \ | |||
| .fusion_type("ELEMWISE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("arg_min_with_value.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("arg_min_with_value") \ | |||
| .partial_flag(True) \ | |||
| .attr("axis", "required", "int", "all") \ | |||
| .input(0, "x", False, "required", "all") \ | |||
| .output(0, "indice", False, "required", "all") \ | |||
| .output(1, "values", False, "required", "all") \ | |||
| .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default) \ | |||
| .get_op_info() | |||
| @op_info_register("""{ | |||
| "op_name": "ArgMinWithValue", | |||
| "imply_type": "TBE", | |||
| "fusion_type": "ELEMWISE", | |||
| "async_flag": false, | |||
| "binfile_name": "arg_min_with_value.so", | |||
| "compute_cost": 10, | |||
| "kernel_name": "arg_min_with_value", | |||
| "partial_flag": true, | |||
| "attr": [ | |||
| { | |||
| "name": "axis", | |||
| "param_type": "required", | |||
| "type": "int", | |||
| "value": "all" | |||
| } | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16", "float" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "DefaultFormat" | |||
| ], | |||
| "name": "x", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "int32", "int32" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "DefaultFormat" | |||
| ], | |||
| "name": "indice", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16", "float" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "DefaultFormat" | |||
| ], | |||
| "name": "values", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ] | |||
| }""") | |||
| @op_info_register(arg_min_with_value_op_info) | |||
| def _arg_min_with_value_tbe(): | |||
| """ArgMinWithValue TBE register""" | |||
| return | |||
| @@ -14,93 +14,43 @@ | |||
| # ============================================================================ | |||
| """Assign op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| assign_op_info = TBERegOp("Assign") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("assign.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("assign") \ | |||
| .partial_flag(True) \ | |||
| .input(0, "resource", False, "required", "all") \ | |||
| .input(1, "value", False, "required", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.I8_Default) \ | |||
| .dtype_format(DataType.I8_5HD, DataType.I8_5HD, DataType.I8_5HD) \ | |||
| .dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.U8_Default) \ | |||
| .dtype_format(DataType.U8_5HD, DataType.U8_5HD, DataType.U8_5HD) \ | |||
| .dtype_format(DataType.I16_Default, DataType.I16_Default, DataType.I16_Default) \ | |||
| .dtype_format(DataType.I16_5HD, DataType.I16_5HD, DataType.I16_5HD) \ | |||
| .dtype_format(DataType.U16_Default, DataType.U16_Default, DataType.U16_Default) \ | |||
| .dtype_format(DataType.U16_5HD, DataType.U16_5HD, DataType.U16_5HD) \ | |||
| .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ | |||
| .dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD) \ | |||
| .dtype_format(DataType.U32_Default, DataType.U32_Default, DataType.U32_Default) \ | |||
| .dtype_format(DataType.U32_5HD, DataType.U32_5HD, DataType.U32_5HD) \ | |||
| .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \ | |||
| .dtype_format(DataType.I64_5HD, DataType.I64_5HD, DataType.I64_5HD) \ | |||
| .dtype_format(DataType.U64_Default, DataType.U64_Default, DataType.U64_Default) \ | |||
| .dtype_format(DataType.U64_5HD, DataType.U64_5HD, DataType.U64_5HD) \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ | |||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \ | |||
| .dtype_format(DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ) \ | |||
| .get_op_info() | |||
| @op_info_register("""{ | |||
| "op_name": "Assign", | |||
| "imply_type": "TBE", | |||
| "fusion_type": "OPAQUE", | |||
| "async_flag": false, | |||
| "binfile_name": "assign.so", | |||
| "compute_cost": 10, | |||
| "kernel_name": "assign", | |||
| "partial_flag": true, | |||
| "attr": [ | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16", "float16", "float16", "float16", "float", "float", "float", "float", | |||
| "int32", "int32", "int32", "int32", "uint32", "uint32", "uint32", "uint32", "int8", | |||
| "int8", "int8", "int8", "uint8", "uint8", "uint8", "uint8", "int16", "int16", "int16", | |||
| "int16", "uint16", "uint16", "uint16", "uint16", "int64", "int64", "int64", "int64", | |||
| "uint64", "uint64", "uint64", "uint64", "float" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", | |||
| "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", | |||
| "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", | |||
| "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", | |||
| "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", | |||
| "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", | |||
| "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "FRACTAL_NZ" | |||
| ], | |||
| "name": "resource", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 1, | |||
| "dtype": [ | |||
| "float16", "float16", "float16", "float16", "float", "float", "float", "float", "int32", "int32", | |||
| "int32", "int32", "uint32", "uint32", "uint32", "uint32", "int8", "int8", "int8", "int8", "uint8", | |||
| "uint8", "uint8", "uint8", "int16", "int16", "int16", "int16", "uint16", "uint16", "uint16", | |||
| "uint16", "int64", "int64", "int64", "int64", "uint64", "uint64", "uint64", "uint64", "float" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", | |||
| "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", | |||
| "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", | |||
| "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", | |||
| "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", | |||
| "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", | |||
| "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "FRACTAL_NZ" | |||
| ], | |||
| "name": "value", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16", "float16", "float16", "float16", "float", "float", "float", "float", "int32", | |||
| "int32", "int32", "int32", "uint32", "uint32", "uint32", "uint32", "int8", "int8", "int8", | |||
| "int8", "uint8", "uint8", "uint8", "uint8", "int16", "int16", "int16", "int16", "uint16", | |||
| "uint16", "uint16", "uint16", "int64", "int64", "int64", "int64", | |||
| "uint64", "uint64", "uint64", "uint64", "float" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", | |||
| "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", | |||
| "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", | |||
| "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", | |||
| "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", | |||
| "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", | |||
| "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "FRACTAL_NZ" | |||
| ], | |||
| "name": "y", | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ] | |||
| }""") | |||
| @op_info_register(assign_op_info) | |||
| def _assign_tbe(): | |||
| """Assign TBE register""" | |||
| return | |||
| @@ -14,80 +14,34 @@ | |||
| # ============================================================================ | |||
| """AssignAdd op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| assign_add_op_info = TBERegOp("AssignAdd") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("assignadd.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("assignadd") \ | |||
| .partial_flag(True) \ | |||
| .input(0, "ref", False, "required", "all") \ | |||
| .input(1, "value", False, "required", "all") \ | |||
| .output(0, "output_ref", False, "required", "all") \ | |||
| .dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.I8_Default) \ | |||
| .dtype_format(DataType.I8_5HD, DataType.I8_5HD, DataType.I8_5HD) \ | |||
| .dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.U8_Default) \ | |||
| .dtype_format(DataType.U8_5HD, DataType.U8_5HD, DataType.U8_5HD) \ | |||
| .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ | |||
| .dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD) \ | |||
| .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \ | |||
| .dtype_format(DataType.I64_5HD, DataType.I64_5HD, DataType.I64_5HD) \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ | |||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \ | |||
| .get_op_info() | |||
| @op_info_register("""{ | |||
| "op_name": "AssignAdd", | |||
| "imply_type": "TBE", | |||
| "fusion_type": "OPAQUE", | |||
| "async_flag": false, | |||
| "binfile_name": "assignadd.so", | |||
| "compute_cost": 10, | |||
| "kernel_name": "assignadd", | |||
| "partial_flag": true, | |||
| "attr": [ | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype":[ | |||
| "float16", "float16", "float16", "float16", "float", "float", "float", "float", "int32", "int32", | |||
| "int32", "int32", "int8", "int8", "int8", "int8", "uint8", "uint8", "uint8", "uint8", "int64", | |||
| "int64", "int64", "int64" | |||
| ], | |||
| "format":[ | |||
| "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", | |||
| "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", | |||
| "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", | |||
| "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat" | |||
| ], | |||
| "name": "ref", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 1, | |||
| "dtype":[ | |||
| "float16", "float16", "float16", "float16", "float", "float", "float", "float", "int32", "int32", | |||
| "int32", "int32", "int8", "int8", "int8", "int8", "uint8", "uint8", "uint8", "uint8", "int64", | |||
| "int64", "int64", "int64" | |||
| ], | |||
| "format":[ | |||
| "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", | |||
| "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", | |||
| "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", | |||
| "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat" | |||
| ], | |||
| "name": "value", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype":[ | |||
| "float16", "float16", "float16", "float16", "float", "float", "float", "float", "int32", "int32", | |||
| "int32", "int32", "int8", "int8", "int8", "int8", "uint8", "uint8", "uint8", "uint8", "int64", | |||
| "int64", "int64", "int64" | |||
| ], | |||
| "format":[ | |||
| "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", | |||
| "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", | |||
| "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", | |||
| "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat" | |||
| ], | |||
| "name": "output_ref", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ] | |||
| }""") | |||
| @op_info_register(assign_add_op_info) | |||
| def _assign_add_tbe(): | |||
| """AssignAdd TBE register""" | |||
| return | |||
| @@ -14,65 +14,27 @@ | |||
| # ============================================================================ | |||
| """AssignSub op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| assign_sub_op_info = TBERegOp("AssignSub") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("assign_sub.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("assign_sub") \ | |||
| .partial_flag(True) \ | |||
| .input(0, "var", False, "required", "all") \ | |||
| .input(1, "value", False, "required", "all") \ | |||
| .output(0, "output_ref", False, "required", "all") \ | |||
| .dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.I8_Default) \ | |||
| .dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.U8_Default) \ | |||
| .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ | |||
| .get_op_info() | |||
| @op_info_register("""{ | |||
| "op_name": "AssignSub", | |||
| "imply_type": "TBE", | |||
| "fusion_type": "OPAQUE", | |||
| "async_flag": false, | |||
| "binfile_name": "assign_sub.so", | |||
| "compute_cost": 10, | |||
| "kernel_name": "assign_sub", | |||
| "partial_flag": true, | |||
| "attr": [ | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16", "float", "int32", "int8", "uint8" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat" | |||
| ], | |||
| "name": "var", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 1, | |||
| "dtype": [ | |||
| "float16", "float", "int32", "int8", "uint8" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat" | |||
| ], | |||
| "name": "value", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16", "float", "int32", "int8", "uint8" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat" | |||
| ], | |||
| "name": "out_ref", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ] | |||
| }""") | |||
| @op_info_register(assign_sub_op_info) | |||
| def _assign_sub_tbe(): | |||
| """AssignSub TBE register""" | |||
| return | |||
| @@ -14,31 +14,20 @@ | |||
| # ============================================================================ | |||
| """AtomicAddrClean op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp | |||
| atomic_addr_clean_op_info = TBERegOp("AtomicAddrClean") \ | |||
| .fusion_type("ELEMWISE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("atomic_addr_clean.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("atomic_addr_clean") \ | |||
| .partial_flag(True) \ | |||
| .attr("automic_add_mem_size", "required", "listInt", "all") \ | |||
| .get_op_info() | |||
| @op_info_register("""{ | |||
| "op_name": "AtomicAddrClean", | |||
| "imply_type": "TBE", | |||
| "fusion_type": "ELEMWISE", | |||
| "async_flag": false, | |||
| "binfile_name": "atomic_addr_clean.so", | |||
| "compute_cost": 10, | |||
| "kernel_name": "atomic_addr_clean", | |||
| "partial_flag": true, | |||
| "attr": [ | |||
| { | |||
| "name": "automic_add_mem_size", | |||
| "param_type": "required", | |||
| "type": "listInt", | |||
| "value": "all" | |||
| } | |||
| ], | |||
| "inputs": [ | |||
| ], | |||
| "outputs": [ | |||
| ] | |||
| }""") | |||
| @op_info_register(atomic_addr_clean_op_info) | |||
| def _atomic_addr_clean_tbe(): | |||
| """AtomicAddrClean TBE register""" | |||
| return | |||
| @@ -14,88 +14,29 @@ | |||
| # ============================================================================ | |||
| """BatchMatMul op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| batch_matmul_op_info = TBERegOp("BatchMatMul") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("batch_matmul.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("batch_matmul") \ | |||
| .attr("transpose_x1", "required", "bool", "all") \ | |||
| .attr("transpose_x2", "required", "bool", "all") \ | |||
| .partial_flag(True) \ | |||
| .input(0, "x1", False, "required", "all") \ | |||
| .input(1, "x2", False, "required", "all") \ | |||
| .input(2, "bias", False, "optional", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_Default, DataType.F16_FracNZ) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ | |||
| .get_op_info() | |||
| @op_info_register("""{ | |||
| "op_name": "BatchMatMul", | |||
| "imply_type": "TBE", | |||
| "fusion_type": "OPAQUE", | |||
| "async_flag": false, | |||
| "binfile_name": "batch_matmul.so", | |||
| "compute_cost": 10, | |||
| "kernel_name": "batch_matmul", | |||
| "partial_flag": true, | |||
| "attr": [ | |||
| { | |||
| "name": "transpose_x1", | |||
| "param_type": "required", | |||
| "type": "bool", | |||
| "value": "all" | |||
| }, | |||
| { | |||
| "name": "transpose_x2", | |||
| "param_type": "required", | |||
| "type": "bool", | |||
| "value": "all" | |||
| } | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16","float16","float16","float","float","int32","int32" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat","DefaultFormat","FRACTAL_NZ","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat" | |||
| ], | |||
| "name": "x1", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 1, | |||
| "dtype": [ | |||
| "float16","float16","float16","float","float","int32","int32" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat","DefaultFormat","FRACTAL_NZ","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat" | |||
| ], | |||
| "name": "x2", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 2, | |||
| "dtype": [ | |||
| "float16","float16","float16","float","float","int32","int32" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat" | |||
| ], | |||
| "name": "bias", | |||
| "need_compile": false, | |||
| "param_type": "optional", | |||
| "shape": "all" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16","float16","float16","float","float","int32","int32" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat","DefaultFormat","FRACTAL_NZ","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat" | |||
| ], | |||
| "name": "y", | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ] | |||
| }""") | |||
| @op_info_register(batch_matmul_op_info) | |||
| def _batch_matmul_tbe(): | |||
| """BatchMatMul TBE register""" | |||
| return | |||
| @@ -14,174 +14,45 @@ | |||
| # ============================================================================ | |||
| """BatchNorm op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| batch_norm_op_info = TBERegOp("BatchNorm") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("batch_norm.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("batch_norm") \ | |||
| .partial_flag(True) \ | |||
| .attr("epsilon", "optional", "float", "all") \ | |||
| .attr("data_format", "optional", "str", "all") \ | |||
| .attr("is_training", "optional", "bool", "all") \ | |||
| .input(0, "x", False, "required", "all") \ | |||
| .input(1, "scale", False, "required", "all") \ | |||
| .input(2, "offset", False, "required", "all") \ | |||
| .input(3, "mean", False, "optional", "all") \ | |||
| .input(4, "variance", False, "optional", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .output(1, "batch_mean", False, "required", "all") \ | |||
| .output(2, "batch_variance", False, "required", "all") \ | |||
| .output(3, "reserve_space_1", False, "optional", "all") \ | |||
| .output(4, "reserve_space_2", False, "optional", "all") \ | |||
| .output(5, "reserve_space_3", False, "optional", "all") \ | |||
| .dtype_format(DataType.F16_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | |||
| DataType.F32_Default, DataType.F16_Default, DataType.F32_Default, DataType.F32_Default, | |||
| DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ | |||
| .dtype_format(DataType.F16_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, | |||
| DataType.F32_5HD, DataType.F16_5HD, DataType.F32_5HD, DataType.F32_5HD, | |||
| DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | |||
| DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | |||
| DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ | |||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, | |||
| DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, | |||
| DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \ | |||
| .get_op_info() | |||
| @op_info_register("""{ | |||
| "op_name": "BatchNorm", | |||
| "imply_type": "TBE", | |||
| "fusion_type": "OPAQUE", | |||
| "async_flag": false, | |||
| "binfile_name": "batch_norm.so", | |||
| "compute_cost": 10, | |||
| "kernel_name": "batch_norm", | |||
| "partial_flag": true, | |||
| "attr": [ | |||
| { | |||
| "name": "epsilon", | |||
| "param_type": "required", | |||
| "type": "float", | |||
| "value": "all" | |||
| }, | |||
| { | |||
| "name": "data_format", | |||
| "param_type": "required", | |||
| "type": "str", | |||
| "value": "all" | |||
| }, | |||
| { | |||
| "name": "is_training", | |||
| "param_type": "required", | |||
| "type": "bool", | |||
| "value": "all" | |||
| } | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16","float16","float","float" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat","NC1HWC0", "DefaultFormat","NC1HWC0" | |||
| ], | |||
| "name": "x", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 1, | |||
| "dtype": [ | |||
| "float","float","float","float" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat","NC1HWC0","DefaultFormat", "NC1HWC0" | |||
| ], | |||
| "name": "scale", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 2, | |||
| "dtype": [ | |||
| "float","float","float","float" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0" | |||
| ], | |||
| "name": "offset", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 3, | |||
| "dtype": [ | |||
| "float","float","float","float" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0" | |||
| ], | |||
| "name": "mean", | |||
| "need_compile": false, | |||
| "param_type": "optional", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 4, | |||
| "dtype": [ | |||
| "float","float","float","float" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0" | |||
| ], | |||
| "name": "variance", | |||
| "need_compile": false, | |||
| "param_type": "optional", | |||
| "shape": "all" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16", "float16", "float", "float" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat","NC1HWC0", "DefaultFormat","NC1HWC0" | |||
| ], | |||
| "name": "y", | |||
| "param_type": "required" | |||
| }, | |||
| { | |||
| "index": 1, | |||
| "dtype": [ | |||
| "float","float","float","float" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0" | |||
| ], | |||
| "name": "batch_mean", | |||
| "param_type": "required" | |||
| }, | |||
| { | |||
| "index": 2, | |||
| "dtype": [ | |||
| "float", "float", "float", "float" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0" | |||
| ], | |||
| "name": "batch_variance", | |||
| "param_type": "required" | |||
| }, | |||
| { | |||
| "index": 3, | |||
| "dtype": [ | |||
| "float", "float", "float", "float" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0" | |||
| ], | |||
| "name": "reserve_space_1", | |||
| "param_type": "optional" | |||
| }, | |||
| { | |||
| "index": 4, | |||
| "dtype": [ | |||
| "float", "float", "float", "float" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0" | |||
| ], | |||
| "name": "reserve_space_2", | |||
| "param_type": "optional" | |||
| }, | |||
| { | |||
| "index": 5, | |||
| "dtype": [ | |||
| "float", "float", "float", "float" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0" | |||
| ], | |||
| "name": "reserve_space_3", | |||
| "param_type": "optional" | |||
| } | |||
| ] | |||
| }""") | |||
| @op_info_register(batch_norm_op_info) | |||
| def _batch_norm_tbe(): | |||
| """BatchNorm TBE register""" | |||
| return | |||
| @@ -14,181 +14,45 @@ | |||
| # ============================================================================ | |||
| """BatchNormGrad op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| batch_norm_grad_op_info = TBERegOp("BatchNormGrad") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("batchnormgrad.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("batchnormgrad") \ | |||
| .partial_flag(True) \ | |||
| .attr("epsilon", "optional", "float", "all") \ | |||
| .attr("data_format", "optional", "str", "all") \ | |||
| .attr("is_training", "optional", "bool", "all") \ | |||
| .input(0, "y_backprop", False, "required", "all") \ | |||
| .input(1, "x", False, "required", "all") \ | |||
| .input(2, "scale", False, "required", "all") \ | |||
| .input(3, "reserve_space_1", False, "required", "all") \ | |||
| .input(4, "reserve_space_2", False, "required", "all") \ | |||
| .input(5, "reserve_space_3", False, "required", "all") \ | |||
| .output(0, "x_backprop", False, "required", "all") \ | |||
| .output(1, "scale_backprop", False, "required", "all") \ | |||
| .output(2, "offset_backprop", False, "required", "all") \ | |||
| .output(3, "reserve_space_4", False, "optional", "all") \ | |||
| .output(4, "reserve_space_5", False, "optional", "all") \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F32_Default, DataType.F32_Default, | |||
| DataType.F32_Default, DataType.F32_Default, DataType.F16_Default, DataType.F32_Default, | |||
| DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ | |||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F32_5HD, DataType.F32_5HD, | |||
| DataType.F32_5HD, DataType.F32_5HD, DataType.F16_5HD, DataType.F32_5HD, | |||
| DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | |||
| DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | |||
| DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ | |||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, | |||
| DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, | |||
| DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \ | |||
| .get_op_info() | |||
| @op_info_register("""{ | |||
| "op_name": "BatchNormGrad", | |||
| "imply_type": "TBE", | |||
| "fusion_type": "OPAQUE", | |||
| "async_flag": false, | |||
| "binfile_name": "batchnormgrad.so", | |||
| "compute_cost": 10, | |||
| "kernel_name": "batchnormgrad", | |||
| "partial_flag": true, | |||
| "attr": [ | |||
| { | |||
| "name": "epsilon", | |||
| "param_type": "optional", | |||
| "type": "float", | |||
| "value": "all" | |||
| }, | |||
| { | |||
| "name": "data_format", | |||
| "param_type": "optional", | |||
| "type": "str", | |||
| "value": "all" | |||
| }, | |||
| { | |||
| "name": "is_training", | |||
| "param_type": "optional", | |||
| "type": "bool", | |||
| "value": "all" | |||
| } | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16","float16","float16","float","float","float" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat","DefaultFormat","NC1HWC0","DefaultFormat","DefaultFormat","NC1HWC0" | |||
| ], | |||
| "name": "y_backprop", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 1, | |||
| "dtype": [ | |||
| "float16","float16","float16","float","float","float" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat","DefaultFormat","NC1HWC0","DefaultFormat","DefaultFormat","NC1HWC0" | |||
| ], | |||
| "name": "x", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 2, | |||
| "dtype": [ | |||
| "float","float","float","float","float","float" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat","DefaultFormat","NC1HWC0","DefaultFormat","DefaultFormat","NC1HWC0" | |||
| ], | |||
| "name": "scale", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 3, | |||
| "dtype": [ | |||
| "float","float","float","float","float","float" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat","DefaultFormat","NC1HWC0","DefaultFormat","DefaultFormat","NC1HWC0" | |||
| ], | |||
| "name": "reserve_space_1", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 4, | |||
| "dtype": [ | |||
| "float","float","float","float","float","float" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat","DefaultFormat","NC1HWC0","DefaultFormat","DefaultFormat","NC1HWC0" | |||
| ], | |||
| "name": "reserve_space_2", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 5, | |||
| "dtype": [ | |||
| "float","float","float","float","float","float" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat","DefaultFormat","NC1HWC0","DefaultFormat","DefaultFormat","NC1HWC0" | |||
| ], | |||
| "name": "reserve_space_3", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16","float16","float16","float","float","float" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat","DefaultFormat","NC1HWC0","DefaultFormat","DefaultFormat","NC1HWC0" | |||
| ], | |||
| "name": "x_backprop", | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 1, | |||
| "dtype": [ | |||
| "float","float","float","float","float","float" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat","DefaultFormat","NC1HWC0","DefaultFormat","DefaultFormat","NC1HWC0" | |||
| ], | |||
| "name": "scale_backprop", | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 2, | |||
| "dtype": [ | |||
| "float","float","float","float","float","float" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat","DefaultFormat","NC1HWC0","DefaultFormat","DefaultFormat","NC1HWC0" | |||
| ], | |||
| "name": "offset_backprop", | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 3, | |||
| "dtype": [ | |||
| "float","float","float","float","float","float" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat","DefaultFormat","NC1HWC0","DefaultFormat","DefaultFormat","NC1HWC0" | |||
| ], | |||
| "name": "reserve_space_4", | |||
| "param_type": "optional", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 4, | |||
| "dtype": [ | |||
| "float","float","float","float","float","float" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat","DefaultFormat","NC1HWC0","DefaultFormat","DefaultFormat","NC1HWC0" | |||
| ], | |||
| "name": "reserve_space_5", | |||
| "param_type": "optional", | |||
| "shape": "all" | |||
| } | |||
| ] | |||
| }""") | |||
| @op_info_register(batch_norm_grad_op_info) | |||
| def _batch_norm_grad_tbe(): | |||
| """BatchNormGrad TBE register""" | |||
| return | |||
| @@ -14,70 +14,26 @@ | |||
| # ============================================================================ | |||
| """BiasAdd op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| bias_add_grad_op_info = TBERegOp("BiasAdd") \ | |||
| .fusion_type("COMMREDUCE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("bias_add.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("bias_add") \ | |||
| .partial_flag(True) \ | |||
| .attr("data_format", "required", "str", "all") \ | |||
| .input(0, "x", False, "required", "all") \ | |||
| .input(1, "bias", False, "required", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ | |||
| .get_op_info() | |||
| @op_info_register("""{ | |||
| "op_name": "BiasAdd", | |||
| "imply_type": "TBE", | |||
| "fusion_type": "COMMREDUCE", | |||
| "async_flag": false, | |||
| "binfile_name": "bias_add.so", | |||
| "compute_cost": 10, | |||
| "kernel_name": "bias_add", | |||
| "partial_flag": true, | |||
| "attr": [ | |||
| { | |||
| "name": "data_format", | |||
| "param_type": "required", | |||
| "type": "str", | |||
| "value": "all" | |||
| } | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "int32", "float16", "float" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "DefaultFormat", "DefaultFormat" | |||
| ], | |||
| "name": "x", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 1, | |||
| "dtype": [ | |||
| "int32", "float16", "float" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "DefaultFormat", "DefaultFormat" | |||
| ], | |||
| "name": "bias", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "int32", "float16", "float" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "DefaultFormat", "DefaultFormat" | |||
| ], | |||
| "name": "y", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ] | |||
| }""") | |||
| @op_info_register(bias_add_grad_op_info) | |||
| def _bias_add_tbe(): | |||
| """BiasAdd TBE register""" | |||
| return | |||
| @@ -14,57 +14,26 @@ | |||
| # ============================================================================ | |||
| """BiasAddGrad op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| bias_add_grad_op_info = TBERegOp("BiasAddGrad") \ | |||
| .fusion_type("COMMREDUCE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("biasaddgrad.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("biasaddgrad") \ | |||
| .partial_flag(True) \ | |||
| .attr("data_format", "required", "str", "all") \ | |||
| .input(0, "output_backprop", False, "required", "all") \ | |||
| .output(0, "output", False, "required", "all") \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F16_FracNZ, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default) \ | |||
| .dtype_format(DataType.F32_FracNZ, DataType.F32_Default) \ | |||
| .get_op_info() | |||
| @op_info_register("""{ | |||
| "op_name": "BiasAddGrad", | |||
| "imply_type": "TBE", | |||
| "fusion_type": "COMMREDUCE", | |||
| "async_flag": false, | |||
| "binfile_name": "biasaddgrad.so", | |||
| "compute_cost": 10, | |||
| "kernel_name": "biasaddgrad", | |||
| "partial_flag": true, | |||
| "attr": [ | |||
| { | |||
| "name": "data_format", | |||
| "param_type": "required", | |||
| "type": "str", | |||
| "value": "all" | |||
| } | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16","float16","float","float" | |||
| ], | |||
| "format": [ | |||
| "FRACTAL_NZ","DefaultFormat","FRACTAL_NZ","DefaultFormat" | |||
| ], | |||
| "name": "out_backprop", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16","float16","float","float" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat" | |||
| ], | |||
| "name": "output", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ] | |||
| }""") | |||
| @op_info_register(bias_add_grad_op_info) | |||
| def _bias_add_grad_tbe(): | |||
| """BiasAddGrad TBE register""" | |||
| return | |||
| @@ -14,60 +14,24 @@ | |||
| # ============================================================================ | |||
| """BatchNorm op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| bn_training_reduce_op_info = TBERegOp("BNTrainingReduce") \ | |||
| .fusion_type("ELEMWISE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("bn_training_reduce.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("bn_training_reduce") \ | |||
| .partial_flag(True) \ | |||
| .input(0, "x", False, "required", "all") \ | |||
| .output(0, "sum", False, "required", "all") \ | |||
| .output(1, "square_sum", False, "required", "all") \ | |||
| .dtype_format(DataType.F16_5HD, DataType.F32_5HD, DataType.F32_5HD) \ | |||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \ | |||
| .get_op_info() | |||
| @op_info_register("""{ | |||
| "op_name": "BNTrainingReduce", | |||
| "imply_type": "TBE", | |||
| "fusion_type": "ELEMWISE", | |||
| "async_flag": false, | |||
| "binfile_name": "bn_training_reduce.so", | |||
| "compute_cost": 10, | |||
| "kernel_name": "bn_training_reduce", | |||
| "partial_flag": true, | |||
| "attr": [ | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16","float" | |||
| ], | |||
| "format": [ | |||
| "NC1HWC0", "NC1HWC0" | |||
| ], | |||
| "name": "x", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float","float" | |||
| ], | |||
| "format": [ | |||
| "NC1HWC0", "NC1HWC0" | |||
| ], | |||
| "name": "sum", | |||
| "param_type": "required" | |||
| }, | |||
| { | |||
| "index": 1, | |||
| "dtype": [ | |||
| "float","float" | |||
| ], | |||
| "format": [ | |||
| "NC1HWC0", "NC1HWC0" | |||
| ], | |||
| "name": "square_sum", | |||
| "param_type": "required" | |||
| } | |||
| ] | |||
| }""") | |||
| @op_info_register(bn_training_reduce_op_info) | |||
| def _bn_training_reduce_tbe(): | |||
| """BNTrainingReduce TBE register""" | |||
| return | |||
| @@ -14,134 +14,32 @@ | |||
| # ============================================================================ | |||
| """BatchNormGrad op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| bn_training_reduce_grad_op_info = TBERegOp("BNTrainingReduceGrad") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("bn_training_reduce_grad.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("bn_training_reduce_grad") \ | |||
| .partial_flag(True) \ | |||
| .attr("epsilon", "optional", "float", "all") \ | |||
| .input(0, "grads", False, "required", "all") \ | |||
| .input(1, "x_norm", False, "required", "all") \ | |||
| .input(2, "diff_scale", False, "required", "all") \ | |||
| .input(3, "diff_offset", False, "required", "all") \ | |||
| .input(4, "scale", False, "required", "all") \ | |||
| .input(5, "batch_mean", False, "required", "all") \ | |||
| .input(6, "batch_variance", False, "required", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .dtype_format(DataType.F16_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, | |||
| DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F16_5HD) \ | |||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, | |||
| DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \ | |||
| .get_op_info() | |||
| @op_info_register("""{ | |||
| "op_name": "BNTrainingReduceGrad", | |||
| "imply_type": "TBE", | |||
| "fusion_type": "OPAQUE", | |||
| "async_flag": false, | |||
| "binfile_name": "bn_training_reduce_grad.so", | |||
| "compute_cost": 10, | |||
| "kernel_name": "bn_training_reduce_grad", | |||
| "partial_flag": true, | |||
| "attr": [ | |||
| { | |||
| "name": "epsilon", | |||
| "param_type": "optional", | |||
| "type": "float", | |||
| "value": "all" | |||
| } | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16", "float" | |||
| ], | |||
| "format": [ | |||
| "NC1HWC0","NC1HWC0" | |||
| ], | |||
| "name": "grads", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 1, | |||
| "dtype": [ | |||
| "float", "float" | |||
| ], | |||
| "format": [ | |||
| "NC1HWC0","NC1HWC0" | |||
| ], | |||
| "name": "x_norm", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 2, | |||
| "dtype": [ | |||
| "float", "float" | |||
| ], | |||
| "format": [ | |||
| "NC1HWC0","NC1HWC0" | |||
| ], | |||
| "name": "diff_scale", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 3, | |||
| "dtype": [ | |||
| "float", "float" | |||
| ], | |||
| "format": [ | |||
| "NC1HWC0","NC1HWC0" | |||
| ], | |||
| "name": "diff_offset", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 4, | |||
| "dtype": [ | |||
| "float", "float" | |||
| ], | |||
| "format": [ | |||
| "NC1HWC0","NC1HWC0" | |||
| ], | |||
| "name": "scale", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 5, | |||
| "dtype": [ | |||
| "float", "float" | |||
| ], | |||
| "format": [ | |||
| "NC1HWC0","NC1HWC0" | |||
| ], | |||
| "name": "batch_mean", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 6, | |||
| "dtype": [ | |||
| "float", "float" | |||
| ], | |||
| "format": [ | |||
| "NC1HWC0","NC1HWC0" | |||
| ], | |||
| "name": "batch_variance", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16", "float" | |||
| ], | |||
| "format": [ | |||
| "NC1HWC0","NC1HWC0" | |||
| ], | |||
| "name": "y", | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ] | |||
| }""") | |||
| @op_info_register(bn_training_reduce_grad_op_info) | |||
| def _bn_training_reduce_grad_tbe(): | |||
| """BNTrainingReduceGrad TBE register""" | |||
| return | |||
| @@ -14,200 +14,40 @@ | |||
| # ============================================================================ | |||
| """BatchNormGrad op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| bn_training_update_op_info = TBERegOp("BNTrainingUpdate") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("bn_training_update.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("bn_training_update") \ | |||
| .partial_flag(True) \ | |||
| .attr("factor", "optional", "float", "all") \ | |||
| .attr("epsilon", "optional", "float", "all") \ | |||
| .attr("isRef", "optional", "bool", "all", "true") \ | |||
| .input(0, "x", False, "required", "all") \ | |||
| .input(1, "sum", False, "required", "all") \ | |||
| .input(2, "square_sum", False, "required", "all") \ | |||
| .input(3, "scale", False, "required", "all") \ | |||
| .input(4, "offset", False, "required", "all") \ | |||
| .input(5, "mean", False, "required", "all") \ | |||
| .input(6, "variance", False, "required", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .output(1, "mean", False, "required", "all") \ | |||
| .output(2, "variance", False, "required", "all") \ | |||
| .output(3, "batch_mean", False, "required", "all") \ | |||
| .output(4, "batch_variance", False, "required", "all") \ | |||
| .dtype_format(DataType.F16_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, | |||
| DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F16_5HD, | |||
| DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \ | |||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, | |||
| DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, | |||
| DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \ | |||
| .get_op_info() | |||
| @op_info_register("""{ | |||
| "op_name": "BNTrainingUpdate", | |||
| "imply_type": "TBE", | |||
| "fusion_type": "OPAQUE", | |||
| "async_flag": false, | |||
| "binfile_name": "bn_training_update.so", | |||
| "compute_cost": 10, | |||
| "kernel_name": "bn_training_update", | |||
| "partial_flag": true, | |||
| "attr": [ | |||
| { | |||
| "name": "factor", | |||
| "param_type": "optional", | |||
| "type": "float", | |||
| "value": "all" | |||
| }, | |||
| { | |||
| "name": "epsilon", | |||
| "param_type": "optional", | |||
| "type": "float", | |||
| "value": "all" | |||
| }, | |||
| { | |||
| "name": "isRef", | |||
| "param_type": "optional", | |||
| "type": "bool", | |||
| "default_value":"true", | |||
| "value": "all" | |||
| } | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16", "float" | |||
| ], | |||
| "format": [ | |||
| "NC1HWC0","NC1HWC0" | |||
| ], | |||
| "name": "x", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 1, | |||
| "dtype": [ | |||
| "float", "float" | |||
| ], | |||
| "format": [ | |||
| "NC1HWC0","NC1HWC0" | |||
| ], | |||
| "name": "sum", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 2, | |||
| "dtype": [ | |||
| "float", "float" | |||
| ], | |||
| "format": [ | |||
| "NC1HWC0","NC1HWC0" | |||
| ], | |||
| "name": "square_sum", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 3, | |||
| "dtype": [ | |||
| "float", "float" | |||
| ], | |||
| "format": [ | |||
| "NC1HWC0","NC1HWC0" | |||
| ], | |||
| "name": "scale", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 4, | |||
| "dtype": [ | |||
| "float", "float" | |||
| ], | |||
| "format": [ | |||
| "NC1HWC0","NC1HWC0" | |||
| ], | |||
| "name": "offset", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 5, | |||
| "dtype": [ | |||
| "float", "float" | |||
| ], | |||
| "format": [ | |||
| "NC1HWC0","NC1HWC0" | |||
| ], | |||
| "name": "mean", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 6, | |||
| "dtype": [ | |||
| "float", "float" | |||
| ], | |||
| "format": [ | |||
| "NC1HWC0","NC1HWC0" | |||
| ], | |||
| "name": "variance", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16", "float" | |||
| ], | |||
| "format": [ | |||
| "NC1HWC0","NC1HWC0" | |||
| ], | |||
| "name": "y", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 1, | |||
| "dtype": [ | |||
| "float", "float" | |||
| ], | |||
| "format": [ | |||
| "NC1HWC0","NC1HWC0" | |||
| ], | |||
| "name": "mean", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 2, | |||
| "dtype": [ | |||
| "float", "float" | |||
| ], | |||
| "format": [ | |||
| "NC1HWC0","NC1HWC0" | |||
| ], | |||
| "name": "variance", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 3, | |||
| "dtype": [ | |||
| "float", "float" | |||
| ], | |||
| "format": [ | |||
| "NC1HWC0","NC1HWC0" | |||
| ], | |||
| "name": "batch_mean", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 4, | |||
| "dtype": [ | |||
| "float", "float" | |||
| ], | |||
| "format": [ | |||
| "NC1HWC0","NC1HWC0" | |||
| ], | |||
| "name": "batch_variance", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ] | |||
| }""") | |||
| @op_info_register(bn_training_update_op_info) | |||
| def _bn_training_update_tbe(): | |||
| """BNTrainingUpdate TBE register""" | |||
| return | |||
| @@ -14,109 +14,30 @@ | |||
| # ============================================================================ | |||
| """BatchNormGrad op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| bn_training_update_grad_op_info = TBERegOp("BNTrainingUpdateGrad") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("bn_training_update_grad.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("bn_training_update_grad") \ | |||
| .partial_flag(True) \ | |||
| .attr("epsilon", "optional", "float", "all") \ | |||
| .input(0, "grads", False, "required", "all") \ | |||
| .input(1, "x", False, "required", "all") \ | |||
| .input(2, "batch_mean", False, "required", "all") \ | |||
| .input(3, "batch_variance", False, "required", "all") \ | |||
| .output(0, "diff_scale", False, "required", "all") \ | |||
| .output(1, "diff_offset", False, "required", "all") \ | |||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F32_5HD, DataType.F32_5HD, | |||
| DataType.F32_5HD, DataType.F32_5HD) \ | |||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, | |||
| DataType.F32_5HD, DataType.F32_5HD) \ | |||
| .get_op_info() | |||
| @op_info_register("""{ | |||
| "op_name": "BNTrainingUpdateGrad", | |||
| "imply_type": "TBE", | |||
| "fusion_type": "OPAQUE", | |||
| "async_flag": false, | |||
| "binfile_name": "bn_training_update_grad.so", | |||
| "compute_cost": 10, | |||
| "kernel_name": "bn_training_update_grad", | |||
| "partial_flag": true, | |||
| "attr": [ | |||
| { | |||
| "name": "epsilon", | |||
| "param_type": "optional", | |||
| "type": "float", | |||
| "value": "all" | |||
| } | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16", "float" | |||
| ], | |||
| "format": [ | |||
| "NC1HWC0","NC1HWC0" | |||
| ], | |||
| "name": "grads", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 1, | |||
| "dtype": [ | |||
| "float16", "float" | |||
| ], | |||
| "format": [ | |||
| "NC1HWC0","NC1HWC0" | |||
| ], | |||
| "name": "x", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 2, | |||
| "dtype": [ | |||
| "float", "float" | |||
| ], | |||
| "format": [ | |||
| "NC1HWC0","NC1HWC0" | |||
| ], | |||
| "name": "batch_mean", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 3, | |||
| "dtype": [ | |||
| "float", "float" | |||
| ], | |||
| "format": [ | |||
| "NC1HWC0","NC1HWC0" | |||
| ], | |||
| "name": "batch_variance", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float", "float" | |||
| ], | |||
| "format": [ | |||
| "NC1HWC0","NC1HWC0" | |||
| ], | |||
| "name": "diff_scale", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 1, | |||
| "dtype": [ | |||
| "float", "float" | |||
| ], | |||
| "format": [ | |||
| "NC1HWC0","NC1HWC0" | |||
| ], | |||
| "name": "diff_offset", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ] | |||
| }""") | |||
| @op_info_register(bn_training_update_grad_op_info) | |||
| def _bn_training_update_grad_tbe(): | |||
| """BNTrainingUpdateGrad TBE register""" | |||
| return | |||
| @@ -14,69 +14,42 @@ | |||
| # ============================================================================ | |||
| """Cast op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| cast_op_info = TBERegOp("Cast") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("cast.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("cast") \ | |||
| .partial_flag(True) \ | |||
| .attr("dst_type", "required", "int", "all") \ | |||
| .input(0, "x", False, "required", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .dtype_format(DataType.BOOL_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.BOOL_Default, DataType.U8_Default) \ | |||
| .dtype_format(DataType.BOOL_Default, DataType.F32_Default) \ | |||
| .dtype_format(DataType.BOOL_Default, DataType.I32_Default) \ | |||
| .dtype_format(DataType.I8_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.I8_Default, DataType.F32_Default) \ | |||
| .dtype_format(DataType.I8_Default, DataType.I32_Default) \ | |||
| .dtype_format(DataType.U8_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.U8_Default, DataType.F32_Default) \ | |||
| .dtype_format(DataType.U8_Default, DataType.I32_Default) \ | |||
| .dtype_format(DataType.I32_Default, DataType.BOOL_Default) \ | |||
| .dtype_format(DataType.I32_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.I32_Default, DataType.F32_Default) \ | |||
| .dtype_format(DataType.I32_Default, DataType.I8_Default) \ | |||
| .dtype_format(DataType.I32_Default, DataType.U8_Default) \ | |||
| .dtype_format(DataType.F16_Default, DataType.U8_Default) \ | |||
| .dtype_format(DataType.F16_Default, DataType.F32_Default) \ | |||
| .dtype_format(DataType.F16_Default, DataType.I32_Default) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F32_Default, DataType.I32_Default) \ | |||
| .get_op_info() | |||
| @op_info_register("""{ | |||
| "op_name": "Cast", | |||
| "imply_type": "TBE", | |||
| "fusion_type": "OPAQUE", | |||
| "async_flag": false, | |||
| "binfile_name": "cast.so", | |||
| "compute_cost": 10, | |||
| "kernel_name": "cast", | |||
| "partial_flag": true, | |||
| "attr": [ | |||
| { | |||
| "name": "dst_type", | |||
| "param_type": "required", | |||
| "type": "int", | |||
| "value": "all" | |||
| } | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16", "float16", "float", "float", | |||
| "int32", "int32", "int32", "int32", "int32", | |||
| "int8", "int8", "int8", "uint8", "uint8", "uint8", | |||
| "bool", "bool", "bool", "bool", "float16" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat", | |||
| "DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat", | |||
| "DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat", | |||
| "DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat" | |||
| ], | |||
| "name": "x", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float", "int32", "float16", "int32", | |||
| "float16", "float", "int8", "uint8", "bool", | |||
| "float16", "float", "int32", "float16", "float", "int32", | |||
| "float16", "float", "int32", "uint8", "uint8" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat", | |||
| "DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat", | |||
| "DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat", | |||
| "DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat" | |||
| ], | |||
| "name": "y", | |||
| "need_compile": true, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ] | |||
| }""") | |||
| @op_info_register(cast_op_info) | |||
| def _cast_tbe(): | |||
| """Cast TBE register""" | |||
| return | |||
| @@ -14,90 +14,28 @@ | |||
| # ============================================================================ | |||
| """ClipByNormNoDivSum op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| clip_by_norm_no_div_sum_op_info = TBERegOp("ClipByNormNoDivSum") \ | |||
| .fusion_type("ELEMWISE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("clip_by_norm_no_div_sum.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("clip_by_norm_no_div_sum") \ | |||
| .partial_flag(True) \ | |||
| .input(0, "input_x", False, "required", "all") \ | |||
| .input(1, "input1", False, "required", "all") \ | |||
| .input(2, "input2", False, "required", "all") \ | |||
| .input(3, "input3", False, "required", "all") \ | |||
| .output(0, "output_y", False, "required", "all") \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, | |||
| DataType.F16_Default) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | |||
| DataType.F32_Default) \ | |||
| .get_op_info() | |||
| @op_info_register("""{ | |||
| "op_name": "ClipByNormNoDivSum", | |||
| "imply_type": "TBE", | |||
| "fusion_type": "ELEMWISE", | |||
| "async_flag": false, | |||
| "binfile_name": "clip_by_norm_no_div_sum.so", | |||
| "compute_cost": 10, | |||
| "kernel_name": "clip_by_norm_no_div_sum", | |||
| "partial_flag": true, | |||
| "attr":[ | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16", "float32" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "DefaultFormat" | |||
| ], | |||
| "name": "input_x", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 1, | |||
| "dtype": [ | |||
| "float16", "float32" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "DefaultFormat" | |||
| ], | |||
| "name": "input1", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 2, | |||
| "dtype": [ | |||
| "float16", "float32" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "DefaultFormat" | |||
| ], | |||
| "name": "input2", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 3, | |||
| "dtype": [ | |||
| "float16", "float32" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "DefaultFormat" | |||
| ], | |||
| "name": "input3", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16", "float32" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "DefaultFormat" | |||
| ], | |||
| "name": "output_y", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ] | |||
| }""") | |||
| @op_info_register(clip_by_norm_no_div_sum_op_info) | |||
| def _clip_by_norm_no_div_sum_tbe(): | |||
| """ClipByNormNoDivSum TBE register""" | |||
| return | |||
| @@ -14,85 +14,30 @@ | |||
| # ============================================================================ | |||
| """ClipByValue op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| clip_by_value_op_info = TBERegOp("ClipByValue") \ | |||
| .fusion_type("ELEMWISE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("clip_by_value.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("clip_by_value") \ | |||
| .partial_flag(True) \ | |||
| .attr("dst_type", "required", "int", "all") \ | |||
| .input(0, "x", False, "required", "all") \ | |||
| .input(1, "clip_value_min", False, "required", "all") \ | |||
| .input(2, "clip_value_max", False, "required", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ | |||
| .dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD) \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ | |||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \ | |||
| .get_op_info() | |||
| @op_info_register("""{ | |||
| "op_name": "ClipByValue", | |||
| "imply_type": "TBE", | |||
| "fusion_type": "ELEMWISE", | |||
| "async_flag": false, | |||
| "binfile_name": "clip_by_value.so", | |||
| "compute_cost": 10, | |||
| "kernel_name": "clip_by_value", | |||
| "partial_flag": true, | |||
| "attr":[ | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16", "float16", "float16", "float16", "float", "float", "float", "float", "int32", | |||
| "int32", "int32", "int32" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", | |||
| "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat" | |||
| ], | |||
| "name": "x", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 1, | |||
| "dtype": [ | |||
| "float16", "float16", "float16", "float16", "float", "float", "float", "float", "int32", | |||
| "int32", "int32", "int32" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", | |||
| "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat" | |||
| ], | |||
| "name": "clip_value_min", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 2, | |||
| "dtype": [ | |||
| "float16", "float16", "float16", "float16", "float", "float", "float", "float", "int32", | |||
| "int32", "int32", "int32" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", | |||
| "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat" | |||
| ], | |||
| "name": "clip_value_max", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16", "float16", "float16", "float16", "float", "float", "float", "float", "int32", | |||
| "int32", "int32", "int32" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", | |||
| "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat" | |||
| ], | |||
| "name": "y", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ] | |||
| }""") | |||
| @op_info_register(clip_by_value_op_info) | |||
| def _clip_by_value_tbe(): | |||
| """ClipByValue TBE register""" | |||
| return | |||
| @@ -14,141 +14,44 @@ | |||
| # ============================================================================ | |||
| """Concat op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| concat_op_info = TBERegOp("Concat") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("concat_d.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("concat_d") \ | |||
| .partial_flag(True) \ | |||
| .attr("axis", "required", "int", "all") \ | |||
| .input(0, "input_values", False, "dynamic", "all") \ | |||
| .output(0, "output_data", False, "required", "all") \ | |||
| .dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \ | |||
| .dtype_format(DataType.BOOL_5HD, DataType.BOOL_5HD) \ | |||
| .dtype_format(DataType.I8_Default, DataType.I8_Default) \ | |||
| .dtype_format(DataType.I8_5HD, DataType.I8_5HD) \ | |||
| .dtype_format(DataType.U8_Default, DataType.U8_Default) \ | |||
| .dtype_format(DataType.U8_5HD, DataType.U8_5HD) \ | |||
| .dtype_format(DataType.I16_Default, DataType.I16_Default) \ | |||
| .dtype_format(DataType.I16_5HD, DataType.I16_5HD) \ | |||
| .dtype_format(DataType.U16_Default, DataType.U16_Default) \ | |||
| .dtype_format(DataType.U16_5HD, DataType.U16_5HD) \ | |||
| .dtype_format(DataType.I32_Default, DataType.I32_Default) \ | |||
| .dtype_format(DataType.I32_5HD, DataType.I32_5HD) \ | |||
| .dtype_format(DataType.U32_Default, DataType.U32_Default) \ | |||
| .dtype_format(DataType.U32_5HD, DataType.U32_5HD) \ | |||
| .dtype_format(DataType.I64_Default, DataType.I64_Default) \ | |||
| .dtype_format(DataType.I64_5HD, DataType.I64_5HD) \ | |||
| .dtype_format(DataType.U64_Default, DataType.U64_Default) \ | |||
| .dtype_format(DataType.U64_5HD, DataType.U64_5HD) \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default) \ | |||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD) \ | |||
| .get_op_info() | |||
| @op_info_register("""{ | |||
| "op_name": "Concat", | |||
| "imply_type": "TBE", | |||
| "fusion_type": "OPAQUE", | |||
| "async_flag": false, | |||
| "binfile_name": "concat_d.so", | |||
| "compute_cost": 10, | |||
| "kernel_name": "concat_d", | |||
| "partial_flag": true, | |||
| "attr": [ | |||
| { | |||
| "name": "axis", | |||
| "param_type": "required", | |||
| "type": "int", | |||
| "value": "all" | |||
| } | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16", | |||
| "float16", | |||
| "float", | |||
| "float", | |||
| "int32", | |||
| "int32", | |||
| "int8", | |||
| "int8", | |||
| "int16", | |||
| "int16", | |||
| "int64", | |||
| "int64", | |||
| "uint8", | |||
| "uint8", | |||
| "uint16", | |||
| "uint16", | |||
| "uint32", | |||
| "uint32", | |||
| "uint64", | |||
| "uint64", | |||
| "bool", | |||
| "bool" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", | |||
| "NC1HWC0", | |||
| "DefaultFormat", | |||
| "NC1HWC0", | |||
| "DefaultFormat", | |||
| "NC1HWC0", | |||
| "DefaultFormat", | |||
| "NC1HWC0", | |||
| "DefaultFormat", | |||
| "NC1HWC0", | |||
| "DefaultFormat", | |||
| "NC1HWC0", | |||
| "DefaultFormat", | |||
| "NC1HWC0", | |||
| "DefaultFormat", | |||
| "NC1HWC0", | |||
| "DefaultFormat", | |||
| "NC1HWC0", | |||
| "DefaultFormat", | |||
| "NC1HWC0", | |||
| "DefaultFormat", | |||
| "NC1HWC0" | |||
| ], | |||
| "name": "input_values", | |||
| "need_compile": false, | |||
| "param_type": "dynamic", | |||
| "shape": "all" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16", | |||
| "float16", | |||
| "float", | |||
| "float", | |||
| "int32", | |||
| "int32", | |||
| "int8", | |||
| "int8", | |||
| "int16", | |||
| "int16", | |||
| "int64", | |||
| "int64", | |||
| "uint8", | |||
| "uint8", | |||
| "uint16", | |||
| "uint16", | |||
| "uint32", | |||
| "uint32", | |||
| "uint64", | |||
| "uint64", | |||
| "bool", | |||
| "bool" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", | |||
| "NC1HWC0", | |||
| "DefaultFormat", | |||
| "NC1HWC0", | |||
| "DefaultFormat", | |||
| "NC1HWC0", | |||
| "DefaultFormat", | |||
| "NC1HWC0", | |||
| "DefaultFormat", | |||
| "NC1HWC0", | |||
| "DefaultFormat", | |||
| "NC1HWC0", | |||
| "DefaultFormat", | |||
| "NC1HWC0", | |||
| "DefaultFormat", | |||
| "NC1HWC0", | |||
| "DefaultFormat", | |||
| "NC1HWC0", | |||
| "DefaultFormat", | |||
| "NC1HWC0", | |||
| "DefaultFormat", | |||
| "NC1HWC0" | |||
| ], | |||
| "name": "output_data", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ] | |||
| }""") | |||
| @op_info_register(concat_op_info) | |||
| def _concat_tbe(): | |||
| """Concat TBE register""" | |||
| return | |||
| @@ -14,65 +14,28 @@ | |||
| # ============================================================================ | |||
| """ConfusionSoftmaxGrad op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| confusion_softmax_grad_op_info = TBERegOp("ConfusionSoftmaxGrad") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("confusion_softmax_grad.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("confusion_softmax_grad") \ | |||
| .partial_flag(True) \ | |||
| .input(0, "grad", False, "required", "all") \ | |||
| .input(1, "x", False, "required", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \ | |||
| .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ | |||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \ | |||
| .dtype_format(DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ) \ | |||
| .get_op_info() | |||
| @op_info_register("""{ | |||
| "op_name": "ConfusionSoftmaxGrad", | |||
| "imply_type": "TBE", | |||
| "fusion_type": "OPAQUE", | |||
| "async_flag": false, | |||
| "binfile_name": "confusion_softmax_grad.so", | |||
| "compute_cost": 10, | |||
| "kernel_name": "confusion_softmax_grad", | |||
| "partial_flag": true, | |||
| "attr": [ | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16", "float16", "float16", "float", "float", "float" | |||
| ], | |||
| "format": [ | |||
| "FRACTAL_NZ", "DefaultFormat", "NC1HWC0", "FRACTAL_NZ", "DefaultFormat", "NC1HWC0" | |||
| ], | |||
| "name": "grad", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 1, | |||
| "dtype": [ | |||
| "float16", "float16", "float16", "float", "float", "float" | |||
| ], | |||
| "format": [ | |||
| "FRACTAL_NZ", "DefaultFormat", "NC1HWC0", "FRACTAL_NZ", "DefaultFormat", "NC1HWC0" | |||
| ], | |||
| "name": "x", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16", "float16", "float16", "float", "float", "float" | |||
| ], | |||
| "format": [ | |||
| "FRACTAL_NZ", "DefaultFormat", "NC1HWC0", "FRACTAL_NZ", "DefaultFormat", "NC1HWC0" | |||
| ], | |||
| "name": "y", | |||
| "need_compile": true, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ] | |||
| }""") | |||
| @op_info_register(confusion_softmax_grad_op_info) | |||
| def _confusion_softmax_grad_tbe(): | |||
| """ConfusionSoftmaxGrad TBE register""" | |||
| return | |||
| @@ -14,79 +14,44 @@ | |||
| # ============================================================================ | |||
| """ConfusionTransposeD op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| confusion_transpose_d_op_info = TBERegOp("ConfusionTransposeD") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("confusion_transpose_d.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("confusion_transpose_d") \ | |||
| .partial_flag(True) \ | |||
| .attr("perm", "required", "listInt", "all") \ | |||
| .attr("shape", "required", "listInt", "all") \ | |||
| .attr("transpose_first", "required", "bool", "all") \ | |||
| .input(0, "x", False, "required", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .dtype_format(DataType.I8_FracNZ, DataType.I8_FracNZ) \ | |||
| .dtype_format(DataType.I8_Default, DataType.I8_Default) \ | |||
| .dtype_format(DataType.U8_FracNZ, DataType.U8_FracNZ) \ | |||
| .dtype_format(DataType.U8_Default, DataType.U8_Default) \ | |||
| .dtype_format(DataType.I16_FracNZ, DataType.I16_FracNZ) \ | |||
| .dtype_format(DataType.I16_Default, DataType.I16_Default) \ | |||
| .dtype_format(DataType.U16_FracNZ, DataType.U16_FracNZ) \ | |||
| .dtype_format(DataType.U16_Default, DataType.U16_Default) \ | |||
| .dtype_format(DataType.I32_FracNZ, DataType.I32_FracNZ) \ | |||
| .dtype_format(DataType.I32_Default, DataType.I32_Default) \ | |||
| .dtype_format(DataType.U32_FracNZ, DataType.U32_FracNZ) \ | |||
| .dtype_format(DataType.U32_Default, DataType.U32_Default) \ | |||
| .dtype_format(DataType.I64_FracNZ, DataType.I64_FracNZ) \ | |||
| .dtype_format(DataType.I64_Default, DataType.I64_Default) \ | |||
| .dtype_format(DataType.U64_FracNZ, DataType.U64_FracNZ) \ | |||
| .dtype_format(DataType.U64_Default, DataType.U64_Default) \ | |||
| .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ) \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F32_FracNZ, DataType.F32_FracNZ) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default) \ | |||
| .get_op_info() | |||
| @op_info_register("""{ | |||
| "op_name": "ConfusionTransposeD", | |||
| "imply_type": "TBE", | |||
| "fusion_type": "OPAQUE", | |||
| "async_flag": false, | |||
| "binfile_name": "confusion_transpose_d.so", | |||
| "compute_cost": 10, | |||
| "kernel_name": "confusion_transpose_d", | |||
| "partial_flag": true, | |||
| "attr":[ | |||
| { | |||
| "name":"perm", | |||
| "param_type":"required", | |||
| "type":"listInt", | |||
| "value":"all" | |||
| }, | |||
| { | |||
| "name":"shape", | |||
| "param_type":"required", | |||
| "type":"listInt", | |||
| "value":"all" | |||
| }, | |||
| { | |||
| "name":"transpose_first", | |||
| "param_type":"required", | |||
| "type":"bool", | |||
| "value":"all" | |||
| } | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16", "float", "int8", "int16", "int32", "int64", "uint8", "uint16", "uint32", | |||
| "uint64", "float16", "float", "int8", "int16", "int32", "int64", "uint8", "uint16", | |||
| "uint32", "uint64" | |||
| ], | |||
| "format": [ | |||
| "FRACTAL_NZ", "FRACTAL_NZ", "FRACTAL_NZ", "FRACTAL_NZ", "FRACTAL_NZ", "FRACTAL_NZ", | |||
| "FRACTAL_NZ", "FRACTAL_NZ", "FRACTAL_NZ", "FRACTAL_NZ", "DefaultFormat", "DefaultFormat", | |||
| "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", | |||
| "DefaultFormat", "DefaultFormat", "DefaultFormat" | |||
| ], | |||
| "name": "x", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16", "float", "int8", "int16", "int32", "int64", "uint8", "uint16", "uint32", | |||
| "uint64", "float16", "float", "int8", "int16", "int32", "int64", "uint8", "uint16", | |||
| "uint32", "uint64" | |||
| ], | |||
| "format": [ | |||
| "FRACTAL_NZ", "FRACTAL_NZ", "FRACTAL_NZ", "FRACTAL_NZ", "FRACTAL_NZ", "FRACTAL_NZ", | |||
| "FRACTAL_NZ", "FRACTAL_NZ", "FRACTAL_NZ", "FRACTAL_NZ", "DefaultFormat", "DefaultFormat", | |||
| "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", | |||
| "DefaultFormat", "DefaultFormat", "DefaultFormat" | |||
| ], | |||
| "name": "y", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ] | |||
| }""") | |||
| @op_info_register(confusion_transpose_d_op_info) | |||
| def _confusion_transpose_d_tbe(): | |||
| """ConfusionTransposeD TBE register""" | |||
| return | |||
| @@ -14,114 +14,30 @@ | |||
| # ============================================================================ | |||
| """Conv2D op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| conv2d_op_info = TBERegOp("Conv2D") \ | |||
| .fusion_type("CONVLUTION") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("conv2d.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("conv2d") \ | |||
| .partial_flag(True) \ | |||
| .attr("stride", "required", "listInt", "all") \ | |||
| .attr("pad_list", "required", "listInt", "all") \ | |||
| .attr("dilation", "required", "listInt", "all") \ | |||
| .attr("offset_a", "optional", "int", "all") \ | |||
| .input(0, "x", False, "required", "all") \ | |||
| .input(1, "filter", False, "required", "all") \ | |||
| .input(2, "bias", False, "optional", "all") \ | |||
| .input(3, "offset_w", False, "optional", "all") \ | |||
| .output(0, "y", True, "required", "all") \ | |||
| .dtype_format(DataType.F16_5HD, DataType.F16_FracZ, DataType.F16_Default, DataType.I8_Default, | |||
| DataType.F16_5HD) \ | |||
| .get_op_info() | |||
| @op_info_register("""{ | |||
| "op_name": "Conv2D", | |||
| "imply_type": "TBE", | |||
| "fusion_type": "CONVLUTION", | |||
| "async_flag": false, | |||
| "binfile_name": "conv2d.so", | |||
| "compute_cost": 10, | |||
| "kernel_name": "conv2d", | |||
| "partial_flag": true, | |||
| "attr": [ | |||
| { | |||
| "name": "stride", | |||
| "param_type": "required", | |||
| "type": "listInt", | |||
| "value": "all" | |||
| }, | |||
| { | |||
| "name": "pad_list", | |||
| "param_type": "required", | |||
| "type": "listInt", | |||
| "value": "all" | |||
| }, | |||
| { | |||
| "name": "dilation", | |||
| "param_type": "required", | |||
| "type": "listInt", | |||
| "value": "all" | |||
| }, | |||
| { | |||
| "name": "offset_a", | |||
| "param_type": "optional", | |||
| "type": "int", | |||
| "value": "all" | |||
| } | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16" | |||
| ], | |||
| "format": [ | |||
| "NC1HWC0" | |||
| ], | |||
| "name": "x", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 1, | |||
| "dtype": [ | |||
| "float16" | |||
| ], | |||
| "format": [ | |||
| "FracZ" | |||
| ], | |||
| "name": "filter", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 2, | |||
| "dtype": [ | |||
| "float16" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat" | |||
| ], | |||
| "name": "bias", | |||
| "need_compile": false, | |||
| "param_type": "optional", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 3, | |||
| "dtype": [ | |||
| "int8" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat" | |||
| ], | |||
| "name": "offset_w", | |||
| "need_compile": false, | |||
| "param_type": "optional", | |||
| "shape": "all" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16" | |||
| ], | |||
| "format": [ | |||
| "NC1HWC0" | |||
| ], | |||
| "name": "y", | |||
| "need_compile": true, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ] | |||
| }""") | |||
| @op_info_register(conv2d_op_info) | |||
| def _conv2d_tbe(): | |||
| """Conv2D TBE register""" | |||
| return | |||
| @@ -14,89 +14,27 @@ | |||
| # ============================================================================ | |||
| """Conv2DBackpropFilter op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| conv2d_backprop_filter_op_info = TBERegOp("Conv2DBackpropFilter") \ | |||
| .fusion_type("CONVLUTION") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("conv2d_backprop_filter_d.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("conv2d_backprop_filter_d") \ | |||
| .partial_flag(True) \ | |||
| .attr("filter_sizes", "required", "listInt", "all") \ | |||
| .attr("stride", "required", "listInt", "all") \ | |||
| .attr("pad_mode", "required", "str", "all") \ | |||
| .attr("dilation", "required", "listInt", "all") \ | |||
| .input(0, "out_backprop", False, "required", "all") \ | |||
| .input(1, "x", False, "required", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F32_FracZ) \ | |||
| .get_op_info() | |||
| # map to tbe kernel name conv2d_backprop_filter_d | |||
| @op_info_register("""{ | |||
| "op_name": "Conv2DBackpropFilter", | |||
| "imply_type": "TBE", | |||
| "fusion_type": "CONVLUTION", | |||
| "async_flag": false, | |||
| "binfile_name": "conv2d_backprop_filter_d.so", | |||
| "compute_cost": 10, | |||
| "kernel_name": "conv2d_backprop_filter_d", | |||
| "partial_flag": true, | |||
| "attr": [ | |||
| { | |||
| "name": "filter_sizes", | |||
| "param_type": "required", | |||
| "type": "listInt", | |||
| "value": "all" | |||
| }, | |||
| { | |||
| "name": "stride", | |||
| "param_type": "required", | |||
| "type": "listInt", | |||
| "value": "all" | |||
| }, | |||
| { | |||
| "name": "pad_mode", | |||
| "param_type": "required", | |||
| "type": "str", | |||
| "value": "all" | |||
| }, | |||
| { | |||
| "name": "dilation", | |||
| "param_type": "required", | |||
| "type": "listInt", | |||
| "value": "all" | |||
| } | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16" | |||
| ], | |||
| "format": [ | |||
| "NC1HWC0" | |||
| ], | |||
| "name": "out_backprop", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 1, | |||
| "dtype": [ | |||
| "float16" | |||
| ], | |||
| "format": [ | |||
| "NC1HWC0" | |||
| ], | |||
| "name": "x", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float32" | |||
| ], | |||
| "format": [ | |||
| "FracZ" | |||
| ], | |||
| "name": "y", | |||
| "need_compile": true, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ] | |||
| }""") | |||
| @op_info_register(conv2d_backprop_filter_op_info) | |||
| def _conv2d_backprop_filter_tbe(): | |||
| """Conv2DBackpropFilter TBE register""" | |||
| return | |||
| @@ -14,88 +14,27 @@ | |||
| # ============================================================================ | |||
| """Conv2DBackpropInput op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| conv2d_backprop_input_op_info = TBERegOp("Conv2DBackpropInput") \ | |||
| .fusion_type("CONVLUTION") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("conv2d_backprop_input_d.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("conv2d_backprop_input_d") \ | |||
| .partial_flag(True) \ | |||
| .attr("input_sizes", "required", "listInt", "all") \ | |||
| .attr("stride", "required", "listInt", "all") \ | |||
| .attr("pad_mode", "required", "str", "all") \ | |||
| .attr("dilation", "required", "listInt", "all") \ | |||
| .input(0, "out_backprop", False, "required", "all") \ | |||
| .input(1, "filter", False, "required", "all") \ | |||
| .output(0, "y", True, "required", "all") \ | |||
| .dtype_format(DataType.F16_5HD, DataType.F16_FracZ, DataType.F16_5HD) \ | |||
| .get_op_info() | |||
| @op_info_register("""{ | |||
| "op_name": "Conv2DBackpropInput", | |||
| "imply_type": "TBE", | |||
| "fusion_type": "CONVLUTION", | |||
| "async_flag": false, | |||
| "binfile_name": "conv2d_backprop_input_d.so", | |||
| "compute_cost": 10, | |||
| "kernel_name": "conv2d_backprop_input_d", | |||
| "partial_flag": true, | |||
| "attr": [ | |||
| { | |||
| "name": "input_sizes", | |||
| "param_type": "required", | |||
| "type": "listInt", | |||
| "value": "all" | |||
| }, | |||
| { | |||
| "name": "stride", | |||
| "param_type": "required", | |||
| "type": "listInt", | |||
| "value": "all" | |||
| }, | |||
| { | |||
| "name": "pad_mode", | |||
| "param_type": "required", | |||
| "type": "str", | |||
| "value": "all" | |||
| }, | |||
| { | |||
| "name": "dilation", | |||
| "param_type": "required", | |||
| "type": "listInt", | |||
| "value": "all" | |||
| } | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16" | |||
| ], | |||
| "format": [ | |||
| "NC1HWC0" | |||
| ], | |||
| "name": "out_backprop", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 1, | |||
| "dtype": [ | |||
| "float16" | |||
| ], | |||
| "format": [ | |||
| "FracZ" | |||
| ], | |||
| "name": "filter", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16" | |||
| ], | |||
| "format": [ | |||
| "NC1HWC0" | |||
| ], | |||
| "name": "y", | |||
| "need_compile": true, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ] | |||
| }""") | |||
| @op_info_register(conv2d_backprop_input_op_info) | |||
| def _conv2d_backprop_input_tbe(): | |||
| """Conv2DBackpropInput TBE register""" | |||
| return | |||
| @@ -14,71 +14,32 @@ | |||
| # ============================================================================ | |||
| """Div op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| div_op_info = TBERegOp("Div") \ | |||
| .fusion_type("ELEMWISE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("div.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("div") \ | |||
| .partial_flag(True) \ | |||
| .input(0, "x1", False, "required", "all") \ | |||
| .input(1, "x2", False, "required", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.I8_Default) \ | |||
| .dtype_format(DataType.I8_5HD, DataType.I8_5HD, DataType.I8_5HD) \ | |||
| .dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.U8_Default) \ | |||
| .dtype_format(DataType.U8_5HD, DataType.U8_5HD, DataType.U8_5HD) \ | |||
| .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ | |||
| .dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD) \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ | |||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \ | |||
| .get_op_info() | |||
| @op_info_register("""{ | |||
| "op_name": "Div", | |||
| "imply_type": "TBE", | |||
| "fusion_type": "ELEMWISE", | |||
| "async_flag": false, | |||
| "binfile_name": "div.so", | |||
| "compute_cost": 10, | |||
| "kernel_name": "div", | |||
| "partial_flag": true, | |||
| "attr": [ | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16", "float16", "float", "float", "int32", "int32", "int8", "int8", "uint8", "uint8" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "NC1HWC0", "DefaultFormat", "NC1HWC0", | |||
| "DefaultFormat", "NC1HWC0", "DefaultFormat", "NC1HWC0", | |||
| "DefaultFormat", "NC1HWC0" | |||
| ], | |||
| "name": "x1", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16", "float16", "float", "float", "int32", "int32", "int8", "int8", "uint8", "uint8" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "NC1HWC0", "DefaultFormat", "NC1HWC0", | |||
| "DefaultFormat", "NC1HWC0", "DefaultFormat", "NC1HWC0", | |||
| "DefaultFormat", "NC1HWC0" | |||
| ], | |||
| "name": "x2", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16", "float16", "float", "float", "int32", "int32", "int8", "int8", "uint8", "uint8" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "NC1HWC0", "DefaultFormat", "NC1HWC0", | |||
| "DefaultFormat", "NC1HWC0", "DefaultFormat", "NC1HWC0", | |||
| "DefaultFormat", "NC1HWC0" | |||
| ], | |||
| "name": "y", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ] | |||
| }""") | |||
| @op_info_register(div_op_info) | |||
| def _div_tbe(): | |||
| """Div TBE register""" | |||
| return | |||
| @@ -14,76 +14,25 @@ | |||
| # ============================================================================ | |||
| """DropoutdoMask op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| drop_out_do_mask_op_info = TBERegOp("DropoutDoMask") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("drop_out_do_mask.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("drop_out_do_mask") \ | |||
| .partial_flag(True) \ | |||
| .input(0, "x", False, "required", "all") \ | |||
| .input(1, "mask", False, "required", "all") \ | |||
| .input(2, "keep_prob", False, "required", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .dtype_format(DataType.F16_Default, DataType.U8_Default, DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F32_Default, DataType.U8_Default, DataType.F32_Default, DataType.F32_Default) \ | |||
| .get_op_info() | |||
| @op_info_register("""{ | |||
| "op_name": "DropoutDoMask", | |||
| "imply_type": "TBE", | |||
| "fusion_type": "OPAQUE", | |||
| "async_flag": false, | |||
| "binfile_name": "drop_out_do_mask.so", | |||
| "compute_cost": 10, | |||
| "kernel_name": "drop_out_do_mask", | |||
| "partial_flag": true, | |||
| "attr": [ | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16","float16","float16","float","float","float" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat" | |||
| ], | |||
| "name": "x", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 1, | |||
| "dtype": [ | |||
| "uint8","uint8","uint8","uint8","uint8","uint8" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat" | |||
| ], | |||
| "name": "mask", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 2, | |||
| "dtype": [ | |||
| "float16","float16","float16","float","float","float" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat" | |||
| ], | |||
| "name": "keep_prob", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16","float16","float16","float","float","float" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat" | |||
| ], | |||
| "name": "y", | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ] | |||
| }""") | |||
| @op_info_register(drop_out_do_mask_op_info) | |||
| def _dropout_do_mask_tbe(): | |||
| """DropoutdoMask TBE register""" | |||
| return | |||
| @@ -14,66 +14,32 @@ | |||
| # ============================================================================ | |||
| """Equal op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| equal_op_info = TBERegOp("Equal") \ | |||
| .fusion_type("ELEMWISE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("equal.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("equal") \ | |||
| .partial_flag(True) \ | |||
| .input(0, "x1", False, "required", "all") \ | |||
| .input(1, "x2", False, "required", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.BOOL_Default) \ | |||
| .dtype_format(DataType.I8_5HD, DataType.I8_5HD, DataType.BOOL_5HD) \ | |||
| .dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.BOOL_Default) \ | |||
| .dtype_format(DataType.U8_5HD, DataType.U8_5HD, DataType.BOOL_5HD) \ | |||
| .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.BOOL_Default) \ | |||
| .dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.BOOL_5HD) \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.BOOL_Default) \ | |||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.BOOL_5HD) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.BOOL_Default) \ | |||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.BOOL_5HD) \ | |||
| .get_op_info() | |||
| @op_info_register("""{ | |||
| "op_name": "Equal", | |||
| "imply_type": "TBE", | |||
| "fusion_type": "ELEMWISE", | |||
| "async_flag": false, | |||
| "binfile_name": "equal.so", | |||
| "compute_cost": 10, | |||
| "kernel_name": "equal", | |||
| "partial_flag": true, | |||
| "attr": [ | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16","float16","float","float","int32","int32","int8","int8","uint8","uint8" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0","DefaultFormat", | |||
| "NC1HWC0","DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0" | |||
| ], | |||
| "name": "x1", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 1, | |||
| "dtype": [ | |||
| "float16","float16","float","float","int32","int32","int8","int8","uint8","uint8" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0","DefaultFormat", | |||
| "NC1HWC0","DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0" | |||
| ], | |||
| "name": "x2", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "bool","bool","bool","bool","bool","bool","bool","bool","bool","bool" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0","DefaultFormat", | |||
| "NC1HWC0","DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0" | |||
| ], | |||
| "name": "y", | |||
| "param_type": "required" | |||
| } | |||
| ] | |||
| }""") | |||
| @op_info_register(equal_op_info) | |||
| def _equal_tbe(): | |||
| """Equal TBE register""" | |||
| return | |||
| @@ -14,52 +14,25 @@ | |||
| # ============================================================================ | |||
| """Exp op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| exp_op_info = TBERegOp("Exp") \ | |||
| .fusion_type("ELEMWISE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("exp.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("exp") \ | |||
| .partial_flag(True) \ | |||
| .input(0, "x", False, "required", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default) \ | |||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD) \ | |||
| .get_op_info() | |||
| @op_info_register("""{ | |||
| "op_name": "Exp", | |||
| "imply_type": "TBE", | |||
| "fusion_type": "ELEMWISE", | |||
| "async_flag": false, | |||
| "binfile_name": "exp.so", | |||
| "compute_cost": 10, | |||
| "kernel_name": "exp", | |||
| "partial_flag": true, | |||
| "attr": [ | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16", "float16", "float", "float" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "NC1HWC0", "DefaultFormat", "NC1HWC0" | |||
| ], | |||
| "name": "x", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16", "float16", "float", "float" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "NC1HWC0", "DefaultFormat", "NC1HWC0" | |||
| ], | |||
| "name": "y", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ] | |||
| }""") | |||
| @op_info_register(exp_op_info) | |||
| def _exp_tbe(): | |||
| """Exp TBE register""" | |||
| return | |||
| @@ -14,57 +14,25 @@ | |||
| # ============================================================================ | |||
| """ExpandDims op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| expand_dims_op_info = TBERegOp("ExpandDims") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("expand_dims.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("expand_dims") \ | |||
| .partial_flag(True) \ | |||
| .attr("axis", "required", "listInt", "all") \ | |||
| .input(0, "x", False, "required", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .dtype_format(DataType.I32_Default, DataType.I32_Default) \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default) \ | |||
| .get_op_info() | |||
| @op_info_register("""{ | |||
| "op_name": "ExpandDims", | |||
| "imply_type": "TBE", | |||
| "fusion_type": "OPAQUE", | |||
| "async_flag": false, | |||
| "binfile_name": "expand_dims.so", | |||
| "compute_cost": 10, | |||
| "kernel_name": "expand_dims", | |||
| "partial_flag": true, | |||
| "attr": [ | |||
| { | |||
| "name": "axis", | |||
| "param_type": "required", | |||
| "type": "listInt", | |||
| "value": "all" | |||
| } | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16", "float32", "int32" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "DefaultFormat", "DefaultFormat" | |||
| ], | |||
| "name": "x", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16", "float32", "int32" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "DefaultFormat", "DefaultFormat" | |||
| ], | |||
| "name": "y", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ] | |||
| }""") | |||
| @op_info_register(expand_dims_op_info) | |||
| def _expand_dims_tbe(): | |||
| """ExpandDims TBE register""" | |||
| return | |||
| @@ -14,64 +14,27 @@ | |||
| # ============================================================================ | |||
| """FloorDiv op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| floordiv_op_info = TBERegOp("FloorDiv") \ | |||
| .fusion_type("ELEMWISE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("floordiv.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("floordiv") \ | |||
| .partial_flag(True) \ | |||
| .input(0, "x1", False, "required", "all") \ | |||
| .input(1, "x2", False, "required", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.I8_Default) \ | |||
| .dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.U8_Default) \ | |||
| .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ | |||
| .get_op_info() | |||
| @op_info_register("""{ | |||
| "op_name": "FloorDiv", | |||
| "imply_type": "TBE", | |||
| "fusion_type": "ELEMWISE", | |||
| "async_flag": false, | |||
| "binfile_name": "floordiv.so", | |||
| "compute_cost": 10, | |||
| "kernel_name": "floordiv", | |||
| "partial_flag": true, | |||
| "attr": [ | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16","float","int32","int8","uint8" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat" | |||
| ], | |||
| "name": "x1", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 1, | |||
| "dtype": [ | |||
| "float16","float","int32","int8","uint8" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat" | |||
| ], | |||
| "name": "x2", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16","float","int32","int8","uint8" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat" | |||
| ], | |||
| "name": "y", | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ] | |||
| }""") | |||
| @op_info_register(floordiv_op_info) | |||
| def _floor_div_tbe(): | |||
| """FloorDiv TBE register""" | |||
| return | |||
| @@ -14,93 +14,38 @@ | |||
| # ============================================================================ | |||
| """FusedMulAdd op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| fused_mul_add_op_info = TBERegOp("FusedMulAdd") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("fused_mul_add.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("fused_mul_add") \ | |||
| .partial_flag(True) \ | |||
| .input(0, "x1", False, "required", "all") \ | |||
| .input(1, "x2", False, "required", "all") \ | |||
| .input(2, "x3", False, "required", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ | |||
| .dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD) \ | |||
| .dtype_format(DataType.I32_FracZ, DataType.I32_FracZ, DataType.I32_FracZ, DataType.I32_FracZ) \ | |||
| .dtype_format(DataType.I32_FracNZ, DataType.I32_FracNZ, DataType.I32_FracNZ, DataType.I32_FracNZ) \ | |||
| .dtype_format(DataType.I32_C1HWNCoC0, DataType.I32_C1HWNCoC0, DataType.I32_C1HWNCoC0, DataType.I32_C1HWNCoC0) \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \ | |||
| .dtype_format(DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ) \ | |||
| .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ) \ | |||
| .dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ | |||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \ | |||
| .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ) \ | |||
| .dtype_format(DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ) \ | |||
| .dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \ | |||
| .get_op_info() | |||
| @op_info_register("""{ | |||
| "op_name": "FusedMulAdd", | |||
| "imply_type": "TBE", | |||
| "fusion_type": "OPAQUE", | |||
| "async_flag": false, | |||
| "binfile_name": "fused_mul_add.so", | |||
| "compute_cost": 10, | |||
| "kernel_name": "fused_mul_add", | |||
| "partial_flag": true, | |||
| "attr": [ | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "int32", "int32", "int32", "int32", "int32", | |||
| "float16", "float16", "float16", "float16", "float16", | |||
| "float", "float", "float", "float", "float" | |||
| ], | |||
| "format": [ | |||
| "FRACTAL_NZ", "DefaultFormat", "FracZ", "C1HWNCoC0", "NC1HWC0", | |||
| "FRACTAL_NZ", "DefaultFormat", "FracZ", "C1HWNCoC0", "NC1HWC0", | |||
| "FRACTAL_NZ", "DefaultFormat", "FracZ", "C1HWNCoC0", "NC1HWC0" | |||
| ], | |||
| "name": "x1", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 1, | |||
| "dtype": [ | |||
| "int32", "int32", "int32", "int32", "int32", | |||
| "float16", "float16", "float16", "float16", "float16", | |||
| "float", "float", "float", "float", "float" | |||
| ], | |||
| "format": [ | |||
| "FRACTAL_NZ", "DefaultFormat", "FracZ", "C1HWNCoC0", "NC1HWC0", | |||
| "FRACTAL_NZ", "DefaultFormat", "FracZ", "C1HWNCoC0", "NC1HWC0", | |||
| "FRACTAL_NZ", "DefaultFormat", "FracZ", "C1HWNCoC0", "NC1HWC0" | |||
| ], | |||
| "name": "x2", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 2, | |||
| "dtype": [ | |||
| "int32", "int32", "int32", "int32", "int32", | |||
| "float16", "float16", "float16", "float16", "float16", | |||
| "float", "float", "float", "float", "float" | |||
| ], | |||
| "format": [ | |||
| "FRACTAL_NZ", "DefaultFormat", "FracZ", "C1HWNCoC0", "NC1HWC0", | |||
| "FRACTAL_NZ", "DefaultFormat", "FracZ", "C1HWNCoC0", "NC1HWC0", | |||
| "FRACTAL_NZ", "DefaultFormat", "FracZ", "C1HWNCoC0", "NC1HWC0" | |||
| ], | |||
| "name": "x3", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "int32", "int32", "int32", "int32", "int32", | |||
| "float16", "float16", "float16", "float16", "float16", | |||
| "float", "float", "float", "float", "float" | |||
| ], | |||
| "format": [ | |||
| "FRACTAL_NZ", "DefaultFormat", "FracZ", "C1HWNCoC0", "NC1HWC0", | |||
| "FRACTAL_NZ", "DefaultFormat", "FracZ", "C1HWNCoC0", "NC1HWC0", | |||
| "FRACTAL_NZ", "DefaultFormat", "FracZ", "C1HWNCoC0", "NC1HWC0" | |||
| ], | |||
| "name": "y", | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ] | |||
| }""") | |||
| @op_info_register(fused_mul_add_op_info) | |||
| def _fused_mul_add_tbe(): | |||
| """FusedMulAdd TBE register""" | |||
| return | |||
| @@ -14,86 +14,31 @@ | |||
| # ============================================================================ | |||
| """FusedMulAddN op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| fused_mul_add_n_op_info = TBERegOp("FusedMulAddN") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("fused_mul_add_n.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("fused_mul_add_n") \ | |||
| .partial_flag(True) \ | |||
| .input(0, "x1", False, "required", "all") \ | |||
| .input(1, "x2", False, "required", "all") \ | |||
| .input(2, "x3", False, "required", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_Default, DataType.F16_5HD) \ | |||
| .dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_Default, DataType.F16_C1HWNCoC0) \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_Default, DataType.F16_FracZ) \ | |||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_Default, DataType.F32_5HD) \ | |||
| .dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_Default, DataType.F32_C1HWNCoC0) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ | |||
| .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_Default, DataType.F32_FracZ) \ | |||
| .get_op_info() | |||
| @op_info_register("""{ | |||
| "op_name": "FusedMulAddN", | |||
| "imply_type": "TBE", | |||
| "fusion_type": "OPAQUE", | |||
| "async_flag": false, | |||
| "binfile_name": "fused_mul_add_n.so", | |||
| "compute_cost": 10, | |||
| "kernel_name": "fused_mul_add_n", | |||
| "partial_flag": true, | |||
| "attr": [ | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16","float16","float16","float16", | |||
| "float","float","float","float" | |||
| ], | |||
| "format": [ | |||
| "NC1HWC0","C1HWNCoC0","DefaultFormat","FracZ", | |||
| "NC1HWC0","C1HWNCoC0","DefaultFormat","FracZ" | |||
| ], | |||
| "name": "x1", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 1, | |||
| "dtype": [ | |||
| "float16","float16","float16","float16", | |||
| "float","float","float","float" | |||
| ], | |||
| "format": [ | |||
| "NC1HWC0","C1HWNCoC0","DefaultFormat","FracZ", | |||
| "NC1HWC0","C1HWNCoC0","DefaultFormat","FracZ" | |||
| ], | |||
| "name": "x2", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 2, | |||
| "dtype": [ | |||
| "float16","float16","float16","float16", | |||
| "float","float","float","float" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat", | |||
| "DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat" | |||
| ], | |||
| "name": "x3", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16","float16","float16","float16", | |||
| "float","float","float","float" | |||
| ], | |||
| "format": [ | |||
| "NC1HWC0","C1HWNCoC0","DefaultFormat","FracZ", | |||
| "NC1HWC0","C1HWNCoC0","DefaultFormat","FracZ" | |||
| ], | |||
| "name": "y", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ] | |||
| }""") | |||
| @op_info_register(fused_mul_add_n_op_info) | |||
| def _fused_mul_add_n_tbe(): | |||
| """FusedMulAddN TBE register""" | |||
| return | |||
| @@ -14,137 +14,43 @@ | |||
| # ============================================================================ | |||
| """FusedMulApplyMomentum op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| fused_mul_apply_momentum_op_info = TBERegOp("FusedMulApplyMomentum") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("fused_mul_apply_momentum.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("fused_mul_apply_momentum") \ | |||
| .partial_flag(True) \ | |||
| .attr("use_nesterov", "optional", "bool", "true,false", "false") \ | |||
| .input(0, "var", False, "required", "all") \ | |||
| .input(1, "accum", False, "required", "all") \ | |||
| .input(2, "lr", False, "required", "all") \ | |||
| .input(3, "x1", False, "required", "all") \ | |||
| .input(4, "momentum", False, "required", "all") \ | |||
| .input(5, "x2", False, "required", "all") \ | |||
| .output(0, "var", False, "required", "all") \ | |||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_Default, DataType.F16_5HD, | |||
| DataType.F16_Default, DataType.F16_5HD, DataType.F16_5HD) \ | |||
| .dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_Default, DataType.F16_C1HWNCoC0, | |||
| DataType.F16_Default, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, | |||
| DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_Default, DataType.F16_FracZ, | |||
| DataType.F16_Default, DataType.F16_FracZ, DataType.F16_FracZ) \ | |||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_Default, DataType.F32_5HD, | |||
| DataType.F32_Default, DataType.F32_5HD, DataType.F32_5HD) \ | |||
| .dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_Default, DataType.F32_C1HWNCoC0, | |||
| DataType.F32_Default, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | |||
| DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ | |||
| .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_Default, DataType.F32_FracZ, | |||
| DataType.F32_Default, DataType.F32_FracZ, DataType.F32_FracZ) \ | |||
| .get_op_info() | |||
| @op_info_register("""{ | |||
| "op_name": "FusedMulApplyMomentum", | |||
| "imply_type": "TBE", | |||
| "fusion_type": "OPAQUE", | |||
| "async_flag": false, | |||
| "binfile_name": "fused_mul_apply_momentum.so", | |||
| "compute_cost": 10, | |||
| "kernel_name": "fused_mul_apply_momentum", | |||
| "partial_flag": true, | |||
| "attr": [ | |||
| { | |||
| "name": "use_nesterov", | |||
| "param_type": "optional", | |||
| "type": "bool", | |||
| "value": "true,false", | |||
| "default_value":"false" | |||
| } | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16","float16","float16","float16", | |||
| "float","float","float","float" | |||
| ], | |||
| "format": [ | |||
| "NC1HWC0","C1HWNCoC0","DefaultFormat","FracZ", | |||
| "NC1HWC0","C1HWNCoC0","DefaultFormat","FracZ" | |||
| ], | |||
| "name": "var", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 1, | |||
| "dtype": [ | |||
| "float16","float16","float16","float16", | |||
| "float","float","float","float" | |||
| ], | |||
| "format": [ | |||
| "NC1HWC0","C1HWNCoC0","DefaultFormat","FracZ", | |||
| "NC1HWC0","C1HWNCoC0","DefaultFormat","FracZ" | |||
| ], | |||
| "name": "accum", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 2, | |||
| "dtype": [ | |||
| "float16","float16","float16","float16", | |||
| "float","float","float","float" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat", | |||
| "DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat" | |||
| ], | |||
| "name": "lr", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 3, | |||
| "dtype": [ | |||
| "float16","float16","float16","float16", | |||
| "float","float","float","float" | |||
| ], | |||
| "format": [ | |||
| "NC1HWC0","C1HWNCoC0","DefaultFormat","FracZ", | |||
| "NC1HWC0","C1HWNCoC0","DefaultFormat","FracZ" | |||
| ], | |||
| "name": "x1", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 4, | |||
| "dtype": [ | |||
| "float16","float16","float16","float16", | |||
| "float","float","float","float" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat", | |||
| "DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat" | |||
| ], | |||
| "name": "momentum", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 5, | |||
| "dtype": [ | |||
| "float16","float16","float16","float16", | |||
| "float","float","float","float" | |||
| ], | |||
| "format": [ | |||
| "NC1HWC0","C1HWNCoC0","DefaultFormat","FracZ", | |||
| "NC1HWC0","C1HWNCoC0","DefaultFormat","FracZ" | |||
| ], | |||
| "name": "x2", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16","float16","float16","float16", | |||
| "float","float","float","float" | |||
| ], | |||
| "format": [ | |||
| "NC1HWC0","C1HWNCoC0","DefaultFormat","FracZ", | |||
| "NC1HWC0","C1HWNCoC0","DefaultFormat","FracZ" | |||
| ], | |||
| "name": "var", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ] | |||
| }""") | |||
| @op_info_register(fused_mul_apply_momentum_op_info) | |||
| def _fused_mul_apply_momentum_tbe(): | |||
| """FusedMulApplyMomentum TBE register""" | |||
| return | |||
| @@ -14,94 +14,53 @@ | |||
| # ============================================================================ | |||
| """AddN op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| gather_v2_op_info = TBERegOp("GatherV2") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("gather_v2_d.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("gather_v2_d") \ | |||
| .partial_flag(True) \ | |||
| .attr("axis", "optional", "int", "all") \ | |||
| .input(0, "x", False, "required", "all") \ | |||
| .input(1, "indices", False, "required", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default) \ | |||
| .dtype_format(DataType.I8_Default, DataType.I64_Default, DataType.I8_Default) \ | |||
| .dtype_format(DataType.I8_5HD, DataType.I32_5HD, DataType.I8_5HD) \ | |||
| .dtype_format(DataType.I8_5HD, DataType.I64_5HD, DataType.I8_5HD) \ | |||
| .dtype_format(DataType.I8_FracZ, DataType.I32_FracZ, DataType.I8_FracZ) \ | |||
| .dtype_format(DataType.I8_FracZ, DataType.I64_FracZ, DataType.I8_FracZ) \ | |||
| .dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default) \ | |||
| .dtype_format(DataType.U8_Default, DataType.I64_Default, DataType.U8_Default) \ | |||
| .dtype_format(DataType.U8_5HD, DataType.I32_5HD, DataType.U8_5HD) \ | |||
| .dtype_format(DataType.U8_5HD, DataType.I64_5HD, DataType.U8_5HD) \ | |||
| .dtype_format(DataType.U8_FracZ, DataType.I32_FracZ, DataType.U8_FracZ) \ | |||
| .dtype_format(DataType.U8_FracZ, DataType.I64_FracZ, DataType.U8_FracZ) \ | |||
| .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ | |||
| .dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.I32_Default) \ | |||
| .dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD) \ | |||
| .dtype_format(DataType.I32_5HD, DataType.I64_5HD, DataType.I32_5HD) \ | |||
| .dtype_format(DataType.I32_FracZ, DataType.I32_FracZ, DataType.I32_FracZ) \ | |||
| .dtype_format(DataType.I32_FracZ, DataType.I64_FracZ, DataType.I32_FracZ) \ | |||
| .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F16_Default, DataType.I64_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F16_5HD, DataType.I32_5HD, DataType.F16_5HD) \ | |||
| .dtype_format(DataType.F16_5HD, DataType.I64_5HD, DataType.F16_5HD) \ | |||
| .dtype_format(DataType.F16_FracZ, DataType.I32_FracZ, DataType.F16_FracZ) \ | |||
| .dtype_format(DataType.F16_FracZ, DataType.I64_FracZ, DataType.F16_FracZ) \ | |||
| .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default) \ | |||
| .dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.F32_Default) \ | |||
| .dtype_format(DataType.F32_5HD, DataType.I32_5HD, DataType.F32_5HD) \ | |||
| .dtype_format(DataType.F32_5HD, DataType.I64_5HD, DataType.F32_5HD) \ | |||
| .dtype_format(DataType.F32_FracZ, DataType.I32_FracZ, DataType.F32_FracZ) \ | |||
| .dtype_format(DataType.F32_FracZ, DataType.I64_FracZ, DataType.F32_FracZ) \ | |||
| .get_op_info() | |||
| @op_info_register("""{ | |||
| "op_name": "GatherV2", | |||
| "imply_type": "TBE", | |||
| "fusion_type": "OPAQUE", | |||
| "async_flag": false, | |||
| "binfile_name": "gather_v2_d.so", | |||
| "compute_cost": 10, | |||
| "kernel_name": "gather_v2_d", | |||
| "partial_flag": true, | |||
| "attr": [ | |||
| { | |||
| "name": "axis", | |||
| "param_type": "optional", | |||
| "type": "int", | |||
| "value": "all" | |||
| } | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16","float16","float16","float16","float16","float16", | |||
| "float","float","float","float","float","float", | |||
| "int32","int32","int32", "int32","int32","int32", | |||
| "uint8","uint8","uint8","uint8","uint8","uint8", | |||
| "int8","int8", "int8","int8","int8", "int8" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat","NC1HWC0","FracZ","DefaultFormat","NC1HWC0","FracZ", | |||
| "DefaultFormat","NC1HWC0","FracZ","DefaultFormat","NC1HWC0","FracZ", | |||
| "DefaultFormat","NC1HWC0","FracZ","DefaultFormat","NC1HWC0","FracZ", | |||
| "DefaultFormat","NC1HWC0","FracZ","DefaultFormat","NC1HWC0","FracZ", | |||
| "DefaultFormat","NC1HWC0","FracZ","DefaultFormat","NC1HWC0","FracZ" | |||
| ], | |||
| "name": "x", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 1, | |||
| "dtype": [ | |||
| "int32","int32","int32","int64","int64","int64", | |||
| "int32","int32","int32","int64","int64","int64", | |||
| "int32","int32","int32","int64","int64","int64", | |||
| "int32","int32","int32","int64","int64","int64", | |||
| "int32","int32","int32","int64","int64","int64" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat","NC1HWC0","FracZ","DefaultFormat","NC1HWC0","FracZ", | |||
| "DefaultFormat","NC1HWC0","FracZ","DefaultFormat","NC1HWC0","FracZ", | |||
| "DefaultFormat","NC1HWC0","FracZ","DefaultFormat","NC1HWC0","FracZ", | |||
| "DefaultFormat","NC1HWC0","FracZ","DefaultFormat","NC1HWC0","FracZ", | |||
| "DefaultFormat","NC1HWC0","FracZ","DefaultFormat","NC1HWC0","FracZ" | |||
| ], | |||
| "name": "indices", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16","float16","float16","float16","float16","float16", | |||
| "float","float","float","float","float","float", | |||
| "int32","int32","int32", "int32","int32","int32", | |||
| "uint8","uint8","uint8","uint8","uint8","uint8", | |||
| "int8","int8", "int8","int8","int8", "int8" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat","NC1HWC0","FracZ","DefaultFormat","NC1HWC0","FracZ", | |||
| "DefaultFormat","NC1HWC0","FracZ","DefaultFormat","NC1HWC0","FracZ", | |||
| "DefaultFormat","NC1HWC0","FracZ","DefaultFormat","NC1HWC0","FracZ", | |||
| "DefaultFormat","NC1HWC0","FracZ","DefaultFormat","NC1HWC0","FracZ", | |||
| "DefaultFormat","NC1HWC0","FracZ","DefaultFormat","NC1HWC0","FracZ" | |||
| ], | |||
| "name": "y", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ] | |||
| }""") | |||
| @op_info_register(gather_v2_op_info) | |||
| def _gather_v2_tbe(): | |||
| """GatherV2 TBE register""" | |||
| return | |||
| @@ -14,51 +14,29 @@ | |||
| # ============================================================================ | |||
| """Gelu op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| gelu_op_info = TBERegOp("Gelu") \ | |||
| .fusion_type("ELEMWISE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("gelu.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("gelu") \ | |||
| .partial_flag(True) \ | |||
| .input(0, "x", False, "required", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \ | |||
| .dtype_format(DataType.F16_FracZ, DataType.F16_FracZ) \ | |||
| .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default) \ | |||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD) \ | |||
| .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ) \ | |||
| .dtype_format(DataType.F32_FracNZ, DataType.F32_FracNZ) \ | |||
| .get_op_info() | |||
| @op_info_register("""{ | |||
| "op_name": "Gelu", | |||
| "imply_type": "TBE", | |||
| "fusion_type": "ELEMWISE", | |||
| "async_flag": false, | |||
| "binfile_name": "gelu.so", | |||
| "compute_cost": 10, | |||
| "kernel_name": "gelu", | |||
| "partial_flag": true, | |||
| "attr": [ | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16","float","float16","float","float16","float16","float16","float16","float","float","float","float" | |||
| ], | |||
| "format": [ | |||
| "FRACTAL_NZ","FRACTAL_NZ","FracZ","FracZ","DefaultFormat","NC1HWC0","DefaultFormat","DefaultFormat","DefaultFormat","NC1HWC0","DefaultFormat","DefaultFormat" | |||
| ], | |||
| "name": "x", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16","float","float16","float","float16","float16","float16","float16","float","float","float","float" | |||
| ], | |||
| "format": [ | |||
| "FRACTAL_NZ","FRACTAL_NZ","FracZ","FracZ","DefaultFormat","NC1HWC0","DefaultFormat","DefaultFormat","DefaultFormat","NC1HWC0","DefaultFormat","DefaultFormat" | |||
| ], | |||
| "name": "y", | |||
| "need_compile": true, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ] | |||
| }""") | |||
| @op_info_register(gelu_op_info) | |||
| def _gelu_tbe(): | |||
| """Gelu TBE register""" | |||
| return | |||
| @@ -14,77 +14,29 @@ | |||
| # ============================================================================ | |||
| """GeluGrad op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| gelu_grad_op_info = TBERegOp("GeluGrad") \ | |||
| .fusion_type("ELEMWISE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("gelu_grad.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("gelu_grad") \ | |||
| .partial_flag(True) \ | |||
| .input(0, "dy", False, "required", "all") \ | |||
| .input(1, "x", False, "required", "all") \ | |||
| .input(2, "y", False, "required", "all") \ | |||
| .output(0, "z", True, "required", "all") \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \ | |||
| .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ | |||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \ | |||
| .dtype_format(DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ) \ | |||
| .get_op_info() | |||
| @op_info_register("""{ | |||
| "op_name": "GeluGrad", | |||
| "imply_type": "TBE", | |||
| "fusion_type": "ELEMWISE", | |||
| "async_flag": false, | |||
| "binfile_name": "gelu_grad.so", | |||
| "compute_cost": 10, | |||
| "kernel_name": "gelu_grad", | |||
| "partial_flag": true, | |||
| "attr": [ | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16","float16","float16","float","float","float" | |||
| ], | |||
| "format": [ | |||
| "FRACTAL_NZ","DefaultFormat","NC1HWC0","FRACTAL_NZ","DefaultFormat","NC1HWC0" | |||
| ], | |||
| "name": "dy", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 1, | |||
| "dtype": [ | |||
| "float16","float16","float16","float","float","float" | |||
| ], | |||
| "format": [ | |||
| "FRACTAL_NZ","DefaultFormat","NC1HWC0","FRACTAL_NZ","DefaultFormat","NC1HWC0" | |||
| ], | |||
| "name": "x", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 2, | |||
| "dtype": [ | |||
| "float16","float16","float16","float","float","float" | |||
| ], | |||
| "format": [ | |||
| "FRACTAL_NZ","DefaultFormat","NC1HWC0","FRACTAL_NZ","DefaultFormat","NC1HWC0" | |||
| ], | |||
| "name": "y", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16","float16","float16","float","float","float" | |||
| ], | |||
| "format": [ | |||
| "FRACTAL_NZ","DefaultFormat","NC1HWC0","FRACTAL_NZ","DefaultFormat","NC1HWC0" | |||
| ], | |||
| "name": "z", | |||
| "need_compile": true, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ] | |||
| }""") | |||
| @op_info_register(gelu_grad_op_info) | |||
| def _gelu_grad_tbe(): | |||
| """GeluGrad TBE register""" | |||
| return | |||
| @@ -14,68 +14,32 @@ | |||
| # ============================================================================ | |||
| """Greater op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| greater_op_info = TBERegOp("Greater") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("greater.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("greater") \ | |||
| .partial_flag(True) \ | |||
| .input(0, "x1", False, "required", "all") \ | |||
| .input(1, "x2", False, "required", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.BOOL_Default) \ | |||
| .dtype_format(DataType.I8_5HD, DataType.I8_5HD, DataType.BOOL_5HD) \ | |||
| .dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.BOOL_Default) \ | |||
| .dtype_format(DataType.U8_5HD, DataType.U8_5HD, DataType.BOOL_5HD) \ | |||
| .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.BOOL_Default) \ | |||
| .dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.BOOL_5HD) \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.BOOL_Default) \ | |||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.BOOL_5HD) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.BOOL_Default) \ | |||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.BOOL_5HD) \ | |||
| .get_op_info() | |||
| @op_info_register("""{ | |||
| "op_name": "Greater", | |||
| "imply_type": "TBE", | |||
| "fusion_type": "OPAQUE", | |||
| "async_flag": false, | |||
| "binfile_name": "greater.so", | |||
| "compute_cost": 10, | |||
| "kernel_name": "greater", | |||
| "partial_flag": true, | |||
| "attr": [ | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16","float16","float","float","int32","int32","int8","int8","uint8","uint8" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0","DefaultFormat", | |||
| "NC1HWC0","DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0" | |||
| ], | |||
| "name": "x1", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 1, | |||
| "dtype": [ | |||
| "float16","float16","float","float","int32","int32","int8","int8","uint8","uint8" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0","DefaultFormat", | |||
| "NC1HWC0","DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0" | |||
| ], | |||
| "name": "x2", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "bool","bool","bool","bool","bool","bool","bool","bool","bool","bool" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0","DefaultFormat", | |||
| "NC1HWC0","DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0" | |||
| ], | |||
| "name": "y", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ] | |||
| }""") | |||
| @op_info_register(greater_op_info) | |||
| def _greater_tbe(): | |||
| """Greater TBE register""" | |||
| return | |||
| @@ -14,279 +14,46 @@ | |||
| # ============================================================================ | |||
| """LambNextMV op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| lamb_next_mv_op_info = TBERegOp("LambNextMV") \ | |||
| .fusion_type("ELEMWISE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("lamb_next_m_v.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("lamb_next_m_v") \ | |||
| .partial_flag(True) \ | |||
| .input(0, "input1", False, "required", "all") \ | |||
| .input(1, "input2", False, "required", "all") \ | |||
| .input(2, "input3", False, "required", "all") \ | |||
| .input(3, "input4", False, "required", "all") \ | |||
| .input(4, "input5", False, "required", "all") \ | |||
| .input(5, "input6", False, "required", "all") \ | |||
| .input(6, "input7", False, "required", "all") \ | |||
| .input(7, "input8", False, "required", "all") \ | |||
| .input(8, "input9", False, "required", "all") \ | |||
| .input(9, "inputx0", False, "required", "all") \ | |||
| .input(10, "inputx1", False, "required", "all") \ | |||
| .input(11, "inputx2", False, "required", "all") \ | |||
| .input(12, "inputx3", False, "required", "all") \ | |||
| .output(0, "output1", False, "required", "all") \ | |||
| .output(1, "output2", False, "required", "all") \ | |||
| .output(2, "output3", False, "required", "all") \ | |||
| .output(3, "output4", False, "required", "all") \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, | |||
| DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, | |||
| DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, | |||
| DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, | |||
| DataType.F16_Default) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | |||
| DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | |||
| DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | |||
| DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | |||
| DataType.F32_Default) \ | |||
| .get_op_info() | |||
| @op_info_register("""{ | |||
| "op_name":"LambNextMV", | |||
| "imply_type":"TBE", | |||
| "fusion_type":"ELEMWISE", | |||
| "async_flag":false, | |||
| "binfile_name":"lamb_next_m_v.so", | |||
| "compute_cost":10, | |||
| "kernel_name":"lamb_next_m_v", | |||
| "partial_flag":true, | |||
| "attr":[], | |||
| "inputs":[ | |||
| { | |||
| "index":0, | |||
| "dtype":[ | |||
| "float16", | |||
| "float32" | |||
| ], | |||
| "format":[ | |||
| "DefaultFormat", | |||
| "DefaultFormat" | |||
| ], | |||
| "name":"input1", | |||
| "need_compile":false, | |||
| "param_type":"required", | |||
| "shape":"all" | |||
| }, | |||
| { | |||
| "index":1, | |||
| "dtype":[ | |||
| "float16", | |||
| "float32" | |||
| ], | |||
| "format":[ | |||
| "DefaultFormat", | |||
| "DefaultFormat" | |||
| ], | |||
| "name":"input2", | |||
| "need_compile":false, | |||
| "param_type":"required", | |||
| "shape":"all" | |||
| }, | |||
| { | |||
| "index":2, | |||
| "dtype":[ | |||
| "float16", | |||
| "float32" | |||
| ], | |||
| "format":[ | |||
| "DefaultFormat", | |||
| "DefaultFormat" | |||
| ], | |||
| "name":"input3", | |||
| "need_compile":false, | |||
| "param_type":"required", | |||
| "shape":"all" | |||
| }, | |||
| { | |||
| "index":3, | |||
| "dtype":[ | |||
| "float16", | |||
| "float32" | |||
| ], | |||
| "format":[ | |||
| "DefaultFormat", | |||
| "DefaultFormat" | |||
| ], | |||
| "name":"input4", | |||
| "need_compile":false, | |||
| "param_type":"required", | |||
| "shape":"all" | |||
| }, | |||
| { | |||
| "index":4, | |||
| "dtype":[ | |||
| "float16", | |||
| "float32" | |||
| ], | |||
| "format":[ | |||
| "DefaultFormat", | |||
| "DefaultFormat" | |||
| ], | |||
| "name":"input5", | |||
| "need_compile":false, | |||
| "param_type":"required", | |||
| "shape":"all" | |||
| }, | |||
| { | |||
| "index":5, | |||
| "dtype":[ | |||
| "float16", | |||
| "float32" | |||
| ], | |||
| "format":[ | |||
| "DefaultFormat", | |||
| "DefaultFormat" | |||
| ], | |||
| "name":"input6", | |||
| "need_compile":false, | |||
| "param_type":"required", | |||
| "shape":"all" | |||
| }, | |||
| { | |||
| "index":6, | |||
| "dtype":[ | |||
| "float16", | |||
| "float32" | |||
| ], | |||
| "format":[ | |||
| "DefaultFormat", | |||
| "DefaultFormat" | |||
| ], | |||
| "name":"input7", | |||
| "need_compile":false, | |||
| "param_type":"required", | |||
| "shape":"all" | |||
| }, | |||
| { | |||
| "index":7, | |||
| "dtype":[ | |||
| "float16", | |||
| "float32" | |||
| ], | |||
| "format":[ | |||
| "DefaultFormat", | |||
| "DefaultFormat" | |||
| ], | |||
| "name":"input8", | |||
| "need_compile":false, | |||
| "param_type":"required", | |||
| "shape":"all" | |||
| }, | |||
| { | |||
| "index":8, | |||
| "dtype":[ | |||
| "float16", | |||
| "float32" | |||
| ], | |||
| "format":[ | |||
| "DefaultFormat", | |||
| "DefaultFormat" | |||
| ], | |||
| "name":"input9", | |||
| "need_compile":false, | |||
| "param_type":"required", | |||
| "shape":"all" | |||
| }, | |||
| { | |||
| "index":9, | |||
| "dtype":[ | |||
| "float16", | |||
| "float32" | |||
| ], | |||
| "format":[ | |||
| "DefaultFormat", | |||
| "DefaultFormat" | |||
| ], | |||
| "name":"inputx0", | |||
| "need_compile":false, | |||
| "param_type":"required", | |||
| "shape":"all" | |||
| }, | |||
| { | |||
| "index":10, | |||
| "dtype":[ | |||
| "float16", | |||
| "float32" | |||
| ], | |||
| "format":[ | |||
| "DefaultFormat", | |||
| "DefaultFormat" | |||
| ], | |||
| "name":"inputx1", | |||
| "need_compile":false, | |||
| "param_type":"required", | |||
| "shape":"all" | |||
| }, | |||
| { | |||
| "index":11, | |||
| "dtype":[ | |||
| "float16", | |||
| "float32" | |||
| ], | |||
| "format":[ | |||
| "DefaultFormat", | |||
| "DefaultFormat" | |||
| ], | |||
| "name":"inputx2", | |||
| "need_compile":false, | |||
| "param_type":"required", | |||
| "shape":"all" | |||
| }, | |||
| { | |||
| "index":12, | |||
| "dtype":[ | |||
| "float16", | |||
| "float32" | |||
| ], | |||
| "format":[ | |||
| "DefaultFormat", | |||
| "DefaultFormat" | |||
| ], | |||
| "name":"inputx3", | |||
| "need_compile":false, | |||
| "param_type":"required", | |||
| "shape":"all" | |||
| } | |||
| ], | |||
| "outputs":[ | |||
| { | |||
| "index":0, | |||
| "dtype":[ | |||
| "float16", | |||
| "float32" | |||
| ], | |||
| "format":[ | |||
| "DefaultFormat", | |||
| "DefaultFormat" | |||
| ], | |||
| "name":"output1", | |||
| "need_compile":false, | |||
| "param_type":"required", | |||
| "shape":"all" | |||
| }, | |||
| { | |||
| "index":1, | |||
| "dtype":[ | |||
| "float16", | |||
| "float32" | |||
| ], | |||
| "format":[ | |||
| "DefaultFormat", | |||
| "DefaultFormat" | |||
| ], | |||
| "name":"output2", | |||
| "need_compile":false, | |||
| "param_type":"required", | |||
| "shape":"all" | |||
| }, | |||
| { | |||
| "index":2, | |||
| "dtype":[ | |||
| "float16", | |||
| "float32" | |||
| ], | |||
| "format":[ | |||
| "DefaultFormat", | |||
| "DefaultFormat" | |||
| ], | |||
| "name":"output3", | |||
| "need_compile":false, | |||
| "param_type":"required", | |||
| "shape":"all" | |||
| }, | |||
| { | |||
| "index":3, | |||
| "dtype":[ | |||
| "float16", | |||
| "float32" | |||
| ], | |||
| "format":[ | |||
| "DefaultFormat", | |||
| "DefaultFormat" | |||
| ], | |||
| "name":"output4", | |||
| "need_compile":false, | |||
| "param_type":"required", | |||
| "shape":"all" | |||
| } | |||
| ] | |||
| }""") | |||
| @op_info_register(lamb_next_mv_op_info) | |||
| def _lamb_next_mv_tbe(): | |||
| """LambNextMV TBE register""" | |||
| return | |||
| @@ -14,279 +14,46 @@ | |||
| # ============================================================================ | |||
| """LambNextMVWithDecayV1 op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| lamb_next_m_v_with_decay_v1_op_info = TBERegOp("LambNextMVWithDecayV1") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("lamb_next_m_v_with_decay_v1.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("lamb_next_m_v_with_decay_v1") \ | |||
| .partial_flag(True) \ | |||
| .input(0, "input1", False, "required", "all") \ | |||
| .input(1, "input2", False, "required", "all") \ | |||
| .input(2, "input3", False, "required", "all") \ | |||
| .input(3, "input4", False, "required", "all") \ | |||
| .input(4, "input5", False, "required", "all") \ | |||
| .input(5, "input6", False, "required", "all") \ | |||
| .input(6, "input7", False, "required", "all") \ | |||
| .input(7, "input8", False, "required", "all") \ | |||
| .input(8, "input9", False, "required", "all") \ | |||
| .input(9, "inputx0", False, "required", "all") \ | |||
| .input(10, "inputx1", False, "required", "all") \ | |||
| .input(11, "inputx2", False, "required", "all") \ | |||
| .input(12, "inputx3", False, "required", "all") \ | |||
| .output(0, "output1", False, "required", "all") \ | |||
| .output(1, "output2", False, "required", "all") \ | |||
| .output(2, "output3", False, "required", "all") \ | |||
| .output(3, "output4", False, "required", "all") \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, | |||
| DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, | |||
| DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, | |||
| DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, | |||
| DataType.F16_Default) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | |||
| DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | |||
| DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | |||
| DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | |||
| DataType.F32_Default) \ | |||
| .get_op_info() | |||
| @op_info_register("""{ | |||
| "op_name":"LambNextMVWithDecayV1", | |||
| "imply_type":"TBE", | |||
| "fusion_type":"OPAQUE", | |||
| "async_flag":false, | |||
| "binfile_name":"lamb_next_m_v_with_decay_v1.so", | |||
| "compute_cost":10, | |||
| "kernel_name":"lamb_next_m_v_with_decay_v1", | |||
| "partial_flag":true, | |||
| "attr":[], | |||
| "inputs":[ | |||
| { | |||
| "index":0, | |||
| "dtype":[ | |||
| "float16", | |||
| "float32" | |||
| ], | |||
| "format":[ | |||
| "DefaultFormat", | |||
| "DefaultFormat" | |||
| ], | |||
| "name":"input1", | |||
| "need_compile":false, | |||
| "param_type":"required", | |||
| "shape":"all" | |||
| }, | |||
| { | |||
| "index":1, | |||
| "dtype":[ | |||
| "float16", | |||
| "float32" | |||
| ], | |||
| "format":[ | |||
| "DefaultFormat", | |||
| "DefaultFormat" | |||
| ], | |||
| "name":"input2", | |||
| "need_compile":false, | |||
| "param_type":"required", | |||
| "shape":"all" | |||
| }, | |||
| { | |||
| "index":2, | |||
| "dtype":[ | |||
| "float16", | |||
| "float32" | |||
| ], | |||
| "format":[ | |||
| "DefaultFormat", | |||
| "DefaultFormat" | |||
| ], | |||
| "name":"input3", | |||
| "need_compile":false, | |||
| "param_type":"required", | |||
| "shape":"all" | |||
| }, | |||
| { | |||
| "index":3, | |||
| "dtype":[ | |||
| "float16", | |||
| "float32" | |||
| ], | |||
| "format":[ | |||
| "DefaultFormat", | |||
| "DefaultFormat" | |||
| ], | |||
| "name":"input4", | |||
| "need_compile":false, | |||
| "param_type":"required", | |||
| "shape":"all" | |||
| }, | |||
| { | |||
| "index":4, | |||
| "dtype":[ | |||
| "float16", | |||
| "float32" | |||
| ], | |||
| "format":[ | |||
| "DefaultFormat", | |||
| "DefaultFormat" | |||
| ], | |||
| "name":"input5", | |||
| "need_compile":false, | |||
| "param_type":"required", | |||
| "shape":"all" | |||
| }, | |||
| { | |||
| "index":5, | |||
| "dtype":[ | |||
| "float16", | |||
| "float32" | |||
| ], | |||
| "format":[ | |||
| "DefaultFormat", | |||
| "DefaultFormat" | |||
| ], | |||
| "name":"input6", | |||
| "need_compile":false, | |||
| "param_type":"required", | |||
| "shape":"all" | |||
| }, | |||
| { | |||
| "index":6, | |||
| "dtype":[ | |||
| "float16", | |||
| "float32" | |||
| ], | |||
| "format":[ | |||
| "DefaultFormat", | |||
| "DefaultFormat" | |||
| ], | |||
| "name":"input7", | |||
| "need_compile":false, | |||
| "param_type":"required", | |||
| "shape":"all" | |||
| }, | |||
| { | |||
| "index":7, | |||
| "dtype":[ | |||
| "float16", | |||
| "float32" | |||
| ], | |||
| "format":[ | |||
| "DefaultFormat", | |||
| "DefaultFormat" | |||
| ], | |||
| "name":"input8", | |||
| "need_compile":false, | |||
| "param_type":"required", | |||
| "shape":"all" | |||
| }, | |||
| { | |||
| "index":8, | |||
| "dtype":[ | |||
| "float16", | |||
| "float32" | |||
| ], | |||
| "format":[ | |||
| "DefaultFormat", | |||
| "DefaultFormat" | |||
| ], | |||
| "name":"input9", | |||
| "need_compile":false, | |||
| "param_type":"required", | |||
| "shape":"all" | |||
| }, | |||
| { | |||
| "index":9, | |||
| "dtype":[ | |||
| "float16", | |||
| "float32" | |||
| ], | |||
| "format":[ | |||
| "DefaultFormat", | |||
| "DefaultFormat" | |||
| ], | |||
| "name":"inputx0", | |||
| "need_compile":false, | |||
| "param_type":"required", | |||
| "shape":"all" | |||
| }, | |||
| { | |||
| "index":10, | |||
| "dtype":[ | |||
| "float16", | |||
| "float32" | |||
| ], | |||
| "format":[ | |||
| "DefaultFormat", | |||
| "DefaultFormat" | |||
| ], | |||
| "name":"inputx1", | |||
| "need_compile":false, | |||
| "param_type":"required", | |||
| "shape":"all" | |||
| }, | |||
| { | |||
| "index":11, | |||
| "dtype":[ | |||
| "float16", | |||
| "float32" | |||
| ], | |||
| "format":[ | |||
| "DefaultFormat", | |||
| "DefaultFormat" | |||
| ], | |||
| "name":"inputx2", | |||
| "need_compile":false, | |||
| "param_type":"required", | |||
| "shape":"all" | |||
| }, | |||
| { | |||
| "index":12, | |||
| "dtype":[ | |||
| "float16", | |||
| "float32" | |||
| ], | |||
| "format":[ | |||
| "DefaultFormat", | |||
| "DefaultFormat" | |||
| ], | |||
| "name":"inputx3", | |||
| "need_compile":false, | |||
| "param_type":"required", | |||
| "shape":"all" | |||
| } | |||
| ], | |||
| "outputs":[ | |||
| { | |||
| "index":0, | |||
| "dtype":[ | |||
| "float16", | |||
| "float32" | |||
| ], | |||
| "format":[ | |||
| "DefaultFormat", | |||
| "DefaultFormat" | |||
| ], | |||
| "name":"output1", | |||
| "need_compile":false, | |||
| "param_type":"required", | |||
| "shape":"all" | |||
| }, | |||
| { | |||
| "index":1, | |||
| "dtype":[ | |||
| "float16", | |||
| "float32" | |||
| ], | |||
| "format":[ | |||
| "DefaultFormat", | |||
| "DefaultFormat" | |||
| ], | |||
| "name":"output2", | |||
| "need_compile":false, | |||
| "param_type":"required", | |||
| "shape":"all" | |||
| }, | |||
| { | |||
| "index":2, | |||
| "dtype":[ | |||
| "float16", | |||
| "float32" | |||
| ], | |||
| "format":[ | |||
| "DefaultFormat", | |||
| "DefaultFormat" | |||
| ], | |||
| "name":"output3", | |||
| "need_compile":false, | |||
| "param_type":"required", | |||
| "shape":"all" | |||
| }, | |||
| { | |||
| "index":3, | |||
| "dtype":[ | |||
| "float16", | |||
| "float32" | |||
| ], | |||
| "format":[ | |||
| "DefaultFormat", | |||
| "DefaultFormat" | |||
| ], | |||
| "name":"output4", | |||
| "need_compile":false, | |||
| "param_type":"required", | |||
| "shape":"all" | |||
| } | |||
| ] | |||
| }""") | |||
| @op_info_register(lamb_next_m_v_with_decay_v1_op_info) | |||
| def _lamb_next_mv_with_decay_v1_tbe(): | |||
| """LambNextMVWithDecayV1 TBE register""" | |||
| return | |||
| @@ -14,174 +14,35 @@ | |||
| # ============================================================================ | |||
| """LambUpdateWithLr op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| lamb_update_with_lr_op_info = TBERegOp("LambUpdateWithLR") \ | |||
| .fusion_type("ELEMWISE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("lamb_update_with_lr.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("lamb_update_with_lr") \ | |||
| .partial_flag(True) \ | |||
| .input(0, "input1", False, "required", "all") \ | |||
| .input(1, "input2", False, "required", "all") \ | |||
| .input(2, "input3", False, "required", "all") \ | |||
| .input(3, "input4", False, "required", "all") \ | |||
| .input(4, "input5", False, "required", "all") \ | |||
| .input(5, "input6", False, "required", "all") \ | |||
| .input(6, "input7", False, "required", "all") \ | |||
| .input(7, "input8", False, "required", "all") \ | |||
| .input(8, "input9", False, "required", "all") \ | |||
| .output(0, "output_y", False, "required", "all") \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, | |||
| DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, | |||
| DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | |||
| DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | |||
| DataType.F32_Default, DataType.F32_Default) \ | |||
| .get_op_info() | |||
| @op_info_register("""{ | |||
| "op_name":"LambUpdateWithLR", | |||
| "imply_type":"TBE", | |||
| "fusion_type":"ELEMWISE", | |||
| "async_flag":false, | |||
| "binfile_name":"lamb_update_with_lr.so", | |||
| "compute_cost":10, | |||
| "kernel_name":"lamb_update_with_lr", | |||
| "partial_flag":true, | |||
| "attr":[], | |||
| "inputs":[ | |||
| { | |||
| "index":0, | |||
| "dtype":[ | |||
| "float16", | |||
| "float32" | |||
| ], | |||
| "format":[ | |||
| "DefaultFormat", | |||
| "DefaultFormat" | |||
| ], | |||
| "name":"input1", | |||
| "need_compile":false, | |||
| "param_type":"required", | |||
| "shape":"all" | |||
| }, | |||
| { | |||
| "index":1, | |||
| "dtype":[ | |||
| "float16", | |||
| "float32" | |||
| ], | |||
| "format":[ | |||
| "DefaultFormat", | |||
| "DefaultFormat" | |||
| ], | |||
| "name":"input2", | |||
| "need_compile":false, | |||
| "param_type":"required", | |||
| "shape":"all" | |||
| }, | |||
| { | |||
| "index":2, | |||
| "dtype":[ | |||
| "float16", | |||
| "float32" | |||
| ], | |||
| "format":[ | |||
| "DefaultFormat", | |||
| "DefaultFormat" | |||
| ], | |||
| "name":"input3", | |||
| "need_compile":false, | |||
| "param_type":"required", | |||
| "shape":"all" | |||
| }, | |||
| { | |||
| "index":3, | |||
| "dtype":[ | |||
| "float16", | |||
| "float32" | |||
| ], | |||
| "format":[ | |||
| "DefaultFormat", | |||
| "DefaultFormat" | |||
| ], | |||
| "name":"input4", | |||
| "need_compile":false, | |||
| "param_type":"required", | |||
| "shape":"all" | |||
| }, | |||
| { | |||
| "index":4, | |||
| "dtype":[ | |||
| "float16", | |||
| "float32" | |||
| ], | |||
| "format":[ | |||
| "DefaultFormat", | |||
| "DefaultFormat" | |||
| ], | |||
| "name":"input5", | |||
| "need_compile":false, | |||
| "param_type":"required", | |||
| "shape":"all" | |||
| }, | |||
| { | |||
| "index":5, | |||
| "dtype":[ | |||
| "float16", | |||
| "float32" | |||
| ], | |||
| "format":[ | |||
| "DefaultFormat", | |||
| "DefaultFormat" | |||
| ], | |||
| "name":"input6", | |||
| "need_compile":false, | |||
| "param_type":"required", | |||
| "shape":"all" | |||
| }, | |||
| { | |||
| "index":6, | |||
| "dtype":[ | |||
| "float16", | |||
| "float32" | |||
| ], | |||
| "format":[ | |||
| "DefaultFormat", | |||
| "DefaultFormat" | |||
| ], | |||
| "name":"input7", | |||
| "need_compile":false, | |||
| "param_type":"required", | |||
| "shape":"all" | |||
| }, | |||
| { | |||
| "index":7, | |||
| "dtype":[ | |||
| "float16", | |||
| "float32" | |||
| ], | |||
| "format":[ | |||
| "DefaultFormat", | |||
| "DefaultFormat" | |||
| ], | |||
| "name":"input8", | |||
| "need_compile":false, | |||
| "param_type":"required", | |||
| "shape":"all" | |||
| }, | |||
| { | |||
| "index":8, | |||
| "dtype":[ | |||
| "float16", | |||
| "float32" | |||
| ], | |||
| "format":[ | |||
| "DefaultFormat", | |||
| "DefaultFormat" | |||
| ], | |||
| "name":"input9", | |||
| "need_compile":false, | |||
| "param_type":"required", | |||
| "shape":"all" | |||
| } | |||
| ], | |||
| "outputs":[ | |||
| { | |||
| "index":0, | |||
| "dtype":[ | |||
| "float16", | |||
| "float32" | |||
| ], | |||
| "format":[ | |||
| "DefaultFormat", | |||
| "DefaultFormat" | |||
| ], | |||
| "name":"output_y", | |||
| "need_compile":false, | |||
| "param_type":"required", | |||
| "shape":"all" | |||
| } | |||
| ] | |||
| }""") | |||
| @op_info_register(lamb_update_with_lr_op_info) | |||
| def _lamb_update_with_lr_tbe(): | |||
| """LambUpdateWithLr TBE register""" | |||
| return | |||
| @@ -14,144 +14,31 @@ | |||
| # ============================================================================ | |||
| """LambUpdateWithLrV2 op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| lamb_update_with_lr_v2_op_info = TBERegOp("LambUpdateWithLrV2") \ | |||
| .fusion_type("ELEMWISE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("lamb_update_with_lr_v2.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("lamb_update_with_lr_v2") \ | |||
| .partial_flag(True) \ | |||
| .input(0, "x1", False, "required", "all") \ | |||
| .input(1, "x2", False, "required", "all") \ | |||
| .input(2, "x3", False, "required", "all") \ | |||
| .input(3, "x4", False, "required", "all") \ | |||
| .input(4, "x5", False, "required", "all") \ | |||
| .input(5, "greater_y", False, "required", "all") \ | |||
| .input(6, "select_e", False, "required", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, | |||
| DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | |||
| DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ | |||
| .get_op_info() | |||
| @op_info_register("""{ | |||
| "op_name":"LambUpdateWithLrV2", | |||
| "imply_type":"TBE", | |||
| "fusion_type":"ELEMWISE", | |||
| "async_flag":false, | |||
| "binfile_name":"lamb_update_with_lr_v2.so", | |||
| "compute_cost":10, | |||
| "kernel_name":"lamb_update_with_lr_v2", | |||
| "partial_flag":true, | |||
| "attr":[], | |||
| "inputs":[ | |||
| { | |||
| "index":0, | |||
| "dtype":[ | |||
| "float16", | |||
| "float32" | |||
| ], | |||
| "format":[ | |||
| "DefaultFormat", | |||
| "DefaultFormat" | |||
| ], | |||
| "name":"x1", | |||
| "need_compile":false, | |||
| "param_type":"required", | |||
| "shape":"all" | |||
| }, | |||
| { | |||
| "index":1, | |||
| "dtype":[ | |||
| "float16", | |||
| "float32" | |||
| ], | |||
| "format":[ | |||
| "DefaultFormat", | |||
| "DefaultFormat" | |||
| ], | |||
| "name":"x2", | |||
| "need_compile":false, | |||
| "param_type":"required", | |||
| "shape":"all" | |||
| }, | |||
| { | |||
| "index":2, | |||
| "dtype":[ | |||
| "float16", | |||
| "float32" | |||
| ], | |||
| "format":[ | |||
| "DefaultFormat", | |||
| "DefaultFormat" | |||
| ], | |||
| "name":"x3", | |||
| "need_compile":false, | |||
| "param_type":"required", | |||
| "shape":"all" | |||
| }, | |||
| { | |||
| "index":3, | |||
| "dtype":[ | |||
| "float16", | |||
| "float32" | |||
| ], | |||
| "format":[ | |||
| "DefaultFormat", | |||
| "DefaultFormat" | |||
| ], | |||
| "name":"x4", | |||
| "need_compile":false, | |||
| "param_type":"required", | |||
| "shape":"all" | |||
| }, | |||
| { | |||
| "index":4, | |||
| "dtype":[ | |||
| "float16", | |||
| "float32" | |||
| ], | |||
| "format":[ | |||
| "DefaultFormat", | |||
| "DefaultFormat" | |||
| ], | |||
| "name":"x5", | |||
| "need_compile":false, | |||
| "param_type":"required", | |||
| "shape":"all" | |||
| }, | |||
| { | |||
| "index":5, | |||
| "dtype":[ | |||
| "float16", | |||
| "float32" | |||
| ], | |||
| "format":[ | |||
| "DefaultFormat", | |||
| "DefaultFormat" | |||
| ], | |||
| "name":"greater_y", | |||
| "need_compile":false, | |||
| "param_type":"required", | |||
| "shape":"all" | |||
| }, | |||
| { | |||
| "index":6, | |||
| "dtype":[ | |||
| "float16", | |||
| "float32" | |||
| ], | |||
| "format":[ | |||
| "DefaultFormat", | |||
| "DefaultFormat" | |||
| ], | |||
| "name":"select_e", | |||
| "need_compile":false, | |||
| "param_type":"required", | |||
| "shape":"all" | |||
| } | |||
| ], | |||
| "outputs":[ | |||
| { | |||
| "index":0, | |||
| "dtype":[ | |||
| "float16", | |||
| "float32" | |||
| ], | |||
| "format":[ | |||
| "DefaultFormat", | |||
| "DefaultFormat" | |||
| ], | |||
| "name":"y", | |||
| "need_compile":false, | |||
| "param_type":"required", | |||
| "shape":"all" | |||
| } | |||
| ] | |||
| }""") | |||
| @op_info_register(lamb_update_with_lr_v2_op_info) | |||
| def _lamb_update_with_lr_v2_tbe(): | |||
| """LambUpdateWithLrV2 TBE register""" | |||
| return | |||
| @@ -14,111 +14,39 @@ | |||
| # ============================================================================ | |||
| """LayerNorm op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| layer_norm_op_info = TBERegOp("LayerNorm") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("layer_norm.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("layer_norm") \ | |||
| .partial_flag(True) \ | |||
| .attr("begin_norm_axis", "required", "int", "all") \ | |||
| .attr("begin_params_axis", "required", "int", "all") \ | |||
| .input(0, "x", False, "required", "all") \ | |||
| .input(1, "gamma", False, "required", "all") \ | |||
| .input(2, "beta", False, "required", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .output(1, "mean", False, "required", "all") \ | |||
| .output(2, "variance", False, "required", "all") \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, | |||
| DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, | |||
| DataType.F16_5HD, DataType.F16_5HD) \ | |||
| .dtype_format(DataType.F16_FracNZ, DataType.F16_Default, DataType.F16_Default, DataType.F16_FracNZ, | |||
| DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | |||
| DataType.F32_Default, DataType.F32_Default) \ | |||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, | |||
| DataType.F32_5HD, DataType.F32_5HD) \ | |||
| .dtype_format(DataType.F32_FracNZ, DataType.F32_Default, DataType.F32_Default, DataType.F32_FracNZ, | |||
| DataType.F32_Default, DataType.F32_Default) \ | |||
| .get_op_info() | |||
| @op_info_register("""{ | |||
| "op_name": "LayerNorm", | |||
| "imply_type": "TBE", | |||
| "fusion_type": "OPAQUE", | |||
| "async_flag": false, | |||
| "binfile_name": "layer_norm.so", | |||
| "compute_cost": 10, | |||
| "kernel_name": "layer_norm", | |||
| "partial_flag": true, | |||
| "attr": [ | |||
| { | |||
| "name": "begin_norm_axis", | |||
| "param_type": "required", | |||
| "type": "int", | |||
| "value": "all" | |||
| }, | |||
| { | |||
| "name": "begin_params_axis", | |||
| "param_type": "required", | |||
| "type": "int", | |||
| "value": "all" | |||
| } | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16","float16","float16","float","float","float" | |||
| ], | |||
| "format": [ | |||
| "FRACTAL_NZ","DefaultFormat","NC1HWC0","FRACTAL_NZ","DefaultFormat","NC1HWC0" | |||
| ], | |||
| "name": "x", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 1, | |||
| "dtype": [ | |||
| "float16","float16","float16","float","float","float" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat","DefaultFormat","NC1HWC0","DefaultFormat","DefaultFormat","NC1HWC0" | |||
| ], | |||
| "name": "gamma", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 2, | |||
| "dtype": [ | |||
| "float16","float16","float16","float","float","float" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat","DefaultFormat","NC1HWC0","DefaultFormat","DefaultFormat","NC1HWC0" | |||
| ], | |||
| "name": "beta", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16","float16","float16","float","float","float" | |||
| ], | |||
| "format": [ | |||
| "FRACTAL_NZ","DefaultFormat","NC1HWC0","FRACTAL_NZ","DefaultFormat","NC1HWC0" | |||
| ], | |||
| "name": "y", | |||
| "param_type": "required" | |||
| }, | |||
| { | |||
| "index": 1, | |||
| "dtype": [ | |||
| "float16","float16","float16","float","float","float" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat","DefaultFormat","NC1HWC0","DefaultFormat","DefaultFormat","NC1HWC0" | |||
| ], | |||
| "name": "mean", | |||
| "param_type": "required" | |||
| }, | |||
| { | |||
| "index": 2, | |||
| "dtype": [ | |||
| "float16","float16","float16","float","float","float" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat","DefaultFormat","NC1HWC0","DefaultFormat","DefaultFormat","NC1HWC0" | |||
| ], | |||
| "name": "variance", | |||
| "param_type": "required" | |||
| } | |||
| ] | |||
| }""") | |||
| @op_info_register(layer_norm_op_info) | |||
| def _layer_norm_tbe(): | |||
| """LayerNorm TBE register""" | |||
| return | |||
| @@ -14,105 +14,38 @@ | |||
| # ============================================================================ | |||
| """LayerNormBetaGammaBackprop op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| layer_norm_beta_gamma_backprop_op_info = TBERegOp("LayerNormBetaGammaBackprop") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("layer_norm_beta_gamma_backprop.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("layer_norm_beta_gamma_backprop") \ | |||
| .partial_flag(True) \ | |||
| .attr("shape_gamma", "required", "listInt", "all") \ | |||
| .input(0, "dy", False, "required", "all") \ | |||
| .input(1, "x", False, "required", "all") \ | |||
| .input(2, "variance", False, "required", "all") \ | |||
| .input(3, "mean", False, "required", "all") \ | |||
| .output(0, "pd_gamma", False, "required", "all") \ | |||
| .output(1, "pd_beta", False, "required", "all") \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, | |||
| DataType.F32_Default, DataType.F32_Default) \ | |||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, | |||
| DataType.F32_5HD, DataType.F32_5HD) \ | |||
| .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_Default, DataType.F16_Default, | |||
| DataType.F32_Default, DataType.F32_Default) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | |||
| DataType.F32_Default, DataType.F32_Default) \ | |||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, | |||
| DataType.F32_5HD, DataType.F32_5HD) \ | |||
| .dtype_format(DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_Default, DataType.F32_Default, | |||
| DataType.F32_Default, DataType.F32_Default) \ | |||
| .get_op_info() | |||
| @op_info_register("""{ | |||
| "op_name": "LayerNormBetaGammaBackprop", | |||
| "imply_type": "TBE", | |||
| "fusion_type": "OPAQUE", | |||
| "async_flag": false, | |||
| "binfile_name": "layer_norm_beta_gamma_backprop.so", | |||
| "compute_cost": 10, | |||
| "kernel_name": "layer_norm_beta_gamma_backprop", | |||
| "partial_flag": true, | |||
| "attr": [ | |||
| { | |||
| "name": "shape_gamma", | |||
| "param_type": "required", | |||
| "type": "listInt", | |||
| "value": "all" | |||
| } | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16","float16","float16","float","float","float" | |||
| ], | |||
| "format": [ | |||
| "FRACTAL_NZ","DefaultFormat","NC1HWC0","FRACTAL_NZ","DefaultFormat","NC1HWC0" | |||
| ], | |||
| "name": "dy", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 1, | |||
| "dtype": [ | |||
| "float16","float16","float16","float","float","float" | |||
| ], | |||
| "format": [ | |||
| "FRACTAL_NZ","DefaultFormat","NC1HWC0","FRACTAL_NZ","DefaultFormat","NC1HWC0" | |||
| ], | |||
| "name": "x", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 2, | |||
| "dtype": [ | |||
| "float16","float16","float16","float","float","float" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat","DefaultFormat","NC1HWC0","DefaultFormat","DefaultFormat","NC1HWC0" | |||
| ], | |||
| "name": "variance", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 3, | |||
| "dtype": [ | |||
| "float16","float16","float16","float","float","float" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat","DefaultFormat","NC1HWC0","DefaultFormat","DefaultFormat","NC1HWC0" | |||
| ], | |||
| "name": "mean", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float","float","float","float","float","float" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat","DefaultFormat","NC1HWC0","DefaultFormat","DefaultFormat","NC1HWC0" | |||
| ], | |||
| "name": "pd_gamma", | |||
| "param_type": "required" | |||
| }, | |||
| { | |||
| "index": 1, | |||
| "dtype": [ | |||
| "float","float","float","float","float","float" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat","DefaultFormat","NC1HWC0","DefaultFormat","DefaultFormat","NC1HWC0" | |||
| ], | |||
| "name": "pd_beta", | |||
| "param_type": "required" | |||
| } | |||
| ] | |||
| }""") | |||
| @op_info_register(layer_norm_beta_gamma_backprop_op_info) | |||
| def _layer_norm_beta_gamma_backprop_tbe(): | |||
| """LayerNormBetaGammaBackprop TBE register""" | |||
| return | |||
| @@ -14,124 +14,35 @@ | |||
| # ============================================================================ | |||
| """LayerNormGrad op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| layer_norm_grad_op_info = TBERegOp("LayerNormGrad") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("layer_norm_grad.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("layer_norm_grad") \ | |||
| .partial_flag(True) \ | |||
| .input(0, "dy", False, "required", "all") \ | |||
| .input(1, "x", False, "required", "all") \ | |||
| .input(2, "variance", False, "required", "all") \ | |||
| .input(3, "mean", False, "required", "all") \ | |||
| .input(4, "gamma", False, "required", "all") \ | |||
| .output(0, "pd_x", False, "required", "all") \ | |||
| .output(1, "pd_gamma", False, "required", "all") \ | |||
| .output(2, "pd_beta", False, "required", "all") \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, | |||
| DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, | |||
| DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | |||
| DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ | |||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, | |||
| DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \ | |||
| .get_op_info() | |||
| @op_info_register("""{ | |||
| "op_name": "LayerNormGrad", | |||
| "imply_type": "TBE", | |||
| "fusion_type": "OPAQUE", | |||
| "async_flag": false, | |||
| "binfile_name": "layer_norm_grad.so", | |||
| "compute_cost": 10, | |||
| "kernel_name": "layer_norm_grad", | |||
| "partial_flag": true, | |||
| "attr": [ | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16","float16","float","float" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0" | |||
| ], | |||
| "name": "dy", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 1, | |||
| "dtype": [ | |||
| "float16","float16","float","float" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0" | |||
| ], | |||
| "name": "x", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 2, | |||
| "dtype": [ | |||
| "float16","float16","float","float" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0" | |||
| ], | |||
| "name": "variance", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 3, | |||
| "dtype": [ | |||
| "float16","float16","float","float" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0" | |||
| ], | |||
| "name": "mean", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 4, | |||
| "dtype": [ | |||
| "float16","float16","float","float" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0" | |||
| ], | |||
| "name": "gamma", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16","float16","float","float" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0" | |||
| ], | |||
| "name": "pd_x", | |||
| "param_type": "required" | |||
| }, | |||
| { | |||
| "index": 1, | |||
| "dtype": [ | |||
| "float16","float16","float","float" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0" | |||
| ], | |||
| "name": "pd_gamma", | |||
| "param_type": "required" | |||
| }, | |||
| { | |||
| "index": 2, | |||
| "dtype": [ | |||
| "float16","float16","float","float" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0" | |||
| ], | |||
| "name": "pd_beta", | |||
| "param_type": "required" | |||
| } | |||
| ] | |||
| }""") | |||
| @op_info_register(layer_norm_grad_op_info) | |||
| def _layer_norm_grad_tbe(): | |||
| """LayerNormGrad TBE register""" | |||
| return | |||
| @@ -14,102 +14,37 @@ | |||
| # ============================================================================ | |||
| """LayerNormXBackprop op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| layer_norm_x_backprop_op_info = TBERegOp("LayerNormXBackprop") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("layer_norm_x_backprop.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("layer_norm_x_backprop") \ | |||
| .partial_flag(True) \ | |||
| .input(0, "dy", False, "required", "all") \ | |||
| .input(1, "x", False, "required", "all") \ | |||
| .input(2, "variance", False, "required", "all") \ | |||
| .input(3, "mean", False, "required", "all") \ | |||
| .input(4, "gamma", False, "required", "all") \ | |||
| .output(0, "pd_x", False, "required", "all") \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, | |||
| DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, | |||
| DataType.F16_5HD, DataType.F16_5HD) \ | |||
| .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_Default, DataType.F16_Default, | |||
| DataType.F16_Default, DataType.F16_FracNZ) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | |||
| DataType.F32_Default, DataType.F32_Default) \ | |||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, | |||
| DataType.F32_5HD, DataType.F32_5HD) \ | |||
| .dtype_format(DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_Default, DataType.F32_Default, | |||
| DataType.F32_Default, DataType.F32_FracNZ) \ | |||
| .get_op_info() | |||
| @op_info_register("""{ | |||
| "op_name": "LayerNormXBackprop", | |||
| "imply_type": "TBE", | |||
| "fusion_type": "OPAQUE", | |||
| "async_flag": false, | |||
| "binfile_name": "layer_norm_x_backprop.so", | |||
| "compute_cost": 10, | |||
| "kernel_name": "layer_norm_x_backprop", | |||
| "partial_flag": true, | |||
| "attr": [ | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16","float16","float16","float","float","float" | |||
| ], | |||
| "format": [ | |||
| "FRACTAL_NZ","DefaultFormat","NC1HWC0","FRACTAL_NZ","DefaultFormat","NC1HWC0" | |||
| ], | |||
| "name": "dy", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 1, | |||
| "dtype": [ | |||
| "float16","float16","float16","float","float","float" | |||
| ], | |||
| "format": [ | |||
| "FRACTAL_NZ","DefaultFormat","NC1HWC0","FRACTAL_NZ","DefaultFormat","NC1HWC0" | |||
| ], | |||
| "name": "x", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 2, | |||
| "dtype": [ | |||
| "float16","float16","float16","float","float","float" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat","DefaultFormat","NC1HWC0","DefaultFormat","DefaultFormat","NC1HWC0" | |||
| ], | |||
| "name": "variance", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 3, | |||
| "dtype": [ | |||
| "float16","float16","float16","float","float","float" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat","DefaultFormat","NC1HWC0","DefaultFormat","DefaultFormat","NC1HWC0" | |||
| ], | |||
| "name": "mean", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 4, | |||
| "dtype": [ | |||
| "float16","float16","float16","float","float","float" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat","DefaultFormat","NC1HWC0","DefaultFormat","DefaultFormat","NC1HWC0" | |||
| ], | |||
| "name": "gamma", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16","float16","float16","float","float","float" | |||
| ], | |||
| "format": [ | |||
| "FRACTAL_NZ","DefaultFormat","NC1HWC0","FRACTAL_NZ","DefaultFormat","NC1HWC0" | |||
| ], | |||
| "name": "pd_x", | |||
| "param_type": "required" | |||
| } | |||
| ] | |||
| }""") | |||
| @op_info_register(layer_norm_x_backprop_op_info) | |||
| def _layer_norm_x_backprop_tbe(): | |||
| """LayerNormXBackprop TBE register""" | |||
| return | |||
| @@ -14,67 +14,32 @@ | |||
| # ============================================================================ | |||
| """Less op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| less_op_info = TBERegOp("Less") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("less.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("less") \ | |||
| .partial_flag(True) \ | |||
| .input(0, "x1", False, "required", "all") \ | |||
| .input(1, "x2", False, "required", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.BOOL_Default) \ | |||
| .dtype_format(DataType.I8_5HD, DataType.I8_5HD, DataType.BOOL_5HD) \ | |||
| .dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.BOOL_Default) \ | |||
| .dtype_format(DataType.U8_5HD, DataType.U8_5HD, DataType.BOOL_5HD) \ | |||
| .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.BOOL_Default) \ | |||
| .dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.BOOL_5HD) \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.BOOL_Default) \ | |||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.BOOL_5HD) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.BOOL_Default) \ | |||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.BOOL_5HD) \ | |||
| .get_op_info() | |||
| @op_info_register("""{ | |||
| "op_name": "Less", | |||
| "imply_type": "TBE", | |||
| "fusion_type": "OPAQUE", | |||
| "async_flag": false, | |||
| "binfile_name": "less.so", | |||
| "compute_cost": 10, | |||
| "kernel_name": "less", | |||
| "partial_flag": true, | |||
| "attr": [ | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16","float16","float","float","int32","int32","int8","int8","uint8","uint8" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0","DefaultFormat", | |||
| "NC1HWC0","DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0" | |||
| ], | |||
| "name": "x1", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 1, | |||
| "dtype": [ | |||
| "float16","float16","float","float","int32","int32","int8","int8","uint8","uint8" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0","DefaultFormat", | |||
| "NC1HWC0","DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0" | |||
| ], | |||
| "name": "x2", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "bool","bool","bool","bool","bool","bool","bool","bool","bool","bool" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0","DefaultFormat", | |||
| "NC1HWC0","DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0" | |||
| ], | |||
| "name": "y", | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ] | |||
| }""") | |||
| @op_info_register(less_op_info) | |||
| def _less_tbe(): | |||
| """Less TBE register""" | |||
| return | |||
| @@ -14,67 +14,34 @@ | |||
| # ============================================================================ | |||
| """LessEqual op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| less_equal_op_info = TBERegOp("LessEqual") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("less_equal.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("less_equal") \ | |||
| .partial_flag(True) \ | |||
| .attr("begin_norm_axis", "required", "int", "all") \ | |||
| .attr("begin_params_axis", "required", "int", "all") \ | |||
| .input(0, "x1", False, "required", "all") \ | |||
| .input(1, "x2", False, "required", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.BOOL_Default) \ | |||
| .dtype_format(DataType.I8_5HD, DataType.I8_5HD, DataType.BOOL_5HD) \ | |||
| .dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.BOOL_Default) \ | |||
| .dtype_format(DataType.U8_5HD, DataType.U8_5HD, DataType.BOOL_5HD) \ | |||
| .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.BOOL_Default) \ | |||
| .dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.BOOL_5HD) \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.BOOL_Default) \ | |||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.BOOL_5HD) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.BOOL_Default) \ | |||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.BOOL_5HD) \ | |||
| .get_op_info() | |||
| @op_info_register("""{ | |||
| "op_name": "LessEqual", | |||
| "imply_type": "TBE", | |||
| "fusion_type": "OPAQUE", | |||
| "async_flag": false, | |||
| "binfile_name": "less_equal.so", | |||
| "compute_cost": 10, | |||
| "kernel_name": "less_equal", | |||
| "partial_flag": true, | |||
| "attr": [ | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16","float16","float","float","int32","int32","int8","int8","uint8","uint8" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0","DefaultFormat", | |||
| "NC1HWC0","DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0" | |||
| ], | |||
| "name": "x1", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 1, | |||
| "dtype": [ | |||
| "float16","float16","float","float","int32","int32","int8","int8","uint8","uint8" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0","DefaultFormat", | |||
| "NC1HWC0","DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0" | |||
| ], | |||
| "name": "x2", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "bool","bool","bool","bool","bool","bool","bool","bool","bool","bool" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0","DefaultFormat", | |||
| "NC1HWC0","DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0" | |||
| ], | |||
| "name": "y", | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ] | |||
| }""") | |||
| @op_info_register(less_equal_op_info) | |||
| def _less_equal_tbe(): | |||
| """LessEqual TBE register""" | |||
| return | |||
| @@ -14,52 +14,25 @@ | |||
| # ============================================================================ | |||
| """Log op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| log_op_info = TBERegOp("Log") \ | |||
| .fusion_type("ELEMWISE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("log.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("log") \ | |||
| .partial_flag(True) \ | |||
| .input(0, "x", False, "required", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default) \ | |||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD) \ | |||
| .get_op_info() | |||
| @op_info_register("""{ | |||
| "op_name": "Log", | |||
| "imply_type": "TBE", | |||
| "fusion_type": "ELEMWISE", | |||
| "async_flag": false, | |||
| "binfile_name": "log.so", | |||
| "compute_cost": 10, | |||
| "kernel_name": "log", | |||
| "partial_flag": true, | |||
| "attr": [ | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16", "float16", "float", "float" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "NC1HWC0", "DefaultFormat", "NC1HWC0" | |||
| ], | |||
| "name": "x", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16", "float16", "float", "float" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "NC1HWC0", "DefaultFormat", "NC1HWC0" | |||
| ], | |||
| "name": "y", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ] | |||
| }""") | |||
| @op_info_register(log_op_info) | |||
| def _log_tbe(): | |||
| """Log TBE register""" | |||
| return | |||
| @@ -14,65 +14,26 @@ | |||
| # ============================================================================ | |||
| """LogicalAnd op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| logical_and_op_info = TBERegOp("LogicalAnd") \ | |||
| .fusion_type("ELEMWISE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("logical_and.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("logical_and") \ | |||
| .partial_flag(True) \ | |||
| .input(0, "x1", False, "required", "all") \ | |||
| .input(1, "x2", False, "required", "all") \ | |||
| .output(0, "y", True, "required", "all") \ | |||
| .dtype_format(DataType.BOOL_Default, DataType.BOOL_Default, DataType.BOOL_Default) \ | |||
| .dtype_format(DataType.BOOL_FracZ, DataType.BOOL_FracZ, DataType.BOOL_FracZ) \ | |||
| .dtype_format(DataType.BOOL_C1HWNCoC0, DataType.BOOL_C1HWNCoC0, DataType.BOOL_C1HWNCoC0) \ | |||
| .dtype_format(DataType.BOOL_5HD, DataType.BOOL_5HD, DataType.BOOL_5HD) \ | |||
| .get_op_info() | |||
| @op_info_register("""{ | |||
| "op_name": "LogicalAnd", | |||
| "imply_type": "TBE", | |||
| "fusion_type": "ELEMWISE", | |||
| "async_flag": false, | |||
| "binfile_name": "logical_and.so", | |||
| "compute_cost": 10, | |||
| "kernel_name": "logical_and", | |||
| "partial_flag": true, | |||
| "attr": [ | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "bool", "bool", "bool", "bool" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "FracZ", "C1HWNCoC0", "NC1HWC0" | |||
| ], | |||
| "name": "x1", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 1, | |||
| "dtype": [ | |||
| "bool", "bool", "bool", "bool" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "FracZ", "C1HWNCoC0", "NC1HWC0" | |||
| ], | |||
| "name": "x2", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "bool", "bool", "bool", "bool" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "FracZ", "C1HWNCoC0", "NC1HWC0" | |||
| ], | |||
| "name": "y", | |||
| "need_compile": true, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ] | |||
| }""") | |||
| @op_info_register(logical_and_op_info) | |||
| def _logical_and_tbe(): | |||
| """LogicalAnd TBE register""" | |||
| return | |||
| @@ -14,52 +14,25 @@ | |||
| # ============================================================================ | |||
| """LogicalNot op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| logical_not_op_info = TBERegOp("LogicalNot") \ | |||
| .fusion_type("ELEMWISE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("logical_not.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("logical_not") \ | |||
| .partial_flag(True) \ | |||
| .input(0, "x", False, "required", "all") \ | |||
| .output(0, "y", True, "required", "all") \ | |||
| .dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \ | |||
| .dtype_format(DataType.BOOL_FracZ, DataType.BOOL_FracZ) \ | |||
| .dtype_format(DataType.BOOL_C1HWNCoC0, DataType.BOOL_C1HWNCoC0) \ | |||
| .dtype_format(DataType.BOOL_5HD, DataType.BOOL_5HD) \ | |||
| .get_op_info() | |||
| @op_info_register("""{ | |||
| "op_name": "LogicalNot", | |||
| "imply_type": "TBE", | |||
| "fusion_type": "ELEMWISE", | |||
| "async_flag": false, | |||
| "binfile_name": "logical_not.so", | |||
| "compute_cost": 10, | |||
| "kernel_name": "logical_not", | |||
| "partial_flag": true, | |||
| "attr": [ | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "bool", "bool", "bool", "bool" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "FracZ", "C1HWNCoC0", "NC1HWC0" | |||
| ], | |||
| "name": "x", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "bool", "bool", "bool", "bool" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "FracZ", "C1HWNCoC0", "NC1HWC0" | |||
| ], | |||
| "name": "y", | |||
| "need_compile": true, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ] | |||
| }""") | |||
| @op_info_register(logical_not_op_info) | |||
| def _logical_not_tbe(): | |||
| """LogicalNot TBE register""" | |||
| return | |||
| @@ -14,65 +14,26 @@ | |||
| # ============================================================================ | |||
| """LogicalOr op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| logical_or_op_info = TBERegOp("LogicalOr") \ | |||
| .fusion_type("ELEMWISE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("logical_or.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("logical_or") \ | |||
| .partial_flag(True) \ | |||
| .input(0, "x1", False, "required", "all") \ | |||
| .input(1, "x2", False, "required", "all") \ | |||
| .output(0, "y", True, "required", "all") \ | |||
| .dtype_format(DataType.BOOL_Default, DataType.BOOL_Default, DataType.BOOL_Default) \ | |||
| .dtype_format(DataType.BOOL_FracZ, DataType.BOOL_FracZ, DataType.BOOL_FracZ) \ | |||
| .dtype_format(DataType.BOOL_C1HWNCoC0, DataType.BOOL_C1HWNCoC0, DataType.BOOL_C1HWNCoC0) \ | |||
| .dtype_format(DataType.BOOL_5HD, DataType.BOOL_5HD, DataType.BOOL_5HD) \ | |||
| .get_op_info() | |||
| @op_info_register("""{ | |||
| "op_name": "LogicalOr", | |||
| "imply_type": "TBE", | |||
| "fusion_type": "ELEMWISE", | |||
| "async_flag": false, | |||
| "binfile_name": "logical_or.so", | |||
| "compute_cost": 10, | |||
| "kernel_name": "logical_or", | |||
| "partial_flag": true, | |||
| "attr": [ | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "bool", "bool", "bool", "bool" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "FracZ", "C1HWNCoC0", "NC1HWC0" | |||
| ], | |||
| "name": "x1", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 1, | |||
| "dtype": [ | |||
| "bool", "bool", "bool", "bool" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "FracZ", "C1HWNCoC0", "NC1HWC0" | |||
| ], | |||
| "name": "x2", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "bool", "bool", "bool", "bool" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "FracZ", "C1HWNCoC0", "NC1HWC0" | |||
| ], | |||
| "name": "y", | |||
| "need_compile": true, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ] | |||
| }""") | |||
| @op_info_register(logical_or_op_info) | |||
| def _logical_or_tbe(): | |||
| """LogicalOr TBE register""" | |||
| return | |||
| @@ -14,57 +14,24 @@ | |||
| # ============================================================================ | |||
| """LogSoftmax op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| log_softmax_op_info = TBERegOp("LogSoftmax") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("log_softmax.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("log_softmax") \ | |||
| .partial_flag(True) \ | |||
| .attr("axis", "optional", "listInt", "all") \ | |||
| .input(0, "logits", False, "required", "all") \ | |||
| .output(0, "logsoftmax", False, "required", "all") \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default) \ | |||
| .get_op_info() | |||
| @op_info_register("""{ | |||
| "op_name": "LogSoftmax", | |||
| "imply_type": "TBE", | |||
| "fusion_type": "OPAQUE", | |||
| "async_flag": false, | |||
| "binfile_name": "log_softmax.so", | |||
| "compute_cost": 10, | |||
| "kernel_name": "log_softmax", | |||
| "partial_flag": true, | |||
| "attr": [ | |||
| { | |||
| "name": "axis", | |||
| "param_type": "optional", | |||
| "type": "listInt", | |||
| "value": "all" | |||
| } | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16", "float" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "DefaultFormat" | |||
| ], | |||
| "name": "logits", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16", "float" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "DefaultFormat" | |||
| ], | |||
| "name": "logsoftmax", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ] | |||
| }""") | |||
| @op_info_register(log_softmax_op_info) | |||
| def _logsoftmax_tbe(): | |||
| """LogSoftMaxGrad TBE register""" | |||
| return | |||
| @@ -14,70 +14,25 @@ | |||
| # ============================================================================ | |||
| """LogSoftmaxGrad op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| log_softmax_grad_op_info = TBERegOp("LogSoftmaxGrad") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("log_softmax_grad.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("log_softmax_grad") \ | |||
| .partial_flag(True) \ | |||
| .attr("axis", "optional", "listInt", "all") \ | |||
| .input(0, "x", False, "required", "all") \ | |||
| .input(1, "grad", False, "required", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ | |||
| .get_op_info() | |||
| @op_info_register("""{ | |||
| "op_name": "LogSoftmaxGrad", | |||
| "imply_type": "TBE", | |||
| "fusion_type": "OPAQUE", | |||
| "async_flag": false, | |||
| "binfile_name": "log_softmax_grad.so", | |||
| "compute_cost": 10, | |||
| "kernel_name": "log_softmax_grad", | |||
| "partial_flag": true, | |||
| "attr": [ | |||
| { | |||
| "name": "axis", | |||
| "param_type": "optional", | |||
| "type": "listInt", | |||
| "value": "all" | |||
| } | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16", "float" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "DefaultFormat" | |||
| ], | |||
| "name": "x", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 1, | |||
| "dtype": [ | |||
| "float16", "float" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "DefaultFormat" | |||
| ], | |||
| "name": "grad", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16", "float" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "DefaultFormat" | |||
| ], | |||
| "name": "y", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ] | |||
| }""") | |||
| @op_info_register(log_softmax_grad_op_info) | |||
| def _logsoftmax_grad_tbe(): | |||
| """LogSoftMaxGrad TBE register""" | |||
| return | |||
| @@ -14,89 +14,29 @@ | |||
| # ============================================================================ | |||
| """MatMul op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| matmul_op_info = TBERegOp("MatMul") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("matmul.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("matmul") \ | |||
| .partial_flag(True) \ | |||
| .attr("transpose_a", "required", "bool", "all") \ | |||
| .attr("transpose_b", "required", "bool", "all") \ | |||
| .input(0, "x1", False, "required", "all") \ | |||
| .input(1, "x2", False, "required", "all") \ | |||
| .input(2, "x3", False, "optional", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ | |||
| .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_Default, DataType.F16_FracNZ) \ | |||
| .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F32_Default, DataType.F32_FracNZ) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ | |||
| .get_op_info() | |||
| @op_info_register("""{ | |||
| "op_name": "MatMul", | |||
| "imply_type": "TBE", | |||
| "fusion_type": "OPAQUE", | |||
| "async_flag": false, | |||
| "binfile_name": "matmul.so", | |||
| "compute_cost": 10, | |||
| "kernel_name": "matmul", | |||
| "partial_flag": true, | |||
| "attr": [ | |||
| { | |||
| "name": "transpose_a", | |||
| "param_type": "required", | |||
| "type": "bool", | |||
| "value": "all" | |||
| }, | |||
| { | |||
| "name": "transpose_b", | |||
| "param_type": "required", | |||
| "type": "bool", | |||
| "value": "all" | |||
| } | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16","float16","float","int32" | |||
| ], | |||
| "format": [ | |||
| "FRACTAL_NZ","FRACTAL_NZ","DefaultFormat","DefaultFormat" | |||
| ], | |||
| "name": "x1", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 1, | |||
| "dtype": [ | |||
| "float16","float16","float","int32" | |||
| ], | |||
| "format": [ | |||
| "FRACTAL_NZ","FRACTAL_NZ","DefaultFormat","DefaultFormat" | |||
| ], | |||
| "name": "x2", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 2, | |||
| "dtype": [ | |||
| "float16","float","float","int32" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat" | |||
| ], | |||
| "name": "x3", | |||
| "need_compile": false, | |||
| "param_type": "optional", | |||
| "shape": "all" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16","float","float","int32" | |||
| ], | |||
| "format": [ | |||
| "FRACTAL_NZ","FRACTAL_NZ","DefaultFormat","DefaultFormat" | |||
| ], | |||
| "name": "y", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ] | |||
| }""") | |||
| @op_info_register(matmul_op_info) | |||
| def _matmul_tbe(): | |||
| """Mul TBE register""" | |||
| return | |||
| @@ -14,74 +14,26 @@ | |||
| # ============================================================================ | |||
| """MaxPool op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| max_pool_op_info = TBERegOp("MaxPool") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("max_pool.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("max_pool") \ | |||
| .partial_flag(True) \ | |||
| .attr("ksize", "required", "listInt", "all") \ | |||
| .attr("strides", "required", "listInt", "all") \ | |||
| .attr("padding", "required", "str", "all") \ | |||
| .attr("data_format", "required", "str", "all") \ | |||
| .input(0, "input_data", False, "required", "all") \ | |||
| .output(0, "output_data", False, "required", "all") \ | |||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \ | |||
| .get_op_info() | |||
| @op_info_register("""{ | |||
| "op_name": "MaxPool", | |||
| "imply_type": "TBE", | |||
| "fusion_type": "OPAQUE", | |||
| "async_flag": false, | |||
| "binfile_name": "max_pool.so", | |||
| "compute_cost": 10, | |||
| "kernel_name": "max_pool", | |||
| "partial_flag": true, | |||
| "attr": [ | |||
| { | |||
| "name": "ksize", | |||
| "param_type": "required", | |||
| "type": "listInt", | |||
| "value": "all" | |||
| }, | |||
| { | |||
| "name": "strides", | |||
| "param_type": "required", | |||
| "type": "listInt", | |||
| "value": "all" | |||
| }, | |||
| { | |||
| "name": "padding", | |||
| "param_type": "required", | |||
| "type": "str", | |||
| "value": "all" | |||
| }, | |||
| { | |||
| "name": "data_format", | |||
| "param_type": "required", | |||
| "type": "str", | |||
| "value": "all" | |||
| } | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16" | |||
| ], | |||
| "format": [ | |||
| "NC1HWC0" | |||
| ], | |||
| "name": "input_data", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16" | |||
| ], | |||
| "format": [ | |||
| "NC1HWC0" | |||
| ], | |||
| "name": "output_data", | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ] | |||
| }""") | |||
| @op_info_register(max_pool_op_info) | |||
| def _max_pool_tbe(): | |||
| """MaxPool TBE register""" | |||
| return | |||
| @@ -14,93 +14,27 @@ | |||
| # ============================================================================ | |||
| """MaxPoolGrad op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| max_pool_grad_op_info = TBERegOp("MaxPoolGrad") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("max_pool_grad.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("max_pool_grad") \ | |||
| .partial_flag(True) \ | |||
| .attr("ksize", "required", "listInt", "all") \ | |||
| .attr("strides", "required", "listInt", "all") \ | |||
| .attr("padding", "required", "str", "all") \ | |||
| .input(0, "x1", False, "required", "all") \ | |||
| .input(1, "x2", False, "required", "all") \ | |||
| .input(2, "grad", False, "required", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \ | |||
| .get_op_info() | |||
| @op_info_register("""{ | |||
| "op_name": "MaxPoolGrad", | |||
| "imply_type": "TBE", | |||
| "fusion_type": "OPAQUE", | |||
| "async_flag": false, | |||
| "binfile_name": "max_pool_grad.so", | |||
| "compute_cost": 10, | |||
| "kernel_name": "max_pool_grad", | |||
| "partial_flag": true, | |||
| "attr": [ | |||
| { | |||
| "name": "ksize", | |||
| "param_type": "required", | |||
| "type": "listInt", | |||
| "value": "all" | |||
| }, | |||
| { | |||
| "name": "strides", | |||
| "param_type": "required", | |||
| "type": "listInt", | |||
| "value": "all" | |||
| }, | |||
| { | |||
| "name": "padding", | |||
| "param_type": "required", | |||
| "type": "str", | |||
| "value": "all" | |||
| } | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16" | |||
| ], | |||
| "format": [ | |||
| "NC1HWC0" | |||
| ], | |||
| "name": "x1", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 1, | |||
| "dtype": [ | |||
| "float16" | |||
| ], | |||
| "format": [ | |||
| "NC1HWC0" | |||
| ], | |||
| "name": "x2", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 2, | |||
| "dtype": [ | |||
| "float16" | |||
| ], | |||
| "format": [ | |||
| "NC1HWC0" | |||
| ], | |||
| "name": "grad", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16" | |||
| ], | |||
| "format": [ | |||
| "NC1HWC0" | |||
| ], | |||
| "name": "y", | |||
| "param_type": "required" | |||
| } | |||
| ] | |||
| }""") | |||
| @op_info_register(max_pool_grad_op_info) | |||
| def _max_pool_grad_tbe(): | |||
| """MaxPoolGrad TBE register""" | |||
| return | |||
| @@ -14,95 +14,28 @@ | |||
| # ============================================================================ | |||
| """MaxPoolGradWithArgmax op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| max_pool_grad_with_argmax_op_info = TBERegOp("MaxPoolGradWithArgmax") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("max_pool_grad_with_argmax.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("max_pool_grad_with_argmax") \ | |||
| .partial_flag(True) \ | |||
| .attr("ksize", "required", "listInt", "all") \ | |||
| .attr("strides", "required", "listInt", "all") \ | |||
| .attr("padding", "required", "str", "all") \ | |||
| .input(0, "x", False, "required", "all") \ | |||
| .input(1, "grad", False, "required", "all") \ | |||
| .input(2, "argmax", False, "optional", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.U16_5HD, DataType.F16_5HD) \ | |||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.I64_5HD, DataType.F16_5HD) \ | |||
| .get_op_info() | |||
| @op_info_register("""{ | |||
| "op_name": "MaxPoolGradWithArgmax", | |||
| "imply_type": "TBE", | |||
| "fusion_type": "OPAQUE", | |||
| "async_flag": false, | |||
| "binfile_name": "max_pool_grad_with_argmax.so", | |||
| "compute_cost": 10, | |||
| "kernel_name": "max_pool_grad_with_argmax", | |||
| "partial_flag": true, | |||
| "attr": [ | |||
| { | |||
| "name": "ksize", | |||
| "param_type": "required", | |||
| "type": "listInt", | |||
| "value": "all" | |||
| }, | |||
| { | |||
| "name": "strides", | |||
| "param_type": "required", | |||
| "type": "listInt", | |||
| "value": "all" | |||
| }, | |||
| { | |||
| "name": "padding", | |||
| "param_type": "required", | |||
| "type": "str", | |||
| "value": "all" | |||
| } | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16", "float16" | |||
| ], | |||
| "format": [ | |||
| "NC1HWC0", "NC1HWC0" | |||
| ], | |||
| "name": "x", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 1, | |||
| "dtype": [ | |||
| "float16", "float16" | |||
| ], | |||
| "format": [ | |||
| "NC1HWC0", "NC1HWC0" | |||
| ], | |||
| "name": "grad", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 2, | |||
| "dtype": [ | |||
| "uint16", "int64" | |||
| ], | |||
| "format": [ | |||
| "NC1HWC0", "NC1HWC0" | |||
| ], | |||
| "name": "argmax", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16", "float16" | |||
| ], | |||
| "format": [ | |||
| "NC1HWC0", "NC1HWC0" | |||
| ], | |||
| "name": "y", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ] | |||
| }""") | |||
| @op_info_register(max_pool_grad_with_argmax_op_info) | |||
| def _max_pool_grad_with_argmax_tbe(): | |||
| """MaxPoolGradWithArgmax TBE register""" | |||
| return | |||
| @@ -14,82 +14,26 @@ | |||
| # ============================================================================ | |||
| """MaxPoolWithArgmax op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| max_pool_with_argmax_op_info = TBERegOp("MaxPoolWithArgmax") \ | |||
| .fusion_type("CONVLUTION") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("max_pool_with_argmax.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("max_pool_with_argmax") \ | |||
| .partial_flag(True) \ | |||
| .attr("ksize", "required", "listInt", "all") \ | |||
| .attr("strides", "required", "listInt", "all") \ | |||
| .attr("padding", "required", "str", "all") \ | |||
| .input(0, "x", False, "required", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .output(1, "argmax", False, "required", "all") \ | |||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.U16_5HD) \ | |||
| .get_op_info() | |||
| @op_info_register("""{ | |||
| "op_name": "MaxPoolWithArgmax", | |||
| "imply_type": "TBE", | |||
| "fusion_type": "CONVLUTION", | |||
| "async_flag": false, | |||
| "binfile_name": "max_pool_with_argmax.so", | |||
| "compute_cost": 10, | |||
| "kernel_name": "max_pool_with_argmax", | |||
| "partial_flag": true, | |||
| "attr": [ | |||
| { | |||
| "name": "ksize", | |||
| "param_type": "required", | |||
| "type": "listInt", | |||
| "value": "all" | |||
| }, | |||
| { | |||
| "name": "strides", | |||
| "param_type": "required", | |||
| "type": "listInt", | |||
| "value": "all" | |||
| }, | |||
| { | |||
| "name": "padding", | |||
| "param_type": "required", | |||
| "type": "str", | |||
| "value": "all" | |||
| } | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16" | |||
| ], | |||
| "format": [ | |||
| "NC1HWC0" | |||
| ], | |||
| "name": "x", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16" | |||
| ], | |||
| "format": [ | |||
| "NC1HWC0" | |||
| ], | |||
| "name": "y", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 1, | |||
| "dtype": [ | |||
| "uint16" | |||
| ], | |||
| "format": [ | |||
| "NC1HWC0" | |||
| ], | |||
| "name": "argmax", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ] | |||
| }""") | |||
| @op_info_register(max_pool_with_argmax_op_info) | |||
| def _max_pool_with_argmax_tbe(): | |||
| """MaxPoolWithArgmax TBE register""" | |||
| return | |||
| @@ -14,69 +14,28 @@ | |||
| # ============================================================================ | |||
| """Maximum op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| maximum_op_info = TBERegOp("Maximum") \ | |||
| .fusion_type("ELEMWISE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("maximum.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("maximum") \ | |||
| .partial_flag(True) \ | |||
| .input(0, "x1", False, "required", "all") \ | |||
| .input(1, "x2", False, "required", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ | |||
| .dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD) \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ | |||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \ | |||
| .get_op_info() | |||
| @op_info_register("""{ | |||
| "op_name":"Maximum", | |||
| "imply_type":"TBE", | |||
| "fusion_type":"ELEMWISE", | |||
| "async_flag":false, | |||
| "binfile_name":"maximum.so", | |||
| "compute_cost":10, | |||
| "kernel_name":"maximum", | |||
| "partial_flag":true, | |||
| "attr":[], | |||
| "inputs":[ | |||
| { | |||
| "index":0, | |||
| "dtype":[ | |||
| "float16", "float16", "float16", "float16", "float", "float", "float", "float", "int32", | |||
| "int32", "int32", "int32" | |||
| ], | |||
| "format":[ | |||
| "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", | |||
| "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat" | |||
| ], | |||
| "name":"x1", | |||
| "need_compile":false, | |||
| "param_type":"required", | |||
| "shape":"all" | |||
| }, | |||
| { | |||
| "index":1, | |||
| "dtype":[ | |||
| "float16", "float16", "float16", "float16", "float", "float", "float", "float", "int32", | |||
| "int32", "int32", "int32" | |||
| ], | |||
| "format":[ | |||
| "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", | |||
| "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat" | |||
| ], | |||
| "name":"x2", | |||
| "need_compile":false, | |||
| "param_type":"required", | |||
| "shape":"all" | |||
| } | |||
| ], | |||
| "outputs":[ | |||
| { | |||
| "index":0, | |||
| "dtype":[ | |||
| "float16", "float16", "float16", "float16", "float", "float", "float", "float", "int32", | |||
| "int32", "int32", "int32" | |||
| ], | |||
| "format":[ | |||
| "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", | |||
| "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat" | |||
| ], | |||
| "name":"y", | |||
| "need_compile":false, | |||
| "param_type":"required", | |||
| "shape":"all" | |||
| } | |||
| ] | |||
| }""") | |||
| @op_info_register(maximum_op_info) | |||
| def _maximum_tbe(): | |||
| """Maximum TBE register""" | |||
| return | |||
| @@ -14,112 +14,38 @@ | |||
| # ============================================================================ | |||
| """MaximumGrad op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| maximum_grad_op_info = TBERegOp("MaximumGrad") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("maximum_grad.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("maximum_grad") \ | |||
| .partial_flag(True) \ | |||
| .attr("grad_x", "optional", "bool", "all") \ | |||
| .attr("grad_y", "optional", "bool", "all") \ | |||
| .input(0, "grads", False, "required", "all") \ | |||
| .input(1, "x1", False, "required", "all") \ | |||
| .input(2, "x2", False, "required", "all") \ | |||
| .output(0, "y1", False, "required", "all") \ | |||
| .output(1, "y2", False, "required", "all") \ | |||
| .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, | |||
| DataType.I32_Default) \ | |||
| .dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD, | |||
| DataType.I32_5HD) \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, | |||
| DataType.F16_Default) \ | |||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, | |||
| DataType.F16_5HD) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | |||
| DataType.F32_Default) \ | |||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, | |||
| DataType.F32_5HD) \ | |||
| .get_op_info() | |||
| @op_info_register("""{ | |||
| "op_name":"MaximumGrad", | |||
| "imply_type":"TBE", | |||
| "fusion_type":"OPAQUE", | |||
| "async_flag":false, | |||
| "binfile_name":"maximum_grad.so", | |||
| "compute_cost":10, | |||
| "kernel_name":"maximum_grad", | |||
| "partial_flag":true, | |||
| "attr":[ | |||
| { | |||
| "name":"grad_x", | |||
| "param_type":"optional", | |||
| "type":"bool", | |||
| "value":"all" | |||
| }, | |||
| { | |||
| "name":"grad_y", | |||
| "param_type":"optional", | |||
| "type":"bool", | |||
| "value":"all" | |||
| } | |||
| ], | |||
| "inputs":[ | |||
| { | |||
| "index":0, | |||
| "dtype":[ | |||
| "float16", "float16", "float16", "float16", "float", "float", "float", "float", | |||
| "int32", "int32", "int32", "int32" | |||
| ], | |||
| "format":[ | |||
| "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", | |||
| "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat" | |||
| ], | |||
| "name":"grads", | |||
| "need_compile":false, | |||
| "param_type":"required", | |||
| "shape":"all" | |||
| }, | |||
| { | |||
| "index":1, | |||
| "dtype":[ | |||
| "float16", "float16", "float16", "float16", "float", "float", "float", "float", | |||
| "int32", "int32", "int32", "int32" | |||
| ], | |||
| "format":[ | |||
| "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", | |||
| "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat" | |||
| ], | |||
| "name":"x1", | |||
| "need_compile":false, | |||
| "param_type":"required", | |||
| "shape":"all" | |||
| }, | |||
| { | |||
| "index":2, | |||
| "dtype":[ | |||
| "float16", "float16", "float16", "float16", "float", "float", "float", "float", | |||
| "int32", "int32", "int32", "int32" | |||
| ], | |||
| "format":[ | |||
| "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", | |||
| "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat" | |||
| ], | |||
| "name":"x2", | |||
| "need_compile":false, | |||
| "param_type":"required", | |||
| "shape":"all" | |||
| } | |||
| ], | |||
| "outputs":[ | |||
| { | |||
| "index":0, | |||
| "dtype":[ | |||
| "float16", "float16", "float16", "float16", "float", "float", "float", "float", | |||
| "int32", "int32", "int32", "int32" | |||
| ], | |||
| "format":[ | |||
| "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", | |||
| "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat" | |||
| ], | |||
| "name":"y1", | |||
| "need_compile":false, | |||
| "param_type":"required", | |||
| "shape":"all" | |||
| }, | |||
| { | |||
| "index":1, | |||
| "dtype":[ | |||
| "float16", "float16", "float16", "float16", "float", "float", "float", "float", | |||
| "int32", "int32", "int32", "int32" | |||
| ], | |||
| "format":[ | |||
| "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", | |||
| "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat" | |||
| ], | |||
| "name":"y2", | |||
| "need_compile":false, | |||
| "param_type":"required", | |||
| "shape":"all" | |||
| } | |||
| ] | |||
| }""") | |||
| @op_info_register(maximum_grad_op_info) | |||
| def _maximum_grad_tbe(): | |||
| """MaximumGrad TBE register""" | |||
| return | |||
| @@ -15,74 +15,28 @@ | |||
| """Minimum op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| minimum_op_info = TBERegOp("Minimum") \ | |||
| .fusion_type("ELEMWISE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("minimum.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("minimum") \ | |||
| .partial_flag(True) \ | |||
| .input(0, "x1", False, "required", "all") \ | |||
| .input(1, "x2", False, "required", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ | |||
| .dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD) \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ | |||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \ | |||
| .get_op_info() | |||
| @op_info_register("""{ | |||
| "op_name": "Minimum", | |||
| "imply_type": "TBE", | |||
| "fusion_type": "ELEMWISE", | |||
| "async_flag": false, | |||
| "binfile_name": "minimum.so", | |||
| "compute_cost": 10, | |||
| "kernel_name": "minimum", | |||
| "partial_flag": true, | |||
| "attr": [ | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16", "float16", "float16", "float16", "float", "float", "float", "float", | |||
| "int32", "int32", "int32", "int32" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", | |||
| "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", | |||
| "DefaultFormat", "DefaultFormat" | |||
| ], | |||
| "name": "x1", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 1, | |||
| "dtype": [ | |||
| "float16", "float16", "float16", "float16", "float", "float", "float", "float", | |||
| "int32", "int32", "int32", "int32" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", | |||
| "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", | |||
| "DefaultFormat", "DefaultFormat" | |||
| ], | |||
| "name": "x2", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16", "float16", "float16", "float16", "float", "float", "float", "float", | |||
| "int32", "int32", "int32", "int32" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", | |||
| "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", | |||
| "DefaultFormat", "DefaultFormat" | |||
| ], | |||
| "name": "y", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ] | |||
| }""") | |||
| @op_info_register(minimum_op_info) | |||
| def _minimum_tbe(): | |||
| """Minimum TBE register""" | |||
| return | |||
| @@ -14,112 +14,38 @@ | |||
| # ============================================================================ | |||
| """MinimumGrad op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| minimum_grad_op_info = TBERegOp("MinimumGrad") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("minimum_grad.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("minimum_grad") \ | |||
| .partial_flag(True) \ | |||
| .attr("grad_x", "optional", "bool", "all") \ | |||
| .attr("grad_y", "optional", "bool", "all") \ | |||
| .input(0, "grads", False, "required", "all") \ | |||
| .input(1, "x1", False, "required", "all") \ | |||
| .input(2, "x2", False, "required", "all") \ | |||
| .output(0, "y1", False, "required", "all") \ | |||
| .output(1, "y2", False, "required", "all") \ | |||
| .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, | |||
| DataType.I32_Default) \ | |||
| .dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD, | |||
| DataType.I32_5HD) \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, | |||
| DataType.F16_Default) \ | |||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, | |||
| DataType.F16_5HD) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | |||
| DataType.F32_Default) \ | |||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, | |||
| DataType.F32_5HD) \ | |||
| .get_op_info() | |||
| @op_info_register("""{ | |||
| "op_name":"MinimumGrad", | |||
| "imply_type":"TBE", | |||
| "fusion_type":"OPAQUE", | |||
| "async_flag":false, | |||
| "binfile_name":"minimum_grad.so", | |||
| "compute_cost":10, | |||
| "kernel_name":"minimum_grad", | |||
| "partial_flag":true, | |||
| "attr":[ | |||
| { | |||
| "name":"grad_x", | |||
| "param_type":"optional", | |||
| "type":"bool", | |||
| "value":"all" | |||
| }, | |||
| { | |||
| "name":"grad_y", | |||
| "param_type":"optional", | |||
| "type":"bool", | |||
| "value":"all" | |||
| } | |||
| ], | |||
| "inputs":[ | |||
| { | |||
| "index":0, | |||
| "dtype":[ | |||
| "float16", "float16", "float16", "float16", "float", "float", "float", "float", | |||
| "int32", "int32", "int32", "int32" | |||
| ], | |||
| "format":[ | |||
| "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", | |||
| "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat" | |||
| ], | |||
| "name":"grads", | |||
| "need_compile":false, | |||
| "param_type":"required", | |||
| "shape":"all" | |||
| }, | |||
| { | |||
| "index":1, | |||
| "dtype":[ | |||
| "float16", "float16", "float16", "float16", "float", "float", "float", "float", | |||
| "int32", "int32", "int32", "int32" | |||
| ], | |||
| "format":[ | |||
| "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", | |||
| "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat" | |||
| ], | |||
| "name":"x1", | |||
| "need_compile":false, | |||
| "param_type":"required", | |||
| "shape":"all" | |||
| }, | |||
| { | |||
| "index":2, | |||
| "dtype":[ | |||
| "float16", "float16", "float16", "float16", "float", "float", "float", "float", | |||
| "int32", "int32", "int32", "int32" | |||
| ], | |||
| "format":[ | |||
| "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", | |||
| "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat" | |||
| ], | |||
| "name":"x2", | |||
| "need_compile":false, | |||
| "param_type":"required", | |||
| "shape":"all" | |||
| } | |||
| ], | |||
| "outputs":[ | |||
| { | |||
| "index":0, | |||
| "dtype":[ | |||
| "float16", "float16", "float16", "float16", "float", "float", "float", "float", | |||
| "int32", "int32", "int32", "int32" | |||
| ], | |||
| "format":[ | |||
| "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", | |||
| "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat" | |||
| ], | |||
| "name":"y1", | |||
| "need_compile":false, | |||
| "param_type":"required", | |||
| "shape":"all" | |||
| }, | |||
| { | |||
| "index":1, | |||
| "dtype":[ | |||
| "float16", "float16", "float16", "float16", "float", "float", "float", "float", | |||
| "int32", "int32", "int32", "int32" | |||
| ], | |||
| "format":[ | |||
| "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", | |||
| "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat" | |||
| ], | |||
| "name":"y2", | |||
| "need_compile":false, | |||
| "param_type":"required", | |||
| "shape":"all" | |||
| } | |||
| ] | |||
| }""") | |||
| @op_info_register(minimum_grad_op_info) | |||
| def _minimum_grad_tbe(): | |||
| """MinimumGrad TBE register""" | |||
| return | |||
| @@ -14,77 +14,37 @@ | |||
| # ============================================================================ | |||
| """Mul op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| mul_op_info = TBERegOp("Mul") \ | |||
| .fusion_type("ELEMWISE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("mul.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("mul") \ | |||
| .partial_flag(True) \ | |||
| .input(0, "x", False, "required", "all") \ | |||
| .input(1, "y", False, "required", "all") \ | |||
| .output(0, "output", False, "required", "all") \ | |||
| .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ | |||
| .dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD) \ | |||
| .dtype_format(DataType.I32_FracZ, DataType.I32_FracZ, DataType.I32_FracZ) \ | |||
| .dtype_format(DataType.I32_FracNZ, DataType.I32_FracNZ, DataType.I32_FracNZ) \ | |||
| .dtype_format(DataType.I32_C1HWNCoC0, DataType.I32_C1HWNCoC0, DataType.I32_C1HWNCoC0) \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \ | |||
| .dtype_format(DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ) \ | |||
| .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ) \ | |||
| .dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ | |||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \ | |||
| .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ) \ | |||
| .dtype_format(DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ) \ | |||
| .dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \ | |||
| .get_op_info() | |||
| @op_info_register("""{ | |||
| "op_name": "Mul", | |||
| "imply_type": "TBE", | |||
| "fusion_type": "ELEMWISE", | |||
| "async_flag": false, | |||
| "binfile_name": "mul.so", | |||
| "compute_cost": 10, | |||
| "kernel_name": "mul", | |||
| "partial_flag": true, | |||
| "attr": [ | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "int32", "int32", "int32", "int32", "int32", | |||
| "float16", "float16", "float16", "float16", "float16", | |||
| "float", "float", "float", "float", "float" | |||
| ], | |||
| "format": [ | |||
| "FRACTAL_NZ", "DefaultFormat", "FracZ", "C1HWNCoC0", "NC1HWC0", | |||
| "FRACTAL_NZ", "DefaultFormat", "FracZ", "C1HWNCoC0", "NC1HWC0", | |||
| "FRACTAL_NZ", "DefaultFormat", "FracZ", "C1HWNCoC0", "NC1HWC0" | |||
| ], | |||
| "name": "x", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 1, | |||
| "dtype": [ | |||
| "int32", "int32", "int32", "int32", "int32", | |||
| "float16", "float16", "float16", "float16", "float16", | |||
| "float", "float", "float", "float", "float" | |||
| ], | |||
| "format": [ | |||
| "FRACTAL_NZ", "DefaultFormat", "FracZ", "C1HWNCoC0", "NC1HWC0", | |||
| "FRACTAL_NZ", "DefaultFormat", "FracZ", "C1HWNCoC0", "NC1HWC0", | |||
| "FRACTAL_NZ", "DefaultFormat", "FracZ", "C1HWNCoC0", "NC1HWC0" | |||
| ], | |||
| "name": "y", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "int32", "int32", "int32", "int32", "int32", | |||
| "float16", "float16", "float16", "float16", "float16", | |||
| "float", "float", "float", "float","float" | |||
| ], | |||
| "format": [ | |||
| "FRACTAL_NZ", "DefaultFormat", "FracZ", "C1HWNCoC0", "NC1HWC0", | |||
| "FRACTAL_NZ", "DefaultFormat", "FracZ", "C1HWNCoC0", "NC1HWC0", | |||
| "FRACTAL_NZ", "DefaultFormat", "FracZ", "C1HWNCoC0", "NC1HWC0" | |||
| ], | |||
| "name": "output", | |||
| "need_compile": true, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ] | |||
| }""") | |||
| @op_info_register(mul_op_info) | |||
| def _mul_tbe(): | |||
| """Mul TBE register""" | |||
| return | |||
| @@ -14,51 +14,29 @@ | |||
| # ============================================================================ | |||
| """Neg op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| neg_op_info = TBERegOp("Neg") \ | |||
| .fusion_type("ELEMWISE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("neg.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("neg") \ | |||
| .partial_flag(True) \ | |||
| .input(0, "x", False, "required", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .dtype_format(DataType.I8_Default, DataType.I8_Default) \ | |||
| .dtype_format(DataType.I8_5HD, DataType.I8_5HD) \ | |||
| .dtype_format(DataType.I32_Default, DataType.I32_Default) \ | |||
| .dtype_format(DataType.I32_5HD, DataType.I32_5HD) \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default) \ | |||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD) \ | |||
| .get_op_info() | |||
| @op_info_register("""{ | |||
| "op_name": "Neg", | |||
| "imply_type": "TBE", | |||
| "fusion_type": "ELEMWISE", | |||
| "async_flag": false, | |||
| "binfile_name": "neg.so", | |||
| "compute_cost": 10, | |||
| "kernel_name": "neg", | |||
| "partial_flag": true, | |||
| "attr": [ | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float","float","float16","float16","int32","int32","int8","int8" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0" | |||
| ], | |||
| "name": "x", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float","float","float16","float16","int32","int32","int8","int8" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0" | |||
| ], | |||
| "name": "y", | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ] | |||
| }""") | |||
| @op_info_register(neg_op_info) | |||
| def _neg_tbe(): | |||
| """Neg TBE register""" | |||
| return | |||
| @@ -14,39 +14,21 @@ | |||
| # ============================================================================ | |||
| """NPUAllocFloatStatus op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| npu_alloc_float_status_op_info = TBERegOp("NPUAllocFloatStatus") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("n_p_u_alloc_float_status.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("n_p_u_alloc_float_status") \ | |||
| .partial_flag(True) \ | |||
| .output(0, "data", False, "required", "all") \ | |||
| .dtype_format(DataType.F32_Default) \ | |||
| .get_op_info() | |||
| @op_info_register("""{ | |||
| "op_name": "NPUAllocFloatStatus", | |||
| "imply_type": "TBE", | |||
| "fusion_type": "OPAQUE", | |||
| "async_flag": false, | |||
| "binfile_name": "n_p_u_alloc_float_status.so", | |||
| "compute_cost": 10, | |||
| "kernel_name": "n_p_u_alloc_float_status", | |||
| "partial_flag": true, | |||
| "attr": [ | |||
| ], | |||
| "inputs": [ | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat" | |||
| ], | |||
| "name": "data", | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ] | |||
| }""") | |||
| @op_info_register(npu_alloc_float_status_op_info) | |||
| def _npu_alloc_float_status_tbe(): | |||
| """NPUAllocFloatStatus TBE register""" | |||
| return | |||
| @@ -14,52 +14,22 @@ | |||
| # ============================================================================ | |||
| """NPUClearFloatStatus op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| npu_clear_float_status_op_info = TBERegOp("NPUClearFloatStatus") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("n_p_u_clear_float_status.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("n_p_u_clear_float_status") \ | |||
| .partial_flag(True) \ | |||
| .input(0, "addr", False, "required", "all") \ | |||
| .output(0, "data", False, "required", "all") \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default) \ | |||
| .get_op_info() | |||
| @op_info_register("""{ | |||
| "op_name": "NPUClearFloatStatus", | |||
| "imply_type": "TBE", | |||
| "fusion_type": "ELEMWISE", | |||
| "async_flag": false, | |||
| "binfile_name": "n_p_u_clear_float_status.so", | |||
| "compute_cost": 10, | |||
| "kernel_name": "n_p_u_clear_float_status", | |||
| "partial_flag": true, | |||
| "attr": [ | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat" | |||
| ], | |||
| "name": "addr", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat" | |||
| ], | |||
| "name": "data", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ] | |||
| }""") | |||
| @op_info_register(npu_clear_float_status_op_info) | |||
| def _npu_clear_float_status_tbe(): | |||
| """NPUClearFloatStatus TBE register""" | |||
| return | |||
| @@ -14,52 +14,22 @@ | |||
| # ============================================================================ | |||
| """NPUGetFloatStatus op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| npu_get_float_status_op_info = TBERegOp("NPUGetFloatStatus") \ | |||
| .fusion_type("ELEMWISE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("n_p_u_get_float_status.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("n_p_u_get_float_status") \ | |||
| .partial_flag(True) \ | |||
| .input(0, "addr", False, "required", "all") \ | |||
| .output(0, "data", False, "required", "all") \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default) \ | |||
| .get_op_info() | |||
| @op_info_register("""{ | |||
| "op_name": "NPUGetFloatStatus", | |||
| "imply_type": "TBE", | |||
| "fusion_type": "ELEMWISE", | |||
| "async_flag": false, | |||
| "binfile_name": "n_p_u_get_float_status.so", | |||
| "compute_cost": 10, | |||
| "kernel_name": "n_p_u_get_float_status", | |||
| "partial_flag": true, | |||
| "attr": [ | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat" | |||
| ], | |||
| "name": "addr", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat" | |||
| ], | |||
| "name": "data", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ] | |||
| }""") | |||
| @op_info_register(npu_get_float_status_op_info) | |||
| def _npu_get_float_status_tbe(): | |||
| """NPUGetFloatStatus TBE register""" | |||
| return | |||
| @@ -14,96 +14,35 @@ | |||
| # ============================================================================ | |||
| """OneHot op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| one_hot_op_info = TBERegOp("OneHot") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("one_hot.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("one_hot") \ | |||
| .partial_flag(True) \ | |||
| .attr("depth", "required", "int", "all") \ | |||
| .attr("axis", "required", "int", "all") \ | |||
| .input(0, "x", False, "required", "all") \ | |||
| .input(1, "on_value", False, "required", "all") \ | |||
| .input(2, "off_value", False, "required", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .dtype_format(DataType.U8_Default, DataType.I8_Default, DataType.I8_Default, DataType.I8_Default) \ | |||
| .dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.U8_Default, DataType.U8_Default) \ | |||
| .dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ | |||
| .dtype_format(DataType.U8_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.U8_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ | |||
| .dtype_format(DataType.I32_Default, DataType.I8_Default, DataType.I8_Default, DataType.I8_Default) \ | |||
| .dtype_format(DataType.I32_Default, DataType.U8_Default, DataType.U8_Default, DataType.U8_Default) \ | |||
| .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ | |||
| .dtype_format(DataType.I32_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.I32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ | |||
| .get_op_info() | |||
| @op_info_register("""{ | |||
| "op_name": "OneHot", | |||
| "imply_type": "TBE", | |||
| "fusion_type": "OPAQUE", | |||
| "async_flag": false, | |||
| "binfile_name": "one_hot.so", | |||
| "compute_cost": 10, | |||
| "kernel_name": "one_hot", | |||
| "partial_flag": true, | |||
| "attr": [ | |||
| { | |||
| "name": "depth", | |||
| "param_type": "required", | |||
| "type": "int", | |||
| "value": "all" | |||
| }, | |||
| { | |||
| "name": "axis", | |||
| "param_type": "required", | |||
| "type": "int", | |||
| "value": "all" | |||
| } | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "int32","int32","int32","int32","int32", | |||
| "uint8","uint8","uint8","uint8","uint8" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat", | |||
| "DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat" | |||
| ], | |||
| "name": "x", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 1, | |||
| "dtype": [ | |||
| "float16","float32","int32","int8","uint8", | |||
| "float16","float32","int32","int8","uint8" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat", | |||
| "DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat" | |||
| ], | |||
| "name": "on_value", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 2, | |||
| "dtype": [ | |||
| "float16","float32","int32","int8","uint8", | |||
| "float16","float32","int32","int8","uint8" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat", | |||
| "DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat" | |||
| ], | |||
| "name": "off_value", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16","float32","int32","int8","uint8", | |||
| "float16","float32","int32","int8","uint8" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat", | |||
| "DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat" | |||
| ], | |||
| "name": "y", | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ] | |||
| }""") | |||
| @op_info_register(one_hot_op_info) | |||
| def _one_hot_tbe(): | |||
| """OneHot TBE register""" | |||
| return | |||
| @@ -14,57 +14,27 @@ | |||
| # ============================================================================ | |||
| """Pad op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| pad_d_op_info = TBERegOp("Pad") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("pad_d.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("pad_d") \ | |||
| .partial_flag(True) \ | |||
| .attr("paddings", "optional", "listListInt", "all") \ | |||
| .input(0, "x", False, "required", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .dtype_format(DataType.I8_Default, DataType.I8_Default) \ | |||
| .dtype_format(DataType.U8_Default, DataType.U8_Default) \ | |||
| .dtype_format(DataType.I32_Default, DataType.I32_Default) \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default) \ | |||
| .get_op_info() | |||
| @op_info_register("""{ | |||
| "op_name": "Pad", | |||
| "imply_type": "TBE", | |||
| "fusion_type": "OPAQUE", | |||
| "async_flag": false, | |||
| "binfile_name": "pad_d.so", | |||
| "compute_cost": 10, | |||
| "kernel_name": "pad_d", | |||
| "partial_flag": true, | |||
| "attr": [ | |||
| { | |||
| "name": "paddings", | |||
| "param_type": "optional", | |||
| "type": "listListInt", | |||
| "value": "all" | |||
| } | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16","float","int8","uint8","int32" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat" | |||
| ], | |||
| "name": "x", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16","float","int8","uint8","int32" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat" | |||
| ], | |||
| "name": "y", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ] | |||
| }""") | |||
| @op_info_register(pad_d_op_info) | |||
| def _pad_d_tbe(): | |||
| """Pad TBE register""" | |||
| return | |||
| @@ -14,65 +14,27 @@ | |||
| # ============================================================================ | |||
| """Pow op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| pow_op_info = TBERegOp("Pow") \ | |||
| .fusion_type("ELEMWISE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("pow.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("pow") \ | |||
| .partial_flag(True) \ | |||
| .input(0, "x1", False, "required", "all") \ | |||
| .input(1, "x2", False, "required", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.I8_Default) \ | |||
| .dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.U8_Default) \ | |||
| .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ | |||
| .get_op_info() | |||
| @op_info_register("""{ | |||
| "op_name": "Pow", | |||
| "imply_type": "TBE", | |||
| "fusion_type": "ELEMWISE", | |||
| "async_flag": false, | |||
| "binfile_name": "pow.so", | |||
| "compute_cost": 10, | |||
| "kernel_name": "pow", | |||
| "partial_flag": true, | |||
| "attr": [ | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16", "float", "int32", "int8", "uint8" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat" | |||
| ], | |||
| "name": "x1", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 1, | |||
| "dtype": [ | |||
| "float16", "float", "int32", "int8", "uint8" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat" | |||
| ], | |||
| "name": "x2", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16", "float", "int32", "int8", "uint8" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat" | |||
| ], | |||
| "name": "y", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ] | |||
| }""") | |||
| @op_info_register(pow_op_info) | |||
| def _pow_tbe(): | |||
| """Pow TBE register""" | |||
| return | |||
| @@ -14,64 +14,26 @@ | |||
| # ============================================================================ | |||
| """RealDiv op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| realdiv_op_info = TBERegOp("RealDiv") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("realdiv.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("realdiv") \ | |||
| .partial_flag(True) \ | |||
| .input(0, "x", False, "required", "all") \ | |||
| .input(1, "y", False, "required", "all") \ | |||
| .output(0, "z", False, "required", "all") \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ | |||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \ | |||
| .get_op_info() | |||
| @op_info_register("""{ | |||
| "op_name": "RealDiv", | |||
| "imply_type": "TBE", | |||
| "fusion_type": "OPAQUE", | |||
| "async_flag": false, | |||
| "binfile_name": "realdiv.so", | |||
| "compute_cost": 10, | |||
| "kernel_name": "realdiv", | |||
| "partial_flag": true, | |||
| "attr": [ | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16", "float16", "float", "float" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "NC1HWC0", "DefaultFormat", "NC1HWC0" | |||
| ], | |||
| "name": "x", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 1, | |||
| "dtype": [ | |||
| "float16", "float16", "float", "float" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "NC1HWC0", "DefaultFormat", "NC1HWC0" | |||
| ], | |||
| "name": "y", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16", "float16", "float", "float" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "NC1HWC0", "DefaultFormat", "NC1HWC0" | |||
| ], | |||
| "name": "z", | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ] | |||
| }""") | |||
| @op_info_register(realdiv_op_info) | |||
| def _real_div_tbe(): | |||
| """RealDiv TBE register""" | |||
| return | |||
| @@ -14,52 +14,27 @@ | |||
| # ============================================================================ | |||
| """Add op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| reciprocal_op_info = TBERegOp("Reciprocal") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("reciprocal.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("reciprocal") \ | |||
| .partial_flag(True) \ | |||
| .input(0, "x", False, "required", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \ | |||
| .dtype_format(DataType.F16_NHWC, DataType.F16_NHWC) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default) \ | |||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD) \ | |||
| .dtype_format(DataType.F32_NHWC, DataType.F32_NHWC) \ | |||
| .get_op_info() | |||
| @op_info_register("""{ | |||
| "op_name": "Reciprocal", | |||
| "imply_type": "TBE", | |||
| "fusion_type": "OPAQUE", | |||
| "async_flag": false, | |||
| "binfile_name": "reciprocal.so", | |||
| "compute_cost": 10, | |||
| "kernel_name": "reciprocal", | |||
| "partial_flag": true, | |||
| "attr": [ | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16", "float16", "float16", "float32", "float32", "float32" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "NC1HWC0", "NHWC", "DefaultFormat", "NC1HWC0", "NHWC" | |||
| ], | |||
| "name": "x", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16", "float16", "float16", "float32", "float32", "float32" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "NC1HWC0", "NHWC", "DefaultFormat", "NC1HWC0", "NHWC" | |||
| ], | |||
| "name": "y", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ] | |||
| }""") | |||
| @op_info_register(reciprocal_op_info) | |||
| def _reciprocal_tbe(): | |||
| """Add TBE register""" | |||
| return | |||
| @@ -14,63 +14,29 @@ | |||
| # ============================================================================ | |||
| """ReduceMax op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| reduce_max_d_op_info = TBERegOp("ReduceMax") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("reduce_max_d.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("reduce_max_d") \ | |||
| .partial_flag(True) \ | |||
| .attr("axis", "optional", "listInt", "all") \ | |||
| .attr("keep_dims", "optional", "bool", "all") \ | |||
| .input(0, "x", False, "required", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \ | |||
| .dtype_format(DataType.I8_Default, DataType.I8_Default) \ | |||
| .dtype_format(DataType.U8_Default, DataType.U8_Default) \ | |||
| .dtype_format(DataType.I32_Default, DataType.I32_Default) \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default) \ | |||
| .get_op_info() | |||
| @op_info_register("""{ | |||
| "op_name": "ReduceMax", | |||
| "imply_type": "TBE", | |||
| "fusion_type": "OPAQUE", | |||
| "async_flag": false, | |||
| "binfile_name": "reduce_max_d.so", | |||
| "compute_cost": 10, | |||
| "kernel_name": "reduce_max_d", | |||
| "partial_flag": true, | |||
| "attr": [ | |||
| { | |||
| "name": "axis", | |||
| "param_type": "required", | |||
| "type": "listInt", | |||
| "value": "all" | |||
| }, | |||
| { | |||
| "name": "keep_dims", | |||
| "param_type": "required", | |||
| "type": "bool", | |||
| "value": "all" | |||
| } | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16", "float", "int8", "uint8", "bool", "int32" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat" | |||
| ], | |||
| "name": "x", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16", "float", "int8", "uint8", "bool", "int32" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat" | |||
| ], | |||
| "name": "y", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ] | |||
| }""") | |||
| @op_info_register(reduce_max_d_op_info) | |||
| def _reduce_max_tbe(): | |||
| """ReduceMax TBE register""" | |||
| return | |||
| @@ -14,63 +14,27 @@ | |||
| # ============================================================================ | |||
| """ReduceMean op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| reduce_mean_op_info = TBERegOp("ReduceMean") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("reduce_mean.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("reduce_mean") \ | |||
| .partial_flag(True) \ | |||
| .attr("axis", "optional", "listInt", "all") \ | |||
| .attr("keep_dims", "optional", "bool", "all") \ | |||
| .input(0, "x", False, "required", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .dtype_format(DataType.I8_Default, DataType.I8_Default) \ | |||
| .dtype_format(DataType.U8_Default, DataType.U8_Default) \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default) \ | |||
| .get_op_info() | |||
| @op_info_register("""{ | |||
| "op_name": "ReduceMean", | |||
| "imply_type": "TBE", | |||
| "fusion_type": "OPAQUE", | |||
| "async_flag": false, | |||
| "binfile_name": "reduce_mean.so", | |||
| "compute_cost": 10, | |||
| "kernel_name": "reduce_mean", | |||
| "partial_flag": true, | |||
| "attr": [ | |||
| { | |||
| "name": "axis", | |||
| "param_type": "optional", | |||
| "type": "listInt", | |||
| "value": "all" | |||
| }, | |||
| { | |||
| "name": "keep_dims", | |||
| "param_type": "optional", | |||
| "type": "bool", | |||
| "value": "all" | |||
| } | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16","float","float16","int8","uint8" | |||
| ], | |||
| "format": [ | |||
| "NC1HWC0","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat" | |||
| ], | |||
| "name": "x", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16","float","float16","int8","uint8" | |||
| ], | |||
| "format": [ | |||
| "NC1HWC0","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat" | |||
| ], | |||
| "name": "y", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ] | |||
| }""") | |||
| @op_info_register(reduce_mean_op_info) | |||
| def _reduce_mean_tbe(): | |||
| """ReduceMean TBE register""" | |||
| return | |||
| @@ -14,63 +14,27 @@ | |||
| # ============================================================================ | |||
| """ReduceMeanD op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| reduce_mean_d_op_info = TBERegOp("ReduceMeanD") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("reduce_mean_d.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("reduce_mean_d") \ | |||
| .partial_flag(True) \ | |||
| .attr("axis", "optional", "listInt", "all") \ | |||
| .attr("keep_dims", "optional", "bool", "all") \ | |||
| .input(0, "x", False, "required", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .dtype_format(DataType.I8_Default, DataType.I8_Default) \ | |||
| .dtype_format(DataType.U8_Default, DataType.U8_Default) \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default) \ | |||
| .get_op_info() | |||
| @op_info_register("""{ | |||
| "op_name": "ReduceMeanD", | |||
| "imply_type": "TBE", | |||
| "fusion_type": "OPAQUE", | |||
| "async_flag": false, | |||
| "binfile_name": "reduce_mean_d.so", | |||
| "compute_cost": 10, | |||
| "kernel_name": "reduce_mean_d", | |||
| "partial_flag": true, | |||
| "attr": [ | |||
| { | |||
| "name": "axis", | |||
| "param_type": "optional", | |||
| "type": "listInt", | |||
| "value": "all" | |||
| }, | |||
| { | |||
| "name": "keep_dims", | |||
| "param_type": "optional", | |||
| "type": "bool", | |||
| "value": "all" | |||
| } | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float","float16","int8","uint8" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat" | |||
| ], | |||
| "name": "x", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float","float16","int8","uint8" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat" | |||
| ], | |||
| "name": "y", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ] | |||
| }""") | |||
| @op_info_register(reduce_mean_d_op_info) | |||
| def _reduce_mean_d_tbe(): | |||
| """Conv2D TBE register""" | |||
| return | |||
| @@ -14,63 +14,31 @@ | |||
| # ============================================================================ | |||
| """ReduceMin op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| reduce_min_op_info = TBERegOp("ReduceMin") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("reduce_min_d.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("reduce_min_d") \ | |||
| .partial_flag(True) \ | |||
| .attr("axis", "required", "listInt", "all") \ | |||
| .attr("keep_dims", "required", "bool", "all") \ | |||
| .input(0, "x", False, "required", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .dtype_format(DataType.I8_Default, DataType.I8_Default) \ | |||
| .dtype_format(DataType.I8_FracZ, DataType.I8_FracZ) \ | |||
| .dtype_format(DataType.U8_Default, DataType.U8_Default) \ | |||
| .dtype_format(DataType.U8_FracZ, DataType.U8_FracZ) \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F16_FracZ, DataType.F16_FracZ) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default) \ | |||
| .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ) \ | |||
| .get_op_info() | |||
| @op_info_register("""{ | |||
| "op_name": "ReduceMin", | |||
| "imply_type": "TBE", | |||
| "fusion_type": "OPAQUE", | |||
| "async_flag": false, | |||
| "binfile_name": "reduce_min_d.so", | |||
| "compute_cost": 10, | |||
| "kernel_name": "reduce_min_d", | |||
| "partial_flag": true, | |||
| "attr": [ | |||
| { | |||
| "name": "axis", | |||
| "param_type": "required", | |||
| "type": "listInt", | |||
| "value": "all" | |||
| }, | |||
| { | |||
| "name": "keep_dims", | |||
| "param_type": "required", | |||
| "type": "bool", | |||
| "value": "all" | |||
| } | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16", "float16", "float", "float", "int8", "int8", "uint8", "uint8" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "FracZ", "DefaultFormat", "FracZ", "DefaultFormat", "FracZ", "DefaultFormat", "FracZ" | |||
| ], | |||
| "name": "x", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16", "float16", "float", "float", "int8", "int8", "uint8", "uint8" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "FracZ", "DefaultFormat", "FracZ", "DefaultFormat", "FracZ", "DefaultFormat", "FracZ" | |||
| ], | |||
| "name": "y", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ] | |||
| }""") | |||
| @op_info_register(reduce_min_op_info) | |||
| def _reduce_min_tbe(): | |||
| """ReduceMin TBE register""" | |||
| return | |||
| @@ -14,63 +14,25 @@ | |||
| # ============================================================================ | |||
| """ReduceSum op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| reduce_sum_op_info = TBERegOp("ReduceSum") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("reduce_sum_d.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("reduce_sum_d") \ | |||
| .partial_flag(True) \ | |||
| .attr("axis", "optional", "listInt", "all") \ | |||
| .attr("keep_dims", "optional", "bool", "all") \ | |||
| .input(0, "x", False, "required", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default) \ | |||
| .get_op_info() | |||
| @op_info_register("""{ | |||
| "op_name": "ReduceSum", | |||
| "imply_type": "TBE", | |||
| "fusion_type": "OPAQUE", | |||
| "async_flag": false, | |||
| "binfile_name": "reduce_sum_d.so", | |||
| "compute_cost": 10, | |||
| "kernel_name": "reduce_sum_d", | |||
| "partial_flag": true, | |||
| "attr": [ | |||
| { | |||
| "name": "axis", | |||
| "param_type": "optional", | |||
| "type": "listInt", | |||
| "value": "all" | |||
| }, | |||
| { | |||
| "name": "keep_dims", | |||
| "param_type": "optional", | |||
| "type": "bool", | |||
| "value": "all" | |||
| } | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16", "float" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "DefaultFormat" | |||
| ], | |||
| "name": "x", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16", "float" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "DefaultFormat" | |||
| ], | |||
| "name": "y", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ] | |||
| }""") | |||
| @op_info_register(reduce_sum_op_info) | |||
| def _reduce_sum_tbe(): | |||
| """ReduceSum TBE register""" | |||
| return | |||
| @@ -14,54 +14,29 @@ | |||
| # ============================================================================ | |||
| """ReLU op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| relu_op_info = TBERegOp("ReLU") \ | |||
| .fusion_type("ELEMWISE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("relu.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("relu") \ | |||
| .partial_flag(True) \ | |||
| .input(0, "x", False, "required", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .dtype_format(DataType.I8_Default, DataType.I8_Default) \ | |||
| .dtype_format(DataType.I8_5HD, DataType.I8_5HD) \ | |||
| .dtype_format(DataType.I32_Default, DataType.I32_Default) \ | |||
| .dtype_format(DataType.I32_5HD, DataType.I32_5HD) \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default) \ | |||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD) \ | |||
| .get_op_info() | |||
| @op_info_register("""{ | |||
| "op_name": "ReLU", | |||
| "imply_type": "TBE", | |||
| "fusion_type": "ELEMWISE", | |||
| "async_flag": false, | |||
| "binfile_name": "relu.so", | |||
| "compute_cost": 10, | |||
| "kernel_name": "relu", | |||
| "partial_flag": true, | |||
| "attr": [ | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16", "float16", "float", "float","int32", "int32", "int8", "int8" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "NC1HWC0", "DefaultFormat", "NC1HWC0", "DefaultFormat", "NC1HWC0", | |||
| "DefaultFormat", "NC1HWC0" | |||
| ], | |||
| "name": "x", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16", "float16", "float", "float", "int32", "int32", "int8", "int8" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "NC1HWC0", "DefaultFormat", "NC1HWC0", "DefaultFormat", "NC1HWC0", | |||
| "DefaultFormat", "NC1HWC0" | |||
| ], | |||
| "name": "y", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ] | |||
| }""") | |||
| @op_info_register(relu_op_info) | |||
| def _relu_tbe(): | |||
| """Relu TBE register""" | |||
| return | |||
| @@ -14,68 +14,32 @@ | |||
| # ============================================================================ | |||
| """ReluGrad op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| relugrad_op_info = TBERegOp("ReluGrad") \ | |||
| .fusion_type("ELEMWISE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("relugrad.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("relugrad") \ | |||
| .partial_flag(True) \ | |||
| .input(0, "gradients", False, "required", "all") \ | |||
| .input(1, "features", False, "required", "all") \ | |||
| .output(0, "backprops", True, "required", "all") \ | |||
| .dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.I8_Default) \ | |||
| .dtype_format(DataType.I8_5HD, DataType.I8_5HD, DataType.I8_5HD) \ | |||
| .dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.U8_Default) \ | |||
| .dtype_format(DataType.U8_5HD, DataType.U8_5HD, DataType.U8_5HD) \ | |||
| .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ | |||
| .dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD) \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ | |||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \ | |||
| .get_op_info() | |||
| @op_info_register("""{ | |||
| "op_name": "ReluGrad", | |||
| "imply_type": "TBE", | |||
| "fusion_type": "ELEMWISE", | |||
| "async_flag": false, | |||
| "binfile_name": "relugrad.so", | |||
| "compute_cost": 10, | |||
| "kernel_name": "relugrad", | |||
| "partial_flag": true, | |||
| "attr": [ | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16", "float16", "float", "float", "int32", "int32", "int8", "int8", "uint8", "uint8" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "NC1HWC0","DefaultFormat", "NC1HWC0", "DefaultFormat", "NC1HWC0", | |||
| "DefaultFormat", "NC1HWC0", "DefaultFormat", "NC1HWC0" | |||
| ], | |||
| "name": "gradients", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 1, | |||
| "dtype": [ | |||
| "float16", "float16", "float", "float", "int32", "int32", "int8", "int8", "uint8", "uint8" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "NC1HWC0", "DefaultFormat", "NC1HWC0", "DefaultFormat", "NC1HWC0", | |||
| "DefaultFormat", "NC1HWC0", "DefaultFormat", "NC1HWC0" | |||
| ], | |||
| "name": "features", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16", "float16", "float", "float", "int32", "int32", "int8", "int8", "uint8", "uint8" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "NC1HWC0", "DefaultFormat", "NC1HWC0", "DefaultFormat", "NC1HWC0", | |||
| "DefaultFormat", "NC1HWC0", "DefaultFormat", "NC1HWC0" | |||
| ], | |||
| "name": "backprops", | |||
| "need_compile": true, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ] | |||
| }""") | |||
| @op_info_register(relugrad_op_info) | |||
| def _relu_grad_tbe(): | |||
| """ReluGrad TBE register""" | |||
| return | |||
| @@ -14,57 +14,25 @@ | |||
| # ============================================================================ | |||
| """Reshape op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| reshape_op_info = TBERegOp("Reshape") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("reshape.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("reshape") \ | |||
| .partial_flag(True) \ | |||
| .attr("shape", "required", "listInt", "all") \ | |||
| .input(0, "x", False, "required", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .dtype_format(DataType.I32_Default, DataType.I32_Default) \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default) \ | |||
| .get_op_info() | |||
| @op_info_register("""{ | |||
| "op_name": "Reshape", | |||
| "imply_type": "TBE", | |||
| "fusion_type": "OPAQUE", | |||
| "async_flag": false, | |||
| "binfile_name": "reshape.so", | |||
| "compute_cost": 10, | |||
| "kernel_name": "reshape", | |||
| "partial_flag": true, | |||
| "attr": [ | |||
| { | |||
| "name": "shape", | |||
| "param_type": "required", | |||
| "type": "listInt", | |||
| "value": "all" | |||
| } | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16", "float32", "int32" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "DefaultFormat", "DefaultFormat" | |||
| ], | |||
| "name": "x", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16", "float32", "int32" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "DefaultFormat", "DefaultFormat" | |||
| ], | |||
| "name": "y", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ] | |||
| }""") | |||
| @op_info_register(reshape_op_info) | |||
| def _reshape_tbe(): | |||
| """Reshape TBE register""" | |||
| return | |||
| @@ -14,67 +14,33 @@ | |||
| # ============================================================================ | |||
| """ResizeNearestNeighbor op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| resize_nearest_neighbor_op_info = TBERegOp("ResizeNearestNeighbor") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("resize_nearest_neighbor_d.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("resize_nearest_neighbor_d") \ | |||
| .partial_flag(True) \ | |||
| .attr("size", "required", "listInt", "all") \ | |||
| .attr("align_corners", "optional", "bool", "all") \ | |||
| .input(0, "images", False, "required", "all") \ | |||
| .output(0, "y", True, "required", "all") \ | |||
| .dtype_format(DataType.I8_Default, DataType.I8_Default) \ | |||
| .dtype_format(DataType.I8_5HD, DataType.I8_5HD) \ | |||
| .dtype_format(DataType.U8_Default, DataType.U8_Default) \ | |||
| .dtype_format(DataType.U8_5HD, DataType.U8_5HD) \ | |||
| .dtype_format(DataType.I32_Default, DataType.I32_Default) \ | |||
| .dtype_format(DataType.I32_5HD, DataType.I32_5HD) \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default) \ | |||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD) \ | |||
| .get_op_info() | |||
| @op_info_register("""{ | |||
| "op_name": "ResizeNearestNeighbor", | |||
| "imply_type": "TBE", | |||
| "fusion_type": "OPAQUE", | |||
| "async_flag": false, | |||
| "binfile_name": "resize_nearest_neighbor_d.so", | |||
| "compute_cost": 10, | |||
| "kernel_name": "resize_nearest_neighbor_d", | |||
| "partial_flag": true, | |||
| "attr": [ | |||
| { | |||
| "name": "size", | |||
| "param_type": "required", | |||
| "type": "listInt", | |||
| "value": "all" | |||
| }, | |||
| { | |||
| "name": "align_corners", | |||
| "param_type": "optional", | |||
| "type": "bool", | |||
| "value": "all" | |||
| } | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16","float","int32","int8","uint8", | |||
| "float16","float","int32","int8","uint8" | |||
| ], | |||
| "format": [ | |||
| "NC1HWC0","NC1HWC0","NC1HWC0","NC1HWC0","NC1HWC0", | |||
| "DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat" | |||
| ], | |||
| "name": "images", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16","float","int32","int8","uint8", | |||
| "float16","float","int32","int8","uint8" | |||
| ], | |||
| "format": [ | |||
| "NC1HWC0","NC1HWC0","NC1HWC0","NC1HWC0","NC1HWC0", | |||
| "DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat" | |||
| ], | |||
| "name": "y", | |||
| "need_compile": true, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ] | |||
| }""") | |||
| @op_info_register(resize_nearest_neighbor_op_info) | |||
| def _resize_nearest_neighbor_d_tbe(): | |||
| """ResizeNearestNeighbor TBE register""" | |||
| return | |||
| @@ -14,63 +14,28 @@ | |||
| # ============================================================================ | |||
| """ResizeNearestNeighbor op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| resize_nearest_neighbor_d_op_info = TBERegOp("ResizeNearestNeighbor") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("resize_nearest_neighbor_d.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("resize_nearest_neighbor_d") \ | |||
| .partial_flag(True) \ | |||
| .attr("size", "required", "listInt", "all") \ | |||
| .attr("align_corners", "optional", "bool", "all") \ | |||
| .input(0, "images", False, "required", "all") \ | |||
| .output(0, "y", True, "required", "all") \ | |||
| .dtype_format(DataType.I8_5HD, DataType.I8_5HD) \ | |||
| .dtype_format(DataType.U8_5HD, DataType.U8_5HD) \ | |||
| .dtype_format(DataType.I32_5HD, DataType.I32_5HD) \ | |||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \ | |||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD) \ | |||
| .get_op_info() | |||
| @op_info_register("""{ | |||
| "op_name": "ResizeNearestNeighbor", | |||
| "imply_type": "TBE", | |||
| "fusion_type": "OPAQUE", | |||
| "async_flag": false, | |||
| "binfile_name": "resize_nearest_neighbor_d.so", | |||
| "compute_cost": 10, | |||
| "kernel_name": "resize_nearest_neighbor_d", | |||
| "partial_flag": true, | |||
| "attr": [ | |||
| { | |||
| "name": "size", | |||
| "param_type": "required", | |||
| "type": "listInt", | |||
| "value": "all" | |||
| }, | |||
| { | |||
| "name": "align_corners", | |||
| "param_type": "optional", | |||
| "type": "bool", | |||
| "value": "all" | |||
| } | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16","float","int32","int8","uint8" | |||
| ], | |||
| "format": [ | |||
| "NC1HWC0","NC1HWC0","NC1HWC0","NC1HWC0","NC1HWC0" | |||
| ], | |||
| "name": "images", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16","float","int32","int8","uint8" | |||
| ], | |||
| "format": [ | |||
| "NC1HWC0","NC1HWC0","NC1HWC0","NC1HWC0","NC1HWC0" | |||
| ], | |||
| "name": "y", | |||
| "need_compile": true, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ] | |||
| }""") | |||
| @op_info_register(resize_nearest_neighbor_d_op_info) | |||
| def _resize_nearest_neighbor_d_tbe(): | |||
| """ResizeNearestNeighbor TBE register""" | |||
| return | |||
| @@ -14,63 +14,24 @@ | |||
| # ============================================================================ | |||
| """ResizeNearestNeighborgrad op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| resize_nearest_neighbor_grad_d_op_info = TBERegOp("ResizeNearestNeighborGrad") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("resize_nearest_neighbor_grad_d.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("resize_nearest_neighbor_grad_d") \ | |||
| .partial_flag(True) \ | |||
| .attr("size", "required", "listInt", "all") \ | |||
| .attr("align_corners", "optional", "bool", "all") \ | |||
| .input(0, "grads", False, "required", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD) \ | |||
| .get_op_info() | |||
| @op_info_register("""{ | |||
| "op_name": "ResizeNearestNeighborGrad", | |||
| "imply_type": "TBE", | |||
| "fusion_type": "OPAQUE", | |||
| "async_flag": false, | |||
| "binfile_name": "resize_nearest_neighbor_grad_d.so", | |||
| "compute_cost": 10, | |||
| "kernel_name": "resize_nearest_neighbor_grad_d", | |||
| "partial_flag": true, | |||
| "attr": [ | |||
| { | |||
| "name": "size", | |||
| "param_type": "required", | |||
| "type": "listInt", | |||
| "value": "all" | |||
| }, | |||
| { | |||
| "name": "align_corners", | |||
| "param_type": "optional", | |||
| "type": "bool", | |||
| "value": "all" | |||
| } | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float" | |||
| ], | |||
| "format": [ | |||
| "NC1HWC0" | |||
| ], | |||
| "name": "grads", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float" | |||
| ], | |||
| "format": [ | |||
| "NC1HWC0" | |||
| ], | |||
| "name": "y", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ] | |||
| }""") | |||
| @op_info_register(resize_nearest_neighbor_grad_d_op_info) | |||
| def _resize_nearest_neighbor_grad_d_tbe(): | |||
| """ResizeNearestNeighborGrad TBE register""" | |||
| return | |||
| @@ -14,52 +14,27 @@ | |||
| # ============================================================================ | |||
| """Round op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| round_op_info = TBERegOp("Round") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("round.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("round") \ | |||
| .partial_flag(True) \ | |||
| .input(0, "x", False, "required", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \ | |||
| .dtype_format(DataType.F16_FracZ, DataType.F16_FracZ) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default) \ | |||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD) \ | |||
| .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ) \ | |||
| .get_op_info() | |||
| @op_info_register("""{ | |||
| "op_name": "Round", | |||
| "imply_type": "TBE", | |||
| "fusion_type": "ELEMWISE", | |||
| "async_flag": false, | |||
| "binfile_name": "round.so", | |||
| "compute_cost": 10, | |||
| "kernel_name": "round", | |||
| "partial_flag": true, | |||
| "attr": [ | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16", "float16", "float16", "float", "float", "float" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "NC1HWC0", "FracZ", "DefaultFormat", "NC1HWC0", "FracZ" | |||
| ], | |||
| "name": "x", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16", "float16", "float16", "float", "float", "float" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "NC1HWC0", "FracZ", "DefaultFormat", "NC1HWC0", "FracZ" | |||
| ], | |||
| "name": "y", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ] | |||
| }""") | |||
| @op_info_register(round_op_info) | |||
| def _round_tbe(): | |||
| """Round TBE register""" | |||
| return | |||
| @@ -14,94 +14,29 @@ | |||
| # ============================================================================ | |||
| """Rsqrt op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| rsqrt_op_info = TBERegOp("Rsqrt") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("rsqrt.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("rsqrt") \ | |||
| .partial_flag(True) \ | |||
| .input(0, "x", False, "required", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \ | |||
| .dtype_format(DataType.F16_FracZ, DataType.F16_FracZ) \ | |||
| .dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default) \ | |||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD) \ | |||
| .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ) \ | |||
| .dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \ | |||
| .get_op_info() | |||
| @op_info_register("""{ | |||
| "op_name":"Rsqrt", | |||
| "imply_type":"TBE", | |||
| "fusion_type":"OPAQUE", | |||
| "async_flag":false, | |||
| "binfile_name":"rsqrt.so", | |||
| "compute_cost":10, | |||
| "kernel_name":"rsqrt", | |||
| "partial_flag":true, | |||
| "attr":[], | |||
| "inputs":[ | |||
| { | |||
| "index":0, | |||
| "dtype":[ | |||
| "float16", | |||
| "float16", | |||
| "float16", | |||
| "float16", | |||
| "float16", | |||
| "float16", | |||
| "float", | |||
| "float", | |||
| "float", | |||
| "float", | |||
| "float", | |||
| "float" | |||
| ], | |||
| "format":[ | |||
| "DefaultFormat", | |||
| "NC1HWC0", | |||
| "DefaultFormat", | |||
| "FracZ", | |||
| "C1HWNCoC0", | |||
| "DefaultFormat", | |||
| "DefaultFormat", | |||
| "NC1HWC0", | |||
| "DefaultFormat", | |||
| "DefaultFormat", | |||
| "FracZ", | |||
| "C1HWNCoC0" | |||
| ], | |||
| "name":"x", | |||
| "need_compile":false, | |||
| "param_type":"required", | |||
| "shape":"all" | |||
| } | |||
| ], | |||
| "outputs":[ | |||
| { | |||
| "index":0, | |||
| "dtype":[ | |||
| "float16", | |||
| "float16", | |||
| "float16", | |||
| "float16", | |||
| "float16", | |||
| "float16", | |||
| "float", | |||
| "float", | |||
| "float", | |||
| "float", | |||
| "float", | |||
| "float" | |||
| ], | |||
| "format":[ | |||
| "DefaultFormat", | |||
| "NC1HWC0", | |||
| "DefaultFormat", | |||
| "FracZ", | |||
| "C1HWNCoC0", | |||
| "DefaultFormat", | |||
| "DefaultFormat", | |||
| "NC1HWC0", | |||
| "DefaultFormat", | |||
| "DefaultFormat", | |||
| "FracZ", | |||
| "C1HWNCoC0" | |||
| ], | |||
| "name":"y", | |||
| "need_compile":false, | |||
| "param_type":"required", | |||
| "shape":"all" | |||
| } | |||
| ] | |||
| }""") | |||
| @op_info_register(rsqrt_op_info) | |||
| def _rsqrt_tbe(): | |||
| """Rsqrt TBE register""" | |||
| return | |||
| @@ -14,71 +14,28 @@ | |||
| # ============================================================================ | |||
| """ScatterNd op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| scatter_nd_op_info = TBERegOp("ScatterNd") \ | |||
| .fusion_type("ELEMWISE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("scatter_nd_d.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("scatter_nd_d") \ | |||
| .partial_flag(True) \ | |||
| .attr("shape", "optional", "listInt", "all") \ | |||
| .input(0, "indices", False, "required", "all") \ | |||
| .input(1, "x", False, "required", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .dtype_format(DataType.I32_Default, DataType.I8_Default, DataType.I8_Default) \ | |||
| .dtype_format(DataType.I32_Default, DataType.U8_Default, DataType.U8_Default) \ | |||
| .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ | |||
| .dtype_format(DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \ | |||
| .get_op_info() | |||
| # map to tbe kernel name scatter_nd_d | |||
| @op_info_register("""{ | |||
| "op_name": "ScatterNd", | |||
| "imply_type": "TBE", | |||
| "fusion_type": "ELEMWISE", | |||
| "async_flag": false, | |||
| "binfile_name": "scatter_nd_d.so", | |||
| "compute_cost": 10, | |||
| "kernel_name": "scatter_nd_d", | |||
| "partial_flag": true, | |||
| "attr": [ | |||
| { | |||
| "name": "shape", | |||
| "param_type": "optional", | |||
| "type": "listInt", | |||
| "value": "all" | |||
| } | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "int32", "int32", "int32", "int32", "int32" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat" | |||
| ], | |||
| "name": "indices", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 1, | |||
| "dtype": [ | |||
| "float16","float","int32","int8","uint8" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat" | |||
| ], | |||
| "name": "x", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16","float","int32","int8","uint8" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat" | |||
| ], | |||
| "name": "y", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ] | |||
| }""") | |||
| @op_info_register(scatter_nd_op_info) | |||
| def _scatter_nd_tbe(): | |||
| """Conv2D TBE register""" | |||
| return | |||
| @@ -14,70 +14,28 @@ | |||
| # ============================================================================ | |||
| """ScatterNdD op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| scatter_nd_d_op_info = TBERegOp("ScatterNdD") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("scatter_nd_d.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("scatter_nd_d") \ | |||
| .partial_flag(True) \ | |||
| .attr("shape", "optional", "listInt", "all") \ | |||
| .input(0, "indices", False, "required", "all") \ | |||
| .input(1, "x", False, "required", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .dtype_format(DataType.I32_Default, DataType.I8_Default, DataType.I8_Default) \ | |||
| .dtype_format(DataType.I32_Default, DataType.U8_Default, DataType.U8_Default) \ | |||
| .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ | |||
| .dtype_format(DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \ | |||
| .get_op_info() | |||
| @op_info_register("""{ | |||
| "op_name": "ScatterNdD", | |||
| "imply_type": "TBE", | |||
| "fusion_type": "OPAQUE", | |||
| "async_flag": false, | |||
| "binfile_name": "scatter_nd_d.so", | |||
| "compute_cost": 10, | |||
| "kernel_name": "scatter_nd_d", | |||
| "partial_flag": true, | |||
| "attr": [ | |||
| { | |||
| "name": "shape", | |||
| "param_type": "optional", | |||
| "type": "listInt", | |||
| "value": "all" | |||
| } | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "int32", "int32", "int32", "int32", "int32" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat" | |||
| ], | |||
| "name": "indices", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 1, | |||
| "dtype": [ | |||
| "float16","float","int32","int8","uint8" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat" | |||
| ], | |||
| "name": "x", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16","float","int32","int8","uint8" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat" | |||
| ], | |||
| "name": "y", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ] | |||
| }""") | |||
| @op_info_register(scatter_nd_d_op_info) | |||
| def _scatter_nd_d_tbe(): | |||
| """ScatterNdD TBE register""" | |||
| return | |||
| @@ -14,94 +14,33 @@ | |||
| # ============================================================================ | |||
| """Select op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| select_op_info = TBERegOp("Select") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("select.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("select") \ | |||
| .partial_flag(True) \ | |||
| .input(0, "condition", False, "required", "all") \ | |||
| .input(1, "x1", False, "required", "all") \ | |||
| .input(2, "x2", False, "required", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .dtype_format(DataType.BOOL_Default, DataType.I8_Default, DataType.I8_Default, DataType.I8_Default) \ | |||
| .dtype_format(DataType.BOOL_Default, DataType.U8_Default, DataType.U8_Default, DataType.U8_Default) \ | |||
| .dtype_format(DataType.BOOL_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ | |||
| .dtype_format(DataType.BOOL_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.BOOL_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ | |||
| .dtype_format(DataType.BOOL_5HD, DataType.I8_5HD, DataType.I8_5HD, DataType.I8_5HD) \ | |||
| .dtype_format(DataType.BOOL_5HD, DataType.U8_5HD, DataType.U8_5HD, DataType.U8_5HD) \ | |||
| .dtype_format(DataType.BOOL_5HD, DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD) \ | |||
| .dtype_format(DataType.BOOL_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \ | |||
| .dtype_format(DataType.BOOL_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \ | |||
| .get_op_info() | |||
| @op_info_register("""{ | |||
| "op_name": "Select", | |||
| "imply_type": "TBE", | |||
| "fusion_type": "OPAQUE", | |||
| "async_flag": false, | |||
| "binfile_name": "select.so", | |||
| "compute_cost": 10, | |||
| "kernel_name": "select", | |||
| "partial_flag": true, | |||
| "attr":[ | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "bool", "bool", "bool", "bool", "bool", "bool", "bool", "bool", "bool", "bool", | |||
| "bool", "bool", "bool", "bool", "bool", "bool", "bool", "bool", "bool", "bool" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", | |||
| "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", | |||
| "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", | |||
| "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat" | |||
| ], | |||
| "name": "condition", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 1, | |||
| "dtype": [ | |||
| "float16", "float16", "float16", "float16", "float", "float", "float", "float", | |||
| "int32", "int32", "int32", "int32", "int8", "int8", "int8", "int8", "uint8", | |||
| "uint8", "uint8", "uint8" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", | |||
| "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", | |||
| "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", | |||
| "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat" | |||
| ], | |||
| "name": "x1", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 2, | |||
| "dtype": [ | |||
| "float16", "float16", "float16", "float16", "float", "float", "float", "float", "int32", | |||
| "int32", "int32", "int32", "int8", "int8", "int8", "int8", "uint8", "uint8", "uint8", "uint8" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", | |||
| "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", | |||
| "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", | |||
| "DefaultFormat", "DefaultFormat" | |||
| ], | |||
| "name": "x2", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16", "float16", "float16", "float16", "float", "float", "float", "float", "int32", | |||
| "int32", "int32", "int32", "int8", "int8", "int8", "int8", "uint8", "uint8", "uint8", "uint8" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", | |||
| "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", | |||
| "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", | |||
| "DefaultFormat", "DefaultFormat" | |||
| ], | |||
| "name": "y", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ] | |||
| }""") | |||
| @op_info_register(select_op_info) | |||
| def _select_tbe(): | |||
| """Select TBE register""" | |||
| return | |||
| @@ -14,67 +14,31 @@ | |||
| # ============================================================================ | |||
| """Sigmoid op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| sigmoid_op_info = TBERegOp("Sigmoid") \ | |||
| .fusion_type("ELEMWISE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("sigmoid.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("sigmoid") \ | |||
| .partial_flag(True) \ | |||
| .input(0, "x", False, "required", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \ | |||
| .dtype_format(DataType.F16_FracZ, DataType.F16_FracZ) \ | |||
| .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ) \ | |||
| .dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default) \ | |||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD) \ | |||
| .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ) \ | |||
| .dtype_format(DataType.F32_FracNZ, DataType.F32_FracNZ) \ | |||
| .dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \ | |||
| .get_op_info() | |||
| @op_info_register("""{ | |||
| "op_name": "Sigmoid", | |||
| "imply_type": "TBE", | |||
| "fusion_type": "ELEMWISE", | |||
| "async_flag": false, | |||
| "binfile_name": "Sigmoid.so", | |||
| "compute_cost": 10, | |||
| "kernel_name": "sigmoid", | |||
| "partial_flag": true, | |||
| "attr": [ | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16","float", | |||
| "float16","float", | |||
| "float16","float", | |||
| "float16","float", | |||
| "float16","float" | |||
| ], | |||
| "format": [ | |||
| "FracZ","FracZ", | |||
| "FRACTAL_NZ","FRACTAL_NZ", | |||
| "C1HWNCoC0","C1HWNCoC0", | |||
| "NC1HWC0","NC1HWC0", | |||
| "DefaultFormat","DefaultFormat" | |||
| ], | |||
| "name": "x", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16","float", | |||
| "float16","float", | |||
| "float16","float", | |||
| "float16","float", | |||
| "float16","float" | |||
| ], | |||
| "format": [ | |||
| "FracZ","FracZ", | |||
| "FRACTAL_NZ","FRACTAL_NZ", | |||
| "C1HWNCoC0","C1HWNCoC0", | |||
| "NC1HWC0","NC1HWC0", | |||
| "DefaultFormat","DefaultFormat" | |||
| ], | |||
| "name": "y", | |||
| "need_compile": true, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ] | |||
| }""") | |||
| @op_info_register(sigmoid_op_info) | |||
| def _sigmoid_tbe(): | |||
| """Sigmoid TBE register""" | |||
| return | |||
| @@ -14,64 +14,26 @@ | |||
| # ============================================================================ | |||
| """SigmoidCrossEntropyWithLogits op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| sigmoid_cross_entropy_with_logits_op_info = TBERegOp("SigmoidCrossEntropyWithLogits") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("sigmoid_cross_entropy_with_logits.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("sigmoid_cross_entropy_with_logits") \ | |||
| .partial_flag(True) \ | |||
| .input(0, "predict", False, "required", "all") \ | |||
| .input(1, "target", False, "required", "all") \ | |||
| .output(0, "loss", False, "required", "all") \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ | |||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \ | |||
| .get_op_info() | |||
| @op_info_register("""{ | |||
| "op_name": "SigmoidCrossEntropyWithLogits", | |||
| "imply_type": "TBE", | |||
| "fusion_type": "OPAQUE", | |||
| "async_flag": false, | |||
| "binfile_name": "sigmoid_cross_entropy_with_logits.so", | |||
| "compute_cost": 10, | |||
| "kernel_name": "sigmoid_cross_entropy_with_logits", | |||
| "partial_flag": true, | |||
| "attr": [ | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16", "float16", "float", "float" | |||
| ], | |||
| "format": [ | |||
| "NC1HWC0", "DefaultFormat", "NC1HWC0", "DefaultFormat" | |||
| ], | |||
| "name": "predict", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 1, | |||
| "dtype": [ | |||
| "float16", "float16", "float", "float" | |||
| ], | |||
| "format": [ | |||
| "NC1HWC0", "DefaultFormat", "NC1HWC0", "DefaultFormat" | |||
| ], | |||
| "name": "target", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16", "float16", "float", "float" | |||
| ], | |||
| "format": [ | |||
| "NC1HWC0", "DefaultFormat", "NC1HWC0", "DefaultFormat" | |||
| ], | |||
| "name": "loss", | |||
| "need_compile": true, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ] | |||
| }""") | |||
| @op_info_register(sigmoid_cross_entropy_with_logits_op_info) | |||
| def _sigmoid_cross_entropy_with_logits_tbe(): | |||
| """SigmoidCrossEntropyWithLogits TBE register""" | |||
| return | |||
| @@ -14,77 +14,27 @@ | |||
| # ============================================================================ | |||
| """SigmoidCrossEntropyWithLogitsGrad op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| sigmoid_cross_entropy_with_logits_grad_op_info = TBERegOp("SigmoidCrossEntropyWithLogitsGrad") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("sigmoid_cross_entropy_with_logits_grad.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("sigmoid_cross_entropy_with_logits_grad") \ | |||
| .partial_flag(True) \ | |||
| .input(0, "predict", False, "required", "all") \ | |||
| .input(1, "target", False, "required", "all") \ | |||
| .input(2, "dout", False, "required", "all") \ | |||
| .output(0, "gradient", False, "required", "all") \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ | |||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \ | |||
| .get_op_info() | |||
| @op_info_register("""{ | |||
| "op_name": "SigmoidCrossEntropyWithLogitsGrad", | |||
| "imply_type": "TBE", | |||
| "fusion_type": "OPAQUE", | |||
| "async_flag": false, | |||
| "binfile_name": "sigmoid_cross_entropy_with_logits_grad.so", | |||
| "compute_cost": 10, | |||
| "kernel_name": "sigmoid_cross_entropy_with_logits_grad", | |||
| "partial_flag": true, | |||
| "attr": [ | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16", "float16", "float", "float" | |||
| ], | |||
| "format": [ | |||
| "NC1HWC0", "DefaultFormat", "NC1HWC0", "DefaultFormat" | |||
| ], | |||
| "name": "predict", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 1, | |||
| "dtype": [ | |||
| "float16", "float16", "float", "float" | |||
| ], | |||
| "format": [ | |||
| "NC1HWC0", "DefaultFormat", "NC1HWC0", "DefaultFormat" | |||
| ], | |||
| "name": "target", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 2, | |||
| "dtype": [ | |||
| "float16", "float16", "float", "float" | |||
| ], | |||
| "format": [ | |||
| "NC1HWC0", "DefaultFormat", "NC1HWC0", "DefaultFormat" | |||
| ], | |||
| "name": "dout", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16", "float16", "float", "float" | |||
| ], | |||
| "format": [ | |||
| "NC1HWC0", "DefaultFormat", "NC1HWC0", "DefaultFormat" | |||
| ], | |||
| "name": "gradient", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ] | |||
| }""") | |||
| @op_info_register(sigmoid_cross_entropy_with_logits_grad_op_info) | |||
| def _sigmoid_cross_entropy_with_logits_grad_tbe(): | |||
| """SigmoidCrossEntropyWithLogitsGrad TBE register""" | |||
| return | |||
| @@ -14,64 +14,26 @@ | |||
| # ============================================================================ | |||
| """SigmoidGrad op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| sigmoid_cross_entropy_with_logits_op_info = TBERegOp("SigmoidGrad") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("sigmoid_grad.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("sigmoid_grad") \ | |||
| .partial_flag(True) \ | |||
| .input(0, "x", False, "required", "all") \ | |||
| .input(1, "y", False, "required", "all") \ | |||
| .output(0, "z", False, "required", "all") \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ | |||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \ | |||
| .get_op_info() | |||
| @op_info_register("""{ | |||
| "op_name": "SigmoidGrad", | |||
| "imply_type": "TBE", | |||
| "fusion_type": "ELEMWISE", | |||
| "async_flag": false, | |||
| "binfile_name": "sigmoid_grad.so", | |||
| "compute_cost": 10, | |||
| "kernel_name": "sigmoid_grad", | |||
| "partial_flag": true, | |||
| "attr": [ | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16","float","float16","float" | |||
| ], | |||
| "format": [ | |||
| "NC1HWC0","NC1HWC0","DefaultFormat","DefaultFormat" | |||
| ], | |||
| "name": "x", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 1, | |||
| "dtype": [ | |||
| "float16","float","float16","float" | |||
| ], | |||
| "format": [ | |||
| "NC1HWC0","NC1HWC0","DefaultFormat","DefaultFormat" | |||
| ], | |||
| "name": "y", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16","float","float16","float" | |||
| ], | |||
| "format": [ | |||
| "NC1HWC0","NC1HWC0","DefaultFormat","DefaultFormat" | |||
| ], | |||
| "name": "z", | |||
| "need_compile": true, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ] | |||
| }""") | |||
| @op_info_register(sigmoid_cross_entropy_with_logits_op_info) | |||
| def _sigmoid_grad_tbe(): | |||
| """SigmoidGrad TBE register""" | |||
| return | |||
| @@ -14,99 +14,33 @@ | |||
| # ============================================================================ | |||
| """Slice op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| slice_op_info = TBERegOp("Slice") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("slice_d.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("slice_d") \ | |||
| .partial_flag(True) \ | |||
| .attr("begin", "required", "listInt", "all") \ | |||
| .attr("size", "required", "listInt", "all") \ | |||
| .input(0, "x", False, "required", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .dtype_format(DataType.I8_Default, DataType.I8_Default) \ | |||
| .dtype_format(DataType.U8_Default, DataType.U8_Default) \ | |||
| .dtype_format(DataType.I16_Default, DataType.I16_Default) \ | |||
| .dtype_format(DataType.U16_Default, DataType.U16_Default) \ | |||
| .dtype_format(DataType.I32_Default, DataType.I32_Default) \ | |||
| .dtype_format(DataType.I64_Default, DataType.I64_Default) \ | |||
| .dtype_format(DataType.U32_Default, DataType.U32_Default) \ | |||
| .dtype_format(DataType.U64_Default, DataType.U64_Default) \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default) \ | |||
| .get_op_info() | |||
| @op_info_register("""{ | |||
| "op_name":"Slice", | |||
| "imply_type":"TBE", | |||
| "fusion_type":"OPAQUE", | |||
| "async_flag":false, | |||
| "binfile_name":"slice_d.so", | |||
| "compute_cost":10, | |||
| "kernel_name":"slice_d", | |||
| "partial_flag":true, | |||
| "attr":[ | |||
| { | |||
| "name":"begin", | |||
| "param_type":"required", | |||
| "type":"listInt", | |||
| "value":"all" | |||
| }, | |||
| { | |||
| "name":"size", | |||
| "param_type":"required", | |||
| "type":"listInt", | |||
| "value":"all" | |||
| } | |||
| ], | |||
| "inputs":[ | |||
| { | |||
| "index":0, | |||
| "dtype":[ | |||
| "float", | |||
| "float16", | |||
| "int8", | |||
| "int16", | |||
| "int32", | |||
| "int64", | |||
| "uint8", | |||
| "uint16", | |||
| "uint32", | |||
| "uint64" | |||
| ], | |||
| "format":[ | |||
| "DefaultFormat", | |||
| "DefaultFormat", | |||
| "DefaultFormat", | |||
| "DefaultFormat", | |||
| "DefaultFormat", | |||
| "DefaultFormat", | |||
| "DefaultFormat", | |||
| "DefaultFormat", | |||
| "DefaultFormat", | |||
| "DefaultFormat" | |||
| ], | |||
| "name":"x", | |||
| "need_compile":false, | |||
| "param_type":"required", | |||
| "shape":"all" | |||
| } | |||
| ], | |||
| "outputs":[ | |||
| { | |||
| "index":0, | |||
| "dtype":[ | |||
| "float", | |||
| "float16", | |||
| "int8", | |||
| "int16", | |||
| "int32", | |||
| "int64", | |||
| "uint8", | |||
| "uint16", | |||
| "uint32", | |||
| "uint64" | |||
| ], | |||
| "format":[ | |||
| "DefaultFormat", | |||
| "DefaultFormat", | |||
| "DefaultFormat", | |||
| "DefaultFormat", | |||
| "DefaultFormat", | |||
| "DefaultFormat", | |||
| "DefaultFormat", | |||
| "DefaultFormat", | |||
| "DefaultFormat", | |||
| "DefaultFormat" | |||
| ], | |||
| "name":"y", | |||
| "need_compile":false, | |||
| "param_type":"required", | |||
| "shape":"all" | |||
| } | |||
| ] | |||
| }""") | |||
| @op_info_register(slice_op_info) | |||
| def _slice_tbe(): | |||
| """Slice TBE register""" | |||
| return | |||
| @@ -14,57 +14,27 @@ | |||
| # ============================================================================ | |||
| """Softmax op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| softmax_op_info = TBERegOp("Softmax") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("softmax.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("softmax") \ | |||
| .partial_flag(True) \ | |||
| .attr("axis", "optional", "listInt", "all") \ | |||
| .input(0, "x", False, "required", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \ | |||
| .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default) \ | |||
| .dtype_format(DataType.F32_FracNZ, DataType.F32_FracNZ) \ | |||
| .get_op_info() | |||
| @op_info_register("""{ | |||
| "op_name": "Softmax", | |||
| "imply_type": "TBE", | |||
| "fusion_type": "OPAQUE", | |||
| "async_flag": false, | |||
| "binfile_name": "softmax.so", | |||
| "compute_cost": 10, | |||
| "kernel_name": "softmax", | |||
| "partial_flag": true, | |||
| "attr": [ | |||
| { | |||
| "name": "axis", | |||
| "param_type": "optional", | |||
| "type": "listInt", | |||
| "value": "all" | |||
| } | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16", "float16", "float16", "float", "float" | |||
| ], | |||
| "format": [ | |||
| "FRACTAL_NZ", "DefaultFormat", "NC1HWC0", "FRACTAL_NZ", "DefaultFormat" | |||
| ], | |||
| "name": "x", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16", "float16", "float16", "float", "float" | |||
| ], | |||
| "format": [ | |||
| "FRACTAL_NZ", "DefaultFormat", "NC1HWC0", "FRACTAL_NZ", "DefaultFormat" | |||
| ], | |||
| "name": "y", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ] | |||
| }""") | |||
| @op_info_register(softmax_op_info) | |||
| def _softmax_tbe(): | |||
| """Softmax TBE register""" | |||
| return | |||
| @@ -14,78 +14,25 @@ | |||
| # ============================================================================ | |||
| """SoftmaxCrossEntropyWithLogits op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| softmax_cross_entropy_with_logits_op_info = TBERegOp("SoftmaxCrossEntropyWithLogits") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("softmax_cross_entropy_with_logits.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("softmax_cross_entropy_with_logits") \ | |||
| .partial_flag(True) \ | |||
| .input(0, "input_features", False, "required", "all") \ | |||
| .input(1, "input_labels", False, "required", "all") \ | |||
| .output(0, "output_loss", True, "required", "all") \ | |||
| .output(1, "output_backprop", True, "required", "all") \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ | |||
| .get_op_info() | |||
| @op_info_register("""{ | |||
| "op_name": "SoftmaxCrossEntropyWithLogits", | |||
| "imply_type": "TBE", | |||
| "fusion_type": "OPAQUE", | |||
| "async_flag": false, | |||
| "binfile_name": "softmax_cross_entropy_with_logits.so", | |||
| "compute_cost": 10, | |||
| "kernel_name": "softmax_cross_entropy_with_logits", | |||
| "partial_flag": true, | |||
| "attr": [ | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16", "float" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "DefaultFormat" | |||
| ], | |||
| "name": "input_features", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 1, | |||
| "dtype": [ | |||
| "float16", "float" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "DefaultFormat" | |||
| ], | |||
| "name": "input_labels", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16", "float" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "DefaultFormat" | |||
| ], | |||
| "name": "output_loss", | |||
| "need_compile": true, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| }, | |||
| { | |||
| "index": 1, | |||
| "dtype": [ | |||
| "float16", "float" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "DefaultFormat" | |||
| ], | |||
| "name": "output_backprop", | |||
| "need_compile": true, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ] | |||
| }""") | |||
| @op_info_register(softmax_cross_entropy_with_logits_op_info) | |||
| def _softmax_cross_entropy_with_logits_tbe(): | |||
| """SoftmaxCrossEntropyWithLogits TBE register""" | |||
| return | |||
| @@ -14,71 +14,45 @@ | |||
| # ============================================================================ | |||
| """Add op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| split_d_op_info = TBERegOp("Split") \ | |||
| .fusion_type("ELEMWISE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("split_d.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("split_d") \ | |||
| .partial_flag(True) \ | |||
| .attr("axis", "required", "int", "all") \ | |||
| .attr("output_num", "required", "int", "all") \ | |||
| .input(0, "value", False, "required", "all") \ | |||
| .output(0, "output", False, "dynamic", "all") \ | |||
| .dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \ | |||
| .dtype_format(DataType.BOOL_NHWC, DataType.BOOL_NHWC) \ | |||
| .dtype_format(DataType.I8_Default, DataType.I8_Default) \ | |||
| .dtype_format(DataType.I8_NHWC, DataType.I8_NHWC) \ | |||
| .dtype_format(DataType.U8_Default, DataType.U8_Default) \ | |||
| .dtype_format(DataType.U8_NHWC, DataType.U8_NHWC) \ | |||
| .dtype_format(DataType.I16_Default, DataType.I16_Default) \ | |||
| .dtype_format(DataType.I16_NHWC, DataType.I16_NHWC) \ | |||
| .dtype_format(DataType.U16_Default, DataType.U16_Default) \ | |||
| .dtype_format(DataType.U16_NHWC, DataType.U16_NHWC) \ | |||
| .dtype_format(DataType.I32_Default, DataType.I32_Default) \ | |||
| .dtype_format(DataType.I32_NHWC, DataType.I32_NHWC) \ | |||
| .dtype_format(DataType.U32_Default, DataType.U32_Default) \ | |||
| .dtype_format(DataType.U32_NHWC, DataType.U32_NHWC) \ | |||
| .dtype_format(DataType.I64_Default, DataType.I64_Default) \ | |||
| .dtype_format(DataType.I64_NHWC, DataType.I64_NHWC) \ | |||
| .dtype_format(DataType.U64_Default, DataType.U64_Default) \ | |||
| .dtype_format(DataType.U64_NHWC, DataType.U64_NHWC) \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F16_NHWC, DataType.F16_NHWC) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default) \ | |||
| .dtype_format(DataType.F32_NHWC, DataType.F32_NHWC) \ | |||
| .get_op_info() | |||
| @op_info_register("""{ | |||
| "op_name": "Split", | |||
| "imply_type": "TBE", | |||
| "fusion_type": "ELEMWISE", | |||
| "async_flag": false, | |||
| "binfile_name": "split_d.so", | |||
| "compute_cost": 10, | |||
| "kernel_name": "split_d", | |||
| "partial_flag": true, | |||
| "attr": [ | |||
| { | |||
| "name": "axis", | |||
| "param_type": "required", | |||
| "type": "int", | |||
| "value": "all" | |||
| }, | |||
| { | |||
| "name": "output_num", | |||
| "param_type": "required", | |||
| "type": "int", | |||
| "value": "all" | |||
| } | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16", "float16","float32", "float32", "int32", "int32", "int8", "int8", | |||
| "int16", "int16", "int64", "int64", "uint8", "uint8", "uint16", "uint16", | |||
| "uint32", "uint32", "uint64", "uint64", "bool", "bool" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "NHWC", "DefaultFormat", "NHWC", "DefaultFormat", "NHWC", "DefaultFormat", "NHWC" | |||
| , "DefaultFormat", "NHWC", "DefaultFormat", "NHWC", "DefaultFormat", "NHWC", "DefaultFormat", "NHWC" | |||
| , "DefaultFormat", "NHWC", "DefaultFormat", "NHWC", "DefaultFormat", "NHWC" | |||
| ], | |||
| "name": "value", | |||
| "need_compile": false, | |||
| "param_type": "required", | |||
| "shape": "all" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float16", "float16","float32", "float32", "int32", "int32", "int8", "int8", | |||
| "int16", "int16", "int64", "int64", "uint8", "uint8", "uint16", "uint16", | |||
| "uint32", "uint32", "uint64", "uint64", "bool", "bool" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "NHWC", "DefaultFormat", "NHWC", "DefaultFormat", "NHWC", "DefaultFormat", "NHWC" | |||
| , "DefaultFormat", "NHWC", "DefaultFormat", "NHWC", "DefaultFormat", "NHWC", "DefaultFormat", "NHWC" | |||
| , "DefaultFormat", "NHWC", "DefaultFormat", "NHWC", "DefaultFormat", "NHWC" | |||
| ], | |||
| "name": "output", | |||
| "need_compile": false, | |||
| "param_type": "dynamic", | |||
| "shape": "all" | |||
| } | |||
| ] | |||
| }""") | |||
| @op_info_register(split_d_op_info) | |||
| def _split_d_tbe(): | |||
| """Add TBE register""" | |||
| return | |||