@@ -34,22 +34,33 @@ void LogSoftmax::Init(const int64_t axis) { this->set_axis(axis); } |

abstract::ShapePtr LogSoftmaxInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  auto op_name = primitive->name();
  auto axis = GetValue<int64_t>(primitive->GetAttr(kAxis));
  CheckAndConvertUtils::CheckInteger("log_softmax infer", input_args.size(), kEqual, 1, op_name);
  MS_EXCEPTION_IF_NULL(input_args[0]);
  auto shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape());
  if (shape_map.empty()) {
    // Scalar input, has no shape.
    return std::make_shared<abstract::Shape>(std::vector<int64_t>());
  }
  auto in_shape = shape_map[kShape];
  auto min_shape = shape_map[kMinShape];
  auto max_shape = shape_map[kMaxShape];
  auto rank = SizeToLong(in_shape.size());
  CheckAndConvertUtils::CheckInRange<int64_t>("axis", axis, kIncludeLeft, {-rank, rank}, op_name);
  if (min_shape.size() != 0 && max_shape.size() != 0) {
    // Dynamic shape: propagate the min/max bounds together with the shape.
    return std::make_shared<abstract::Shape>(in_shape, min_shape, max_shape);
  }
  return std::make_shared<abstract::Shape>(in_shape);
}
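
// Illustrative sketch, not part of the original change: for a float32 input of
// shape {2, 3}, rank is 2, so any axis in [-2, 2) passes the CheckInRange call
// above and the inferred output shape is {2, 3}; axis = 2 or axis = -3 would
// raise an out-of-range error.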

TypePtr LogSoftmaxInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(prim);
  auto op_name = prim->name();
  CheckAndConvertUtils::CheckInteger("log_softmax infer", input_args.size(), kEqual, 1, op_name);
  MS_EXCEPTION_IF_NULL(input_args[0]);
  // LogSoftmax is only defined for float16 and float32 tensors.
  const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
  return CheckAndConvertUtils::CheckTensorTypeValid("x", input_args[0]->BuildType(), valid_types, op_name);
}
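
// Illustrative sketch, not part of the original change: an int32 input tensor
// would be rejected by CheckTensorTypeValid here, since valid_types only
// admits kFloat16 and kFloat32 for input "x".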

@@ -57,6 +68,6 @@ AbstractBasePtr LogSoftmaxInfer(const abstract::AnalysisEnginePtr &, const Primi

AbstractBasePtr LogSoftmaxInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                const std::vector<AbstractBasePtr> &input_args) {
  // Combine type and shape inference into the output abstract tensor.
  return std::make_shared<abstract::AbstractTensor>(LogSoftmaxInferType(primitive, input_args),
                                                    LogSoftmaxInferShape(primitive, input_args)->shape());
}
REGISTER_PRIMITIVE_C(kNameLogSoftmax, LogSoftmax);
REGISTER_PRIMITIVE_EVAL_IMPL(LogSoftmax, prim::kPrimLogSoftmax, LogSoftmaxInfer, nullptr, true);
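// Note (our reading, not stated in this change): REGISTER_PRIMITIVE_C registers
// the primitive under kNameLogSoftmax, while REGISTER_PRIMITIVE_EVAL_IMPL binds
// LogSoftmaxInfer as its eval implementation; the nullptr slot appears to be an
// optional infer-value (constant-folding) hook that LogSoftmax does not provide.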
} // namespace ops
} // namespace mindspore