Merge pull request !2895 from fanglei/casetags/v0.6.0-beta
| @@ -230,6 +230,20 @@ bool ValueToBool(const ValuePtr &v, bool *value) { | |||||
| return true; | return true; | ||||
| } | } | ||||
| bool BaseRefToInt(const ValuePtr &v, int *value) { | |||||
| MS_EXCEPTION_IF_NULL(v); | |||||
| if (v->isa<tensor::Tensor>()) { | |||||
| auto tensor = v->cast<tensor::TensorPtr>(); | |||||
| (void)tensor->data_sync(); | |||||
| int *tensor_data = static_cast<int *>(tensor->data_c()); | |||||
| auto vb = tensor_data[0]; | |||||
| *value = vb; | |||||
| return true; | |||||
| } | |||||
| MS_LOG(ERROR) << "Index must be tensor type."; | |||||
| return false; | |||||
| } | |||||
| bool BaseRefToBool(const BaseRef &v, bool *value) { | bool BaseRefToBool(const BaseRef &v, bool *value) { | ||||
| if (utils::isa<ValuePtr>(v)) { | if (utils::isa<ValuePtr>(v)) { | ||||
| return ValueToBool(utils::cast<ValuePtr>(v), value); | return ValueToBool(utils::cast<ValuePtr>(v), value); | ||||
| @@ -42,6 +42,7 @@ using TensorPtr = std::shared_ptr<Tensor>; | |||||
| py::object AnyToPyData(const Any &value); | py::object AnyToPyData(const Any &value); | ||||
| py::object BaseRefToPyData(const BaseRef &value); | py::object BaseRefToPyData(const BaseRef &value); | ||||
| bool BaseRefToBool(const BaseRef &in, bool *out); | bool BaseRefToBool(const BaseRef &in, bool *out); | ||||
| bool BaseRefToInt(const ValuePtr &v, int *value); | |||||
| bool ValueToBool(const ValuePtr &in, bool *out); | bool ValueToBool(const ValuePtr &in, bool *out); | ||||
| py::object ValuePtrToPyData(const ValuePtr &value); | py::object ValuePtrToPyData(const ValuePtr &value); | ||||
| @@ -32,6 +32,7 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace compile { | namespace compile { | ||||
| bool Backend::GetCond(const BaseRef &c, bool *const value) { return BaseRefToBool(c, value); } | bool Backend::GetCond(const BaseRef &c, bool *const value) { return BaseRefToBool(c, value); } | ||||
| bool Backend::GetIndex(const BaseRef &c, int *const value) { return BaseRefToInt(utils::cast<ValuePtr>(c), value); } | |||||
| LinConvertResult MsBackend::GetMultiGraphRun(const FuncGraphPtr &g) { | LinConvertResult MsBackend::GetMultiGraphRun(const FuncGraphPtr &g) { | ||||
| // multi_graph merge to one, big graph have paramters in begin and only have one output | // multi_graph merge to one, big graph have paramters in begin and only have one output | ||||
| @@ -46,6 +46,7 @@ class Backend { | |||||
| virtual void SimulateRun(FinalVMPtr, FuncGraphPtr) {} | virtual void SimulateRun(FinalVMPtr, FuncGraphPtr) {} | ||||
| virtual SwitchCondStatus SetSimuCond(const BaseRef &, bool) { return kCondOk; } | virtual SwitchCondStatus SetSimuCond(const BaseRef &, bool) { return kCondOk; } | ||||
| virtual bool GetCond(const BaseRef &c, bool *value); | virtual bool GetCond(const BaseRef &c, bool *value); | ||||
| virtual bool GetIndex(const BaseRef &c, int *value); | |||||
| virtual void SetSwitchGraph() {} | virtual void SetSwitchGraph() {} | ||||
| virtual void SetSwitchActive(const BaseRef &, bool) {} | virtual void SetSwitchActive(const BaseRef &, bool) {} | ||||
| virtual void RecallGraphInput(const FuncGraphPtr &, const VectorRef &, const BaseRef &) {} | virtual void RecallGraphInput(const FuncGraphPtr &, const VectorRef &, const BaseRef &) {} | ||||
| @@ -46,8 +46,9 @@ using TypedPrimitiveAbstractClosurePtr = std::shared_ptr<abstract::TypedPrimitiv | |||||
| std::vector<PrimitivePtr> nonlinear_ops = {prim::kPrimReturn, prim::kPrimPartial, prim::kPrimSwitch, | std::vector<PrimitivePtr> nonlinear_ops = {prim::kPrimReturn, prim::kPrimPartial, prim::kPrimSwitch, | ||||
| prim::kPrimMakeTuple, prim::kPrimBpropCut}; | prim::kPrimMakeTuple, prim::kPrimBpropCut}; | ||||
| const std::vector<PrimitivePtr> &GetMsNonlinearOps() { | const std::vector<PrimitivePtr> &GetMsNonlinearOps() { | ||||
| static const std::vector<PrimitivePtr> ms_nonlinear_ops = {prim::kPrimReturn, prim::kPrimPartial, prim::kPrimSwitch, | |||||
| prim::kPrimBpropCut}; | |||||
| static const std::vector<PrimitivePtr> ms_nonlinear_ops = {prim::kPrimReturn, prim::kPrimPartial, | |||||
| prim::kPrimSwitch, prim::kPrimMakeTuple, | |||||
| prim::kPrimBpropCut, prim::kPrimSwitchLayer}; | |||||
| return ms_nonlinear_ops; | return ms_nonlinear_ops; | ||||
| } | } | ||||
| @@ -187,6 +188,30 @@ std::vector<AnfNodePtr> SplitSort(const FuncGraphPtr &graph, const std::string & | |||||
| std::reverse(result.begin(), result.end()); | std::reverse(result.begin(), result.end()); | ||||
| return result; | return result; | ||||
| } | } | ||||
| bool IsSubGraph(const AnfNodePtr &node) { | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| if (node->isa<CNode>()) { | |||||
| auto cnode = node->cast<CNodePtr>(); | |||||
| auto &inputs = cnode->inputs(); | |||||
| if (inputs.empty()) { | |||||
| MS_LOG(EXCEPTION) << "Inputs of apply node is empty"; | |||||
| } | |||||
| AnfNodePtr fn = inputs[0]; | |||||
| MS_EXCEPTION_IF_NULL(fn); | |||||
| if (!IsValueNode<Primitive>(fn)) { | |||||
| return false; | |||||
| } | |||||
| auto node_prim = GetValueNode<PrimitivePtr>(fn); | |||||
| if (node_prim->name() == prim::kPrimPartial->name()) { | |||||
| return true; | |||||
| } | |||||
| } else if (IsValueNode<FuncGraph>(node)) { | |||||
| return true; | |||||
| } | |||||
| return false; | |||||
| } | |||||
| } // namespace | } // namespace | ||||
| CompileGraph::CompileGraph(const BackendPtr &backend, const std::vector<PrimitivePtr> &cut_list) | CompileGraph::CompileGraph(const BackendPtr &backend, const std::vector<PrimitivePtr> &cut_list) | ||||
| @@ -235,6 +260,15 @@ bool CompileGraph::IsCut(const AnfNodePtr &node) { | |||||
| MS_EXCEPTION_IF_NULL(ms_context); | MS_EXCEPTION_IF_NULL(ms_context); | ||||
| ms_context->set_enable_pynative_hook(true); | ms_context->set_enable_pynative_hook(true); | ||||
| } | } | ||||
| if (backend_->name() == kMsConvert && prim->name() == prim::kPrimMakeTuple->name()) { | |||||
| if (inputs.size() < 2) { | |||||
| return false; | |||||
| } | |||||
| auto ret = IsSubGraph(inputs[1]); | |||||
| return ret; | |||||
| } | |||||
| return true; | return true; | ||||
| } | } | ||||
| } | } | ||||
| @@ -466,6 +500,8 @@ int CompileGraph::InterpretNode(const FuncGraphPtr &graph, const CNodePtr &node) | |||||
| } else if (IsPrimitive(fn, prim::kPrimSwitch)) { | } else if (IsPrimitive(fn, prim::kPrimSwitch)) { | ||||
| AddSwitch(node); | AddSwitch(node); | ||||
| AddSinkSwitch(node); | AddSinkSwitch(node); | ||||
| } else if (IsPrimitive(fn, prim::kPrimSwitchLayer)) { | |||||
| AddSwitchLayer(node); | |||||
| } else if (IsPrimitive(fn, prim::kPrimMakeTuple)) { | } else if (IsPrimitive(fn, prim::kPrimMakeTuple)) { | ||||
| AddMakeTuple(node); | AddMakeTuple(node); | ||||
| } else { | } else { | ||||
| @@ -622,6 +658,17 @@ void CompileGraph::AddSwitch(const CNodePtr &node) { | |||||
| AddInst(Instruction::kSwitch, args); | AddInst(Instruction::kSwitch, args); | ||||
| } | } | ||||
| void CompileGraph::AddSwitchLayer(const CNodePtr &node) { | |||||
| auto inputs = node->inputs(); | |||||
| if (inputs.size() != 3) { | |||||
| MS_LOG(EXCEPTION) << "Switch layer must have index and branches."; | |||||
| } | |||||
| VectorRef args; | |||||
| args.emplace_back(Ref(inputs[1])); | |||||
| args.emplace_back(Ref(inputs[2])); | |||||
| AddInst(Instruction::kSwitchLayer, args); | |||||
| } | |||||
| void CompileGraph::AddReturn(const CNodePtr &node) { | void CompileGraph::AddReturn(const CNodePtr &node) { | ||||
| VectorRef args; | VectorRef args; | ||||
| if (backend_->simu_flag()) { | if (backend_->simu_flag()) { | ||||
| @@ -90,6 +90,7 @@ class CompileGraph { | |||||
| void AddPartial(const CNodePtr &node); | void AddPartial(const CNodePtr &node); | ||||
| void AddMakeTuple(const CNodePtr &node); | void AddMakeTuple(const CNodePtr &node); | ||||
| void AddSwitch(const CNodePtr &node); | void AddSwitch(const CNodePtr &node); | ||||
| void AddSwitchLayer(const CNodePtr &node); | |||||
| void AddReturn(const CNodePtr &node); | void AddReturn(const CNodePtr &node); | ||||
| void AddPrimitive(const CNodePtr &node, const PrimitivePtr &prim); | void AddPrimitive(const CNodePtr &node, const PrimitivePtr &prim); | ||||
| void AddInput(const AnfNodePtr &node); | void AddInput(const AnfNodePtr &node); | ||||
| @@ -480,6 +480,35 @@ void FinalVM::InstSwitch(const VectorRef &args) { | |||||
| MS_LOG(DEBUG) << "End"; | MS_LOG(DEBUG) << "End"; | ||||
| } | } | ||||
| void FinalVM::InstSwitchLayer(const VectorRef &args) { | |||||
| MS_LOG(DEBUG) << "Start"; | |||||
| const size_t args_size = 2; | |||||
| if (args.size() != args_size) { | |||||
| MS_LOG(ERROR) << __FUNCTION__ << " requires " << args_size << " parameters, while the input size is " << args.size() | |||||
| << "."; | |||||
| return; | |||||
| } | |||||
| int idx = utils::cast<int>(args[0]); | |||||
| VectorRef branches = utils::cast<VectorRef>(Ref(utils::cast<int>(args[1]))); | |||||
| int size = static_cast<int>(branches.size()); | |||||
| BaseRef index = Ref(idx); | |||||
| int idx_value = 0; | |||||
| if (!backend_->GetIndex(index, &idx_value)) { | |||||
| MS_LOG(EXCEPTION) << "Not supported type to be casted to int."; | |||||
| } | |||||
| if (idx_value < 0) { | |||||
| // Add support negative index range [-size, -1]. | |||||
| idx_value += size; | |||||
| } | |||||
| if (idx_value < 0 || idx_value >= size) { | |||||
| MS_LOG(EXCEPTION) << __FUNCTION__ << " given index " << idx_value << " out of range."; | |||||
| } | |||||
| Push(branches[idx_value]); | |||||
| MS_LOG(DEBUG) << "End"; | |||||
| } | |||||
| void FinalVM::InstTuple(const VectorRef &args) { | void FinalVM::InstTuple(const VectorRef &args) { | ||||
| MS_LOG(DEBUG) << "Start"; | MS_LOG(DEBUG) << "Start"; | ||||
| VectorRef tuple; | VectorRef tuple; | ||||
| @@ -51,15 +51,17 @@ enum Instruction { | |||||
| kPush, | kPush, | ||||
| kPrim, | kPrim, | ||||
| kGraph, | kGraph, | ||||
| kPadStack | |||||
| kPadStack, | |||||
| kSwitchLayer | |||||
| }; | }; | ||||
| using InstType = std::pair<Instruction, VectorRef>; | using InstType = std::pair<Instruction, VectorRef>; | ||||
| using InstSet = std::vector<InstType>; | using InstSet = std::vector<InstType>; | ||||
| using InstFunctionMap = std::map<Instruction, std::function<void(const VectorRef &)>>; | using InstFunctionMap = std::map<Instruction, std::function<void(const VectorRef &)>>; | ||||
| const std::vector<std::string> inst_str{"call", "tail_call", "return", "partial", "switch", "switch_return", "tuple", | |||||
| "input", "external", "push", "primitive", "graph", "pad_stack"}; | |||||
| const std::vector<std::string> inst_str{"call", "tail_call", "return", "partial", "switch", | |||||
| "switch_return", "tuple", "input", "external", "push", | |||||
| "primitive", "graph", "pad_stack", "switch_layer"}; | |||||
| class StructPartial : public Base { | class StructPartial : public Base { | ||||
| public: | public: | ||||
| // Initialize StructPartial. | // Initialize StructPartial. | ||||
| @@ -114,6 +116,7 @@ class FinalVM { | |||||
| void InstExternal(const VectorRef &args); | void InstExternal(const VectorRef &args); | ||||
| void InstPushPrim(const VectorRef &args); | void InstPushPrim(const VectorRef &args); | ||||
| void InstSwitchReturn(const VectorRef &args); | void InstSwitchReturn(const VectorRef &args); | ||||
| void InstSwitchLayer(const VectorRef &args); | |||||
| void set_insts(const InstSet &value) { insts_ = value; } | void set_insts(const InstSet &value) { insts_ = value; } | ||||
| BaseRef RunHook(const PrimitivePtr &prim, const VectorRef &arg); | BaseRef RunHook(const PrimitivePtr &prim, const VectorRef &arg); | ||||
| @@ -157,7 +160,7 @@ class FinalVM { | |||||
| {Instruction::kExternal, [this](const VectorRef &args) { InstExternal(args); }}, | {Instruction::kExternal, [this](const VectorRef &args) { InstExternal(args); }}, | ||||
| {Instruction::kPrim, [this](const VectorRef &args) { InstPushPrim(args); }}, | {Instruction::kPrim, [this](const VectorRef &args) { InstPushPrim(args); }}, | ||||
| {Instruction::kSwitchReturn, [this](const VectorRef &args) { InstSwitchReturn(args); }}, | {Instruction::kSwitchReturn, [this](const VectorRef &args) { InstSwitchReturn(args); }}, | ||||
| }; | |||||
| {Instruction::kSwitchLayer, [this](const VectorRef &args) { InstSwitchLayer(args); }}}; | |||||
| std::map<std::string, py::object> _hook_grad; | std::map<std::string, py::object> _hook_grad; | ||||
| }; | }; | ||||