Browse Source

pynative-insert-transdata-for-hook-mode

tags/v0.5.0-beta
lvliang 5 years ago
parent
commit
075da9a4b1
4 changed files with 11 additions and 1 deletions
  1. +1
    -1
      mindspore/ccsrc/pre_activate/ascend/format_type/insert_trans_op.cc
  2. +1
    -0
      mindspore/ccsrc/utils/context/ms_context.cc
  3. +4
    -0
      mindspore/ccsrc/utils/context/ms_context.h
  4. +5
    -0
      mindspore/ccsrc/vm/transform.cc

+ 1
- 1
mindspore/ccsrc/pre_activate/ascend/format_type/insert_trans_op.cc View File

@@ -51,7 +51,7 @@ const AnfNodePtr InsertTransOp::Process(const FuncGraphPtr &func_graph, const An
AnfNodePtr new_node = InsertTransOpForInput(func_graph, node, kernel_select_);
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
if (ms_context->execution_mode() == kPynativeMode) {
if (ms_context->execution_mode() == kPynativeMode && !ms_context->enable_pynative_hook()) {
if (IsGraphOutput(node, AnfAlgo::GetAllOutput(func_graph->output(), {prim::kPrimTupleGetItem}))) {
return new_node;
}


+ 1
- 0
mindspore/ccsrc/utils/context/ms_context.cc View File

@@ -74,6 +74,7 @@ MsContext::MsContext(const std::string &policy, const std::string &target) {
precompile_only_ = false;
auto_mixed_precision_flag_ = false;
enable_pynative_infer_ = false;
enable_pynative_hook_ = false;
enable_dynamic_mem_pool_ = true;
graph_memory_max_size_ = "0";
variable_memory_max_size_ = "0";


+ 4
- 0
mindspore/ccsrc/utils/context/ms_context.h View File

@@ -64,6 +64,9 @@ class MsContext {
bool enable_pynative_infer() const { return enable_pynative_infer_; }
void set_enable_pynative_infer(bool enable_pynative_infer) { enable_pynative_infer_ = enable_pynative_infer; }

bool enable_pynative_hook() const { return enable_pynative_hook_; }
void set_enable_pynative_hook(bool enable_pynative_hook) { enable_pynative_hook_ = enable_pynative_hook; }

bool enable_task_sink() const { return enable_task_sink_; }

void set_precompile_only(bool precompile_only) { precompile_only_ = precompile_only; }
@@ -161,6 +164,7 @@ class MsContext {
uint32_t device_id_;
int execution_mode_;
bool enable_pynative_infer_;
bool enable_pynative_hook_;
bool save_graphs_flag_;
std::string save_graphs_path_;
uint32_t tsd_ref_;


+ 5
- 0
mindspore/ccsrc/vm/transform.cc View File

@@ -277,6 +277,11 @@ bool CompileGraph::IsCut(const AnfNodePtr &node) {
for (auto &prim : cut_list_) {
MS_EXCEPTION_IF_NULL(prim);
if (prim->name() == node_prim->name()) {
if (prim->name() == prim::kPrimBpropCut->name()) {
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
ms_context->set_enable_pynative_hook(true);
}
return true;
}
}


Loading…
Cancel
Save