/** * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). * * Copyright 2019-2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "vm/vmimpl.h" #include #include #include #include #include "frontend/operator/ops.h" #include "ir/manager.h" #include "ir/func_graph_cloner.h" #include "utils/convert_utils.h" #include "utils/primitive_utils.h" namespace mindspore { namespace compile { // Indicate a call to a new frame. struct CallWrap : public Base { explicit CallWrap(const VMFramePtr &vm_frame) : frame(vm_frame) {} VMFramePtr frame{nullptr}; }; using CallWrapPtr = std::shared_ptr; // Indicates a return with its value. struct ReturnWrap : public Base { explicit ReturnWrap(const BaseRef &r_value) : value(r_value) {} BaseRef value{BaseRef()}; }; using ReturnWrapPtr = std::shared_ptr; VMFrame::VMFrame(const AnfNodePtrList &nodes, const AnfNodePtrToBaseRefMap &values, const AnfNodePtrToBaseRefMap &closure) : values_(values), todo_(nodes), closure_(closure) { std::reverse(std::begin(todo_), std::end(todo_)); } const BaseRef VMFrame::operator[](const AnfNodePtr &node) { MS_EXCEPTION_IF_NULL(node); auto ret = values_.find(node); if (ret != values_.end()) { return ret->second; } ret = closure_.find(node); if (ret != closure_.end()) { return ret->second; } if (node->isa()) { return GetValueNode(node); } MS_LOG(EXCEPTION) << "ValueError " << node->type_name(); } Closure::Closure(const FuncGraphPtr &graph, const AnfNodePtrToBaseRefMap &values) : func_graph_(graph), values_(values) {} BaseRef Closure::operator()(const VectorRef &args) { MS_LOG(DEBUG) << "start closure"; return vm_->Evaluate(func_graph_, args, values_); } Partial::Partial(const BaseRef &fn, const VectorRef &args, const VMPtr &vm) : fn_(fn), args_(args), vm_(vm) {} BaseRef Partial::operator()(const VectorRef &nodes) { VectorRef arglist; (void)arglist.insert(arglist.end(), args_.begin(), args_.end()); (void)arglist.insert(arglist.end(), nodes.begin(), nodes.end()); return vm_->Call(fn_, arglist); } SetRef VM::ComputeFvs(const FuncGraphPtr &graph) { MS_EXCEPTION_IF_NULL(graph); SetRef rval; for (auto &fkv : graph->free_variables_total()) { if (utils::isa(fkv.first)) { // Add all value_nodes of g that refer to a fv graph auto g = utils::cast(fkv.first); for (auto &ctkv : g->value_nodes()) { auto ct = ctkv.first; if (GetValueNode(ct) == g) { (void)rval.insert(ct); } } } else { // Add a normal fv (void)rval.insert(fkv.first); } } return rval; } void VM::AcquireGraph(const FuncGraphPtr &graph) { // Already acquired if (vars_.find(graph) != vars_.end()) { return; } // Add g to manager manager_->AddFuncGraph(graph); // Compute fvs for all acquired graph auto graphs = graph->manager()->func_graphs(); for (auto g = graphs.begin(); g != graphs.end(); ++g) { vars_[*g] = ComputeFvs(*g); } } VectorRef VM::ExportSequence(const VectorRef &seq) { std::vector ret; (void)std::transform(std::begin(seq), std::end(seq), std::back_inserter(ret), [&, this](const BaseRef &x) -> BaseRef { return Export(x); }); return VectorRef(ret); } ClosurePtr VM::ExportClosure(const ClosurePtr &clos) { MS_EXCEPTION_IF_NULL(clos); clos->set_vm(shared_from_this()); return clos; } // transform graph to executable closure ClosurePtr VM::ExportGraph(const FuncGraphPtr &g) { auto c = std::make_shared(g, AnfNodePtrToBaseRefMap()); MS_EXCEPTION_IF_NULL(c); c->set_vm(shared_from_this()); return c; } BaseRef VM::ExportObj(const BaseRef &obj) const { return obj; } BaseRef VM::Export(const BaseRef &value) { if (utils::isa(value) && utils::cast(value)->isa()) { return ExportGraph(utils::cast(value)->cast()); } if (utils::isa(value) && utils::cast(value)->isa()) { return ExportPrimitive(utils::cast(value)->cast()); } if (utils::isa(value)) { return ExportGraph(utils::cast(value)); } if (utils::isa(value)) { return ExportClosure(utils::cast(value)); } if (utils::isa(value)) { return ExportPrimitive(utils::cast(value)); } if (utils::isa(value)) { return ExportSequence(utils::cast(value)); } return ExportObj(value); } // Run a graph. // This will evaluate the passed-in graph and return the resulting value. BaseRef VM::Evaluate(const FuncGraphPtr &graph, const VectorRef &args, const AnfNodePtrToBaseRefMap &closure) { AcquireGraph(graph); MS_LOG(DEBUG) << "evalue arg size: " << args.size(); if (args.size() != graph->parameters().size()) { MS_LOG(EXCEPTION) << "Call with wrong number of arguments, expect " << graph->parameters().size() << ", but got " << args.size(); } // toposort graph nodes, the order will be reversed by frame so that the dependent be computed first auto nodes = TopoSort(graph->get_return(), SuccVm(graph)); // mapping parameters to args AnfNodePtrToBaseRefMap values; for (size_t i = 0; i < args.size(); i++) { values[graph->parameters()[i]] = args[i]; } // create top frame with params initialized VMFramePtrList frames{std::make_shared(nodes, values, closure)}; // execute frames starting from top frame while (!frames.empty()) { auto frame = frames[frames.size() - 1]; auto todo = frame->todo(); while (!todo.empty()) { auto except = HandleNode(todo[todo.size() - 1], frame); if (utils::isa(except)) { if (todo.size() == 2) { // The last element is always a return, replace the ret with call frame frames[frames.size() - 1] = utils::cast(except)->frame; } else { frames.push_back(utils::cast(except)->frame); } break; } if (utils::isa(except)) { (void)frames.erase(frames.begin() + (static_cast(frames.size()) - 1)); if (frames.size() > 0) { auto top = frames[frames.size() - 1]; auto td = top->todo(); // set value for top frame's last evaluated node if (td.empty()) { MS_LOG(EXCEPTION) << "The td is empty"; } top->values()[td[td.size() - 1]] = utils::cast(except)->value; (void)td.erase(td.begin() + (static_cast(td.size()) - 1)); } else { return Export(utils::cast(except)->value); } break; } (void)todo.erase(todo.begin() + (static_cast(todo.size()) - 1)); } } MS_LOG(EXCEPTION) << "VM Evaluate error"; } SuccFunc VM::SuccVm(const FuncGraphPtr &graph) { auto fn = [&, this](const AnfNodePtr &node) -> AnfNodePtrList { MS_EXCEPTION_IF_NULL(node); AnfNodePtrList ret; // Follow node.incoming if (node->isa()) { auto &inputs = node->cast()->inputs(); for (auto &i : inputs) { if (i->func_graph() == node->func_graph() || (IsValueNode(i) && GetValueNode(i)->parent() == graph)) { ret.push_back(i); } } } // for subgraph input, add their fvs as succ nodes if (IsValueNode(node) && GetValueNode(node)->parent() == graph) { auto fvs = utils::cast(vars_[GetValueNode(node)]); (void)std::transform(fvs.begin(), fvs.end(), std::back_inserter(ret), [](const BaseRef &value) -> AnfNodePtr { return utils::cast(value); }); } return ret; }; return fn; } BaseRef VM::Call(const BaseRef &fn, const VectorRef &args) { if (utils::isa(fn)) { return RunOperation(utils::cast(fn), args); } if (utils::isa(fn)) { return Evaluate(utils::cast(fn), args); } if (utils::isa(fn)) { auto clos = utils::cast(fn); return Evaluate(clos->func_graph(), args, clos->values()); } MS_LOG(EXCEPTION) << "Can't call fn"; } // make call frame for graph BaseRef VM::_Call(const BaseRef &graph, const VectorRef &args) { AnfNodePtrToBaseRefMap clos; auto func_graph = graph; if (utils::isa(func_graph)) { clos = utils::cast(func_graph)->values(); func_graph = utils::cast(func_graph)->func_graph(); } if (utils::isa(func_graph)) { func_graph = utils::cast(func_graph)->cast(); } if (!utils::isa(func_graph)) { MS_LOG(EXCEPTION) << "Graph type error"; } auto graphPtr = utils::cast(func_graph); if (vars_.find(graphPtr) == vars_.end()) { AcquireGraph(graphPtr); } if (args.size() != graphPtr->parameters().size()) { MS_LOG(EXCEPTION) << "Call with wrong number of arguments, expect " << graphPtr->parameters().size() << ", but got " << args.size(); } auto nodes = TopoSort(graphPtr->get_return(), SuccVm(graphPtr)); AnfNodePtrToBaseRefMap values; for (size_t i = 0; i < args.size(); i++) { values[graphPtr->parameters()[i]] = args[i]; } return std::make_shared(std::make_shared(nodes, values, clos)); } // make closure out of graph with fv values from frame ClosurePtr VM::MakeClosure(const FuncGraphPtr &graph, const VMFramePtr &frame) { MS_EXCEPTION_IF_NULL(frame); AnfNodePtrToBaseRefMap clos; for (auto &v : utils::cast(vars_[graph])) { auto anf = utils::cast(v); clos[anf] = (*frame)[anf]; } return std::make_shared(graph, clos); } BaseRef VM::DispatchCall(const AnfNodePtr &node, const VMFramePtr &frame, const BaseRef &fn, const VectorRef &args) { if (utils::isa(fn) && utils::cast(fn)->isa()) { auto fnval = utils::cast(fn)->cast(); MS_LOG(DEBUG) << "DispatchCall prim:" << fnval->name() << ", node:" << node->DebugString(true); if (args.empty()) { MS_LOG(EXCEPTION) << "args is empty"; } if (fnval == prim::kPrimReturn) { MS_LOG(DEBUG) << "return args:" << args.size(); return std::make_shared(args[0]); } if (fnval == prim::kPrimMakeTuple) { frame->values()[node] = args; return BaseRef(); } if (fnval == prim::kPrimPartial) { VectorRef partial_args(args.begin() + 1, args.end()); frame->values()[node] = (std::make_shared(args[0], partial_args, shared_from_this())); return BaseRef(); } // call prim implementation frame->values()[node] = RunOperation(fnval, args); return BaseRef(); } // partial args logic if (utils::isa(fn)) { auto fnPtr = utils::cast(fn); VectorRef arglist; (void)arglist.insert(arglist.end(), fnPtr->args().begin(), fnPtr->args().end()); (void)arglist.insert(arglist.end(), args.begin(), args.end()); auto ret = DispatchCall(node, frame, fnPtr->fn(), arglist); if (utils::isa(ret) || utils::isa(ret)) { return ret; } } // create frame for graph and closure if ((utils::isa(fn) && utils::cast(fn)->isa()) || utils::isa(fn)) { auto ret = _Call(fn, args); if (utils::isa(ret) || utils::isa(ret)) { return ret; } } MS_LOG(EXCEPTION) << "Invalid fn to call"; } BaseRef VM::HandleNode(const AnfNodePtr &node, const VMFramePtr &frame) { MS_EXCEPTION_IF_NULL(node); if (node->isa()) { // pass return BaseRef(); } if (node->isa()) { // We only visit valuenode graphs if (!IsValueNode(node)) { MS_LOG(EXCEPTION) << "We only visit valuenode graphs "; } auto g = GetValueNode(node); // if g is a graph with fvs, we need to make a closure for it auto iterG = vars_.find(g); if (iterG != vars_.end() && utils::cast(iterG->second).size() != 0) { frame->values()[node] = MakeClosure(g, frame); } return BaseRef(); } if (node->isa()) { std::vector fnArgs; auto &inputs = node->cast()->inputs(); // set args' values in frame (void)std::transform(std::begin(inputs), std::end(inputs), std::back_inserter(fnArgs), [&](const AnfNodePtr &inp) -> BaseRef { return (*frame)[inp]; }); if (fnArgs.empty()) { MS_LOG(EXCEPTION) << "function arguments is empty"; } else { auto args = VectorRef(fnArgs.begin() + 1, fnArgs.end()); auto except = DispatchCall(node, frame, fnArgs[0], args); return except; } } MS_LOG(EXCEPTION) << "Unknown node type"; } VectorRef VM::RunGraph(const FuncGraphPtr &g, const VectorRef &args) { this->manager_ = Manage(g); auto fn = utils::cast(Export(g)); auto result = (*fn)(args); if (utils::isa(result)) { return utils::cast(result); } else { VectorRef ret({result}); return ret; } } BaseRef RunOperation(const PrimitivePtr &prim, const VectorRef &args) { MS_LOG(DEBUG) << "operation start " << prim->name(); MS_EXCEPTION_IF_NULL(prim); auto result = prim->RunComputeFunction(args); if (result.is_null()) { return RunComputeFunction(prim, args); } return result; } } // namespace compile } // namespace mindspore