Browse Source

[feat] [assistant] [I48OBG] add new operator ApplyAdamD
tags/v1.6.0
windhxs 4 years ago
parent
commit
7fe548fc7c
5 changed files with 140 additions and 43 deletions
  1. +57
    -25
      mindspore/core/ops/adam.cc
  2. +1
    -0
      mindspore/core/ops/adam.h
  3. +1
    -0
      mindspore/ops/_op_impl/tbe/__init__.py
  4. +80
    -0
      mindspore/ops/_op_impl/tbe/apply_adam_ds.py
  5. +1
    -18
      mindspore/ops/operations/nn_ops.py

+ 57
- 25
mindspore/core/ops/adam.cc View File

@@ -15,41 +15,68 @@
*/

#include "ops/adam.h"

#include <set>

#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"

namespace mindspore {
namespace ops {
namespace {
abstract::AbstractBasePtr AdamInfer(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
TuplePtr AdamInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
const int64_t input_num = 10;
CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, input_num, prim_name);

// infer shape
auto var_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->GetShapeTrack())[kShape];
auto m_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->GetShapeTrack())[kShape];
auto v_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->GetShapeTrack())[kShape];
auto grad_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex9]->GetShapeTrack())[kShape];
CheckAndConvertUtils::Check("var_shape", var_shape, kEqual, "m_shape", m_shape, prim_name);
CheckAndConvertUtils::Check("var_shape", var_shape, kEqual, "v_shape", v_shape, prim_name);
CheckAndConvertUtils::Check("var_shape", var_shape, kEqual, "grad_shape", grad_shape, prim_name);

