From c83ae2f2944a2e9b2a20039a890dd2bf1a0d8063 Mon Sep 17 00:00:00 2001
From: l00591931
Date: Thu, 15 Apr 2021 20:18:02 +0800
Subject: [PATCH] Add logsoftmax C++ infer

---
 mindspore/core/abstract/infer_functions.h |  2 --
 mindspore/core/abstract/prim_others.cc    | 16 ------------
 mindspore/core/ops/log_softmax.cc         | 31 +++++++++++++++--------
 mindspore/core/ops/log_softmax.h          |  3 ---
 mindspore/ops/operations/nn_ops.py        | 11 +-------
 5 files changed, 22 insertions(+), 41 deletions(-)

diff --git a/mindspore/core/abstract/infer_functions.h b/mindspore/core/abstract/infer_functions.h
index d25c6c51ef..bb2f8e1ca5 100644
--- a/mindspore/core/abstract/infer_functions.h
+++ b/mindspore/core/abstract/infer_functions.h
@@ -304,8 +304,6 @@ AbstractBasePtr InferImplArgMaxWithValue(const AnalysisEnginePtr &, const Primit
                                          const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplSparseSoftmaxCrossEntropyWithLogits(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                                              const AbstractBasePtrList &args_spec_list);
-AbstractBasePtr InferImplDType(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
-                               const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplLoad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                               const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplAssign(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
diff --git a/mindspore/core/abstract/prim_others.cc b/mindspore/core/abstract/prim_others.cc
index e71d398e9f..ac4686856e 100644
--- a/mindspore/core/abstract/prim_others.cc
+++ b/mindspore/core/abstract/prim_others.cc
@@ -545,22 +545,6 @@ AbstractBasePtr InferImplGpuConvertToDynamicShape(const AnalysisEnginePtr &, con
   return std::make_shared<AbstractTensor>(input->element(), shape);
 }
 
-AbstractBasePtr InferImplDType(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
-                               const AbstractBasePtrList &args_spec_list) {
-  MS_EXCEPTION_IF_NULL(primitive);
-  auto op_name = primitive->name();
-  CheckArgsSize(op_name, args_spec_list, 1);
-  MS_EXCEPTION_IF_NULL(args_spec_list[0]);
-  auto type = args_spec_list[0]->BuildType();
-  MS_EXCEPTION_IF_NULL(type);
-  auto tensor_type = type->cast<TensorTypePtr>();
-  MS_EXCEPTION_IF_NULL(tensor_type);
-  auto value = tensor_type->element();
-  auto abstract = std::make_shared<AbstractType>(value);
-  abstract->set_value(value);
-  return abstract;
-}
-
 AbstractBasePtr InferImplLoad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                               const AbstractBasePtrList &args_spec_list) {
   // Inputs: Ref/Tensor, universal
diff --git a/mindspore/core/ops/log_softmax.cc b/mindspore/core/ops/log_softmax.cc
index eb31b486f9..cb32039e2e 100644
--- a/mindspore/core/ops/log_softmax.cc
+++ b/mindspore/core/ops/log_softmax.cc
@@ -34,22 +34,33 @@ void LogSoftmax::Init(const int64_t axis) { this->set_axis(axis); }
 abstract::ShapePtr LogSoftmaxInferShape(const PrimitivePtr &primitive,
                                         const std::vector<AbstractBasePtr> &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
-  auto LogSoftmax_prim = primitive->cast<PrimLogSoftmaxPtr>();
-  MS_EXCEPTION_IF_NULL(LogSoftmax_prim);
-  auto op_name = LogSoftmax_prim->name();
-  auto axis = LogSoftmax_prim->get_axis();
-  auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->GetShapeTrack(), op_name);
+  auto op_name = primitive->name();
+  auto axis = GetValue<int64_t>(primitive->GetAttr(kAxis));
+  CheckAndConvertUtils::CheckInteger("log_softmax infer", input_args.size(), kEqual, 1, op_name);
+  MS_EXCEPTION_IF_NULL(input_args[0]);
+  auto shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape());
+  if (shape_map.empty()) {
+    // Scalar input, has no shape
+    return std::make_shared<abstract::Shape>(std::vector<int64_t>());
+  }
+  auto in_shape = shape_map[kShape];
+  auto min_shape = shape_map[kMinShape];
+  auto max_shape = shape_map[kMaxShape];
   auto rank = SizeToLong(in_shape.size());
   CheckAndConvertUtils::CheckInRange("axis", axis, kIncludeLeft, {-rank, rank}, op_name);
+  if (min_shape.size() != 0 && max_shape.size() != 0) {
+    return std::make_shared<abstract::Shape>(in_shape, min_shape, max_shape);
+  }
   return std::make_shared<abstract::Shape>(in_shape);
 }
 
 TypePtr LogSoftmaxInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
-  if (std::any_of(input_args.begin(), input_args.end(), [](const AbstractBasePtr &a) { return a == nullptr; })) {
-    MS_LOG(EXCEPTION) << "nullptr";
-  }
+  MS_EXCEPTION_IF_NULL(prim);
+  auto op_name = prim->name();
+  CheckAndConvertUtils::CheckInteger("log_softmax infer", input_args.size(), kEqual, 1, op_name);
+  MS_EXCEPTION_IF_NULL(input_args[0]);
   const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
-  return CheckAndConvertUtils::CheckTensorTypeValid("x", input_args[0]->BuildType(), valid_types, prim->name());
+  return CheckAndConvertUtils::CheckTensorTypeValid("x", input_args[0]->BuildType(), valid_types, op_name);
 }
 
 AbstractBasePtr LogSoftmaxInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
@@ -57,6 +68,6 @@ AbstractBasePtr LogSoftmaxInfer(const abstract::AnalysisEnginePtr &, const Primi
   return std::make_shared<abstract::AbstractTensor>(LogSoftmaxInferType(primitive, input_args),
                                                     LogSoftmaxInferShape(primitive, input_args)->shape());
 }
-REGISTER_PRIMITIVE_C(kNameLogSoftmax, LogSoftmax);
+REGISTER_PRIMITIVE_EVAL_IMPL(LogSoftmax, prim::kPrimLogSoftmax, LogSoftmaxInfer, nullptr, true);
 }  // namespace ops
 }  // namespace mindspore
diff --git a/mindspore/core/ops/log_softmax.h b/mindspore/core/ops/log_softmax.h
index ad6bef4417..4815dbaf07 100644
--- a/mindspore/core/ops/log_softmax.h
+++ b/mindspore/core/ops/log_softmax.h
@@ -37,9 +37,6 @@ class LogSoftmax : public PrimitiveC {
   void set_axis(const int64_t axis);
   int64_t get_axis() const;
 };
-
-AbstractBasePtr LogSoftmaxInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
-                                const std::vector<AbstractBasePtr> &input_args);
 using PrimLogSoftmaxPtr = std::shared_ptr<LogSoftmax>;
 }  // namespace ops
 }  // namespace mindspore
diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py
index 17fc4ee8b3..64e32da8a0 100644
--- a/mindspore/ops/operations/nn_ops.py
+++ b/mindspore/ops/operations/nn_ops.py
@@ -178,7 +178,7 @@ class Softmax(Primitive):
             validator.check_value_type("item of axis", item, [int], self.name)
 
 
-class LogSoftmax(PrimitiveWithInfer):
+class LogSoftmax(Primitive):
     r"""
     Log Softmax activation function.
 
@@ -220,15 +220,6 @@ class LogSoftmax(PrimitiveWithInfer):
     def __init__(self, axis=-1):
         validator.check_value_type("axis", axis, [int], self.name)
 
-    def infer_shape(self, logits):
-        rank = len(logits)
-        validator.check_int_range(self.axis, -rank, rank, Rel.INC_LEFT, 'axis', self.name)
-        return logits
-
-    def infer_dtype(self, logits):
-        validator.check_tensor_dtype_valid("logits", logits, (mstype.float16, mstype.float32), self.name)
-        return logits
-
 
 class Softplus(Primitive):
     r"""
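
A quick way to exercise the new C++ infer path from the Python frontend (a
minimal sketch, not part of the patch: it assumes a MindSpore build with this
change applied, and the axis, input shape, and dtype below are illustrative
only):

    import numpy as np
    from mindspore import Tensor, context
    from mindspore import dtype as mstype
    from mindspore.ops import operations as P

    # Run eagerly so the operator executes immediately.
    context.set_context(mode=context.PYNATIVE_MODE)

    # LogSoftmax is now a plain Primitive: output shape and dtype come from the
    # C++ LogSoftmaxInferShape/LogSoftmaxInferType registered through
    # REGISTER_PRIMITIVE_EVAL_IMPL, not from the removed Python
    # infer_shape/infer_dtype methods.
    log_softmax = P.LogSoftmax(axis=-1)
    x = Tensor(np.random.rand(2, 3).astype(np.float32))
    out = log_softmax(x)
    assert out.shape == (2, 3)
    assert out.dtype == mstype.float32

Note that the registration leaves the value-inference argument of
REGISTER_PRIMITIVE_EVAL_IMPL as nullptr, so this patch provides shape and type
inference only, with no compile-time value inference for LogSoftmax.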