Browse Source

update mindspore/core/ops/apply_ada_max.cc.

tags/v1.6.0
闻泽恺 fakeen 4 years ago
parent
commit
d9155b264e
6 changed files with 281 additions and 40 deletions
  1. +1
    -0
      mindspore/core/base/core_ops.h
  2. +164
    -0
      mindspore/core/ops/apply_ada_max.cc
  3. +45
    -0
      mindspore/core/ops/apply_ada_max.h
  4. +1
    -0
      mindspore/ops/_op_impl/tbe/__init__.py
  5. +69
    -0
      mindspore/ops/_op_impl/tbe/apply_ada_max_ds.py
  6. +1
    -40
      mindspore/ops/operations/nn_ops.py

+ 1
- 0
mindspore/core/base/core_ops.h View File

@@ -273,6 +273,7 @@ inline const PrimitivePtr kPrimExtractVolumePatches = std::make_shared<Primitive
// NN
inline const PrimitivePtr kPrimCeLU = std::make_shared<Primitive>("CeLU");
inline const PrimitivePtr kPrimAdam = std::make_shared<Primitive>("Adam");
// Primitive identifier for ApplyAdaMax; its infer implementation is registered
// against this pointer in mindspore/core/ops/apply_ada_max.cc.
inline const PrimitivePtr kPrimApplyAdaMax = std::make_shared<Primitive>("ApplyAdaMax");
inline const PrimitivePtr kPrimAudioSpectrogram = std::make_shared<Primitive>("AudioSpectrogram");
inline const PrimitivePtr kPrimFlatten = std::make_shared<Primitive>("Flatten");
inline const PrimitivePtr kPrimCrop = std::make_shared<Primitive>("Crop");


+ 164
- 0
mindspore/core/ops/apply_ada_max.cc View File