// infer type
auto var_type = input_args[kInputIndex0]->BuildType();
auto m_type = input_args[kInputIndex1]->BuildType();
auto v_type = input_args[kInputIndex2]->BuildType();
auto beta1_power_type = input_args[kInputIndex3]->BuildType();
auto beta2_power_type = input_args[kInputIndex4]->BuildType();
auto lr_type = input_args[kInputIndex5]->BuildType();
auto beta1_type = input_args[kInputIndex6]->BuildType();
auto beta2_type = input_args[kInputIndex7]->BuildType();
auto epsilon_type = input_args[kInputIndex8]->BuildType();
auto grad_type = input_args[kInputIndex9]->BuildType();
auto infer_var_type = CheckAndConvertUtils::CheckTensorTypeValid("var_type", var_type, common_valid_types, prim_name);
auto infer_m_type = CheckAndConvertUtils::CheckTensorTypeValid("m_type", m_type, common_valid_types, prim_name);
auto infer_v_type = CheckAndConvertUtils::CheckTensorTypeValid("v_type", v_type, common_valid_types, prim_name);
(void)CheckAndConvertUtils::CheckTensorTypeValid("grad_type", grad_type, common_valid_types, prim_name);
auto output0 = std::make_shared<abstract::AbstractTensor>(infer_var_type, var_shape);
auto output1 = std::make_shared<abstract::AbstractTensor>(infer_m_type, m_shape);
auto output2 = std::make_shared<abstract::AbstractTensor>(infer_v_type, v_shape);
AbstractBasePtrList output = {output0, output1, output2};
return std::make_shared<abstract::AbstractTuple>(output);
std::map<std::string, TypePtr> type_dict;
type_dict.emplace("var", var_type);
type_dict.emplace("m", m_type);
type_dict.emplace("v", v_type);
type_dict.emplace("grad", grad_type);
std::set<TypePtr> num_type = {kInt8, kInt16, kInt32, kInt64, kUInt8, kUInt16, kUInt32,
kUInt64, kFloat16, kFloat32, kFloat64, kComplex64, kComplex128};
(void)CheckAndConvertUtils::CheckTensorTypeSame(type_dict, num_type, prim_name);
std::map<std::string, TypePtr> type_dict1;
type_dict1.emplace("beta1_power", beta1_power_type);
type_dict1.emplace("beta2_power", beta2_power_type);
type_dict1.emplace("lr", lr_type);
type_dict1.emplace("beta1", beta1_type);
type_dict1.emplace("beta2", beta2_type);
type_dict1.emplace("epsilon", epsilon_type);
std::set<TypePtr> float_set = {kFloat16, kFloat32};
(void)CheckAndConvertUtils::CheckScalarOrTensorTypesSame(type_dict1, float_set, prim_name, true);
return std::make_shared<Tuple>(std::vector<TypePtr>{var_type, m_type, v_type});
}
// Infers the output shapes of Adam: the three outputs mirror the shapes of
// var, m and v. When every relevant shape is static, m, v and grad must each
// equal var's shape; when any is dynamic, the check is deferred and the
// (possibly dynamic) shapes are returned as-is.
abstract::TupleShapePtr AdamInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  auto op_name = primitive->name();
  auto var_shape_ptr = input_args[kInputIndex0]->BuildShape();
  auto m_shape_ptr = input_args[kInputIndex1]->BuildShape();
  auto v_shape_ptr = input_args[kInputIndex2]->BuildShape();
  auto grad_shape_ptr = input_args[kInputIndex9]->BuildShape();
  const bool any_dynamic = var_shape_ptr->IsDynamic() || m_shape_ptr->IsDynamic() || v_shape_ptr->IsDynamic() ||
                           grad_shape_ptr->IsDynamic();
  if (!any_dynamic) {
    // All shapes are static: enforce element-wise shape equality against var.
    auto var_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
    auto m_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
    auto v_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
    auto grad_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex9]->BuildShape())[kShape];
    CheckAndConvertUtils::Check("var_shape", var_shape, kEqual, "m_shape", m_shape, op_name);
    CheckAndConvertUtils::Check("var_shape", var_shape, kEqual, "v_shape", v_shape, op_name);
    CheckAndConvertUtils::Check("var_shape", var_shape, kEqual, "grad_shape", grad_shape, op_name);
  }
  return std::make_shared<abstract::TupleShape>(
    std::vector<abstract::BaseShapePtr>{var_shape_ptr, m_shape_ptr, v_shape_ptr});
}
} // namespace
void Adam::Init(const bool use_locking, const bool use_nesterov) {
@@ -73,8 +100,13 @@ bool Adam::get_use_nesterov() const {

// Public infer entry point for the Adam primitive: validates arity, then
// combines shape and dtype inference into one abstract tuple (var, m, v).
// NOTE(review): the stale pre-refactor single-tensor return that the diff
// rendering left at the top of this body (making the real body unreachable)
// has been removed.
AbstractBasePtr AdamInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                          const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  auto prim_name = primitive->name();
  const int64_t kInputNum = 10;
  // Arity is checked once here; the shape/type helpers assume >= 10 inputs.
  CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, kInputNum, prim_name);
  auto infer_shape = AdamInferShape(primitive, input_args);
  auto infer_type = AdamInferType(primitive, input_args);
  return abstract::MakeAbstract(infer_shape, infer_type);
}
REGISTER_PRIMITIVE_C(kNameAdam, Adam);
REGISTER_PRIMITIVE_EVAL_IMPL(Adam, prim::kPrimAdam, AdamInfer, nullptr, true);
} // namespace ops
} // namespace mindspore

+ 1
- 0
mindspore/core/ops/adam.h View File

@@ -53,6 +53,7 @@ class MS_CORE_API Adam : public PrimitiveC {
};
AbstractBasePtr AdamInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
using kPrimAdamPtr = std::shared_ptr<Adam>;
} // namespace ops
} // namespace mindspore



+ 1
- 0
mindspore/ops/_op_impl/tbe/__init__.py View File

@@ -39,6 +39,7 @@ from .apply_ftrl import _apply_ftrl_tbe
from .apply_keras_momentum import _apply_keras_momentum_tbe
from .apply_momentum import _apply_momentum_tbe
from .apply_adam import _apply_adam_tbe
from .apply_adam_ds import _apply_adam_ds_tbe
from .apply_ada_max import _apply_ada_max_tbe
from .apply_adadelta import _apply_adadelta_tbe
from .apply_adagrad import _apply_adagrad_tbe


+ 80
- 0
mindspore/ops/_op_impl/tbe/apply_adam_ds.py View File

@@ -0,0 +1,80 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""ApplyAdam op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

