Browse Source

Use env variable for INFER and TRAIN modes

feature/build-system-rewrite
xiao_yao1994 4 years ago
parent
commit
52817c2bbd
4 changed files with 20 additions and 11 deletions
  1. +10
    -4
      mindspore/ccsrc/pipeline/jit/pipeline_ge.cc
  2. +8
    -1
      mindspore/ccsrc/transform/graph_ir/convert.h
  3. +0
    -5
      mindspore/ccsrc/utils/config_manager.h
  4. +2
    -1
      mindspore/ccsrc/utils/context/context_extends.cc

+ 10
- 4
mindspore/ccsrc/pipeline/jit/pipeline_ge.cc View File

@@ -21,7 +21,6 @@
#include <cstdlib>
#include <algorithm>

#include "utils/config_manager.h"
#include "utils/hash_map.h"
#include "debug/anf_ir_dump.h"
#include "ir/tensor.h"
@@ -130,7 +129,11 @@ bool InitExecDatasetGe(const std::string &queue_name, int64_t size, int64_t batc
return false;
}

auto training = ConfigManager::GetInstance().training();
auto env_training = common::GetEnv("MS_GE_TRAIN");
bool training = false;
if (env_training == "1") {
training = true;
}
if (training) {
(void)setenv("GE_TRAIN", "1", 1);
} else {
@@ -246,7 +249,6 @@ FuncGraphPtr BuildDFGraph(const std::map<std::string, ExecutorInfoPtr> &info, co
MS_LOG(EXCEPTION) << "No phase in executor:" << GetPhasePrefix(phase);
}
FuncGraphPtr anf_graph = info.at(phase)->func_graph;
ConfigManager::GetInstance().set_training(anf_graph->has_flag("training"));
#ifdef ENABLE_DUMP_IR
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
draw::Draw("anf_graph.dot", anf_graph); // for debug
@@ -259,7 +261,11 @@ FuncGraphPtr BuildDFGraph(const std::map<std::string, ExecutorInfoPtr> &info, co
return nullptr;
}

auto training = ConfigManager::GetInstance().training();
auto env_training = common::GetEnv("MS_GE_TRAIN");
bool training = false;
if (env_training == "1") {
training = true;
}
if (training) {
(void)setenv("GE_TRAIN", "1", 1);
} else {


+ 8
- 1
mindspore/ccsrc/transform/graph_ir/convert.h View File

@@ -19,6 +19,7 @@

#define DRAW_GE_GRAPH

#include <cstdlib>
#include <memory>
#include <map>
#include <set>
@@ -56,7 +57,13 @@ class DfGraphConvertor {
explicit DfGraphConvertor(const AnfGraphPtr &anf_graph) : anf_graph_(anf_graph) {
MS_EXCEPTION_IF_NULL(anf_graph);
df_graph_ = std::make_shared<DfGraph>(anf_graph_->ToString());
training_ = anf_graph->has_flag("training");
auto env_ge = mindspore::common::GetEnv("MS_ENABLE_GE");
auto env_training = mindspore::common::GetEnv("MS_GE_TRAIN");
if (env_ge == "1" && env_training == "1") {
training_ = true;
} else {
training_ = anf_graph->has_flag("training");
}
distribute_ = anf_graph->has_flag("broadcast_flag");
if (anf_graph->has_flag("broadcast_flag")) {
ConfigManager::GetInstance().set_parallel_strategy(ParallelStrategy::DISTRIBUTION);


+ 0
- 5
mindspore/ccsrc/utils/config_manager.h View File

@@ -117,10 +117,6 @@ class ConfigManager {

void set_gpu_loopsink_size(const int64_t size) { gpu_loopsink_size_ = size; }

bool training() const { return training_; }

void set_training(const bool training) { training_ = training; }

private:
ConfigManager() = default;
~ConfigManager() = default;
@@ -134,7 +130,6 @@ class ConfigManager {
std::map<std::string, int16_t> queue_info_map;
std::string dataset_phase_{""};
int64_t gpu_loopsink_size_{1};
bool training_{false};
};

} // namespace mindspore


+ 2
- 1
mindspore/ccsrc/utils/context/context_extends.cc View File

@@ -196,7 +196,8 @@ void GetGeOptions(const std::shared_ptr<MsContext> &ms_context_ptr, std::map<std
(*ge_options)["ge.variableMemoryMaxSize"] = ms_context_ptr->get_param<std::string>(MS_CTX_VARIABLE_MEMORY_MAX_SIZE);
}

if (ConfigManager::GetInstance().training()) {
auto training = common::GetEnv("MS_GE_TRAIN");
if (training == "1") {
(*ge_options)["ge.graphRunMode"] = "1";
}



Loading…
Cancel
Save