| @@ -433,6 +433,7 @@ inline const PrimitivePtr kPrimSparseApplyRMSProp = std::make_shared<Primitive>( | |||
| inline const PrimitivePtr kPrimApplyKerasMomentum = std::make_shared<Primitive>("ApplyKerasMomentum"); | |||
| inline const PrimitivePtr kPrimLARSUpdate = std::make_shared<Primitive>("LARSUpdate"); | |||
| inline const PrimitivePtr kPrimApplyAddSign = std::make_shared<Primitive>("ApplyAddSign"); | |||
| inline const PrimitivePtr kPrimApplyAdagrad = std::make_shared<Primitive>("ApplyAdagrad"); | |||
| // Comm ops | |||
| inline const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOperator"); | |||
| @@ -0,0 +1,93 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "ops/apply_adagrad.h" | |||
| #include <algorithm> | |||
| #include <set> | |||
| #include "ops/op_utils.h" | |||
| #include "abstract/primitive_infer_map.h" | |||
| #include "utils/tensor_construct_utils.h" | |||
| namespace mindspore { | |||
| namespace ops { | |||
| namespace { | |||
| abstract::TupleShapePtr ApplyAdagradInferShape(const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args) { | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| auto prim_name = primitive->name(); | |||
| const int64_t kInputNum = 4; | |||
| (void)CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, kInputNum, primitive->name()); | |||
| auto lr_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape]; | |||
| auto var_shape_ptr = input_args[kInputIndex0]->BuildShape(); | |||
| auto accum_shape_ptr = input_args[kInputIndex1]->BuildShape(); | |||
| auto grad_shape_ptr = input_args[kInputIndex3]->BuildShape(); | |||
| // lr should be scalar or size equal with 1 | |||
| (void)CheckAndConvertUtils::CheckInteger("lr_shape size", lr_shape.size(), kLessEqual, 1, prim_name); | |||
| if (lr_shape.size() == 1) { | |||
| (void)CheckAndConvertUtils::CheckInteger("lr_shape's first rank must be 1", lr_shape[0], kEqual, 1, prim_name); | |||
| } | |||
| if (grad_shape_ptr->IsDynamic()) { | |||
| return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{var_shape_ptr, accum_shape_ptr}); | |||
| } | |||
| // var, accum and grad must have the same shape | |||
| std::map<std::string, abstract::BaseShapePtr> same_shape_args_map; | |||
| same_shape_args_map.insert({"accum", accum_shape_ptr}); | |||
| same_shape_args_map.insert({"grad", grad_shape_ptr}); | |||
| for (auto &elem : same_shape_args_map) { | |||
| if (*elem.second != *var_shape_ptr) { | |||
| MS_EXCEPTION(ValueError) << prim_name << " evaluator arg " << elem.first << " shape " << elem.second->ToString() | |||
| << " are not consistent with var shape " << var_shape_ptr->ToString(); | |||
| } | |||
| } | |||
| return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{var_shape_ptr, accum_shape_ptr}); | |||
| } | |||
| TuplePtr ApplyAdagradInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| auto prim_name = primitive->name(); | |||
| const int64_t kInputNum = 4; | |||
| (void)CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, kInputNum, primitive->name()); | |||
| auto var_type = input_args[kInputIndex0]->BuildType(); | |||
| auto accum_type = input_args[kInputIndex1]->BuildType(); | |||
| auto lr_type = input_args[kInputIndex2]->BuildType(); | |||
| auto grad_type = input_args[kInputIndex3]->BuildType(); | |||
| const std::set<TypePtr> valid_types = {kFloat16, kFloat32}; | |||
| // var, accum and grad must have the same type | |||
| std::map<std::string, TypePtr> args; | |||
| args.insert({"var", var_type}); | |||
| args.insert({"accum", accum_type}); | |||
| args.insert({"grad", grad_type}); | |||
| (void)CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim_name); | |||
| // lr type must be valid | |||
| std::map<std::string, TypePtr> args_lr; | |||
| args_lr.insert({"lr", lr_type}); | |||
| (void)CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args_lr, valid_types, prim_name); | |||
| return std::make_shared<Tuple>(std::vector<TypePtr>{var_type, accum_type}); | |||
| } | |||
| } // namespace | |||
| AbstractBasePtr ApplyAdagradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args) { | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| auto infer_type = ApplyAdagradInferType(primitive, input_args); | |||
| auto infer_shape = ApplyAdagradInferShape(primitive, input_args); | |||
| return abstract::MakeAbstract(infer_shape, infer_type); | |||
| } | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(ApplyAdagrad, prim::kPrimApplyAdagrad, ApplyAdagradInfer, nullptr, true); | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,45 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CORE_OPS_APPLY_ADAGRAD_H_ | |||
| #define MINDSPORE_CORE_OPS_APPLY_ADAGRAD_H_ | |||
| #include <map> | |||
| #include <memory> | |||
| #include <vector> | |||
| #include <string> | |||
| #include "ops/primitive_c.h" | |||
| #include "abstract/abstract_value.h" | |||
| #include "utils/check_convert_utils.h" | |||
| namespace mindspore { | |||
| namespace ops { | |||
| constexpr auto kNameApplyAdagrad = "ApplyAdagrad"; | |||
| class ApplyAdagrad : public PrimitiveC { | |||
| public: | |||
| ApplyAdagrad() : PrimitiveC(kNameApplyAdagrad) { InitIOName({"var", "accum", "lr", "grad"}, {"var", "accum"}); } | |||
| ~ApplyAdagrad() = default; | |||
| MS_DECLARE_PARENT(ApplyAdagrad, PrimitiveC); | |||
| }; | |||
| AbstractBasePtr ApplyAdagradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args); | |||
| using kPrimApplyAdagradPtr = std::shared_ptr<ApplyAdagrad>; | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CORE_OPS_APPLY_ADAGRAD_D_H_ | |||
| @@ -44,6 +44,7 @@ from .apply_adam_ds import _apply_adam_ds_tbe | |||
| from .apply_ada_max import _apply_ada_max_tbe | |||
| from .apply_adadelta import _apply_adadelta_tbe | |||
| from .apply_adagrad import _apply_adagrad_tbe | |||
| from .apply_adagrad_ds import _apply_adagrad_ds_tbe | |||
| from .apply_adagrad_v2 import _apply_adagrad_v2_tbe | |||
| from .apply_adagrad_v2_ds import _apply_adagrad_v2_ds_tbe | |||
| from .apply_adagrad_d_a import _apply_adagrad_d_a_tbe | |||
| @@ -0,0 +1,56 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ApplyAdagradD op""" | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| apply_adagrad_d_op_info = TBERegOp("ApplyAdagrad") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("apply_adagrad_d.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("apply_adagrad_d") \ | |||
| .partial_flag(True) \ | |||
| .dynamic_shape(True) \ | |||
| .attr("update_slots", "optional", "bool", "true,false", "true") \ | |||
| .input(0, "var", False, "required", "all") \ | |||
| .input(1, "accum", False, "required", "all") \ | |||
| .input(2, "lr", False, "required", "all") \ | |||
| .input(3, "grad", False, "required", "all") \ | |||
| .output(0, "var", False, "required", "all") \ | |||
| .output(1, "accum", False, "required", "all") \ | |||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_Default, DataType.F16_5HD, | |||
| DataType.F16_5HD, DataType.F16_5HD) \ | |||
| .dtype_format(DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_Default, DataType.F16_FracZ, | |||
| DataType.F16_FracZ, DataType.F16_FracZ) \ | |||
| .dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_Default, DataType.F16_C1HWNCoC0, | |||
| DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, | |||
| DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_Default, DataType.F32_5HD, | |||
| DataType.F32_5HD, DataType.F32_5HD) \ | |||
| .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_Default, DataType.F32_FracZ, | |||
| DataType.F32_FracZ, DataType.F32_FracZ) \ | |||
| .dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_Default, DataType.F32_C1HWNCoC0, | |||
| DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | |||
| DataType.F32_Default, DataType.F32_Default) \ | |||
| .get_op_info() | |||
| @op_info_register(apply_adagrad_d_op_info) | |||
| def _apply_adagrad_ds_tbe(): | |||
| """ApplyAdagradD TBE register""" | |||
| return | |||
| @@ -5636,7 +5636,7 @@ class ApplyAdadelta(PrimitiveWithInfer): | |||
| return var_dtype, accum_dtype, accum_update_dtype | |||
| class ApplyAdagrad(PrimitiveWithInfer): | |||
| class ApplyAdagrad(Primitive): | |||
| r""" | |||
| Updates relevant entries according to the adagrad scheme. | |||
| It has been proposed in Adaptive Subgradient Methods for Online Learning and Stochastic Optimization. | |||
| @@ -5716,23 +5716,6 @@ class ApplyAdagrad(PrimitiveWithInfer): | |||
| validator.check_value_type("update_slots", update_slots, [bool], self.name) | |||
| self.add_prim_attr('side_effect_mem', True) | |||
| def infer_shape(self, var_shape, accum_shape, lr_shape, grad_shape): | |||
| validator.check('accum shape', accum_shape, 'var shape', var_shape, Rel.EQ, self.name) | |||
| validator.check('grad shape', grad_shape, 'var shape', var_shape, Rel.EQ, self.name) | |||
| lr_shp_len = len(lr_shape) | |||
| validator.check_int(lr_shp_len, 1, Rel.LE, "lr's rank", self.name) | |||
| if lr_shp_len == 1: | |||
| validator.check_int(lr_shape[0], 1, Rel.EQ, "lr_shape[0]", self.name) | |||
| return var_shape, accum_shape | |||
| def infer_dtype(self, var_dtype, accum_dtype, lr_dtype, grad_dtype): | |||
| args = {'var': var_dtype, 'accum': accum_dtype, 'grad': grad_dtype} | |||
| valid_dtypes = [mstype.float16, mstype.float32] | |||
| validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name) | |||
| validator.check_scalar_or_tensor_types_same({'lr': lr_dtype}, valid_dtypes, self.name) | |||
| return var_dtype, accum_dtype | |||
| class ApplyAdagradV2(Primitive): | |||
| r""" | |||
| Updates relevant entries according to the adagradv2 scheme. | |||