|
|
|
@@ -933,7 +933,11 @@ AbstractBasePtr CppInferShape(const PrimitivePtr &prim, const AbstractBasePtrLis |
|
|
|
auto ret_backend = prim_backend_eval_impl_map.find(prim); |
|
|
|
if (ret_backend != prim_backend_eval_impl_map.end()) { |
|
|
|
MS_EXCEPTION_IF_NULL(ret_backend->second.infer_shape_impl_); |
|
|
|
return ret_backend->second.infer_shape_impl_(nullptr, prim, args_spec_list); |
|
|
|
auto infer_spec_list = args_spec_list; |
|
|
|
if (!ret_backend->second.in_white_list_) { |
|
|
|
infer_spec_list = RectifyAbstract(prim, args_spec_list); |
|
|
|
} |
|
|
|
return ret_backend->second.infer_shape_impl_(nullptr, prim, infer_spec_list); |
|
|
|
} |
|
|
|
} |
|
|
|
MS_LOG(EXCEPTION) << "Get infer shape function failed, primitive name:" << prim->name() |
|
|
|
|