|
|
@@ -13,33 +13,33 @@ |
|
|
# limitations under the License. |
|
|
# limitations under the License. |
|
|
# ============================================================================ |
|
|
# ============================================================================ |
|
|
|
|
|
|
|
|
"""LambNextMVWithDecayV1 op""" |
|
|
|
|
|
|
|
|
"""LambNextMVWithDecay op""" |
|
|
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType |
|
|
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType |
|
|
|
|
|
|
|
|
lamb_next_m_v_with_decay_v1_op_info = TBERegOp("LambNextMVWithDecayV1") \ |
|
|
|
|
|
|
|
|
lamb_next_m_v_with_decay_op_info = TBERegOp("LambNextMVWithDecay") \ |
|
|
.fusion_type("OPAQUE") \ |
|
|
.fusion_type("OPAQUE") \ |
|
|
.async_flag(False) \ |
|
|
.async_flag(False) \ |
|
|
.binfile_name("lamb_next_m_v_with_decay_v1.so") \ |
|
|
|
|
|
|
|
|
.binfile_name("lamb_next_m_v_with_decay.so") \ |
|
|
.compute_cost(10) \ |
|
|
.compute_cost(10) \ |
|
|
.kernel_name("lamb_next_m_v_with_decay_v1") \ |
|
|
|
|
|
|
|
|
.kernel_name("lamb_next_m_v_with_decay") \ |
|
|
.partial_flag(True) \ |
|
|
.partial_flag(True) \ |
|
|
.input(0, "input1", False, "required", "all") \ |
|
|
|
|
|
.input(1, "input2", False, "required", "all") \ |
|
|
|
|
|
.input(2, "input3", False, "required", "all") \ |
|
|
|
|
|
.input(3, "input4", False, "required", "all") \ |
|
|
|
|
|
.input(4, "input5", False, "required", "all") \ |
|
|
|
|
|
.input(5, "input6", False, "required", "all") \ |
|
|
|
|
|
.input(6, "input7", False, "required", "all") \ |
|
|
|
|
|
.input(7, "input8", False, "required", "all") \ |
|
|
|
|
|
.input(8, "input9", False, "required", "all") \ |
|
|
|
|
|
.input(9, "inputx0", False, "required", "all") \ |
|
|
|
|
|
.input(10, "inputx1", False, "required", "all") \ |
|
|
|
|
|
.input(11, "inputx2", False, "required", "all") \ |
|
|
|
|
|
.input(12, "inputx3", False, "required", "all") \ |
|
|
|
|
|
.output(0, "output1", False, "required", "all") \ |
|
|
|
|
|
.output(1, "output2", False, "required", "all") \ |
|
|
|
|
|
.output(2, "output3", False, "required", "all") \ |
|
|
|
|
|
.output(3, "output4", False, "required", "all") \ |
|
|
|
|
|
|
|
|
.input(0, "input_mul3", False, "required", "all") \ |
|
|
|
|
|
.input(1, "input_mul2", False, "required", "all") \ |
|
|
|
|
|
.input(2, "input_realdiv1", False, "required", "all") \ |
|
|
|
|
|
.input(3, "input_mul1", False, "required", "all") \ |
|
|
|
|
|
.input(4, "input_mul0", False, "required", "all") \ |
|
|
|
|
|
.input(5, "input_realdiv0", False, "required", "all") \ |
|
|
|
|
|
.input(6, "input_mul4", False, "required", "all") \ |
|
|
|
|
|
.input(7, "mul0_x", False, "required", "all") \ |
|
|
|
|
|
.input(8, "mul1_sub", False, "required", "all") \ |
|
|
|
|
|
.input(9, "mul2_x", False, "required", "all") \ |
|
|
|
|
|
.input(10, "mul3_sub1", False, "required", "all") \ |
|
|
|
|
|
.input(11, "mul4_x", False, "required", "all") \ |
|
|
|
|
|
.input(12, "add2_y", False, "required", "all") \ |
|
|
|
|
|
.output(0, "y1", True, "required", "all") \ |
|
|
|
|
|
.output(1, "y2", True, "required", "all") \ |
|
|
|
|
|
.output(2, "y3", True, "required", "all") \ |
|
|
|
|
|
.output(3, "y4", True, "required", "all") \ |
|
|
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, |
|
|
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, |
|
|
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, |
|
|
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, |
|
|
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, |
|
|
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, |
|
|
@@ -53,7 +53,7 @@ lamb_next_m_v_with_decay_v1_op_info = TBERegOp("LambNextMVWithDecayV1") \ |
|
|
.get_op_info() |
|
|
.get_op_info() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@op_info_register(lamb_next_m_v_with_decay_v1_op_info) |
|
|
|
|
|
def _lamb_next_mv_with_decay_v1_tbe(): |
|
|
|
|
|
"""LambNextMVWithDecayV1 TBE register""" |
|
|
|
|
|
|
|
|
@op_info_register(lamb_next_m_v_with_decay_op_info) |
|
|
|
|
|
def _lamb_next_mv_with_decay_tbe(): |
|
|
|
|
|
"""LambNextMVWithDecay TBE register""" |
|
|
return |
|
|
return |