|
|
|
@@ -302,7 +302,7 @@ std::map<string, string> GenerateJitConfigMap(const py::dict &jit_config) { |
|
|
|
return ret; |
|
|
|
} |
|
|
|
|
|
|
|
FuncGraphPtr LoadFuncGraphFromMindIR(const ResourcePtr &resource, const py::dict &weights, bool need_layout) { |
|
|
|
FuncGraphPtr LoadFuncGraphFromMindIR(const ResourcePtr &resource, const py::dict &weights, bool has_parallel_info) { |
|
|
|
const size_t idx = resource->compile_cache_id(); |
|
|
|
std::string compile_cache_path = GetCompileCachePath(idx); |
|
|
|
auto realpath = Common::CreatePrefixPath(compile_cache_path, true); |
|
|
|
@@ -320,9 +320,9 @@ FuncGraphPtr LoadFuncGraphFromMindIR(const ResourcePtr &resource, const py::dict |
|
|
|
MindIRLoader mindir_loader; |
|
|
|
mindir_loader.set_need_renormalize(false); |
|
|
|
mindir_loader.set_weights_value_map(GenerateWeightsValueMap(weights)); |
|
|
|
mindir_loader.set_need_layout(need_layout); |
|
|
|
mindir_loader.set_has_parallel_info(has_parallel_info); |
|
|
|
auto fg = mindir_loader.LoadMindIR(realpath.value()); |
|
|
|
if (need_layout) { |
|
|
|
if (has_parallel_info) { |
|
|
|
resource->set_layout_map(mindir_loader.get_layout_map()); |
|
|
|
} |
|
|
|
return fg; |
|
|
|
@@ -336,14 +336,14 @@ FuncGraphPtr GetCachedFuncGraph(const ResourcePtr &resource, const py::dict &wei |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
|
|
|
|
// Determine whether to load layout information. |
|
|
|
// Determine whether to load parallel information. |
|
|
|
std::string parallel_mode = parallel::ParallelContext::GetInstance()->parallel_mode(); |
|
|
|
bool need_layout = false; |
|
|
|
bool has_parallel_info = false; |
|
|
|
if ((parallel_mode == parallel::AUTO_PARALLEL) || (parallel_mode == parallel::SEMI_AUTO_PARALLEL)) { |
|
|
|
need_layout = true; |
|
|
|
has_parallel_info = true; |
|
|
|
} |
|
|
|
// Load the compilation cache file. |
|
|
|
FuncGraphPtr fg = LoadFuncGraphFromMindIR(resource, weights, need_layout); |
|
|
|
FuncGraphPtr fg = LoadFuncGraphFromMindIR(resource, weights, has_parallel_info); |
|
|
|
if (fg == nullptr) { |
|
|
|
MS_LOG(WARNING) << "Failed to load the compilation cache file. Execute all the compilation actions."; |
|
|
|
return nullptr; |
|
|
|
|