| @@ -273,6 +273,7 @@ inline const PrimitivePtr kPrimExtractVolumePatches = std::make_shared<Primitive | |||
| // NN | |||
| inline const PrimitivePtr kPrimCeLU = std::make_shared<Primitive>("CeLU"); | |||
| inline const PrimitivePtr kPrimAdam = std::make_shared<Primitive>("Adam"); | |||
| inline const PrimitivePtr kPrimApplyAdaMax = std::make_shared<Primitive>("ApplyAdaMax"); | |||
| inline const PrimitivePtr kPrimAudioSpectrogram = std::make_shared<Primitive>("AudioSpectrogram"); | |||
| inline const PrimitivePtr kPrimFlatten = std::make_shared<Primitive>("Flatten"); | |||
| inline const PrimitivePtr kPrimCrop = std::make_shared<Primitive>("Crop"); | |||
| @@ -0,0 +1,164 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
#include "ops/apply_ada_max.h"

#include <algorithm>
#include <map>
#include <set>
#include <string>

#include "abstract/primitive_infer_map.h"
#include "ops/op_utils.h"
#include "utils/tensor_construct_utils.h"
| namespace mindspore { | |||
| namespace ops { | |||
| namespace { | |||
| abstract::TupleShapePtr ApplyAdaMaxInferShape(const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args) { | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| const int64_t kInputNum = 9; | |||
| (void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kGreaterEqual, kInputNum, | |||
| primitive->name()); | |||
| for (const auto &item : input_args) { | |||
| MS_EXCEPTION_IF_NULL(item); | |||
| } | |||
| auto prim_name = primitive->name(); | |||
| auto var_shape = input_args[kInputIndex0]->BuildShape(); | |||
| auto m_shape = input_args[kInputIndex1]->BuildShape(); | |||
| auto v_shape = input_args[kInputIndex2]->BuildShape(); | |||
| auto var_shape_ptr = var_shape->cast<abstract::ShapePtr>(); | |||
| auto m_shape_ptr = m_shape->cast<abstract::ShapePtr>(); | |||
| auto v_shape_ptr = v_shape->cast<abstract::ShapePtr>(); | |||
| auto beta1_power_shape = | |||
| CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex3]->BuildShape())[kShape]; | |||
| auto lr_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex4]->BuildShape())[kShape]; | |||
| auto beta1_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex5]->BuildShape())[kShape]; | |||
| auto beta2_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex6]->BuildShape())[kShape]; | |||
| auto epsilon_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex7]->BuildShape())[kShape]; | |||
| auto grad_shape = input_args[kInputIndex8]->BuildShape(); | |||
| auto grad_shape_ptr = grad_shape->cast<abstract::ShapePtr>(); | |||
| // beta1_power,lr,beta1,beta2,epsilon must be scalar | |||
| const int64_t kInputShape = 1; | |||
| (void)CheckAndConvertUtils::CheckInteger("beta1 power's rank", beta1_power_shape.size(), kLessEqual, kInputShape, | |||
| prim_name); | |||
| if (beta1_power_shape.size() == 1) { | |||
| (void)CheckAndConvertUtils::CheckInteger("beta1_power_shape[0]", beta1_power_shape.size(), kEqual, kInputShape, | |||
| prim_name); | |||
| } | |||
| (void)CheckAndConvertUtils::CheckInteger("lr's rank", lr_shape.size(), kLessEqual, kInputShape, prim_name); | |||
| if (lr_shape.size() == 1) { | |||
| (void)CheckAndConvertUtils::CheckInteger("lr_shape[0]", lr_shape.size(), kEqual, kInputShape, prim_name); | |||
| } | |||
| (void)CheckAndConvertUtils::CheckInteger("beta1's rank", beta1_shape.size(), kLessEqual, kInputShape, prim_name); | |||
| if (beta1_shape.size() == 1) { | |||
| (void)CheckAndConvertUtils::CheckInteger("beta1_shape[0]", beta1_shape.size(), kEqual, kInputShape, prim_name); | |||
| } | |||
| (void)CheckAndConvertUtils::CheckInteger("beta2's rank", beta2_shape.size(), kLessEqual, kInputShape, prim_name); | |||
| if (beta2_shape.size() == 1) { | |||
| (void)CheckAndConvertUtils::CheckInteger("beta2_shape[0]", beta2_shape.size(), kEqual, kInputShape, prim_name); | |||
| } | |||
| (void)CheckAndConvertUtils::CheckInteger("epsilon's rank", epsilon_shape.size(), kLessEqual, kInputShape, prim_name); | |||
| if (epsilon_shape.size() == 1) { | |||
| (void)CheckAndConvertUtils::CheckInteger("epsilon_shape[0]", epsilon_shape.size(), kEqual, kInputShape, prim_name); | |||
| } | |||
| // var, m,v and grad must have the same shape | |||
| std::map<std::string, abstract::BaseShapePtr> same_shape_args_map; | |||
| same_shape_args_map.insert({"m", m_shape}); | |||
| same_shape_args_map.insert({"v", v_shape}); | |||
| same_shape_args_map.insert({"grad", grad_shape}); | |||
| if (!var_shape_ptr->IsDynamic() && !m_shape_ptr->IsDynamic()) { | |||
| if (*m_shape != *var_shape) { | |||
| MS_EXCEPTION(ValueError) << primitive->name() << " evaluator arg m shape " << m_shape->ToString() | |||
| << " are not consistent with var shape " << var_shape->ToString(); | |||
| } | |||
| } | |||
| if (!v_shape_ptr->IsDynamic() && !var_shape_ptr->IsDynamic()) { | |||
| if (*v_shape != *var_shape) { | |||
| MS_EXCEPTION(ValueError) << primitive->name() << " evaluator arg v shape " << v_shape->ToString() | |||
| << " are not consistent with var shape " << var_shape->ToString(); | |||
| } | |||
| } | |||
| if (!grad_shape_ptr->IsDynamic() && !var_shape_ptr->IsDynamic()) { | |||
| if (*grad_shape != *var_shape) { | |||
| MS_EXCEPTION(ValueError) << primitive->name() << " evaluator arg grad shape " << grad_shape->ToString() | |||
| << " are not consistent with var shape " << var_shape->ToString(); | |||
| } | |||
| } | |||
| return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{var_shape, m_shape, v_shape}); | |||
| } | |||
| TuplePtr ApplyAdaMaxInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) { | |||
| MS_EXCEPTION_IF_NULL(prim); | |||
| auto prim_name = prim->name(); | |||
| const int64_t kInputNum = 9; | |||
| (void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kGreaterEqual, kInputNum, | |||
| prim_name); | |||
| for (const auto &item : input_args) { | |||
| MS_EXCEPTION_IF_NULL(item); | |||
| } | |||
| auto var_type = input_args[kInputIndex0]->BuildType(); | |||
| auto m_type = input_args[kInputIndex1]->BuildType(); | |||
| auto v_type = input_args[kInputIndex2]->BuildType(); | |||
| auto beta1_power_type = input_args[kInputIndex3]->BuildType(); | |||
| auto lr_type = input_args[kInputIndex4]->BuildType(); | |||
| auto beta1_type = input_args[kInputIndex5]->BuildType(); | |||
| auto beta2_type = input_args[kInputIndex6]->BuildType(); | |||
| auto epsilon_type = input_args[kInputIndex7]->BuildType(); | |||
| auto grad_type = input_args[kInputIndex8]->BuildType(); | |||
| const std::set<TypePtr> valid_types = {kFloat16, kFloat32}; | |||
| // m v grad must have the same type as var | |||
| std::map<std::string, TypePtr> args; | |||
| (void)args.insert({"var_type", var_type}); | |||
| (void)args.insert({"m_type", m_type}); | |||
| (void)args.insert({"v_type", v_type}); | |||
| (void)args.insert({"grad_type", grad_type}); | |||
| (void)CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim_name); | |||
| std::map<std::string, TypePtr> args_beta1_power; | |||
| std::map<std::string, TypePtr> args_lr; | |||
| std::map<std::string, TypePtr> args_beta1; | |||
| std::map<std::string, TypePtr> args_beta2; | |||
| std::map<std::string, TypePtr> args_epsilon; | |||
| (void)args_beta1_power.insert({"beta1_power_type", beta1_power_type}); | |||
| (void)args_lr.insert({"lr_type", lr_type}); | |||
| (void)args_beta1.insert({"beta1_type", beta1_type}); | |||
| (void)args_beta2.insert({"beta2_type", beta2_type}); | |||
| (void)args_epsilon.insert({"epsilon_type", epsilon_type}); | |||
| // beta1_power,lr,beta1,beta2,epsilon must be a scalar or zero dimension tensor type | |||
| (void)CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args_beta1_power, valid_types, prim_name); | |||
| (void)CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args_lr, valid_types, prim_name); | |||
| (void)CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args_beta1, valid_types, prim_name); | |||
| (void)CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args_beta2, valid_types, prim_name); | |||
| (void)CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args_epsilon, valid_types, prim_name); | |||
| return std::make_shared<Tuple>(std::vector<TypePtr>{var_type, m_type, v_type}); | |||
| } | |||
| } // namespace | |||
| AbstractBasePtr ApplyAdaMaxInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args) { | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| auto infer_type = ApplyAdaMaxInferType(primitive, input_args); | |||
| auto infer_shape = ApplyAdaMaxInferShape(primitive, input_args); | |||
| return abstract::MakeAbstract(infer_shape, infer_type); | |||
| } | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(ApplyAdaMax, prim::kPrimApplyAdaMax, ApplyAdaMaxInfer, nullptr, true); | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,45 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CORE_OPS_APPLY_ADA_MAX_H_ | |||
| #define MINDSPORE_CORE_OPS_APPLY_ADA_MAX_H_ | |||
| #include <map> | |||
| #include <vector> | |||
| #include <string> | |||
| #include <memory> | |||
| #include "ops/primitive_c.h" | |||
| #include "abstract/abstract_value.h" | |||
| #include "utils/check_convert_utils.h" | |||
| namespace mindspore { | |||
| namespace ops { | |||
| constexpr auto kNameApplyAdaMax = "ApplyAdaMax"; | |||
// Frontend primitive for the AdaMax optimizer update (a variant of Adam based
// on the infinity norm).  This class only declares the primitive's identity
// and its IO naming; shape/type inference is registered separately.
class ApplyAdaMax : public PrimitiveC {
 public:
  // Declares the nine inputs (var, m, v, beta1_power, lr, beta1, beta2,
  // epsilon, grad) and the three outputs (the updated var, m and v).
  ApplyAdaMax() : PrimitiveC(kNameApplyAdaMax) {
    InitIOName({"var", "m", "v", "beta1_power", "lr", "beta1", "beta2", "epsilon", "grad"}, {"var", "m", "v"});
  }
  ~ApplyAdaMax() = default;
  MS_DECLARE_PARENT(ApplyAdaMax, PrimitiveC);
};
| AbstractBasePtr ApplyAdaMaxInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args); | |||
| using kPrimApplyAdaMaxPtr = std::shared_ptr<ApplyAdaMax>; | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CORE_OPS_APPLY_ADA_MAX_H_ | |||
| @@ -35,6 +35,7 @@ from .apply_keras_momentum import _apply_keras_momentum_tbe | |||
| from .apply_momentum import _apply_momentum_tbe | |||
| from .apply_adam import _apply_adam_tbe | |||
| from .apply_ada_max import _apply_ada_max_tbe | |||
| from .apply_ada_max_ds import _apply_ada_max_ds_tbe | |||
| from .apply_adadelta import _apply_adadelta_tbe | |||
| from .apply_adagrad import _apply_adagrad_tbe | |||
| from .apply_adagrad_v2 import _apply_adagrad_v2_tbe | |||
| @@ -0,0 +1,69 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ApplyAdaMaxD op""" | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
# Dynamic-shape TBE registration for ApplyAdaMax.  All nine inputs and three
# outputs are required; the tensor operands (var/m/v/grad and the outputs)
# share a format while the five hyper-parameter scalars always use the
# default format.
_APPLY_ADA_MAX_INPUTS = ("var", "m", "v", "beta1_power", "lr", "beta1", "beta2", "epsilon", "grad")
_APPLY_ADA_MAX_OUTPUTS = ("var", "m", "v")
# (tensor format, scalar format) pairs covering f16/f32 in 5HD, FracZ,
# C1HWNCoC0 and the default layout.
_APPLY_ADA_MAX_FORMATS = (
    (DataType.F16_5HD, DataType.F16_Default),
    (DataType.F16_FracZ, DataType.F16_Default),
    (DataType.F16_C1HWNCoC0, DataType.F16_Default),
    (DataType.F16_Default, DataType.F16_Default),
    (DataType.F32_5HD, DataType.F32_Default),
    (DataType.F32_FracZ, DataType.F32_Default),
    (DataType.F32_C1HWNCoC0, DataType.F32_Default),
    (DataType.F32_Default, DataType.F32_Default),
)

_reg_builder = TBERegOp("ApplyAdaMax")
_reg_builder.fusion_type("OPAQUE")
_reg_builder.async_flag(False)
_reg_builder.binfile_name("apply_ada_max_d.so")
_reg_builder.compute_cost(10)
_reg_builder.kernel_name("apply_ada_max_d")
_reg_builder.partial_flag(True)
_reg_builder.dynamic_shape(True)
for _idx, _in_name in enumerate(_APPLY_ADA_MAX_INPUTS):
    _reg_builder.input(_idx, _in_name, False, "required", "all")
for _idx, _out_name in enumerate(_APPLY_ADA_MAX_OUTPUTS):
    _reg_builder.output(_idx, _out_name, False, "required", "all")
for _tensor_fmt, _scalar_fmt in _APPLY_ADA_MAX_FORMATS:
    # Row layout: var, m, v | beta1_power, lr, beta1, beta2, epsilon | grad | var, m, v.
    _reg_builder.dtype_format(_tensor_fmt, _tensor_fmt, _tensor_fmt, _scalar_fmt,
                              _scalar_fmt, _scalar_fmt, _scalar_fmt, _scalar_fmt,
                              _tensor_fmt, _tensor_fmt, _tensor_fmt, _tensor_fmt)
apply_ada_max_ds_op_info = _reg_builder.get_op_info()


@op_info_register(apply_ada_max_ds_op_info)
def _apply_ada_max_ds_tbe():
    """ApplyAdaMaxD TBE register"""
    return
| @@ -5546,7 +5546,7 @@ class BinaryCrossEntropy(PrimitiveWithInfer): | |||
| return x_type | |||
| class ApplyAdaMax(PrimitiveWithInfer): | |||
| class ApplyAdaMax(Primitive): | |||
| r""" | |||
| Updates relevant entries according to the adamax scheme. | |||
| @@ -5658,45 +5658,6 @@ class ApplyAdaMax(PrimitiveWithInfer): | |||
| """Initialize ApplyAdaMax""" | |||
| self.add_prim_attr('side_effect_mem', True) | |||
| def infer_shape(self, var_shape, m_shape, v_shape, beta1_power_shape, lr_shape, | |||
| beta1_shape, beta2_shape, epsilon_shape, grad_shape): | |||
| validator.check("m_shape", m_shape, "var_shape", var_shape, Rel.EQ, self.name) | |||
| validator.check("v_shape", v_shape, "var_shape", var_shape, Rel.EQ, self.name) | |||
| validator.check("grad_shape", grad_shape, "var_shape", var_shape, Rel.EQ, self.name) | |||
| beta1_power_shp_len = len(beta1_power_shape) | |||
| validator.check_int(beta1_power_shp_len, 1, Rel.LE, "beta1 power's rank", self.name) | |||
| if beta1_power_shp_len == 1: | |||
| validator.check_int(beta1_power_shape[0], 1, Rel.EQ, "beta1_power_shape[0]", self.name) | |||
| lr_shp_len = len(lr_shape) | |||
| validator.check_int(lr_shp_len, 1, Rel.LE, "lr's rank", self.name) | |||
| if lr_shp_len == 1: | |||
| validator.check_int(lr_shape[0], 1, Rel.EQ, "lr_shape[0]", self.name) | |||
| beta1_shp_len = len(beta1_shape) | |||
| validator.check_int(beta1_shp_len, 1, Rel.LE, "beta1's rank", self.name) | |||
| if beta1_shp_len == 1: | |||
| validator.check_int(beta1_shape[0], 1, Rel.EQ, "beta1_shape[0]", self.name) | |||
| beta2_shp_len = len(beta2_shape) | |||
| validator.check_int(beta2_shp_len, 1, Rel.LE, "beta2's rank", self.name) | |||
| if beta2_shp_len == 1: | |||
| validator.check_int(beta2_shape[0], 1, Rel.EQ, "beta2_shape[0]", self.name) | |||
| epsilon_shp_len = len(epsilon_shape) | |||
| validator.check_int(epsilon_shp_len, 1, Rel.LE, "epsilon's rank", self.name) | |||
| if epsilon_shp_len == 1: | |||
| validator.check_int(epsilon_shape[0], 1, Rel.EQ, "epsilon_shape[0]", self.name) | |||
| return var_shape, m_shape, v_shape | |||
| def infer_dtype(self, var_dtype, m_dtype, v_dtype, beta1_power_dtype, lr_dtype, | |||
| beta1_dtype, beta2_dtype, epsilon_dtype, grad_dtype): | |||
| valid_dtypes = [mstype.float16, mstype.float32] | |||
| args = {"var": var_dtype, "m": m_dtype, "v": v_dtype, "grad": grad_dtype} | |||
| validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name) | |||
| validator.check_scalar_or_tensor_types_same({"beta1_power": beta1_power_dtype}, valid_dtypes, self.name) | |||
| validator.check_scalar_or_tensor_types_same({"lr": lr_dtype}, valid_dtypes, self.name) | |||
| validator.check_scalar_or_tensor_types_same({"beta1": beta1_dtype}, valid_dtypes, self.name) | |||
| validator.check_scalar_or_tensor_types_same({"beta2": beta2_dtype}, valid_dtypes, self.name) | |||
| validator.check_scalar_or_tensor_types_same({"epsilon": epsilon_dtype}, valid_dtypes, self.name) | |||
| return var_dtype, m_dtype, v_dtype | |||
| class ApplyAdadelta(PrimitiveWithInfer): | |||
| r""" | |||