diff --git a/mindspore/core/ops/gather_d.cc b/mindspore/core/ops/gather_d.cc index c497992e76..b035f7660e 100644 --- a/mindspore/core/ops/gather_d.cc +++ b/mindspore/core/ops/gather_d.cc @@ -57,16 +57,17 @@ TypePtr GatherDInferType(const PrimitivePtr &prim, const std::vector valid_x_type = {kTensorType}; auto x_type = CheckAndConvertUtils::CheckTensorTypeValid("x_type", input_args[0]->BuildType(), valid_x_type, prim_name); - std::set valid_index_types = {kInt32, kInt64}; - CheckAndConvertUtils::CheckTensorTypeValid("index_type", input_args[2]->BuildType(), valid_index_types, prim_name); - std::set valid_dim_type = {kInt32, kInt64}; - CheckAndConvertUtils::CheckSubClass("dim_type", input_args[1]->BuildType(), valid_dim_type, prim_name); return x_type; } } // namespace AbstractBasePtr GatherDInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); + auto prim_name = primitive->name(); + // check + std::set valid_types = {kInt32, kInt64}; + CheckAndConvertUtils::CheckTensorTypeValid("index_type", input_args[2]->BuildType(), valid_types, prim_name); + CheckAndConvertUtils::CheckSubClass("dim_type", input_args[1]->BuildType(), valid_types, prim_name); auto abs = std::make_shared(GatherDInferType(primitive, input_args), GatherDInferShape(primitive, input_args)); return abs;