Browse Source

!15868 fix diffent error type bug in D and cpu&gpu

From: @simson_wu
Reviewed-by: @ginfung,@zh_qh
Signed-off-by: @zh_qh
pull/15868/MERGE
mindspore-ci-bot Gitee 4 years ago
parent
commit
bec6b0cd9c
1 changed files with 5 additions and 4 deletions
  1. +5
    -4
      mindspore/core/ops/gather_d.cc

+ 5
- 4
mindspore/core/ops/gather_d.cc View File

@@ -57,16 +57,17 @@ TypePtr GatherDInferType(const PrimitivePtr &prim, const std::vector<AbstractBas
std::set<TypePtr> valid_x_type = {kTensorType};
auto x_type =
CheckAndConvertUtils::CheckTensorTypeValid("x_type", input_args[0]->BuildType(), valid_x_type, prim_name);
std::set<TypePtr> valid_index_types = {kInt32, kInt64};
CheckAndConvertUtils::CheckTensorTypeValid("index_type", input_args[2]->BuildType(), valid_index_types, prim_name);
std::set<TypePtr> valid_dim_type = {kInt32, kInt64};
CheckAndConvertUtils::CheckSubClass("dim_type", input_args[1]->BuildType(), valid_dim_type, prim_name);
return x_type;
}
} // namespace
AbstractBasePtr GatherDInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
// check
std::set<TypePtr> valid_types = {kInt32, kInt64};
CheckAndConvertUtils::CheckTensorTypeValid("index_type", input_args[2]->BuildType(), valid_types, prim_name);
CheckAndConvertUtils::CheckSubClass("dim_type", input_args[1]->BuildType(), valid_types, prim_name);
auto abs = std::make_shared<abstract::AbstractTensor>(GatherDInferType(primitive, input_args),
GatherDInferShape(primitive, input_args));
return abs;


Loading…
Cancel
Save