@@ -0,0 +1,164 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/apply_ada_max.h"
#include <algorithm>
#include <set>
#include "abstract/primitive_infer_map.h"
#include "ops/op_utils.h"
#include "utils/tensor_construct_utils.h"
namespace mindspore {
namespace ops {
namespace {
// Infers the output shapes (var, m, v) of ApplyAdaMax.
// Contract: var, m, v and grad share one shape; beta1_power, lr, beta1, beta2
// and epsilon must be scalars — rank 0, or rank 1 with exactly one element.
abstract::TupleShapePtr ApplyAdaMaxInferShape(const PrimitivePtr &primitive,
                                              const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  const int64_t kInputNum = 9;
  (void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kGreaterEqual, kInputNum,
                                           primitive->name());
  for (const auto &item : input_args) {
    MS_EXCEPTION_IF_NULL(item);
  }
  auto prim_name = primitive->name();
  auto var_shape = input_args[kInputIndex0]->BuildShape();
  auto m_shape = input_args[kInputIndex1]->BuildShape();
  auto v_shape = input_args[kInputIndex2]->BuildShape();
  auto var_shape_ptr = var_shape->cast<abstract::ShapePtr>();
  auto m_shape_ptr = m_shape->cast<abstract::ShapePtr>();
  auto v_shape_ptr = v_shape->cast<abstract::ShapePtr>();
  auto beta1_power_shape =
    CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex3]->BuildShape())[kShape];
  auto lr_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex4]->BuildShape())[kShape];
  auto beta1_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex5]->BuildShape())[kShape];
  auto beta2_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex6]->BuildShape())[kShape];
  auto epsilon_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex7]->BuildShape())[kShape];
  auto grad_shape = input_args[kInputIndex8]->BuildShape();
  auto grad_shape_ptr = grad_shape->cast<abstract::ShapePtr>();
  // Guard the casts: BuildShape() may return a BaseShape that is not an
  // abstract::Shape, in which case the casts above yield null.
  MS_EXCEPTION_IF_NULL(var_shape_ptr);
  MS_EXCEPTION_IF_NULL(m_shape_ptr);
  MS_EXCEPTION_IF_NULL(v_shape_ptr);
  MS_EXCEPTION_IF_NULL(grad_shape_ptr);
  // beta1_power, lr, beta1, beta2, epsilon must be scalar.
  const int64_t kInputShape = 1;
  auto check_scalar_shape = [&prim_name, kInputShape](const std::string &rank_arg, const std::string &elem_arg,
                                                      const std::vector<int64_t> &shape) {
    (void)CheckAndConvertUtils::CheckInteger(rank_arg, SizeToLong(shape.size()), kLessEqual, kInputShape, prim_name);
    if (shape.size() == 1) {
      // BUGFIX: validate the element value shape[0]. The original code
      // re-checked shape.size() here, which is always 1 inside this branch
      // (a tautology), so non-scalar rank-1 inputs were never rejected.
      (void)CheckAndConvertUtils::CheckInteger(elem_arg, shape[0], kEqual, kInputShape, prim_name);
    }
  };
  check_scalar_shape("beta1 power's rank", "beta1_power_shape[0]", beta1_power_shape);
  check_scalar_shape("lr's rank", "lr_shape[0]", lr_shape);
  check_scalar_shape("beta1's rank", "beta1_shape[0]", beta1_shape);
  check_scalar_shape("beta2's rank", "beta2_shape[0]", beta2_shape);
  check_scalar_shape("epsilon's rank", "epsilon_shape[0]", epsilon_shape);
  // var, m, v and grad must have the same shape; the comparison is skipped
  // when either side is dynamic so dynamic-shape graphs can still compile.
  // (The unused same_shape_args_map from the original was dead code; the
  // explicit comparisons below are the actual check.)
  if (!var_shape_ptr->IsDynamic() && !m_shape_ptr->IsDynamic()) {
    if (*m_shape != *var_shape) {
      MS_EXCEPTION(ValueError) << primitive->name() << " evaluator arg m shape " << m_shape->ToString()
                               << " are not consistent with var shape " << var_shape->ToString();
    }
  }
  if (!v_shape_ptr->IsDynamic() && !var_shape_ptr->IsDynamic()) {
    if (*v_shape != *var_shape) {
      MS_EXCEPTION(ValueError) << primitive->name() << " evaluator arg v shape " << v_shape->ToString()
                               << " are not consistent with var shape " << var_shape->ToString();
    }
  }
  if (!grad_shape_ptr->IsDynamic() && !var_shape_ptr->IsDynamic()) {
    if (*grad_shape != *var_shape) {
      MS_EXCEPTION(ValueError) << primitive->name() << " evaluator arg grad shape " << grad_shape->ToString()
                               << " are not consistent with var shape " << var_shape->ToString();
    }
  }
  // Outputs mirror the in-place updated inputs: (var, m, v).
  return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{var_shape, m_shape, v_shape});
}
// Infers the output dtypes (var, m, v) of ApplyAdaMax.
// var/m/v/grad must be tensors of one common type in {float16, float32};
// each scalar hyper-parameter may be a scalar or zero-dim tensor of the
// same valid types.
TuplePtr ApplyAdaMaxInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(prim);
  auto prim_name = prim->name();
  const int64_t kInputNum = 9;
  (void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kGreaterEqual, kInputNum,
                                           prim_name);
  for (const auto &item : input_args) {
    MS_EXCEPTION_IF_NULL(item);
  }
  auto var_type = input_args[kInputIndex0]->BuildType();
  auto m_type = input_args[kInputIndex1]->BuildType();
  auto v_type = input_args[kInputIndex2]->BuildType();
  const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
  // m, v and grad must have the same type as var.
  std::map<std::string, TypePtr> tensor_args = {{"var_type", var_type},
                                                {"m_type", m_type},
                                                {"v_type", v_type},
                                                {"grad_type", input_args[kInputIndex8]->BuildType()}};
  (void)CheckAndConvertUtils::CheckTensorTypeSame(tensor_args, valid_types, prim_name);
  // beta1_power, lr, beta1, beta2 and epsilon must each be a scalar or a
  // zero-dimension tensor type; check one at a time, in input order, so the
  // first offending argument is the one reported.
  const std::vector<std::pair<std::string, TypePtr>> scalar_args = {
    {"beta1_power_type", input_args[kInputIndex3]->BuildType()},
    {"lr_type", input_args[kInputIndex4]->BuildType()},
    {"beta1_type", input_args[kInputIndex5]->BuildType()},
    {"beta2_type", input_args[kInputIndex6]->BuildType()},
    {"epsilon_type", input_args[kInputIndex7]->BuildType()}};
  for (const auto &arg : scalar_args) {
    std::map<std::string, TypePtr> one_arg = {{arg.first, arg.second}};
    (void)CheckAndConvertUtils::CheckScalarOrTensorTypesSame(one_arg, valid_types, prim_name);
  }
  // Outputs mirror the in-place updated inputs: (var, m, v).
  return std::make_shared<Tuple>(std::vector<TypePtr>{var_type, m_type, v_type});
}
} // namespace
// Top-level infer entry point for ApplyAdaMax: combines dtype inference and
// shape inference into a single abstract value for the (var, m, v) outputs.
AbstractBasePtr ApplyAdaMaxInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                 const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  // Run the type check before the shape check so type errors are reported
  // first, matching the original diagnostic order.
  auto output_type = ApplyAdaMaxInferType(primitive, input_args);
  auto output_shape = ApplyAdaMaxInferShape(primitive, input_args);
  return abstract::MakeAbstract(output_shape, output_type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(ApplyAdaMax, prim::kPrimApplyAdaMax, ApplyAdaMaxInfer, nullptr, true);
} // namespace ops
} // namespace mindspore

