Merge pull request !84 from zhaozhenlong/adapt-apply-adam-dtags/v0.3.0-alpha
| @@ -1 +1 @@ | |||||
| Subproject commit 976d1e31b777d65f87333c3a125093946e682a6e | |||||
| Subproject commit 1ab4fa8eb55b4f98e9e5e871a54909a1eaedffd3 | |||||
| @@ -391,6 +391,7 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma | |||||
| {string(kNameApplyCenteredRMSProp), ADPT_DESC(ApplyCenteredRMSProp)}}; | {string(kNameApplyCenteredRMSProp), ADPT_DESC(ApplyCenteredRMSProp)}}; | ||||
| #ifdef ENABLE_GE | #ifdef ENABLE_GE | ||||
| adpt_map[string(kNamePrint)] = ADPT_DESC(Print); | adpt_map[string(kNamePrint)] = ADPT_DESC(Print); | ||||
| adpt_map[string(kNameApplyAdam)] = ADPT_DESC(ApplyAdamD); | |||||
| #endif | #endif | ||||
| return adpt_map; | return adpt_map; | ||||
| } | } | ||||
| @@ -468,6 +468,15 @@ ATTR_MAP(ApplyAdam) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>()) | |||||
| {"use_nesterov", ATTR_DESC(use_nesterov, AnyTraits<bool>())}}; | {"use_nesterov", ATTR_DESC(use_nesterov, AnyTraits<bool>())}}; | ||||
| OUTPUT_MAP(ApplyAdam) = {{0, OUTPUT_DESC(var)}}; | OUTPUT_MAP(ApplyAdam) = {{0, OUTPUT_DESC(var)}}; | ||||
| // ApplyAdamD | |||||
| INPUT_MAP(ApplyAdamD) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(m)}, {3, INPUT_DESC(v)}, | |||||
| {4, INPUT_DESC(beta1_power)}, {5, INPUT_DESC(beta2_power)}, {6, INPUT_DESC(lr)}, | |||||
| {7, INPUT_DESC(beta1)}, {8, INPUT_DESC(beta2)}, {9, INPUT_DESC(epsilon)}, | |||||
| {10, INPUT_DESC(grad)}}; | |||||
| ATTR_MAP(ApplyAdamD) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}, | |||||
| {"use_nesterov", ATTR_DESC(use_nesterov, AnyTraits<bool>())}}; | |||||
| OUTPUT_MAP(ApplyAdamD) = {{0, OUTPUT_DESC(var)}, {1, OUTPUT_DESC(m)}, {2, OUTPUT_DESC(v)}}; | |||||
| // Relu6 | // Relu6 | ||||
| INPUT_MAP(Relu6) = {{1, INPUT_DESC(x)}}; | INPUT_MAP(Relu6) = {{1, INPUT_DESC(x)}}; | ||||
| ATTR_MAP(Relu6) = EMPTY_ATTR_MAP; | ATTR_MAP(Relu6) = EMPTY_ATTR_MAP; | ||||
| @@ -124,6 +124,8 @@ DECLARE_OP_ADAPTER(ResizeNearestNeighborV2Grad) | |||||
| DECLARE_OP_USE_OUTPUT(ResizeNearestNeighborV2Grad) | DECLARE_OP_USE_OUTPUT(ResizeNearestNeighborV2Grad) | ||||
| DECLARE_OP_ADAPTER(ApplyAdam) | DECLARE_OP_ADAPTER(ApplyAdam) | ||||
| DECLARE_OP_USE_OUTPUT(ApplyAdam) | DECLARE_OP_USE_OUTPUT(ApplyAdam) | ||||
| DECLARE_OP_ADAPTER(ApplyAdamD) | |||||
| DECLARE_OP_USE_OUTPUT(ApplyAdamD) | |||||
| DECLARE_OP_ADAPTER(Relu6) | DECLARE_OP_ADAPTER(Relu6) | ||||
| DECLARE_OP_USE_OUTPUT(Relu6) | DECLARE_OP_USE_OUTPUT(Relu6) | ||||
| DECLARE_OP_ADAPTER(Relu6Grad) | DECLARE_OP_ADAPTER(Relu6Grad) | ||||
| @@ -2323,7 +2323,11 @@ class Adam(PrimitiveWithInfer): | |||||
| - **gradient** (Tensor) - Gradients. | - **gradient** (Tensor) - Gradients. | ||||
| Outputs: | Outputs: | ||||
| Tuple of 3 Tensors, the updated parameters. | |||||
| - **var** (Tensor) - The same shape and data type as `var`. | - **var** (Tensor) - The same shape and data type as `var`. | ||||
| - **m** (Tensor) - The same shape and data type as `m`. | |||||
| - **v** (Tensor) - The same shape and data type as `v`. | |||||
| """ | """ | ||||
| @prim_attr_register | @prim_attr_register | ||||
| @@ -2336,7 +2340,7 @@ class Adam(PrimitiveWithInfer): | |||||
| validator.check("var_shape", var_shape, "m_shape", m_shape, Rel.EQ, self.name) | validator.check("var_shape", var_shape, "m_shape", m_shape, Rel.EQ, self.name) | ||||
| validator.check("var_shape", var_shape, "v_shape", v_shape, Rel.EQ, self.name) | validator.check("var_shape", var_shape, "v_shape", v_shape, Rel.EQ, self.name) | ||||
| validator.check("var_shape", var_shape, "grad_shape", grad_shape, Rel.EQ, self.name) | validator.check("var_shape", var_shape, "grad_shape", grad_shape, Rel.EQ, self.name) | ||||
| return var_shape | |||||
| return var_shape, m_shape, v_shape | |||||
| def infer_dtype(self, var_dtype, m_dtype, v_dtype, beta1_power_dtype, beta2_power_dtype, lr_dtype, | def infer_dtype(self, var_dtype, m_dtype, v_dtype, beta1_power_dtype, beta2_power_dtype, lr_dtype, | ||||
| beta1_dtype, beta2_dtype, epsilon_dtype, grad_dtype): | beta1_dtype, beta2_dtype, epsilon_dtype, grad_dtype): | ||||
| @@ -2346,7 +2350,7 @@ class Adam(PrimitiveWithInfer): | |||||
| args = {"beta1_power": beta1_power_dtype, "beta2_power": beta2_power_dtype, 'lr': lr_dtype, | args = {"beta1_power": beta1_power_dtype, "beta2_power": beta2_power_dtype, 'lr': lr_dtype, | ||||
| "beta1": beta1_dtype, "beta2": beta2_dtype, "epsilon": epsilon_dtype} | "beta1": beta1_dtype, "beta2": beta2_dtype, "epsilon": epsilon_dtype} | ||||
| validator.check_scalar_or_tensor_type_same(args, [mstype.float16, mstype.float32], self.name, True) | validator.check_scalar_or_tensor_type_same(args, [mstype.float16, mstype.float32], self.name, True) | ||||
| return var_dtype | |||||
| return var_dtype, m_dtype, v_dtype | |||||
| class BinaryCrossEntropy(PrimitiveWithInfer): | class BinaryCrossEntropy(PrimitiveWithInfer): | ||||