Browse Source

fix-bug-avoid-multi-attr-value-be-eliminated-in-pynative-mode

tags/v0.7.0-beta
lvliang 5 years ago
parent
commit
e1a3c39fac
3 changed files with 14 additions and 3 deletions
  1. +8
    -1
      mindspore/ccsrc/backend/optimizer/pass/convert_tuple_output_to_maketuple.cc
  2. +3
    -0
      mindspore/ccsrc/frontend/optimizer/irpass/arithmetic_simplify.cc
  3. +3
    -2
      mindspore/ccsrc/pipeline/pynative/pynative_execute.cc

+ 8
- 1
mindspore/ccsrc/backend/optimizer/pass/convert_tuple_output_to_maketuple.cc View File

@@ -62,7 +62,14 @@ const AnfNodePtr ConvertTupleOutputToMaketuple::Process(const FuncGraphPtr &func
auto cnode = node->cast<CNodePtr>(); auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(cnode);
std::unordered_map<AnfNodePtr, AnfNodePtr> transed_nodes; std::unordered_map<AnfNodePtr, AnfNodePtr> transed_nodes;
if (IsPrimitiveCNode(cnode, prim::kPrimTupleGetItem) || IsPrimitiveCNode(cnode, prim::kPrimControlDepend)) {
if (IsPrimitiveCNode(cnode, prim::kPrimTupleGetItem)) {
auto real_input = AnfAlgo::GetTupleGetItemRealInput(cnode);
MS_EXCEPTION_IF_NULL(real_input);
if (!real_input->isa<Parameter>() && !real_input->isa<ValueNode>()) {
return nullptr;
}
}
if (IsPrimitiveCNode(cnode, prim::kPrimControlDepend)) {
return nullptr; return nullptr;
} }
bool cnode_input_changed = false; bool cnode_input_changed = false;


+ 3
- 0
mindspore/ccsrc/frontend/optimizer/irpass/arithmetic_simplify.cc View File

@@ -41,6 +41,9 @@ AnfNodePtr ArithmeticSimplify::operator()(const OptimizerPtr &, const AnfNodePtr
} }
// Prim Eliminate (identity) // Prim Eliminate (identity)
MATCH_REPLACE(node, PPrimitive(prim::kPrimIdentity, x), x); MATCH_REPLACE(node, PPrimitive(prim::kPrimIdentity, x), x);
if (MsContext::GetInstance()->execution_mode() == kPynativeMode) {
return nullptr;
}


// ConstantDuplicateMul // ConstantDuplicateMul
auto const_dup_lambda = [&node, &x, &const_, &const_2]() -> AnfNodePtr { auto const_dup_lambda = [&node, &x, &const_, &const_2]() -> AnfNodePtr {


+ 3
- 2
mindspore/ccsrc/pipeline/pynative/pynative_execute.cc View File

@@ -393,9 +393,7 @@ bool RunOpConvertConstInputToAttr(const py::object &input_object, size_t input_i
ValuePtr value = parse::data_converter::PyDataToValue(input_object); ValuePtr value = parse::data_converter::PyDataToValue(input_object);
MS_EXCEPTION_IF_NULL(value); MS_EXCEPTION_IF_NULL(value);
auto input_name = input_names_vec[input_index]; auto input_name = input_names_vec[input_index];
op_prim->BeginRecordAddAttr();
op_prim->AddAttr(input_name, value); op_prim->AddAttr(input_name, value);
op_prim->EndRecordAddAttr();
return true; return true;
} }
return false; return false;
@@ -499,6 +497,8 @@ void ConstructInputTensor(const OpExecInfoPtr &op_run_info, std::vector<int> *te


opt::ConstInputToAttrInfoRegister reg; opt::ConstInputToAttrInfoRegister reg;
bool reg_exist = opt::ConstInputToAttrInfoRegistry::Instance().GetRegisterByOpName(op_run_info->op_name, &reg); bool reg_exist = opt::ConstInputToAttrInfoRegistry::Instance().GetRegisterByOpName(op_run_info->op_name, &reg);

op_prim->BeginRecordAddAttr();
size_t input_num = op_run_info->op_inputs.size(); size_t input_num = op_run_info->op_inputs.size();
for (size_t index = 0; index < input_num; ++index) { for (size_t index = 0; index < input_num; ++index) {
// convert const input to attr // convert const input to attr
@@ -513,6 +513,7 @@ void ConstructInputTensor(const OpExecInfoPtr &op_run_info, std::vector<int> *te
std::vector<int> new_mask(input_tensors->size() - tensors_mask->size(), tensor_mask); std::vector<int> new_mask(input_tensors->size() - tensors_mask->size(), tensor_mask);
tensors_mask->insert(tensors_mask->end(), new_mask.begin(), new_mask.end()); tensors_mask->insert(tensors_mask->end(), new_mask.begin(), new_mask.end());
} }
op_prim->EndRecordAddAttr();
} }


void EraseValueNodeTensor(const std::vector<int> &tensors_mask, std::vector<tensor::TensorPtr> *input_tensors) { void EraseValueNodeTensor(const std::vector<int> &tensors_mask, std::vector<tensor::TensorPtr> *input_tensors) {


Loading…
Cancel
Save