|
|
|
@@ -44,6 +44,7 @@ void Softmax::Init(const int64_t axis) { |
|
|
|
this->set_axis(axis_vec); |
|
|
|
} |
|
|
|
|
|
|
|
namespace { |
|
|
|
abstract::ShapePtr SoftMaxInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { |
|
|
|
MS_EXCEPTION_IF_NULL(primitive); |
|
|
|
auto op_name = primitive->name(); |
|
|
|
@@ -74,6 +75,7 @@ TypePtr SoftMaxInferType(const PrimitivePtr &prim, const std::vector<AbstractBas |
|
|
|
const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kFloat64}; |
|
|
|
return CheckAndConvertUtils::CheckTensorTypeValid("x", input_args[0]->BuildType(), valid_types, prim->name()); |
|
|
|
} |
|
|
|
} // namespace |
|
|
|
|
|
|
|
AbstractBasePtr SoftmaxInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, |
|
|
|
const std::vector<AbstractBasePtr> &input_args) { |
|
|
|
|