# TBE registration info for the dynamic-shape variant of the Adam optimizer op.
# Registered under op name "Adam" but bound to the "apply_adam_d" kernel with
# dynamic_shape enabled. Ten required inputs (var, m, v, four power/lr scalars,
# two betas, epsilon, grad) and three in-place outputs (var, m, v).
# Each dtype_format row lists 13 entries: the 10 inputs followed by the
# 3 outputs; tensor inputs/outputs share a layout (Default / C1HWNCoC0 /
# 5HD / FracZ) while scalar hyper-parameters stay in Default format, in both
# float16 and float32.
apply_adam_ds_op_info = TBERegOp("Adam") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("apply_adam.so") \
.compute_cost(10) \
.kernel_name("apply_adam_d") \
.partial_flag(True) \
.dynamic_shape(True) \
.attr("use_locking", "optional", "bool", "true,false", "false") \
.attr("use_nesterov", "optional", "bool", "true,false", "false") \
.input(0, "var", False, "required", "all") \
.input(1, "m", False, "required", "all") \
.input(2, "v", False, "required", "all") \
.input(3, "beta1_power", False, "required", "all") \
.input(4, "beta2_power", False, "required", "all") \
.input(5, "lr", False, "required", "all") \
.input(6, "beta1", False, "required", "all") \
.input(7, "beta2", False, "required", "all") \
.input(8, "epsilon", False, "required", "all") \
.input(9, "grad", False, "required", "all") \
.output(0, "var", False, "required", "all") \
.output(1, "m", False, "required", "all") \
.output(2, "v", False, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
DataType.F16_Default) \
.dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_Default,
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
DataType.F16_Default, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0,
DataType.F16_C1HWNCoC0) \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_Default,
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
DataType.F16_Default, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD,
DataType.F16_5HD) \
.dtype_format(DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_Default,
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
DataType.F16_Default, DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ,
DataType.F16_FracZ) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default) \
.dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_Default,
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0,
DataType.F32_C1HWNCoC0) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_Default,
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
DataType.F32_5HD) \
.dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_Default,
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ,
DataType.F32_FracZ) \
.get_op_info()


@op_info_register(apply_adam_ds_op_info)
def _apply_adam_ds_tbe():
    """Register the dynamic-shape ApplyAdam op info with the TBE backend."""
    pass

+ 1
- 18
mindspore/ops/operations/nn_ops.py View File

@@ -4368,7 +4368,7 @@ class ROIAlign(PrimitiveWithInfer):
return inputs_type


class Adam(PrimitiveWithInfer):
class Adam(Primitive):
r"""
Updates gradients by the Adaptive Moment Estimation (Adam) algorithm.

@@ -4464,23 +4464,6 @@ class Adam(PrimitiveWithInfer):
validator.check_value_type("use_nesterov", use_nesterov, [bool], self.name)
self.add_prim_attr('side_effect_mem', True)

def infer_shape(self, var_shape, m_shape, v_shape, beta1_power_shape, beta2_power_shape, lr_shape,
                beta1_shape, beta2_shape, epsilon_shape, grad_shape):
    """Require m, v and grad to match var's shape; outputs mirror var/m/v."""
    for other_name, other_shape in (("m_shape", m_shape), ("v_shape", v_shape), ("grad_shape", grad_shape)):
        validator.check("var_shape", var_shape, other_name, other_shape, Rel.EQ, self.name)
    return var_shape, m_shape, v_shape

def infer_dtype(self, var_dtype, m_dtype, v_dtype, beta1_power_dtype, beta2_power_dtype, lr_dtype,
                beta1_dtype, beta2_dtype, epsilon_dtype, grad_dtype):
    """Validate tensor operands share a numeric dtype and hyper-parameters are
    float16/float32 scalars or tensors; outputs reuse var/m/v dtypes."""
    tensor_args = {"var": var_dtype, "m": m_dtype, "v": v_dtype, "grad": grad_dtype}
    validator.check_tensors_dtypes_same_and_valid(tensor_args, mstype.number_type, self.name)

    scalar_args = {"beta1_power": beta1_power_dtype, "beta2_power": beta2_power_dtype, "lr": lr_dtype,
                   "beta1": beta1_dtype, "beta2": beta2_dtype, "epsilon": epsilon_dtype}
    validator.check_scalar_or_tensor_types_same(scalar_args, [mstype.float16, mstype.float32], self.name, True)
    return var_dtype, m_dtype, v_dtype


class AdamWeightDecay(PrimitiveWithInfer):
r"""


Loading…
Cancel
Save