diff --git a/mindspore/ccsrc/backend/optimizer/common/helper.cc b/mindspore/ccsrc/backend/optimizer/common/helper.cc index c215b10b7c..f83387d46c 100644 --- a/mindspore/ccsrc/backend/optimizer/common/helper.cc +++ b/mindspore/ccsrc/backend/optimizer/common/helper.cc @@ -933,7 +933,11 @@ AbstractBasePtr CppInferShape(const PrimitivePtr &prim, const AbstractBasePtrLis auto ret_backend = prim_backend_eval_impl_map.find(prim); if (ret_backend != prim_backend_eval_impl_map.end()) { MS_EXCEPTION_IF_NULL(ret_backend->second.infer_shape_impl_); - return ret_backend->second.infer_shape_impl_(nullptr, prim, args_spec_list); + auto infer_spec_list = args_spec_list; + if (!ret_backend->second.in_white_list_) { + infer_spec_list = RectifyAbstract(prim, args_spec_list); + } + return ret_backend->second.infer_shape_impl_(nullptr, prim, infer_spec_list); } } MS_LOG(EXCEPTION) << "Get infer shape function failed, primitive name:" << prim->name()