/** * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). * * Copyright 2019-2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef MINDSPORE_CCSRC_VM_VM_H_ #define MINDSPORE_CCSRC_VM_VM_H_ #include #include #include #include #include #include #include #include #include #include "pybind11/pybind11.h" #include "ir/anf.h" #include "base/base_ref.h" namespace py = pybind11; namespace mindspore { namespace compile { class Backend; using BackendPtr = std::shared_ptr; enum Instruction { kCall = 0, kTailCall, kReturn, kPartial, kSwitch, kSwitchReturn, kTuple, kInput, kExternal, kPush, kPrim, kGraph, kPadStack, kSwitchLayer }; using InstType = std::pair; using InstSet = std::vector; using InstFunctionMap = std::map>; const std::vector inst_str{"call", "tail_call", "return", "partial", "switch", "switch_return", "tuple", "input", "external", "push", "primitive", "graph", "pad_stack", "switch_layer"}; class StructPartial : public Base { public: // Initialize StructPartial. StructPartial(int fn, const VectorRef &args, const FuncGraphPtr &fg = nullptr); virtual ~StructPartial() = default; MS_DECLARE_PARENT(StructPartial, Base) int fn_; VectorRef args_; FuncGraphPtr fg_; }; std::ostream &operator<<(std::ostream &os, const StructPartial &other); bool operator==(const StructPartial &lhs, const StructPartial &rhs); class StructSimuSwitch : public Base { public: StructSimuSwitch(const BaseRef &fn, const BaseRef &value); virtual ~StructSimuSwitch() = default; MS_DECLARE_PARENT(StructSimuSwitch, Base) BaseRef fn_; BaseRef value_; }; std::ostream &operator<<(std::ostream &os, const StructSimuSwitch &other); bool operator==(const StructSimuSwitch &lhs, const StructSimuSwitch &rhs); class FinalVM { public: // Create a VM with the specified instructions and backend. explicit FinalVM(const InstSet &insts, const BackendPtr &backend); virtual ~FinalVM() = default; BaseRef Eval(const VectorRef &args); void InstCall(const VectorRef &args); void InstTailCall(const VectorRef &args); void InstReturn(const VectorRef &args); void InstPartial(const VectorRef &args); void InstRealPartial(const VectorRef &args); void InstSwitch(const VectorRef &args); void InstRealSwitch(const VectorRef &args); void InstTuple(const VectorRef &args); void InstPush(const VectorRef &args); void InstInput(const VectorRef &args); void InstPadStack(const VectorRef &args); void InstExternal(const VectorRef &args); void InstPushPrim(const VectorRef &args); void InstSwitchReturn(const VectorRef &args); void InstSwitchLayer(const VectorRef &args); void set_insts(const InstSet &value) { insts_ = value; } BaseRef RunHook(const PrimitivePtr &prim, const VectorRef &arg); protected: BaseRef Ref(int i); void Push(const BaseRef &v); void Pop(int n = 1); void MoveStack(int nitems, int height); void Pushp(); void Popp(); void Pushsp(); void Popsp(); void DoJmp(const BaseRef &jmp); void SyncData(const py::object &args); private: InstSet insts_; std::deque insts_stack_; std::stack retp_; std::stack retsp_; int pc_; int sp_; BackendPtr backend_; const InstFunctionMap inst_function_map = { {Instruction::kCall, [this](const VectorRef &args) { InstCall(args); }}, {Instruction::kTailCall, [this](const VectorRef &args) { InstTailCall(args); }}, {Instruction::kReturn, [this](const VectorRef &args) { InstReturn(args); }}, {Instruction::kPartial, [this](const VectorRef &args) { InstPartial(args); }}, {Instruction::kSwitch, [this](const VectorRef &args) { InstSwitch(args); }}, {Instruction::kTuple, [this](const VectorRef &args) { InstTuple(args); }}, {Instruction::kPush, [this](const VectorRef &args) { InstPush(args); }}, {Instruction::kInput, [this](const VectorRef &args) { InstInput(args); }}, {Instruction::kPadStack, [this](const VectorRef &args) { InstPadStack(args); }}, {Instruction::kExternal, [this](const VectorRef &args) { InstExternal(args); }}, {Instruction::kPrim, [this](const VectorRef &args) { InstPushPrim(args); }}, {Instruction::kSwitchReturn, [this](const VectorRef &args) { InstSwitchReturn(args); }}, {Instruction::kSwitchLayer, [this](const VectorRef &args) { InstSwitchLayer(args); }}}; }; using FinalVMPtr = std::shared_ptr; } // namespace compile } // namespace mindspore #endif // MINDSPORE_CCSRC_VM_VM_H_