diff --git a/mindspore/core/ops/softmax.cc b/mindspore/core/ops/softmax.cc index 7e76bf0872..0df32d9275 100644 --- a/mindspore/core/ops/softmax.cc +++ b/mindspore/core/ops/softmax.cc @@ -46,15 +46,24 @@ void Softmax::Init(const int64_t axis) { abstract::ShapePtr SoftMaxInferShape(const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto Softmax_prim = primitive->cast(); - MS_EXCEPTION_IF_NULL(Softmax_prim); - auto op_name = Softmax_prim->name(); - auto axis = Softmax_prim->get_axis(); - auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->GetShapeTrack(), op_name); + auto op_name = primitive->name(); + auto axis = GetValue>(primitive->GetAttr(kAxis)); + (void)CheckAndConvertUtils::CheckValue("length of axis", axis.size(), kGreaterEqual, 1, op_name); + auto shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape()); + if (shape_map.empty()) { + // Scalar input, has no shape + return std::make_shared(std::vector()); + } + auto in_shape = shape_map[kShape]; + auto min_shape = shape_map[kMinShape]; + auto max_shape = shape_map[kMaxShape]; auto rank = SizeToLong(in_shape.size()); for (auto &item : axis) { CheckAndConvertUtils::CheckInRange("axis", item, kIncludeLeft, {-rank, rank}, op_name); } + if (min_shape.size() != 0 && max_shape.size() != 0) { + return std::make_shared(in_shape, min_shape, max_shape); + } return std::make_shared(in_shape); } @@ -71,6 +80,7 @@ AbstractBasePtr SoftmaxInfer(const abstract::AnalysisEnginePtr &, const Primitiv return std::make_shared(SoftMaxInferType(primitive, input_args), SoftMaxInferShape(primitive, input_args)->shape()); } +REGISTER_PRIMITIVE_EVAL_IMPL(Softmax, prim::kPrimSoftmax, SoftmaxInfer, nullptr, true); REGISTER_PRIMITIVE_C(kNameSoftmax, Softmax); } // namespace ops } // namespace mindspore diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index 659c1a8c71..cd1f2239e6 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -177,17 +177,6 @@ class Softmax(PrimitiveWithInfer): for item in self.axis: validator.check_value_type("item of axis", item, [int], self.name) - def infer_shape(self, logits): - validator.check_int(len(self.axis), 1, Rel.GE, "length of axis", self.name) - rank = len(logits) - for axis_v in self.axis: - validator.check_int_range(axis_v, -rank, rank, Rel.INC_LEFT, "axis", self.name) - return logits - - def infer_dtype(self, logits): - validator.check_tensor_dtype_valid("logits", logits, mstype.float_type, self.name) - return logits - class LogSoftmax(PrimitiveWithInfer): r"""