|
|
|
@@ -57,16 +57,17 @@ TypePtr GatherDInferType(const PrimitivePtr &prim, const std::vector<AbstractBas |
|
|
|
std::set<TypePtr> valid_x_type = {kTensorType}; |
|
|
|
auto x_type = |
|
|
|
CheckAndConvertUtils::CheckTensorTypeValid("x_type", input_args[0]->BuildType(), valid_x_type, prim_name); |
|
|
|
std::set<TypePtr> valid_index_types = {kInt32, kInt64}; |
|
|
|
CheckAndConvertUtils::CheckTensorTypeValid("index_type", input_args[2]->BuildType(), valid_index_types, prim_name); |
|
|
|
std::set<TypePtr> valid_dim_type = {kInt32, kInt64}; |
|
|
|
CheckAndConvertUtils::CheckSubClass("dim_type", input_args[1]->BuildType(), valid_dim_type, prim_name); |
|
|
|
return x_type; |
|
|
|
} |
|
|
|
} // namespace |
|
|
|
AbstractBasePtr GatherDInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, |
|
|
|
const std::vector<AbstractBasePtr> &input_args) { |
|
|
|
MS_EXCEPTION_IF_NULL(primitive); |
|
|
|
auto prim_name = primitive->name(); |
|
|
|
// check |
|
|
|
std::set<TypePtr> valid_types = {kInt32, kInt64}; |
|
|
|
CheckAndConvertUtils::CheckTensorTypeValid("index_type", input_args[2]->BuildType(), valid_types, prim_name); |
|
|
|
CheckAndConvertUtils::CheckSubClass("dim_type", input_args[1]->BuildType(), valid_types, prim_name); |
|
|
|
auto abs = std::make_shared<abstract::AbstractTensor>(GatherDInferType(primitive, input_args), |
|
|
|
GatherDInferShape(primitive, input_args)); |
|
|
|
return abs; |
|
|
|
|