Merge pull request !694 from rick_sanchez/master (tag: v0.3.0-alpha)
@@ -800,45 +800,77 @@ void AscendSession::UpdateGraphOrder(GraphId to_graph_id) {
  }
}
size_t AscendSession::SetChildGraphInput(const KernelGraphPtr &graph, const AnfNodePtr &node, size_t input_index) {
  auto output_num = AnfAlgo::GetOutputTensorNum(node);
  if (output_num > 1 && !AnfAlgo::CheckPrimitiveType(node, prim::kPrimTupleGetItem)) {
    return input_index + output_num;
  }
  auto &graph_inputs = graph->inputs();
  auto &valid_inputs = graph->ValidInputs();
  if (valid_inputs[input_index]) {
    SetChildGraphParameter(node, graph_inputs[input_index]);
  } else {
    MS_LOG(DEBUG) << "Invalid input arg: " << node->DebugString();
  }
  return ++input_index;
}
size_t AscendSession::SetChildGraphInput(const KernelGraphPtr &graph, const ValuePtr &value, size_t input_index) {
  MS_EXCEPTION_IF_NULL(value);
  if (!value->isa<Tensor>()) {
    MS_LOG(EXCEPTION) << "Value node should be a tensor, unexpected value: " << value->ToString();
  }
  auto &graph_inputs = graph->inputs();
  SetChildGraphParameter(value->cast<TensorPtr>(), graph_inputs[input_index]);
  return ++input_index;
}
size_t AscendSession::SetChildGraphInput(const KernelGraphPtr &graph, const VectorRef &vec_args, size_t input_index) {
  auto index = input_index;
  for (auto &arg : vec_args) {
    if (utils::isa<AnfNodePtr>(arg)) {
      // arg is an AnfNode
      auto node = utils::cast<AnfNodePtr>(arg);
      index = SetChildGraphInput(graph, node, index);
    } else if (utils::isa<ValuePtr>(arg)) {
      // arg is a tensor
      auto value = utils::cast<ValuePtr>(arg);
      index = SetChildGraphInput(graph, value, index);
    } else {
      MS_LOG(EXCEPTION) << "Unexpected arg type " << arg.ToString();
    }
  }
  return index;
}
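Taken together, these three overloads form a small index-advancing dispatch: each call binds (or skips) one or more parameter slots and returns the next free slot, so nested argument lists flatten onto a single parameter vector. A standalone sketch of that contract (std:: types only; Node, Bind, and the variant layout are illustrative stand-ins, not the MindSpore API):

#include <string>
#include <variant>
#include <vector>

struct Node {                        // stands in for AnfNodePtr
  std::string name;
  size_t output_num = 1;
};
using Arg = std::variant<Node, double, std::vector<double>>;  // node / tensor value / nested arg list

// Bind one argument into the parameter slot vector and return the next free slot.
size_t Bind(const Arg &arg, std::vector<std::string> *slots, size_t index) {
  if (const Node *n = std::get_if<Node>(&arg)) {
    if (n->output_num > 1) {
      return index + n->output_num;  // multi-output node: it owns output_num slots
    }
    (*slots)[index] = n->name;
    return index + 1;
  }
  if (const double *v = std::get_if<double>(&arg)) {
    (*slots)[index] = "const(" + std::to_string(*v) + ")";
    return index + 1;
  }
  const auto &vec = std::get<std::vector<double>>(arg);
  for (double v : vec) {
    index = Bind(v, slots, index);   // recurse and let the callee advance the index
  }
  return index;
}

The caller that follows threads the returned index through successive arguments in exactly this way.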
void AscendSession::SetChildGraphInput(GraphId g, const VectorRef &args) {
  MS_LOG(INFO) << "Set input of graph " << g;
  auto to_graph = GetGraph(g);
  MS_EXCEPTION_IF_NULL(to_graph);
  DumpGraphInputArgs(args);
  UpdateGraphOrder(g);
  std::vector<AnfNodePtr> graph_inputs = to_graph->inputs();
  auto valid_inputs = to_graph->ValidInputs();
  auto &graph_inputs = to_graph->inputs();
  auto real_args = GetRealArgs(to_graph, args);
  size_t input_index = 0;
  for (size_t i = 0; i < real_args.size(); i++) {
    if (input_index >= graph_inputs.size()) {
      MS_LOG(EXCEPTION) << "input_index " << input_index << " out of range, size " << graph_inputs.size();
    }
    if (utils::isa<AnfNodePtr>(real_args[i])) {
    auto &real_arg = real_args[i];
    if (utils::isa<AnfNodePtr>(real_arg)) {
      // arg is an AnfNode
      auto real_arg = utils::cast<AnfNodePtr>(real_args[i]);
      auto real_arg_output_num = AnfAlgo::GetOutputTensorNum(real_arg);
      if (!AnfAlgo::CheckPrimitiveType(real_arg, prim::kPrimTupleGetItem) && real_arg_output_num > 1) {
        input_index += real_arg_output_num;
        continue;
      }
      if (valid_inputs[input_index]) {
        SetChildGraphParameter(real_arg, graph_inputs[input_index]);
      } else {
        MS_LOG(DEBUG) << "Invalid input arg: " << real_arg->DebugString();
      }
      input_index++;
    } else if (utils::isa<ValuePtr>(args[i])) {
      auto value = utils::cast<ValuePtr>(args[i]);
      MS_EXCEPTION_IF_NULL(value);
      auto node = utils::cast<AnfNodePtr>(real_arg);
      input_index = SetChildGraphInput(to_graph, node, input_index);
    } else if (utils::isa<ValuePtr>(real_arg)) {
      // arg is a tensor
      if (!value->isa<Tensor>()) {
        MS_LOG(EXCEPTION) << "Value node should be a tensor, unexpected value: " << value->ToString();
      }
      SetChildGraphParameter(value->cast<TensorPtr>(), graph_inputs[input_index]);
      input_index++;
      auto value = utils::cast<ValuePtr>(real_arg);
      input_index = SetChildGraphInput(to_graph, value, input_index);
    } else if (utils::isa<VectorRef>(real_arg)) {
      // arg is a VectorRef
      auto vec_args = utils::cast<VectorRef>(real_arg);
      input_index = SetChildGraphInput(to_graph, vec_args, input_index);
    } else {
      MS_LOG(EXCEPTION) << "Unexpected arg type " << args[i].ToString();
      MS_LOG(EXCEPTION) << "Unexpected arg type " << real_arg.ToString();
    }
  }
  MS_LOG(INFO) << "Finish!";
@@ -79,6 +79,10 @@ class AscendSession : public SessionBasic {
  void RunOpHardwareOptimize(const std::shared_ptr<session::KernelGraph> &kernel_graph) const;
  void RunOpExecTask(const std::shared_ptr<KernelGraph> &kernel_graph) const;
  size_t SetChildGraphInput(const KernelGraphPtr &graph, const AnfNodePtr &node, size_t input_index);
  size_t SetChildGraphInput(const KernelGraphPtr &graph, const ValuePtr &value, size_t input_index);
  size_t SetChildGraphInput(const KernelGraphPtr &graph, const VectorRef &vec_args, size_t input_index);
  // merge the execution order lists of the child graphs
  void MergeGraphExecOrder();
  // insert assign ops to sync data between different graphs
@@ -88,7 +88,7 @@ class KernelGraph : public FuncGraph {
  void set_executable(bool executable) { executable_ = executable; }
  // set invalid inputs for control sink
  std::vector<bool> *MutableValidInputs() { return &valid_inputs_; }
  std::vector<bool> ValidInputs() { return valid_inputs_; }
  const std::vector<bool> &ValidInputs() const { return valid_inputs_; }
 private:
  // remove a value node from the graph
@@ -228,6 +228,8 @@ T cast(const BaseRef &handle) {
class VectorRef : public BaseRef {
 public:
  using value_type = BaseRef;
  VectorRef() {}
  explicit VectorRef(const std::vector<BaseRef> &elements) : elements_(elements) {}
  VectorRef(const const_iterator &begin, const const_iterator &end) : elements_(begin, end) {}
@@ -251,6 +253,13 @@ class VectorRef : public BaseRef {
    return elements_[dim];
  }
  BaseRef &operator[](const std::size_t &dim) {
    if (dim >= size()) {
      MS_LOG(EXCEPTION) << "Index out of range of the tuple.";
    }
    return elements_[dim];
  }
  uint32_t type() const override { return tid(); }
  std::string ToString() const override;
  std::vector<BaseRef> &elements() { return elements_; }
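For reference, a minimal standalone model of the checked accessor pair (illustrative only; the real VectorRef reports errors through MS_LOG(EXCEPTION) rather than throwing):

#include <cstddef>
#include <stdexcept>
#include <utility>
#include <vector>

class CheckedVec {
 public:
  explicit CheckedVec(std::vector<int> elements) : elements_(std::move(elements)) {}
  const int &operator[](std::size_t dim) const {
    if (dim >= elements_.size()) throw std::out_of_range("Index out of range of the tuple.");
    return elements_[dim];
  }
  int &operator[](std::size_t dim) {  // the overload this change adds: callers can write through it
    if (dim >= elements_.size()) throw std::out_of_range("Index out of range of the tuple.");
    return elements_[dim];
  }
 private:
  std::vector<int> elements_;
};

The non-const overload is what RecallGraphInput below depends on: old_args[i] = args[it->second] needs a writable, bounds-checked element reference.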
@@ -143,6 +143,66 @@ void MsBackend::SetSwitchGraph() {
  }
}
// Convert a node from a formal parameter to the actual parameter bound at the
// call site; that actual parameter is in turn the graph user's formal parameter.
// Used to resolve the top while-graph's parameter when recalling the while loop.
AnfNodePtr MsBackend::ConvertGraphInput(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
  std::unordered_map<AnfNodePtr, size_t> params_index;
  auto result = node;
  auto graph = result->func_graph();
  while (func_graph != graph) {
    auto iter = graph_user_inputs_.find(graph);
    if (iter == graph_user_inputs_.end()) {
      break;
    }
    params_index.clear();
    auto &params = graph->parameters();
    for (size_t i = 0; i < params.size(); ++i) {
      params_index[params[i]] = i;
    }
    graph = iter->second.first;
    auto &inputs = iter->second.second;
    result = inputs[params_index[result]];
  }
  return result;
}
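A standalone model of this formal-to-actual walk (Graph, ConvertInput, and string-valued nodes are assumed stand-ins, not the real API): each hop replaces a parameter of the current graph with the argument its user graph passed at the same position, until the top while-graph is reached.

#include <cstddef>
#include <string>
#include <unordered_map>
#include <vector>

struct Graph {
  std::vector<std::string> params;       // formal parameters of this graph
  const Graph *user = nullptr;           // the graph that calls this one
  std::vector<std::string> user_inputs;  // actual arguments at that call site
};

std::string ConvertInput(const Graph *top, const Graph *graph, std::string node) {
  while (graph != top && graph->user != nullptr) {
    std::unordered_map<std::string, std::size_t> params_index;
    for (std::size_t i = 0; i < graph->params.size(); ++i) {
      params_index[graph->params[i]] = i;
    }
    node = graph->user_inputs[params_index[node]];  // formal -> actual at the call site
    graph = graph->user;
  }
  return node;
}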
void MsBackend::SetGraphUserInputs(const FuncGraphPtr &func_graph, const FuncGraphPtr &user,
                                   const AnfNodePtrList &inputs) {
  if (graph_user_inputs_.find(func_graph) != graph_user_inputs_.end()) {
    return;
  }
  graph_user_inputs_[func_graph] = {user, inputs};
}
void MsBackend::RecallGraphInput(const FuncGraphPtr &func_graph, const VectorRef &args, const BaseRef &c) {
  std::unordered_map<AnfNodePtr, size_t> params_index;
  auto &params = func_graph->parameters();
  for (size_t i = 0; i < params.size(); ++i) {
    params_index[params[i]] = i;
  }
  // recall all child graphs in this while loop
  auto &graph_inputs = graph_inputs_[c];
  for (auto &iter : graph_inputs) {
    auto &graph = iter.first;
    auto &old_args = iter.second;
    auto &result = graph_id_map_[graph];
    auto &inputs = result.inputs;
    for (size_t i = 0; i < inputs.size(); ++i) {
      auto input = ConvertGraphInput(func_graph, inputs[i]);
      auto it = params_index.find(input);
      if (it != params_index.end()) {
        old_args[i] = args[it->second];
      }
    }
    sess_->SetChildGraphInput(graph, old_args);
  }
  graph_inputs_.erase(c);
}
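The graph_inputs_ bookkeeping amounts to a queue of pending launches per open switch condition. A sketch under assumed semantics (Recall, the string keys, and int graph ids are illustrative): launches recorded while the condition was open get their stale formal-parameter slots patched and are re-submitted, after which the queue is dropped.

#include <list>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

using Args = std::vector<std::string>;
// cond -> [(graph id, recorded args)], modelling graph_inputs_
std::unordered_map<std::string, std::list<std::pair<int, Args>>> graph_inputs;

void Recall(const std::string &cond,
            const std::unordered_map<std::string, std::string> &param_to_actual) {
  for (auto &entry : graph_inputs[cond]) {
    for (auto &slot : entry.second) {
      auto it = param_to_actual.find(slot);
      if (it != param_to_actual.end()) {
        slot = it->second;  // patch a stale formal parameter with the actual argument
      }
    }
    // here the real code re-submits: sess_->SetChildGraphInput(entry.first, entry.second)
  }
  graph_inputs.erase(cond);  // the condition has been fully recalled
}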
// at compile time, simulate the run to set graph inputs and outputs
VectorRef MsBackend::MsSimuRunGraph(const GraphId &g, const VectorRef &args) {
  MS_LOG(DEBUG) << "set graph input:" << g;
@@ -150,13 +210,20 @@ VectorRef MsBackend::MsSimuRunGraph(const GraphId &g, const VectorRef &args) {
  sess_->SetChildGraphInput(g, args);
  if (is_switch_call_) {
    bool curr_cond = simu_cond_map_[curr_switch_].curr_cond;
    MS_LOG(DEBUG) << "switch call MsSimuRunGraph:" << curr_cond;
    if (0 == simu_cond_map_[curr_switch_].cond_graph_map.count(curr_cond)) {
      MS_LOG(DEBUG) << "switch call MsSimuRunGraph:" << curr_cond << ", " << g;
      simu_cond_map_[curr_switch_].cond_graph_map[curr_cond] = g;
      SetSwitchGraph();
    if (!curr_switch_.is_null()) {
      // push this {g, args} to every user while-graph's graph_inputs to handle nested while loops;
      // once the current condition has been recalled, its entry is deleted from graph_inputs.
      for (auto &iter : graph_inputs_) {
        iter.second.push_back({g, args});
      }
      if (graph_inputs_.find(curr_switch_) == graph_inputs_.end()) {
        graph_inputs_[curr_switch_].push_back({g, args});
      }
    }
    bool curr_cond = simu_cond_map_[curr_switch_].curr_cond;
    MS_LOG(DEBUG) << "switch call MsSimuRunGraph:" << curr_cond << ", " << g;
    simu_cond_map_[curr_switch_].cond_graph_map[curr_cond] = g;
    SetSwitchGraph();
  }
  std::vector<BaseRef> outputs;
@@ -205,42 +272,17 @@ VectorRef MsBackend::MsRunGraph(const GraphId &g, const VectorRef &args) {
  return outputs;
}
void MsBackend::SetSimuCondFlag(const BaseRef &c, int flag) {
  MS_LOG(DEBUG) << "while set cond:" << c.ToString() << ", " << simu_cond_map_.size();
  if (simu_cond_map_.find(c) == simu_cond_map_.end()) {
    MS_LOG(EXCEPTION) << "cond not found";
  }
  simu_cond_map_[c].flag = flag;
}
int MsBackend::GetSimuCondFlag(const BaseRef &c) {
  BaseRef cond = c;
  if (cond.is_null()) {
    MS_LOG(DEBUG) << "get curr_switch";
    cond = curr_switch_;
  }
  if (simu_cond_map_.find(cond) == simu_cond_map_.end()) {
    MS_LOG(ERROR) << "cond not found";
    return -1;
  }
  return simu_cond_map_[cond].flag;
}
SwitchCondStatus MsBackend::SetSimuCond(const BaseRef &c, bool value) {
  MS_LOG(DEBUG) << "set cond:" << c.ToString() << ", " << simu_cond_map_.size();
  CondGraph cond_graph;
  cond_graph.curr_cond = value;
  if (simu_cond_map_.find(c) == simu_cond_map_.end()) {
    cond_graph.flag = 0;
    simu_cond_map_[c] = cond_graph;
  }
  if (simu_cond_map_[c].cond_graph_map.count(value)) {
    return kCondAlreadyRun;
  }
  simu_cond_map_[c].curr_cond = value;
  MS_LOG(DEBUG) << "end set cond";
@@ -16,9 +16,11 @@
#ifndef MINDSPORE_CCSRC_VM_BACKEND_H_
#define MINDSPORE_CCSRC_VM_BACKEND_H_
#include <list>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include "ir/anf.h"
#include "vm/segment_runner.h"
@@ -45,6 +47,8 @@ class Backend {
  virtual bool GetCond(const BaseRef &c, bool *value);
  virtual void SetSwitchGraph() {}
  virtual void SetSwitchActive(const BaseRef &, bool) {}
  virtual void RecallGraphInput(const FuncGraphPtr &, const VectorRef &, const BaseRef &) {}
  virtual void SetGraphUserInputs(const FuncGraphPtr &, const FuncGraphPtr &, const AnfNodePtrList &) {}
  void set_curr_switch(const BaseRef &value) {
    curr_switch_ = value;
@@ -54,8 +58,6 @@ class Backend {
  BaseRef curr_switch() { return curr_switch_; }
  virtual void Link(GraphId) {}
  virtual LinConvertResult GetMultiGraphRun(const FuncGraphPtr &) { return LinConvertResult(); }
  virtual void SetSimuCondFlag(const BaseRef &, int) {}
  virtual int GetSimuCondFlag(const BaseRef &) { return 0; }
  LinConvertResult multi_result() { return multi_result_; }
  void set_multi_result(const LinConvertResult &value) { multi_result_ = value; }
@@ -75,11 +77,11 @@ class Backend {
  bool simu_flag_;
  LinConvertResult multi_result_;
  AnfNodePtr final_output_;
  std::unordered_map<FuncGraphPtr, std::pair<FuncGraphPtr, AnfNodePtrList>> graph_user_inputs_;
};
struct CondGraph {
  bool curr_cond;
  int flag;
  std::unordered_map<bool, GraphId> cond_graph_map;
};
@@ -97,15 +99,17 @@ class MsBackend : public Backend {
  void SetSwitchGraph() override;
  void SetSwitchActive(const BaseRef &c, bool cond) override;
  void RecallGraphInput(const FuncGraphPtr &, const VectorRef &, const BaseRef &) override;
  void SetGraphUserInputs(const FuncGraphPtr &, const FuncGraphPtr &, const AnfNodePtrList &) override;
  void Link(GraphId) override;
  AnfNodePtr ConvertGraphInput(const FuncGraphPtr &, const AnfNodePtr &);
  LinConvertResult GetMultiGraphRun(const FuncGraphPtr &g) override;
  void SetSimuCondFlag(const BaseRef &c, int flag) override;
  int GetSimuCondFlag(const BaseRef &c) override;
 private:
  session::SessionPtr sess_;
  std::unordered_map<BaseRef, CondGraph, BaseRefHash> simu_cond_map_;
  std::unordered_map<GraphId, LinConvertResult> graph_id_map_;
  std::unordered_map<BaseRef, std::list<std::pair<GraphId, VectorRef>>, BaseRefHash> graph_inputs_;
};
}  // namespace compile
}  // namespace mindspore
@@ -390,6 +390,16 @@ void CompileGraph::AddTailCall(const AnfNodePtr &fn, size_t size) {
void CompileGraph::AddPartial(const CNodePtr &node) {
  auto inputs = node->inputs();
  VectorRef args;
  auto fn = inputs[1];
  if (!IsValueNode<FuncGraph>(fn)) {
    MS_LOG(EXCEPTION) << "The first input of the partial node must be a FuncGraph";
  }
  if (backend_->is_multi_graph_sink()) {
    auto func_graph = GetValueNode<FuncGraphPtr>(fn);
    args.emplace_back(func_graph);
    AnfNodePtrList outs(inputs.begin() + 2, inputs.end());
    backend_->SetGraphUserInputs(func_graph, node->func_graph(), outs);
  }
  for (size_t i = 1; i < inputs.size(); i++) {
    args.emplace_back(Ref(inputs[i]));
  }
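What this registers for the backend is essentially a (user graph, actual inputs) record keyed by the callee graph, with the partial's first two inputs (the primitive and the function) stripped off. A toy model (names assumed, not the real API):

#include <cassert>
#include <string>
#include <vector>

struct Registration {
  std::string user;                 // the graph containing the partial node
  std::vector<std::string> inputs;  // the actual arguments handed to the callee
};

// node_inputs = {"Partial", fn, arg0, arg1, ...}; node_inputs[1] must name a FuncGraph
Registration RegisterPartial(const std::vector<std::string> &node_inputs, const std::string &user_graph) {
  assert(node_inputs.size() >= 2 && node_inputs[0] == "Partial");
  return {user_graph, {node_inputs.begin() + 2, node_inputs.end()}};
}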
@@ -442,12 +452,17 @@ void CompileGraph::AddPrimitive(const CNodePtr &node, const PrimitivePtr &prim)
}
int CompileGraph::AddCall(const FuncGraphPtr &graph, const CNodePtr &node) {
  auto node_inputs = node->inputs();
  AnfNodePtr fn = node_inputs[0];
  auto inputs = node->inputs();
  AnfNodePtr fn = inputs[0];
  if (backend_->is_multi_graph_sink() && IsValueNode<FuncGraph>(fn)) {
    auto func_graph = GetValueNode<FuncGraphPtr>(fn);
    AnfNodePtrList outs(inputs.begin() + 1, inputs.end());
    backend_->SetGraphUserInputs(func_graph, node->func_graph(), outs);
  }
  (void)Ref(fn);
  size_t size = node_inputs.size();
  size_t size = inputs.size();
  for (size_t i = size - 1; i > 0; i--) {
    AddInput(node_inputs[i]);
    AddInput(inputs[i]);
  }
  if (node == graph->output()) {
    AddTailCall(fn, size);
@@ -32,7 +32,8 @@ namespace compile {
// Arguments:
// fn_: Callable function.
// args_: Sequence of function args.
StructPartial::StructPartial(int fn, const VectorRef &args) : fn_(fn), args_(args) {}
// fg_: Graph of function.
StructPartial::StructPartial(int fn, const VectorRef &args, const FuncGraphPtr &fg) : fn_(fn), args_(args), fg_(fg) {}
std::ostream &operator<<(std::ostream &os, const StructPartial &other) {
  os << "partial(" << other.fn_ << ", " << other.args_.ToString() << ")";
@@ -40,7 +41,7 @@ std::ostream &operator<<(std::ostream &os, const StructPartial &other) {
}
bool operator==(const StructPartial &lhs, const StructPartial &rhs) {
  return (lhs.fn_ == rhs.fn_ && lhs.args_ == rhs.args_);
  return (lhs.fn_ == rhs.fn_ && lhs.args_ == rhs.args_ && lhs.fg_ == rhs.fg_);
}
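With fg_ in play, two partials that share a jump target and argument list but were built from different sub-graphs no longer compare equal; in simplified form:

struct P {
  int fn;          // jump target
  int args;        // stand-in for the argument list
  const void *fg;  // stand-in for FuncGraphPtr
};
bool operator==(const P &lhs, const P &rhs) {
  // without fg, P{3, 7, &g1} == P{3, 7, &g2} would hold for partials of different graphs
  return lhs.fn == rhs.fn && lhs.args == rhs.args && lhs.fg == rhs.fg;
}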
StructSimuSwitch::StructSimuSwitch(const BaseRef &fn, const BaseRef &value) : fn_(fn), value_(value) {}
@@ -242,16 +243,6 @@ void FinalVM::InstTailCall(const VectorRef &args) {
  int nargs = utils::cast<int>(args[2]);
  auto new_jmp = Ref(jmp);
  if (backend_->simu_flag()) {
    if (backend_->GetSimuCondFlag(BaseRef()) == 2) {
      MS_LOG(DEBUG) << "invoke while tail call first";
      Pop(height);
      Push(1);
      Popp();
      return;
    }
  }
  MoveStack(nargs, height);
  MS_LOG(DEBUG) << "TailCall pushp:" << pc_ << ", jmp:" << jmp;
  DoJmp(new_jmp);
@@ -291,8 +282,30 @@ void FinalVM::InstReturn(const VectorRef &args) {
  MS_LOG(DEBUG) << "End";
}
void FinalVM::InstPartial(const VectorRef &args) {
  MS_LOG(DEBUG) << "Start";
void FinalVM::InstSimuPartial(const VectorRef &args) {
  const size_t args_size = 2;
  if (args.size() < args_size) {
    MS_LOG(ERROR) << __FUNCTION__ << " requires " << args_size << " or more parameters, while the input size is "
                  << args.size() << ".";
    return;
  }
  auto &node = args[0];
  if (!utils::isa<FuncGraphPtr>(node)) {
    MS_LOG(ERROR) << "The first input of the partial node must be a FuncGraph";
    return;
  }
  auto fg = utils::cast<FuncGraphPtr>(node);
  int fn_ = utils::cast<int>(args[1]);
  auto fn = utils::cast<int>(Ref(fn_));
  MS_LOG(DEBUG) << "Partial args size:" << args.size();
  std::vector<BaseRef> outs(args.size() - 2);
  (void)std::transform(args.begin() + 2, args.end(), outs.begin(),
                       [&, this](const BaseRef &a) { return Ref(utils::cast<int>(a)); });
  Push(std::make_shared<StructPartial>(fn, VectorRef(outs), fg));
}
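The operand layout assumed by InstSimuPartial is args = {func_graph, fn_offset, arg_offsets...}, where Ref() resolves a stack-relative offset to a value. A toy decoder modelling that layout (g_stack, Ref, and Partial are illustrative, not the FinalVM types):

#include <cstddef>
#include <vector>

static std::vector<int> g_stack = {10, 11, 12, 13};

int Ref(int offset) {  // toy stand-in for FinalVM::Ref: relative addressing from the stack top
  return g_stack[static_cast<std::size_t>(static_cast<int>(g_stack.size()) - 1 + offset)];
}

struct Partial {
  int fn;                 // resolved jump target
  std::vector<int> args;  // resolved captured arguments (plus FuncGraphPtr fg in the real code)
};

Partial DecodeSimuPartial(const std::vector<int> &operands) {  // operands = {fn_offset, arg_offsets...}
  Partial p{Ref(operands[0]), {}};
  for (std::size_t i = 1; i < operands.size(); ++i) {
    p.args.push_back(Ref(operands[i]));
  }
  return p;  // the VM then pushes StructPartial{fn, args, fg}
}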
void FinalVM::InstRealPartial(const VectorRef &args) {
  const size_t args_size = 1;
  if (args.size() < args_size) {
    MS_LOG(ERROR) << __FUNCTION__ << " requires " << args_size << " or more parameters, while the input size is "
@@ -304,10 +317,18 @@ void FinalVM::InstPartial(const VectorRef &args) {
  auto fn = utils::cast<int>(Ref(fn_));
  MS_LOG(DEBUG) << "Partial args size:" << args.size();
  std::vector<BaseRef> outs(args.size() - 1);
  (void)std::transform(args.begin() + 1, args.end(), outs.begin(),
                       [&, this](const BaseRef &a) { return Ref(utils::cast<int>(a)); });
  Push(std::make_shared<StructPartial>(fn, VectorRef(outs)));
}
void FinalVM::InstPartial(const VectorRef &args) {
  MS_LOG(DEBUG) << "Start";
  if (backend_->is_multi_graph_sink()) {
    InstSimuPartial(args);
  } else {
    InstRealPartial(args);
  }
  MS_LOG(DEBUG) << "End";
}
@@ -328,43 +349,57 @@ void FinalVM::InstSimuSwitch(const VectorRef &args) {
  bool bool_value = cond;
  SwitchCondStatus cond_stat = backend_->SetSimuCond(c, bool_value);
  int cond_flag = backend_->GetSimuCondFlag(c);
  MS_LOG(DEBUG) << "Simu switch cond:" << cond << ", " << cond_flag << ", " << c.cast<AnfNodePtr>()->DebugString();
  if (cond_flag == 2) {
    Popp();
    Popp();
    backend_->SetSimuCondFlag(c, 0);
    return;
  }
  if (cond_stat == kCondAlreadyRun) {
    MS_LOG(DEBUG) << "switch already run, while-true jmp";
    if (cond_flag == 0) {
      MS_LOG(DEBUG) << "switch second run, while-true jmp";
      backend_->SetSwitchActive(c, true);
      Push(std::make_shared<StructSimuSwitch>(Ref(vtrue), c));
      Pushsp();
      backend_->SetSimuCondFlag(c, 1);
      return;
    } else if (cond_flag == 1) {
      MS_LOG(DEBUG) << "switch first run, while-if jmp";
      Push(std::make_shared<StructSimuSwitch>(Ref(vfalse), c));
      (void)backend_->SetSimuCond(c, false);
      backend_->SetSimuCondFlag(c, 2);
      return;
    } else {
      MS_LOG(EXCEPTION) << "cond not found";
      return;
    BaseRef jmp = Ref(vtrue);
    if (utils::isa<StructPartial>(jmp)) {
      auto new_jmp = utils::cast<std::shared_ptr<StructPartial>>(jmp);
      backend_->RecallGraphInput(new_jmp->fg_, new_jmp->args_, c);
    }
    cond_jmp_[c] = Ref(vfalse);
    Push(static_cast<int>(cond_stat));
    Popp();
    backend_->SetSwitchActive(c, bool_value);
    return;
  }
  if (bool_value) {
    Push(std::make_shared<StructSimuSwitch>(Ref(vtrue), c));
    Pushsp();
  } else {
    MergeJmpArgs(Ref(vfalse), c);
    Push(std::make_shared<StructSimuSwitch>(Ref(vfalse), c));
  }
}
void FinalVM::MergeJmpArgs(const BaseRef &jmp, const BaseRef &c) {
  auto iter = cond_jmp_.find(c);
  if (iter == cond_jmp_.end()) {
    return;
  }
  auto old_jmp = utils::cast<std::shared_ptr<StructPartial>>(iter->second);
  auto new_jmp = utils::cast<std::shared_ptr<StructPartial>>(jmp);
  auto &old_args = old_jmp->args_;
  auto &new_args = new_jmp->args_;
  for (size_t i = 0; i < new_args.size(); ++i) {
    auto &old_arg = old_args[i];
    auto &new_arg = new_args[i];
    if (utils::isa<VectorRef>(old_arg)) {
      auto old_vec_ref = utils::cast<VectorRef>(old_arg);
      if (utils::isa<VectorRef>(new_arg)) {
        auto new_vec_ref = utils::cast<VectorRef>(new_arg);
        std::copy(new_vec_ref.begin(), new_vec_ref.end(), std::back_inserter(old_vec_ref));
      }
      new_arg = old_vec_ref;
    } else if (utils::isa<VectorRef>(new_arg)) {
      auto new_vec_ref = utils::cast<VectorRef>(new_arg);
      new_vec_ref.push_back(old_arg);
      new_arg = new_vec_ref;
    } else {
      new_arg = VectorRef({new_arg, old_arg});
    }
  }
}
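MergeJmpArgs folds the arguments recorded for the false branch (cond_jmp_) into the current jump slot by slot, accumulating both branches' values into a list per position. A standalone model of the per-slot rule (assumed semantics; ints stand in for BaseRef):

#include <variant>
#include <vector>

using Slot = std::variant<int, std::vector<int>>;

Slot MergeSlot(const Slot &old_arg, Slot new_arg) {
  if (const auto *old_vec = std::get_if<std::vector<int>>(&old_arg)) {
    std::vector<int> merged = *old_vec;  // keep the old branch's values first
    if (const auto *new_vec = std::get_if<std::vector<int>>(&new_arg)) {
      merged.insert(merged.end(), new_vec->begin(), new_vec->end());
    }
    return merged;  // note: a scalar new_arg is dropped here, mirroring the code above
  }
  if (auto *new_vec = std::get_if<std::vector<int>>(&new_arg)) {
    new_vec->push_back(std::get<int>(old_arg));  // the list absorbs the old scalar
    return new_arg;
  }
  return std::vector<int>{std::get<int>(new_arg), std::get<int>(old_arg)};  // two scalars -> pair list
}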
void FinalVM::InstRealSwitch(const VectorRef &args) {
  const size_t args_size = 3;
  if (args.size() != args_size) {
@@ -399,6 +434,7 @@ void FinalVM::InstSwitch(const VectorRef &args) {
  } else {
    InstRealSwitch(args);
  }
  MS_LOG(DEBUG) << "End";
}
void FinalVM::InstTuple(const VectorRef &args) {
@@ -27,6 +27,9 @@
#include <utility>
#include <vector>
#include <deque>
#include <unordered_map>
#include "ir/anf.h"
#include "utils/base_ref.h"
namespace mindspore {
@@ -60,13 +63,14 @@ const std::vector<std::string> inst_str{"call", "tail_call", "return", "partial
class StructPartial : public Base {
 public:
  // Initialize StructPartial.
  StructPartial(int fn, const VectorRef &args);
  StructPartial(int fn, const VectorRef &args, const FuncGraphPtr &fg = nullptr);
  virtual ~StructPartial() = default;
  MS_DECLARE_PARENT(StructPartial, Base)
  int fn_;
  VectorRef args_;
  FuncGraphPtr fg_;
};
std::ostream &operator<<(std::ostream &os, const StructPartial &other);
@@ -98,6 +102,8 @@ class FinalVM {
  void InstTailCall(const VectorRef &args);
  void InstReturn(const VectorRef &args);
  void InstPartial(const VectorRef &args);
  void InstSimuPartial(const VectorRef &args);
  void InstRealPartial(const VectorRef &args);
  void InstSwitch(const VectorRef &args);
  void InstSimuSwitch(const VectorRef &args);
  void InstRealSwitch(const VectorRef &args);
@@ -120,6 +126,7 @@ class FinalVM {
  void Pushsp();
  void Popsp();
  void DoJmp(const BaseRef &jmp);
  void MergeJmpArgs(const BaseRef &jmp, const BaseRef &c);
 private:
  InstSet insts_;
@@ -128,6 +135,7 @@ class FinalVM {
  std::stack<int> retsp_;
  int pc_;
  int sp_;
  std::unordered_map<BaseRef, BaseRef, BaseRefHash> cond_jmp_;
  BackendPtr backend_;
  const InstFunctionMap inst_function_map = {
    {Instruction::kCall, [this](const VectorRef &args) { InstCall(args); }},
@@ -0,0 +1,184 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test_multigraph_sink """
import pytest
import numpy as np

import mindspore.nn as nn
import mindspore.context as context
from mindspore.common.tensor import Tensor
from mindspore.common import dtype as mstype
from mindspore.common import ms_function
from mindspore.ops import operations as P


def setup_module(module):
    context.set_context(mode=context.PYNATIVE_MODE, save_graphs=True, device_target="Ascend")
    context.set_context(enable_task_sink=True, device_id=0)


c1 = Tensor([2], mstype.int32)
c2 = Tensor([14], mstype.int32)
c3 = Tensor([1], mstype.int32)
c4 = Tensor([0], mstype.int32)
c5 = Tensor([14], mstype.int32)


@ms_function
def simple_if(x, y, z):
    if x < y:
        x = x + 1
    else:
        x = x + 2
    x = x + 3
    return x


@ms_function
def if_by_if(x, y, z):
    if x < y:
        x = x + 1
    if y > x:
        x = x + 2
    x = x + 3
    return x


@ms_function
def if_in_if(x, y, z):
    out = c4
    if x < y:
        z = c4 + c4
        if z < y:
            z = z + 2
        out = out + z
    x = x + 3
    out = out + x
    return out


@ms_function
def simple_while(x, y, z):
    y = y + 4
    while x < y:
        x = x + 1
    x = x + 3
    return x


@ms_function
def while_by_while(x, y, z):
    while x < y:
        x = x + 1
    while z < c5:
        z = z + 1
        x = x + 1
    x = x + 1
    return x


@ms_function
def while_in_while(x, y, z):
    out = c4
    while x < y:
        z = c4 + c4
        while z < y:
            z = z + 1
            out = out + z
        x = x + 1
    out = out + x
    return out


@ms_function
def while_by_while_in_while(x, y, z):
    out = c4
    while x < c2:
        y = c4 + c4
        while y < c2:
            y = y + 1
        out = out + y
        z = c4 + c4
        while z < c2:
            z = z + 1
        out = out + z
        x = x + 1
    out = out + x
    return out


@ms_function
def while_in_while_in_while(x, y, z):
    out = c4
    while x < c2:
        y = c4 + c4
        while y < c2:
            y = y + 1
            z = c4 + c4
            while z < c2:
                z = z + 1
            out = out + z
        out = out + y
        x = x + 1
    out = out + x
    return out


def test_simple_if():
    output = simple_if(c1, c2, c3)
    expect = Tensor([6], mstype.int32)
    assert output == expect


def test_if_by_if():
    output = if_by_if(c1, c2, c3)
    expect = Tensor([8], mstype.int32)
    assert output == expect


def test_if_in_if():
    output = if_in_if(c1, c2, c3)
    expect = Tensor([7], mstype.int32)
    assert output == expect


def test_simple_while():
    output = simple_while(c1, c2, c3)
    expect = Tensor([21], mstype.int32)
    assert output == expect


def test_while_by_while():
    output = while_by_while(c1, c2, c3)
    expect = Tensor([28], mstype.int32)
    assert output == expect


def test_while_in_while():
    output = while_in_while(c1, c2, c3)
    expect = Tensor([1274], mstype.int32)
    assert output == expect


def test_while_by_while_in_while():
    output = while_by_while_in_while(c1, c2, c3)
    expect = Tensor([350], mstype.int32)
    assert output == expect


def test_while_in_while_in_while():
    output = while_in_while_in_while(c1, c2, c3)
    expect = Tensor([2534], mstype.int32)
    assert output == expect
@@ -0,0 +1,119 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test_multigraph_sink """
import pytest
import numpy as np

import mindspore.nn as nn
import mindspore.context as context
from mindspore.common.tensor import Tensor
from mindspore.common import dtype as mstype
from mindspore.common import ms_function
from mindspore.ops import operations as P


def setup_module(module):
    context.set_context(mode=context.PYNATIVE_MODE, save_graphs=True, device_target="Ascend")
    context.set_context(enable_task_sink=True, device_id=0)


c1 = Tensor([2], mstype.int32)
c2 = Tensor([14], mstype.int32)
c3 = Tensor([1], mstype.int32)
c4 = Tensor([0], mstype.int32)
c5 = Tensor([14], mstype.int32)


@ms_function
def simple_if(x, y, z):
    if x < y:
        x = x + 1
    else:
        x = x + 2
    x = x + 3
    return x


@ms_function
def if_by_if(x, y, z):
    if x < y:
        x = x + 1
    if y > x:
        x = x + 2
    x = x + 3
    return x


@ms_function
def if_in_if(x, y, z):
    out = c4
    if x < y:
        z = c4 + c4
        if z < y:
            z = z + 2
        out = out + z
    x = x + 3
    out = out + x
    return out


@ms_function
def simple_while(x, y, z):
    y = y + 4
    while x < y:
        x = x + 1
    x = x + 3
    return x


@ms_function
def while_by_while(x, y, z):
    while x < y:
        x = x + 1
    while z < c5:
        z = z + 1
        x = x + 1
    x = x + 1
    return x


def test_simple_if():
    output = simple_if(c1, c2, c3)
    expect = Tensor([6], mstype.int32)
    assert output == expect


def test_if_by_if():
    output = if_by_if(c1, c2, c3)
    expect = Tensor([8], mstype.int32)
    assert output == expect


def test_if_in_if():
    output = if_in_if(c1, c2, c3)
    expect = Tensor([7], mstype.int32)
    assert output == expect


def test_simple_while():
    output = simple_while(c1, c2, c3)
    expect = Tensor([21], mstype.int32)
    assert output == expect


def test_while_by_while():
    output = while_by_while(c1, c2, c3)
    expect = Tensor([28], mstype.int32)
    assert output == expect