diff --git a/mindspore/core/base/core_ops.h b/mindspore/core/base/core_ops.h
index dd6b225f1f..ee02232295 100644
--- a/mindspore/core/base/core_ops.h
+++ b/mindspore/core/base/core_ops.h
@@ -271,6 +271,7 @@ inline const PrimitivePtr kPrimLrn = std::make_shared<Primitive>("LRN");
 inline const PrimitivePtr kPrimLayerNormGrad = std::make_shared<Primitive>("LayerNormGrad");
 inline const PrimitivePtr kPrimLayerNormXBackprop = std::make_shared<Primitive>("LayerNormXBackprop");
 inline const PrimitivePtr kPrimLayerNormBetaGammaBackprop = std::make_shared<Primitive>("LayerNormBetaGammaBackprop");
+inline const PrimitivePtr kPrimLog1p = std::make_shared<Primitive>("Log1p");
 inline const PrimitivePtr kPrimDropoutGenMask = std::make_shared<Primitive>("DropoutGenMask");
 inline const PrimitivePtr kPrimDropoutDoMask = std::make_shared<Primitive>("DropoutDoMask");
 inline const PrimitivePtr kPrimDropoutGrad = std::make_shared<Primitive>("DropoutGrad");
@@ -287,6 +288,7 @@ inline const PrimitivePtr kPrimElu = std::make_shared<Primitive>("Elu");
 inline const PrimitivePtr kPrimRelu6 = std::make_shared<Primitive>("ReLU6");
 inline const PrimitivePtr kPrimReluV2 = std::make_shared<Primitive>("ReLUV2");
 inline const PrimitivePtr kPrimPRelu = std::make_shared<Primitive>("PReLU");
+inline const PrimitivePtr kPrimSoftplus = std::make_shared<Primitive>("Softplus");
 inline const PrimitivePtr kPrimZeros = std::make_shared<Primitive>("Zeros");
 inline const PrimitivePtr kPrimZerosLike = std::make_shared<Primitive>("ZerosLike");
 inline const PrimitivePtr kPrimOnesLike = std::make_shared<Primitive>("OnesLike");
diff --git a/mindspore/core/ops/broadcast_to.cc b/mindspore/core/ops/broadcast_to.cc
index 85f038867b..44c01bd523 100644
--- a/mindspore/core/ops/broadcast_to.cc
+++ b/mindspore/core/ops/broadcast_to.cc
@@ -23,13 +23,12 @@ namespace {
 abstract::ShapePtr BroadcastToInferShape(const PrimitivePtr &primitive,
                                          const std::vector<AbstractBasePtr> &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
-  auto broad_cast_to = primitive->cast<PrimBroadcastToPtr>();
-  MS_EXCEPTION_IF_NULL(broad_cast_to);
-  auto prim_name = broad_cast_to->name();
+  auto prim_name = primitive->name();
   auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
-  auto input_x = broad_cast_to->get_shape();
+  auto value_ptr = primitive->GetAttr(kShape);
+  auto input_x = GetValue<std::vector<int64_t>>(value_ptr);
   int64_t outer_dim_offset = input_x.size() - x_shape.size();
-  CheckAndConvertUtils::Check("x shape", x_shape, kLessEqual, "input_x", input_x, prim_name);
+  CheckAndConvertUtils::Check("x shape", x_shape.size(), kLessEqual, "input_x", input_x.size(), prim_name);
   bool flag = true;
   if (input_x.end() == find(input_x.begin(), input_x.end(), -1)) {
     flag = false;
@@ -49,7 +48,6 @@ abstract::ShapePtr BroadcastToInferShape(const PrimitivePtr &primitive,
       }
     }
   }
-  std::reverse(input_x.begin(), input_x.end());
   return std::make_shared<abstract::Shape>(input_x);
 }
 
@@ -78,8 +76,8 @@ std::vector<int64_t> BroadcastTo::get_shape() const {
 AbstractBasePtr BroadcastToInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                  const std::vector<AbstractBasePtr> &input_args) {
   return std::make_shared<abstract::AbstractTensor>(BroadcastToInferType(primitive, input_args),
-                                                    BroadcastToInferShape(primitive, input_args)->shape());
+                                                    BroadcastToInferShape(primitive, input_args));
 }
-REGISTER_PRIMITIVE_C(kNameBroadcastTo, BroadcastTo);
+REGISTER_PRIMITIVE_EVAL_IMPL(BroadcastTo, prim::kPrimBroadcastTo, BroadcastToInfer, nullptr, true);
 }  // namespace ops
 }  // namespace mindspore
diff --git a/mindspore/core/ops/gather_d.cc b/mindspore/core/ops/gather_d.cc
index 05d95697de..4d14e3d66e 100644
--- a/mindspore/core/ops/gather_d.cc
+++ b/mindspore/core/ops/gather_d.cc
@@ -69,6 +69,6 @@ AbstractBasePtr GatherDInfer(const abstract::AnalysisEnginePtr &, const Primitiv
                                                         GatherDInferShape(primitive, input_args));
   return abs;
 }
-REGISTER_PRIMITIVE_EVAL_IMPL(GatherD, prim::kPrimGatherD, GatherDInfer, nullptr, false);
+REGISTER_PRIMITIVE_EVAL_IMPL(GatherD, prim::kPrimGatherD, GatherDInfer, nullptr, true);
 }  // namespace ops
 }  // namespace mindspore
diff --git a/mindspore/core/ops/log1p.cc b/mindspore/core/ops/log1p.cc
new file mode 100644
index 0000000000..babb09ad0d
--- /dev/null
+++ b/mindspore/core/ops/log1p.cc
@@ -0,0 +1,58 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ops/log1p.h"
+#include <set>
+#include <map>
+#include "ops/op_utils.h"
+#include "utils/check_convert_utils.h"
+#include "utils/tensor_construct_utils.h"
+#include "abstract/primitive_infer_map.h"
+
+namespace mindspore {
+namespace ops {
+// log1p
+namespace {
+abstract::ShapePtr Log1pInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
+  MS_EXCEPTION_IF_NULL(primitive);
+  auto shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape());
+  auto in_shape = shape_map[kShape];
+  auto min_shape = shape_map[kMinShape];
+  auto max_shape = shape_map[kMaxShape];
+  return std::make_shared<abstract::Shape>(in_shape, min_shape, max_shape);
+}
+
+TypePtr Log1pInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
+  MS_EXCEPTION_IF_NULL(prim);
+  auto prim_name = prim->name();
+  // check
+  std::set<TypePtr> valid_index_types = {kFloat16, kFloat32};
+  auto x_type =
+    CheckAndConvertUtils::CheckTensorTypeValid("x", input_args[0]->BuildType(), valid_index_types, prim_name);
+  return x_type;
+}
+}  // namespace
+
+AbstractBasePtr Log1pInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                           const std::vector<AbstractBasePtr> &input_args) {
+  MS_EXCEPTION_IF_NULL(primitive);
+  auto abs = std::make_shared<abstract::AbstractTensor>(Log1pInferType(primitive, input_args),
+                                                        Log1pInferShape(primitive, input_args));
+  return abs;
+}
+REGISTER_PRIMITIVE_EVAL_IMPL(Log1p, prim::kPrimLog1p, Log1pInfer, nullptr, true);
+}  // namespace ops
+}  // namespace mindspore
diff --git a/mindspore/core/ops/log1p.h b/mindspore/core/ops/log1p.h
new file mode 100644
index 0000000000..0483c19e85
--- /dev/null
+++ b/mindspore/core/ops/log1p.h
@@ -0,0 +1,41 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CORE_OPS_LOG1P_H_
+#define MINDSPORE_CORE_OPS_LOG1P_H_
+#include <map>
+#include <vector>
+#include <string>
+#include <memory>
+#include "ops/primitive_c.h"
+#include "abstract/abstract_value.h"
+#include "utils/check_convert_utils.h"
+#include "ops/op_utils.h"
+
+namespace mindspore {
+namespace ops {
+constexpr auto kNameLog1p = "Log1p";
+class Log1p : public PrimitiveC {
+ public:
+  Log1p() : PrimitiveC(kNameLog1p) { InitIOName({"x"}, {"y"}); }
+  ~Log1p() = default;
+  MS_DECLARE_PARENT(Log1p, PrimitiveC);
+  void Init() {}
+};
+}  // namespace ops
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CORE_OPS_LOG1P_H_
diff --git a/mindspore/core/ops/softplus.cc b/mindspore/core/ops/softplus.cc
new file mode 100644
index 0000000000..f76964b729
--- /dev/null
+++ b/mindspore/core/ops/softplus.cc
@@ -0,0 +1,58 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ops/softplus.h"
+#include <set>
+#include <map>
+#include "ops/op_utils.h"
+#include "utils/check_convert_utils.h"
+#include "utils/tensor_construct_utils.h"
+#include "abstract/primitive_infer_map.h"
+
+namespace mindspore {
+namespace ops {
+// softplus
+namespace {
+abstract::ShapePtr SoftplusInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
+  MS_EXCEPTION_IF_NULL(primitive);
+  auto shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape());
+  auto in_shape = shape_map[kShape];
+  auto min_shape = shape_map[kMinShape];
+  auto max_shape = shape_map[kMaxShape];
+  return std::make_shared<abstract::Shape>(in_shape, min_shape, max_shape);
+}
+
+TypePtr SoftplusInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
+  MS_EXCEPTION_IF_NULL(prim);
+  auto prim_name = prim->name();
+  // check
+  std::set<TypePtr> valid_index_types = {kFloat16, kFloat32, kFloat64};
+  auto x_type =
+    CheckAndConvertUtils::CheckTensorTypeValid("x", input_args[0]->BuildType(), valid_index_types, prim_name);
+  return x_type;
+}
+}  // namespace
+
+AbstractBasePtr SoftplusInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                              const std::vector<AbstractBasePtr> &input_args) {
+  MS_EXCEPTION_IF_NULL(primitive);
+  auto abs = std::make_shared<abstract::AbstractTensor>(SoftplusInferType(primitive, input_args),
+                                                        SoftplusInferShape(primitive, input_args));
+  return abs;
+}
+REGISTER_PRIMITIVE_EVAL_IMPL(Softplus, prim::kPrimSoftplus, SoftplusInfer, nullptr, true);
+}  // namespace ops
+}  // namespace mindspore
diff --git a/mindspore/core/ops/softplus.h b/mindspore/core/ops/softplus.h
new file mode 100644
index 0000000000..cc578e47e4
--- /dev/null
+++ b/mindspore/core/ops/softplus.h
@@ -0,0 +1,41 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CORE_OPS_SOFTPLUS_H_
+#define MINDSPORE_CORE_OPS_SOFTPLUS_H_
+#include <map>
+#include <vector>
+#include <string>
+#include <memory>
+#include "ops/primitive_c.h"
+#include "abstract/abstract_value.h"
+#include "utils/check_convert_utils.h"
+#include "ops/op_utils.h"
+
+namespace mindspore {
+namespace ops {
+constexpr auto kNameSoftplus = "Softplus";
+class Softplus : public PrimitiveC {
+ public:
+  Softplus() : PrimitiveC(kNameSoftplus) { InitIOName({"x"}, {"output"}); }
+  ~Softplus() = default;
+  MS_DECLARE_PARENT(Softplus, PrimitiveC);
+  void Init() {}
+};
+}  // namespace ops
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CORE_OPS_SOFTPLUS_H_
diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py
index df783f810f..314af566e5 100644
--- a/mindspore/ops/operations/array_ops.py
+++ b/mindspore/ops/operations/array_ops.py
@@ -1306,7 +1306,7 @@ class Ones(PrimitiveWithInfer):
         return out
 
 
-class Zeros(PrimitiveWithInfer):
+class Zeros(Primitive):
     r"""
     Creates a tensor filled with value zeros.
 
@@ -4570,7 +4570,7 @@ class BatchToSpaceND(PrimitiveWithInfer):
         return out_shape
 
 
-class BroadcastTo(PrimitiveWithInfer):
+class BroadcastTo(Primitive):
     """
     Broadcasts input tensor to a given shape.
 
@@ -4629,34 +4629,6 @@ class BroadcastTo(PrimitiveWithInfer):
             validator.check("shape element", i, "shape element min limit", -1, Rel.GE, self.name)
         self.shape = shape
 
-    def infer_shape(self, x_shape):
-        validator.check("input_x shape length", len(x_shape), "target shape", len(self.shape), Rel.LE, self.name)
-
-        reversed_x_shape = tuple(reversed(x_shape))
-        reversed_filtered_target = []
-        for i, v in enumerate(tuple(reversed(self.shape))):
-            if v == -1:
-                if i >= len(reversed_x_shape):
-                    raise ValueError("-1 is not valid in a leading, non-existing dimension")
-
-                reversed_filtered_target.append(reversed_x_shape[i])
-            else:
-                reversed_filtered_target.append(v)
-
-        self.shape = tuple(reversed(reversed_filtered_target))
-        self.add_prim_attr('shape', self.shape)
-
-        for i, v in enumerate(reversed_x_shape):
-            if v not in (reversed_filtered_target[i], 1):
-                raise ValueError(f"Not supported shapes for broadcast, "
-                                 f"x_shape: {tuple(x_shape)}, target shape {self.shape}.")
-
-        return self.shape
-
-    def infer_dtype(self, x_dtype):
-        validator.check_subclass("input_x", x_dtype, mstype.tensor, self.name)
-        return x_dtype
-
 
 class Meshgrid(PrimitiveWithInfer):
     """
@@ -5121,7 +5093,7 @@ class EmbeddingLookup(PrimitiveWithCheck):
             raise ValueError("The dimension of 'params' in EmbeddingLookup must <= 2, but got %d." % len(params_shp))
 
 
-class GatherD(PrimitiveWithInfer):
+class GatherD(Primitive):
     """
     Gathers values along an axis specified by dim.
 
diff --git a/mindspore/ops/operations/debug_ops.py b/mindspore/ops/operations/debug_ops.py
index 06b0bb2f60..8242e3a919 100644
--- a/mindspore/ops/operations/debug_ops.py
+++ b/mindspore/ops/operations/debug_ops.py
@@ -20,7 +20,7 @@ from mindspore import context
 from ..._checkparam import Validator as validator
 from ..._checkparam import Rel
 from ...common import dtype as mstype
-from ..primitive import prim_attr_register, PrimitiveWithInfer
+from ..primitive import prim_attr_register, Primitive, PrimitiveWithInfer
 
 
 def _check_mode(class_name):
@@ -50,7 +50,7 @@ def _check_summary_param(name, value, class_name):
 SUMMARY_RETURN_VALUE = {'dtype': mstype.int32, 'shape': [1], 'value': None}
 
 
-class ScalarSummary(PrimitiveWithInfer):
+class ScalarSummary(Primitive):
     """
     Outputs a scalar to a protocol buffer through a scalar summary operator.
 
@@ -141,7 +141,7 @@ class ImageSummary(PrimitiveWithInfer):
         return SUMMARY_RETURN_VALUE
 
 
-class TensorSummary(PrimitiveWithInfer):
+class TensorSummary(Primitive):
     """
     Outputs a tensor to a protocol buffer through a tensor summary operator.
 
diff --git a/mindspore/ops/operations/math_ops.py b/mindspore/ops/operations/math_ops.py
index 27af781d8c..98e3efff3c 100644
--- a/mindspore/ops/operations/math_ops.py
+++ b/mindspore/ops/operations/math_ops.py
@@ -26,7 +26,7 @@ from ...common import dtype as mstype
 from ...common.tensor import Tensor
 from ...common._decorator import deprecated
 from .._utils import get_broadcast_shape
-from ..primitive import PrimitiveWithInfer, PrimitiveWithCheck, prim_attr_register, _run_op
+from ..primitive import Primitive, PrimitiveWithInfer, PrimitiveWithCheck, prim_attr_register, _run_op
 
 
 def _infer_shape_reduce(x, axis, keep_dims, prim_name):
@@ -1873,7 +1873,7 @@ class Log(PrimitiveWithInfer):
         return None
 
 
-class Log1p(PrimitiveWithInfer):
+class Log1p(Primitive):
     """
     Returns the natural logarithm of one plus the input tensor element-wise.
 
@@ -1901,14 +1901,6 @@ class Log1p(PrimitiveWithInfer):
     def __init__(self):
         self.init_prim_io_names(inputs=['x'], outputs=['y'])
 
-    def infer_shape(self, x_shape):
-        return x_shape
-
-    def infer_dtype(self, x_dtype):
-        validator.check_subclass("x", x_dtype, mstype.tensor, self.name)
-        validator.check_tensor_dtype_valid("x", x_dtype, [mstype.float16, mstype.float32], self.name)
-        return x_dtype
-
 
 class Erf(PrimitiveWithInfer):
     r"""
diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py
index 63a5cfc8e7..fb91e67781 100644
--- a/mindspore/ops/operations/nn_ops.py
+++ b/mindspore/ops/operations/nn_ops.py
@@ -230,7 +230,7 @@ class LogSoftmax(PrimitiveWithInfer):
         return logits
 
 
-class Softplus(PrimitiveWithInfer):
+class Softplus(Primitive):
     r"""
     Softplus activation function.
 
@@ -267,13 +267,6 @@ class Softplus(PrimitiveWithInfer):
         """Initialize Softplus"""
         self.init_prim_io_names(inputs=['x'], outputs=['output'])
 
-    def infer_shape(self, x_shape):
-        return x_shape
-
-    def infer_dtype(self, x_dtype):
-        validator.check_tensor_dtype_valid('x', x_dtype, mstype.float_type, self.name)
-        return x_dtype
-
 
 class Softsign(PrimitiveWithInfer):
     r"""
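Usage sketch (not part of the patch): a minimal Python snippet, assuming the standard mindspore.ops.operations API of this codebase, that exercises the ops whose shape/dtype inference this change moves from the Python frontend (PrimitiveWithInfer) into the C++ core (REGISTER_PRIMITIVE_EVAL_IMPL). The concrete input values are illustrative only.

# Minimal sketch: the frontend classes are now plain Primitive, so shape and
# dtype come from the C++ infer functions registered above.
import numpy as np
from mindspore import Tensor, context
from mindspore.ops import operations as P

context.set_context(mode=context.PYNATIVE_MODE)

x = Tensor(np.array([0.0, 1.0, 3.0], np.float32))

log1p = P.Log1p()                     # ln(1 + x); C++ infer accepts float16/float32
softplus = P.Softplus()               # ln(1 + e^x); float16/float32/float64 in C++
broadcast_to = P.BroadcastTo((2, 3))  # 'shape' attr is read via GetAttr(kShape) in C++

print(log1p(x))         # [0.       0.6931...  1.3862...]
print(softplus(x))      # [0.6931... 1.3132... 3.0485...]
print(broadcast_to(x))  # shape (3,) broadcast to (2, 3)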