Merge pull request !901 from rick_sanchez/mastertags/v0.3.0-alpha
| @@ -564,42 +564,67 @@ AnfNodePtr AscendSession::CreateFakeOutput(GraphId fake_graph_id, const AnfNodeP | |||
| return create_parameter_from_cnode(output_item_with_index.first, output_item_with_index.second); | |||
| } | |||
| void AscendSession::SetFinalGraphOutput(const BaseRef &output) { | |||
| auto final_graph = GetGraph(final_graph_id_); | |||
| MS_EXCEPTION_IF_NULL(final_graph); | |||
| if (!utils::isa<AnfNodePtr>(output)) { | |||
| if (!utils::isa<ValuePtr>(output)) { | |||
| MS_LOG(EXCEPTION) << "Unknown output type:" << output.ToString(); | |||
| } | |||
| auto value_ptr = utils::cast<ValuePtr>(output); | |||
| auto value_node = NewValueNode(value_ptr); | |||
| MS_EXCEPTION_IF_NULL(value_node); | |||
| auto kernel_info = std::make_shared<device::KernelInfo>(); | |||
| value_node->set_kernel_info(kernel_info); | |||
| value_node->set_abstract(abstract::FromValue(value_ptr)); | |||
| final_graph->set_output(final_graph->NewCNode({NewValueNode(prim::kPrimMakeTuple), value_node})); | |||
| final_graph->set_executable(false); | |||
| MS_LOG(INFO) << "Not anf output[" << output.ToString() << "]"; | |||
| return; | |||
| } | |||
| void AscendSession::SetFinalGraphOutput(const AnfNodePtr &node) { | |||
| // get the backend anf node related to the output node of front | |||
| auto output_anf_node = utils::cast<AnfNodePtr>(output); | |||
| auto output_from_graph_id = GetGraphIdByNode(output_anf_node); | |||
| auto output_from_graph_id = GetGraphIdByNode(node); | |||
| auto output_from_graph = GetGraph(output_from_graph_id); | |||
| MS_EXCEPTION_IF_NULL(output_anf_node); | |||
| MS_LOG(INFO) << "Set the output[" << output_anf_node->DebugString() << "] of graph[" << output_from_graph_id | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| MS_LOG(INFO) << "Set the output[" << node->DebugString() << "] of graph[" << output_from_graph_id | |||
| << "] to final graph"; | |||
| MS_EXCEPTION_IF_NULL(output_from_graph); | |||
| auto final_graph = GetGraph(final_graph_id_); | |||
| MS_EXCEPTION_IF_NULL(final_graph); | |||
| // if output is from final graph,it remarks no child graph exist | |||
| if (final_graph_id_ == output_from_graph_id) { | |||
| MS_LOG(INFO) << "No child graph,output is " << output_anf_node->DebugString(); | |||
| final_graph->set_output(ConstructOutput({output_anf_node}, final_graph)); | |||
| MS_LOG(INFO) << "No child graph,output is " << node->DebugString(); | |||
| final_graph->set_output(ConstructOutput({node}, final_graph)); | |||
| final_graph->set_executable(false); | |||
| return; | |||
| } | |||
| final_graph->set_output(output_from_graph->output()); | |||
| } | |||
| void AscendSession::SetFinalGraphOutput(const ValuePtr &value) { | |||
| auto value_node = NewValueNode(value); | |||
| auto kernel_info = std::make_shared<device::KernelInfo>(); | |||
| value_node->set_kernel_info(kernel_info); | |||
| value_node->set_abstract(abstract::FromValue(value)); | |||
| auto final_graph = GetGraph(final_graph_id_); | |||
| MS_EXCEPTION_IF_NULL(final_graph); | |||
| final_graph->set_output(final_graph->NewCNode({NewValueNode(prim::kPrimMakeTuple), value_node})); | |||
| final_graph->set_executable(false); | |||
| MS_LOG(INFO) << "Not anf output[" << value->ToString() << "]"; | |||
| } | |||
| void AscendSession::SetFinalGraphOutput(const VectorRef &vec_output) { | |||
| for (auto &output : vec_output) { | |||
| if (utils::isa<AnfNodePtr>(output)) { | |||
| auto output_anf_node = utils::cast<AnfNodePtr>(output); | |||
| SetFinalGraphOutput(output_anf_node); | |||
| } else if (utils::isa<ValuePtr>(output)) { | |||
| auto value = utils::cast<ValuePtr>(output); | |||
| SetFinalGraphOutput(value); | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "Unknown output type:" << output.ToString(); | |||
| } | |||
| } | |||
| } | |||
| void AscendSession::SetFinalGraphOutput(const BaseRef &output) { | |||
| if (utils::isa<AnfNodePtr>(output)) { | |||
| auto output_anf_node = utils::cast<AnfNodePtr>(output); | |||
| SetFinalGraphOutput(output_anf_node); | |||
| } else if (utils::isa<ValuePtr>(output)) { | |||
| auto value = utils::cast<ValuePtr>(output); | |||
| SetFinalGraphOutput(value); | |||
| } else if (utils::isa<VectorRef>(output)) { | |||
| auto vec_output = utils::cast<VectorRef>(output); | |||
| SetFinalGraphOutput(vec_output); | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "Unknown output type:" << output.ToString(); | |||
| } | |||
| } | |||
| KernelGraphPtr AscendSession::GetGraph(mindspore::GraphId graph_id) { | |||
| auto it = graphs_.find(graph_id); | |||
| if (it == graphs_.end()) { | |||
| @@ -88,6 +88,10 @@ class AscendSession : public SessionBasic { | |||
| size_t SetChildGraphInput(const KernelGraphPtr &graph, const ValuePtr &value, size_t input_index); | |||
| size_t SetChildGraphInput(const KernelGraphPtr &graph, const VectorRef &vec_args, size_t input_index); | |||
| void SetFinalGraphOutput(const AnfNodePtr &node); | |||
| void SetFinalGraphOutput(const ValuePtr &value); | |||
| void SetFinalGraphOutput(const VectorRef &vec_output); | |||
| // merge execution order list of child graphs | |||
| void MergeGraphExecOrder(); | |||
| // insert assion op to sync data bettween different graphs | |||
| @@ -243,7 +243,7 @@ void CompileGraph::AddSinkSwitch(const CNodePtr &node) { | |||
| AddInst(Instruction::kCall, args); | |||
| args.clear(); | |||
| args.emplace_back(true); | |||
| args.emplace_back(node->input(1)); | |||
| AddInst(Instruction::kSwitchReturn, args); | |||
| args.clear(); | |||
| @@ -141,17 +141,31 @@ void FinalVM::Popsp() { | |||
| } | |||
| } | |||
| void FinalVM::PushStatus(bool is_switch_call) { ret_status_.push(is_switch_call); } | |||
| bool FinalVM::PopStatus() { | |||
| if (ret_status_.empty()) { | |||
| return false; | |||
| } | |||
| bool status = ret_status_.top(); | |||
| ret_status_.pop(); | |||
| return status; | |||
| } | |||
| void FinalVM::DoJmp(const BaseRef &jmp_orig) { | |||
| MS_LOG(DEBUG) << "Start"; | |||
| BaseRef jmp = jmp_orig; | |||
| if (backend_->simu_flag()) { | |||
| bool is_switch_call = false; | |||
| if (utils::isa<StructSimuSwitch>(jmp)) { // need to inherit from Base | |||
| MS_LOG(DEBUG) << "Start jump StructSwitch"; | |||
| auto simu_value = utils::cast<std::shared_ptr<StructSimuSwitch>>(jmp); | |||
| jmp = simu_value->fn_; | |||
| backend_->set_curr_switch(simu_value->value_); | |||
| is_switch_call = true; | |||
| } | |||
| PushStatus(is_switch_call); | |||
| } | |||
| if (utils::isa<StructPartial>(jmp)) { // need to inherit from Base | |||
| @@ -255,6 +269,13 @@ void FinalVM::InstSwitchReturn(const VectorRef &args) { | |||
| MS_LOG(ERROR) << __FUNCTION__ << " requires one parameter, while the input size is " << args.size() << "."; | |||
| return; | |||
| } | |||
| auto rv = Ref(-1); | |||
| if (utils::isa<AnfNodePtr>(rv) || utils::isa<VectorRef>(rv)) { | |||
| auto &c = args[0]; | |||
| cond_out_[c] = rv; | |||
| } | |||
| Pop(1); | |||
| Popsp(); | |||
| } | |||
| @@ -272,8 +293,20 @@ void FinalVM::InstReturn(const VectorRef &args) { | |||
| int height = utils::cast<int>(args[1]); | |||
| auto rv = Ref(rpos); | |||
| if (backend_->simu_flag() && backend_->is_switch_call()) { | |||
| backend_->SetSwitchGraph(); | |||
| if (backend_->simu_flag()) { | |||
| auto c = backend_->curr_switch(); | |||
| auto status = PopStatus(); | |||
| if (status) { | |||
| auto iter = cond_out_.find(c); | |||
| if (iter != cond_out_.end()) { | |||
| rv = MergeArgs(rv, iter->second); | |||
| cond_out_.erase(iter); | |||
| } | |||
| } | |||
| if (backend_->is_switch_call()) { | |||
| backend_->SetSwitchGraph(); | |||
| } | |||
| } | |||
| Pop(height); | |||
| @@ -383,21 +416,30 @@ void FinalVM::MergeJmpArgs(const BaseRef &jmp, const BaseRef &c) { | |||
| for (size_t i = 0; i < new_args.size(); ++i) { | |||
| auto &old_arg = old_args[i]; | |||
| auto &new_arg = new_args[i]; | |||
| if (utils::isa<VectorRef>(old_arg)) { | |||
| auto old_vec_ref = utils::cast<VectorRef>(old_arg); | |||
| if (utils::isa<VectorRef>(new_arg)) { | |||
| auto new_vec_ref = utils::cast<VectorRef>(new_arg); | |||
| std::copy(new_vec_ref.begin(), new_vec_ref.end(), std::back_inserter(old_vec_ref)); | |||
| } | |||
| new_arg = old_vec_ref; | |||
| } else if (utils::isa<VectorRef>(new_arg)) { | |||
| auto new_vec_ref = utils::cast<VectorRef>(new_arg); | |||
| new_vec_ref.push_back(old_arg); | |||
| new_arg = new_vec_ref; | |||
| new_arg = MergeArgs(old_arg, new_arg); | |||
| } | |||
| } | |||
| BaseRef FinalVM::MergeArgs(const BaseRef &first, const BaseRef &second) { | |||
| MS_LOG(DEBUG) << __FUNCTION__ << ": " << first.ToString() << ", " << second.ToString(); | |||
| if (utils::isa<VectorRef>(first)) { | |||
| auto old_vec_ref = utils::cast<VectorRef>(first); | |||
| if (utils::isa<VectorRef>(second)) { | |||
| auto new_vec_ref = utils::cast<VectorRef>(second); | |||
| std::copy(new_vec_ref.begin(), new_vec_ref.end(), std::back_inserter(old_vec_ref)); | |||
| } else { | |||
| new_arg = VectorRef({new_arg, old_arg}); | |||
| old_vec_ref.push_back(second); | |||
| } | |||
| return old_vec_ref; | |||
| } | |||
| if (utils::isa<VectorRef>(second)) { | |||
| auto new_vec_ref = utils::cast<VectorRef>(second); | |||
| new_vec_ref.push_back(first); | |||
| return new_vec_ref; | |||
| } | |||
| return VectorRef({first, second}); | |||
| } | |||
| void FinalVM::InstRealSwitch(const VectorRef &args) { | |||
| @@ -125,17 +125,22 @@ class FinalVM { | |||
| void Popp(); | |||
| void Pushsp(); | |||
| void Popsp(); | |||
| void PushStatus(bool is_switch_call); | |||
| bool PopStatus(); | |||
| void DoJmp(const BaseRef &jmp); | |||
| void MergeJmpArgs(const BaseRef &jmp, const BaseRef &c); | |||
| BaseRef MergeArgs(const BaseRef &first, const BaseRef &second); | |||
| private: | |||
| InstSet insts_; | |||
| std::deque<BaseRef> insts_stack_; | |||
| std::stack<int> retp_; | |||
| std::stack<int> retsp_; | |||
| std::stack<bool> ret_status_; | |||
| int pc_; | |||
| int sp_; | |||
| std::unordered_map<BaseRef, BaseRef, BaseRefHash> cond_jmp_; | |||
| std::unordered_map<BaseRef, BaseRef, BaseRefHash> cond_out_; | |||
| BackendPtr backend_; | |||
| const InstFunctionMap inst_function_map = { | |||
| {Instruction::kCall, [this](const VectorRef &args) { InstCall(args); }}, | |||
| @@ -26,6 +26,7 @@ from mindspore.ops import operations as P | |||
| def setup_module(module): | |||
| context.set_context(mode = context.PYNATIVE_MODE, device_target = "Ascend") | |||
| c1 = Tensor([2], mstype.int32) | |||
| c2 = Tensor([14], mstype.int32) | |||
| c3 = Tensor([1], mstype.int32) | |||
| @@ -149,6 +150,10 @@ def test_if_by_if(): | |||
| assert output == expect | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.env_onecard | |||
| def test_if_in_if(): | |||
| output = if_in_if(c1, c2, c3) | |||
| expect = Tensor([7], mstype.int32) | |||
| @@ -194,6 +199,7 @@ def test_while_by_while_in_while(): | |||
| expect = Tensor([350], mstype.int32) | |||
| assert output == expect | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.platform_arm_ascend_training | |||