+ 45
- 0
mindspore/core/ops/apply_ada_max.h View File

@@ -0,0 +1,45 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_APPLY_ADA_MAX_H_
#define MINDSPORE_CORE_OPS_APPLY_ADA_MAX_H_
#include <map>
#include <vector>
#include <string>
#include <memory>
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
constexpr auto kNameApplyAdaMax = "ApplyAdaMax";
/// \brief Primitive wrapper for the ApplyAdaMax operator, which updates
/// var/m/v according to the AdaMax scheme. Declares nine inputs
/// (var, m, v, beta1_power, lr, beta1, beta2, epsilon, grad) and three
/// outputs (var, m, v); shape and dtype inference are implemented by
/// ApplyAdaMaxInfer (declared below, defined in apply_ada_max.cc).
class ApplyAdaMax : public PrimitiveC {
 public:
  /// \brief Constructor. Registers the operator's input and output names.
  ApplyAdaMax() : PrimitiveC(kNameApplyAdaMax) {
    InitIOName({"var", "m", "v", "beta1_power", "lr", "beta1", "beta2", "epsilon", "grad"}, {"var", "m", "v"});
  }
  /// \brief Destructor.
  ~ApplyAdaMax() = default;
  MS_DECLARE_PARENT(ApplyAdaMax, PrimitiveC);
};
AbstractBasePtr ApplyAdaMaxInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
using kPrimApplyAdaMaxPtr = std::shared_ptr<ApplyAdaMax>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_APPLY_ADA_MAX_H_

+ 1
- 0
mindspore/ops/_op_impl/tbe/__init__.py View File

@@ -35,6 +35,7 @@ from .apply_keras_momentum import _apply_keras_momentum_tbe
from .apply_momentum import _apply_momentum_tbe
from .apply_adam import _apply_adam_tbe
from .apply_ada_max import _apply_ada_max_tbe
from .apply_ada_max_ds import _apply_ada_max_ds_tbe
from .apply_adadelta import _apply_adadelta_tbe
from .apply_adagrad import _apply_adagrad_tbe
from .apply_adagrad_v2 import _apply_adagrad_v2_tbe


+ 69
- 0
mindspore/ops/_op_impl/tbe/apply_ada_max_ds.py View File

@@ -0,0 +1,69 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ApplyAdaMaxD op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
# TBE registration table for the dynamic-shape ApplyAdaMax kernel
# (binfile/kernel name: apply_ada_max_d), registered under op name
# "ApplyAdaMax" with dynamic_shape enabled.
# Layout of each dtype_format row (12 entries): inputs var, m, v,
# beta1_power, lr, beta1, beta2, epsilon, grad, then outputs var, m, v.
# The five scalar inputs always use the Default format; var/m/v/grad and the
# outputs accept 5HD, FracZ, C1HWNCoC0 and Default layouts, in both float16
# and float32.
apply_ada_max_ds_op_info = TBERegOp("ApplyAdaMax") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("apply_ada_max_d.so") \
    .compute_cost(10) \
    .kernel_name("apply_ada_max_d") \
    .partial_flag(True) \
    .dynamic_shape(True)\
    .input(0, "var", False, "required", "all") \
    .input(1, "m", False, "required", "all") \
    .input(2, "v", False, "required", "all") \
    .input(3, "beta1_power", False, "required", "all") \
    .input(4, "lr", False, "required", "all") \
    .input(5, "beta1", False, "required", "all") \
    .input(6, "beta2", False, "required", "all") \
    .input(7, "epsilon", False, "required", "all") \
    .input(8, "grad", False, "required", "all") \
    .output(0, "var", False, "required", "all") \
    .output(1, "m", False, "required", "all") \
    .output(2, "v", False, "required", "all") \
    .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_Default,
                  DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
                  DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
    .dtype_format(DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_Default,
                  DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
                  DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ) \
    .dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_Default,
                  DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
                  DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \
    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
                  DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
                  DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_Default,
                  DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
                  DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
    .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_Default,
                  DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
                  DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ) \
    .dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_Default,
                  DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
                  DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
                  DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
                  DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
    .get_op_info()


# Registration stub: the decorator records apply_ada_max_ds_op_info in the
# TBE op-info registry; the function body itself is intentionally empty.
@op_info_register(apply_ada_max_ds_op_info)
def _apply_ada_max_ds_tbe():
    """ApplyAdaMaxD TBE register"""
    return

