Browse Source

Fix some coding problems for the compile cache

tags/v1.6.0
l00591931 4 years ago
parent
commit
2cd17bfdc6
5 changed files with 19 additions and 19 deletions
  1. +7
    -7
      mindspore/ccsrc/frontend/parallel/graph_util/get_parallel_info.cc
  2. +7
    -7
      mindspore/ccsrc/pipeline/jit/pipeline.cc
  3. +1
    -1
      mindspore/ccsrc/pipeline/jit/resource.h
  4. +1
    -1
      mindspore/core/load_mindir/load_model.cc
  5. +3
    -3
      mindspore/core/load_mindir/load_model.h

+ 7
- 7
mindspore/ccsrc/frontend/parallel/graph_util/get_parallel_info.cc View File

@@ -39,9 +39,9 @@ py::dict GetParameterLayoutFromGraph(const FuncGraphPtr &graph) {
if (tensor_layout == nullptr) {
MS_LOG(INFO) << "GetParameterLayout nullptr name = " << name;
} else {
auto device_arrangement = tensor_layout->device_arrangement().array();
auto tensor_map = tensor_layout->tensor_map().array();
auto slice_shape = tensor_layout->slice_shape().array();
const auto &device_arrangement = tensor_layout->device_arrangement().array();
const auto &tensor_map = tensor_layout->tensor_map().array();
const auto &slice_shape = tensor_layout->slice_shape().array();
int64_t field_size = tensor_layout->get_field_size();
bool uniform_split = tensor_layout->uniform_split();
const std::string &opt_shard_group = tensor_layout->opt_shard_group();
@@ -56,13 +56,13 @@ py::dict GetParameterLayoutFromGraph(const FuncGraphPtr &graph) {

py::dict GetParameterLayoutFromResource(const pipeline::ResourcePtr &resource) {
py::dict dict;
auto layout_map = resource->get_layout_map();
const auto &layout_map = resource->get_layout_map();
for (auto iter = layout_map.begin(); iter != layout_map.end(); ++iter) {
auto name = iter->first;
auto layout = iter->second;
auto device_arrangement = layout->get_device_arrangement();
auto tensor_map = layout->get_tensor_map();
auto slice_shape = layout->get_slice_shape();
const auto &device_arrangement = layout->get_device_arrangement();
const auto &tensor_map = layout->get_tensor_map();
const auto &slice_shape = layout->get_slice_shape();
int64_t field_size = layout->get_field_size();
bool uniform_split = layout->get_uniform_split();
const std::string &opt_shard_group = layout->get_opt_shard_group();


+ 7
- 7
mindspore/ccsrc/pipeline/jit/pipeline.cc View File

@@ -302,7 +302,7 @@ std::map<string, string> GenerateJitConfigMap(const py::dict &jit_config) {
return ret;
}

FuncGraphPtr LoadFuncGraphFromMindIR(const ResourcePtr &resource, const py::dict &weights, bool need_layout) {
FuncGraphPtr LoadFuncGraphFromMindIR(const ResourcePtr &resource, const py::dict &weights, bool has_parallel_info) {
const size_t idx = resource->compile_cache_id();
std::string compile_cache_path = GetCompileCachePath(idx);
auto realpath = Common::CreatePrefixPath(compile_cache_path, true);
@@ -320,9 +320,9 @@ FuncGraphPtr LoadFuncGraphFromMindIR(const ResourcePtr &resource, const py::dict
MindIRLoader mindir_loader;
mindir_loader.set_need_renormalize(false);
mindir_loader.set_weights_value_map(GenerateWeightsValueMap(weights));
mindir_loader.set_need_layout(need_layout);
mindir_loader.set_has_parallel_info(has_parallel_info);
auto fg = mindir_loader.LoadMindIR(realpath.value());
if (need_layout) {
if (has_parallel_info) {
resource->set_layout_map(mindir_loader.get_layout_map());
}
return fg;
@@ -336,14 +336,14 @@ FuncGraphPtr GetCachedFuncGraph(const ResourcePtr &resource, const py::dict &wei
return nullptr;
}

// Determine whether to load layout information.
// Determine whether to load parallel information.
std::string parallel_mode = parallel::ParallelContext::GetInstance()->parallel_mode();
bool need_layout = false;
bool has_parallel_info = false;
if ((parallel_mode == parallel::AUTO_PARALLEL) || (parallel_mode == parallel::SEMI_AUTO_PARALLEL)) {
need_layout = true;
has_parallel_info = true;
}
// Load the compilation cache file.
FuncGraphPtr fg = LoadFuncGraphFromMindIR(resource, weights, need_layout);
FuncGraphPtr fg = LoadFuncGraphFromMindIR(resource, weights, has_parallel_info);
if (fg == nullptr) {
MS_LOG(WARNING) << "Failed to load the compilation cache file. Execute all the compilation actions.";
return nullptr;


+ 1
- 1
mindspore/ccsrc/pipeline/jit/resource.h View File

@@ -89,7 +89,7 @@ class Resource : public ResourceBase {
int64_t loop_size() { return loop_size_; }

void set_layout_map(const LayoutMap &layout_map) { layout_map_ = layout_map; }
const LayoutMap get_layout_map() const { return layout_map_; }
const LayoutMap &get_layout_map() const { return layout_map_; }

bool enable_compile_cache() { return enable_compile_cache_; }
void set_enable_compile_cache(bool enable_compile_cache) { enable_compile_cache_ = enable_compile_cache; }


+ 1
- 1
mindspore/core/load_mindir/load_model.cc View File

@@ -230,7 +230,7 @@ FuncGraphPtr MindIRLoader::LoadMindIR(const std::string &file_name) {
model_parser.SetLite();
}
FuncGraphPtr dstgraph_ptr = model_parser.Parse(origin_model, weights_value_map_);
if (need_layout_) {
if (has_parallel_info_) {
layout_map_ = model_parser.ParseLayout(origin_model);
}
return dstgraph_ptr;


+ 3
- 3
mindspore/core/load_mindir/load_model.h View File

@@ -65,11 +65,11 @@ class MindIRLoader {

bool get_need_renormalize() const { return need_renormalize_; }
void set_need_renormalize(bool need_renormalize) { need_renormalize_ = need_renormalize; }
void set_need_layout(bool need_layout) { need_layout_ = need_layout; }
void set_has_parallel_info(bool has_parallel_info) { has_parallel_info_ = has_parallel_info; }
void set_weights_value_map(const std::map<string, ValuePtr> &weights_value_map) {
weights_value_map_ = weights_value_map;
}
LayoutMap get_layout_map() { return layout_map_; }
const LayoutMap &get_layout_map() { return layout_map_; }
FuncGraphPtr LoadMindIR(const void *buffer, const size_t &size);
FuncGraphPtr LoadMindIR(const std::string &file_name);
std::vector<FuncGraphPtr> LoadMindIRs(const std::vector<std::string> file_names);
@@ -84,7 +84,7 @@ class MindIRLoader {
bool inc_load_ = false;
bool need_renormalize_ = true;
std::map<string, ValuePtr> weights_value_map_;
bool need_layout_ = false;
bool has_parallel_info_ = false;
LayoutMap layout_map_;
};



Loading…
Cancel
Save