|
|
|
@@ -46,15 +46,24 @@ void Softmax::Init(const int64_t axis) { |
|
|
|
|
|
|
|
abstract::ShapePtr SoftMaxInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { |
|
|
|
MS_EXCEPTION_IF_NULL(primitive); |
|
|
|
auto Softmax_prim = primitive->cast<PrimSoftmaxPtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(Softmax_prim); |
|
|
|
auto op_name = Softmax_prim->name(); |
|
|
|
auto axis = Softmax_prim->get_axis(); |
|
|
|
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->GetShapeTrack(), op_name); |
|
|
|
auto op_name = primitive->name(); |
|
|
|
auto axis = GetValue<std::vector<int64_t>>(primitive->GetAttr(kAxis)); |
|
|
|
(void)CheckAndConvertUtils::CheckValue<size_t>("length of axis", axis.size(), kGreaterEqual, 1, op_name); |
|
|
|
auto shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape()); |
|
|
|
if (shape_map.empty()) { |
|
|
|
// Scalar input, has no shape |
|
|
|
return std::make_shared<abstract::Shape>(std::vector<int64_t>()); |
|
|
|
} |
|
|
|
auto in_shape = shape_map[kShape]; |
|
|
|
auto min_shape = shape_map[kMinShape]; |
|
|
|
auto max_shape = shape_map[kMaxShape]; |
|
|
|
auto rank = SizeToLong(in_shape.size()); |
|
|
|
for (auto &item : axis) { |
|
|
|
CheckAndConvertUtils::CheckInRange<int64_t>("axis", item, kIncludeLeft, {-rank, rank}, op_name); |
|
|
|
} |
|
|
|
if (min_shape.size() != 0 && max_shape.size() != 0) { |
|
|
|
return std::make_shared<abstract::Shape>(in_shape, min_shape, max_shape); |
|
|
|
} |
|
|
|
return std::make_shared<abstract::Shape>(in_shape); |
|
|
|
} |
|
|
|
|
|
|
|
@@ -71,6 +80,7 @@ AbstractBasePtr SoftmaxInfer(const abstract::AnalysisEnginePtr &, const Primitiv |
|
|
|
return std::make_shared<abstract::AbstractTensor>(SoftMaxInferType(primitive, input_args), |
|
|
|
SoftMaxInferShape(primitive, input_args)->shape()); |
|
|
|
} |
|
|
|
REGISTER_PRIMITIVE_EVAL_IMPL(Softmax, prim::kPrimSoftmax, SoftmaxInfer, nullptr, true); |
|
|
|
REGISTER_PRIMITIVE_C(kNameSoftmax, Softmax); |
|
|
|
} // namespace ops |
|
|
|
} // namespace mindspore |