+ 1
- 40
mindspore/ops/operations/nn_ops.py View File

@@ -5546,7 +5546,7 @@ class BinaryCrossEntropy(PrimitiveWithInfer):
return x_type


class ApplyAdaMax(PrimitiveWithInfer):
class ApplyAdaMax(Primitive):
r"""
Updates relevant entries according to the adamax scheme.

@@ -5658,45 +5658,6 @@ class ApplyAdaMax(PrimitiveWithInfer):
"""Initialize ApplyAdaMax"""
self.add_prim_attr('side_effect_mem', True)

def infer_shape(self, var_shape, m_shape, v_shape, beta1_power_shape, lr_shape,
                beta1_shape, beta2_shape, epsilon_shape, grad_shape):
    """Infer ApplyAdaMax output shapes.

    m, v and grad must have exactly var's shape; each of the five scalar
    hyper-parameters (beta1_power, lr, beta1, beta2, epsilon) must be rank 0,
    or rank 1 with a single element. Returns (var_shape, m_shape, v_shape),
    mirroring the in-place updated inputs.
    """
    # m, v and grad must match var's shape exactly.
    validator.check("m_shape", m_shape, "var_shape", var_shape, Rel.EQ, self.name)
    validator.check("v_shape", v_shape, "var_shape", var_shape, Rel.EQ, self.name)
    validator.check("grad_shape", grad_shape, "var_shape", var_shape, Rel.EQ, self.name)
    # Each scalar hyper-parameter: rank <= 1, and if rank 1, exactly one element.
    beta1_power_shp_len = len(beta1_power_shape)
    validator.check_int(beta1_power_shp_len, 1, Rel.LE, "beta1 power's rank", self.name)
    if beta1_power_shp_len == 1:
        validator.check_int(beta1_power_shape[0], 1, Rel.EQ, "beta1_power_shape[0]", self.name)
    lr_shp_len = len(lr_shape)
    validator.check_int(lr_shp_len, 1, Rel.LE, "lr's rank", self.name)
    if lr_shp_len == 1:
        validator.check_int(lr_shape[0], 1, Rel.EQ, "lr_shape[0]", self.name)
    beta1_shp_len = len(beta1_shape)
    validator.check_int(beta1_shp_len, 1, Rel.LE, "beta1's rank", self.name)
    if beta1_shp_len == 1:
        validator.check_int(beta1_shape[0], 1, Rel.EQ, "beta1_shape[0]", self.name)
    beta2_shp_len = len(beta2_shape)
    validator.check_int(beta2_shp_len, 1, Rel.LE, "beta2's rank", self.name)
    if beta2_shp_len == 1:
        validator.check_int(beta2_shape[0], 1, Rel.EQ, "beta2_shape[0]", self.name)
    epsilon_shp_len = len(epsilon_shape)
    validator.check_int(epsilon_shp_len, 1, Rel.LE, "epsilon's rank", self.name)
    if epsilon_shp_len == 1:
        validator.check_int(epsilon_shape[0], 1, Rel.EQ, "epsilon_shape[0]", self.name)
    return var_shape, m_shape, v_shape

def infer_dtype(self, var_dtype, m_dtype, v_dtype, beta1_power_dtype, lr_dtype,
                beta1_dtype, beta2_dtype, epsilon_dtype, grad_dtype):
    """Infer ApplyAdaMax output dtypes.

    var, m, v and grad must be tensors of one common type in
    {float16, float32}; each scalar hyper-parameter may be a scalar or a
    tensor of the same valid types. Returns (var_dtype, m_dtype, v_dtype).
    """
    valid_dtypes = [mstype.float16, mstype.float32]
    # m, v and grad must share var's tensor dtype.
    args = {"var": var_dtype, "m": m_dtype, "v": v_dtype, "grad": grad_dtype}
    validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name)
    # Scalar hyper-parameters are checked individually, in input order.
    validator.check_scalar_or_tensor_types_same({"beta1_power": beta1_power_dtype}, valid_dtypes, self.name)
    validator.check_scalar_or_tensor_types_same({"lr": lr_dtype}, valid_dtypes, self.name)
    validator.check_scalar_or_tensor_types_same({"beta1": beta1_dtype}, valid_dtypes, self.name)
    validator.check_scalar_or_tensor_types_same({"beta2": beta2_dtype}, valid_dtypes, self.name)
    validator.check_scalar_or_tensor_types_same({"epsilon": epsilon_dtype}, valid_dtypes, self.name)
    return var_dtype, m_dtype, v_dtype


class ApplyAdadelta(PrimitiveWithInfer):
r"""


Loading…
Cancel
Save