
change adam output numbers to adapt to tbe

tags/v0.2.0-alpha
zhaoting 5 years ago
commit fa03a66433
2 changed files with 41 additions and 9 deletions
  1. mindspore/ops/_op_impl/tbe/apply_adam.py (+39, -7)
  2. mindspore/ops/operations/nn_ops.py (+2, -2)

mindspore/ops/_op_impl/tbe/apply_adam.py (+39, -7)

@@ -88,7 +88,8 @@ from mindspore.ops.op_info_register import op_info_register
"float16","float16","float16","float16","float","float","float", "float"
],
"format": [
"NC1HWC0", "C1HWNCoC0", "DefaultFormat", "FracZ", "NC1HWC0", "C1HWNCoC0", "DefaultFormat", "FracZ"
"DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat",
"DefaultFormat", "DefaultFormat"
],
"name": "beta1_power",
"need_compile": false,
@@ -101,7 +102,8 @@ from mindspore.ops.op_info_register import op_info_register
"float16","float16","float16","float16","float","float","float","float"
],
"format": [
"NC1HWC0", "C1HWNCoC0", "DefaultFormat", "FracZ", "NC1HWC0", "C1HWNCoC0", "DefaultFormat", "FracZ"
"DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat",
"DefaultFormat", "DefaultFormat"
],
"name": "beta2_power",
"need_compile": false,
@@ -114,7 +116,8 @@ from mindspore.ops.op_info_register import op_info_register
"float16","float16","float16","float16","float","float","float", "float"
],
"format": [
"NC1HWC0", "C1HWNCoC0", "DefaultFormat", "FracZ", "NC1HWC0", "C1HWNCoC0", "DefaultFormat", "FracZ"
"DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat",
"DefaultFormat", "DefaultFormat"
],
"name": "lr",
"need_compile": false,
@@ -127,7 +130,8 @@ from mindspore.ops.op_info_register import op_info_register
"float16","float16","float16","float16","float","float","float", "float"
],
"format": [
"NC1HWC0", "C1HWNCoC0", "DefaultFormat", "FracZ", "NC1HWC0", "C1HWNCoC0", "DefaultFormat", "FracZ"
"DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat",
"DefaultFormat", "DefaultFormat"
],
"name": "beta1",
"need_compile": false,
@@ -140,7 +144,8 @@ from mindspore.ops.op_info_register import op_info_register
"float16","float16","float16","float16","float","float","float", "float"
],
"format": [
"NC1HWC0", "C1HWNCoC0", "DefaultFormat", "FracZ", "NC1HWC0", "C1HWNCoC0", "DefaultFormat", "FracZ"
"DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat",
"DefaultFormat", "DefaultFormat"
],
"name": "beta2",
"need_compile": false,
@@ -153,7 +158,8 @@ from mindspore.ops.op_info_register import op_info_register
"float16","float16","float16","float16","float","float","float", "float"
],
"format": [
"NC1HWC0", "C1HWNCoC0", "DefaultFormat", "FracZ", "NC1HWC0", "C1HWNCoC0", "DefaultFormat", "FracZ"
"DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat",
"DefaultFormat", "DefaultFormat"
],
"name": "epsilon",
"need_compile": false,
@@ -161,7 +167,7 @@ from mindspore.ops.op_info_register import op_info_register
"shape": "all"
},
{
"index": 8,
"index": 9,
"dtype": [
"float16","float16","float16","float16","float","float","float", "float"
],
@@ -187,6 +193,32 @@ from mindspore.ops.op_info_register import op_info_register
"need_compile": false,
"param_type": "required",
"shape": "all"
+ },
+ {
+ "index": 1,
+ "dtype": [
+ "float16","float16","float16","float16","float","float","float","float"
+ ],
+ "format": [
+ "NC1HWC0", "C1HWNCoC0", "DefaultFormat", "FracZ", "NC1HWC0", "C1HWNCoC0", "DefaultFormat", "FracZ"
+ ],
+ "name": "m",
+ "need_compile": false,
+ "param_type": "required",
+ "shape": "all"
+ },
+ {
+ "index": 2,
+ "dtype": [
+ "float16","float16","float16","float16","float","float","float","float"
+ ],
+ "format": [
+ "NC1HWC0", "C1HWNCoC0", "DefaultFormat", "FracZ", "NC1HWC0", "C1HWNCoC0", "DefaultFormat", "FracZ"
+ ],
+ "name": "v",
+ "need_compile": false,
+ "param_type": "required",
+ "shape": "all"
}
]
}""")


mindspore/ops/operations/nn_ops.py (+2, -2)

@@ -2149,7 +2149,7 @@ class Adam(PrimitiveWithInfer):
validator.check_param_equal("var_shape", var_shape, "m_shape", m_shape)
validator.check_param_equal("var_shape", var_shape, "v_shape", v_shape)
validator.check_param_equal("var_shape", var_shape, "grad_shape", grad_shape)
- return var_shape
+ return var_shape, m_shape, v_shape

def infer_dtype(self, var_dtype, m_dtype, v_dtype, beta1_power_dtype, beta2_power_dtype, lr_dtype,
beta1_dtype, beta2_dtype, epsilon_dtype, grad_dtype):
@@ -2159,7 +2159,7 @@ class Adam(PrimitiveWithInfer):
args = {"beta1_power_dtype": beta1_power_dtype, "beta2_power_dtype": beta2_power_dtype, 'lr_dtype': lr_dtype,
"beta1_dtype": beta1_dtype, "beta2_dtype": beta2_dtype, "epsilon_dtype": epsilon_dtype}
validator.check_type_same(args, [mstype.float16, mstype.float32])
- return var_dtype
+ return var_dtype, m_dtype, v_dtype


class BinaryCrossEntropy(PrimitiveWithInfer):
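With infer_shape and infer_dtype now returning 3-tuples, a call to the Adam primitive yields (var, m, v) instead of a single tensor. A rough usage sketch is below; the shapes, scalar values, and use of plain Tensors for the state are illustrative only, and in a real network var, m, and v would normally be Parameters updated by an optimizer cell.

import numpy as np
from mindspore import Tensor
from mindspore.ops import operations as P

adam = P.Adam()

var = Tensor(np.ones([2, 2]).astype(np.float32))   # variable to update
m = Tensor(np.zeros([2, 2]).astype(np.float32))    # first-moment estimate
v = Tensor(np.zeros([2, 2]).astype(np.float32))    # second-moment estimate
grad = Tensor(np.ones([2, 2]).astype(np.float32))  # incoming gradient

def scalar(x):
    # Hyper-parameters are passed as float32 tensors to satisfy infer_dtype.
    return Tensor(np.array(x, np.float32))

# Before this commit the primitive inferred a single output (var);
# it now returns the tuple (var, m, v).
var_out, m_out, v_out = adam(var, m, v,
                             scalar(0.9), scalar(0.999), scalar(0.001),
                             scalar(0.9), scalar(0.999), scalar(1e-8),
                             grad)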

