Merge pull request !2895 from fanglei/casetags/v0.6.0-beta
| @@ -230,6 +230,20 @@ bool ValueToBool(const ValuePtr &v, bool *value) { | |||
| return true; | |||
| } | |||
| bool BaseRefToInt(const ValuePtr &v, int *value) { | |||
| MS_EXCEPTION_IF_NULL(v); | |||
| if (v->isa<tensor::Tensor>()) { | |||
| auto tensor = v->cast<tensor::TensorPtr>(); | |||
| (void)tensor->data_sync(); | |||
| int *tensor_data = static_cast<int *>(tensor->data_c()); | |||
| auto vb = tensor_data[0]; | |||
| *value = vb; | |||
| return true; | |||
| } | |||
| MS_LOG(ERROR) << "Index must be tensor type."; | |||
| return false; | |||
| } | |||
| bool BaseRefToBool(const BaseRef &v, bool *value) { | |||
| if (utils::isa<ValuePtr>(v)) { | |||
| return ValueToBool(utils::cast<ValuePtr>(v), value); | |||
| @@ -42,6 +42,7 @@ using TensorPtr = std::shared_ptr<Tensor>; | |||
| py::object AnyToPyData(const Any &value); | |||
| py::object BaseRefToPyData(const BaseRef &value); | |||
| bool BaseRefToBool(const BaseRef &in, bool *out); | |||
| bool BaseRefToInt(const ValuePtr &v, int *value); | |||
| bool ValueToBool(const ValuePtr &in, bool *out); | |||
| py::object ValuePtrToPyData(const ValuePtr &value); | |||
| @@ -32,6 +32,7 @@ | |||
| namespace mindspore { | |||
| namespace compile { | |||
| bool Backend::GetCond(const BaseRef &c, bool *const value) { return BaseRefToBool(c, value); } | |||
| bool Backend::GetIndex(const BaseRef &c, int *const value) { return BaseRefToInt(utils::cast<ValuePtr>(c), value); } | |||
| LinConvertResult MsBackend::GetMultiGraphRun(const FuncGraphPtr &g) { | |||
| // multi_graph merge to one, big graph have paramters in begin and only have one output | |||
| @@ -46,6 +46,7 @@ class Backend { | |||
| virtual void SimulateRun(FinalVMPtr, FuncGraphPtr) {} | |||
| virtual SwitchCondStatus SetSimuCond(const BaseRef &, bool) { return kCondOk; } | |||
| virtual bool GetCond(const BaseRef &c, bool *value); | |||
| virtual bool GetIndex(const BaseRef &c, int *value); | |||
| virtual void SetSwitchGraph() {} | |||
| virtual void SetSwitchActive(const BaseRef &, bool) {} | |||
| virtual void RecallGraphInput(const FuncGraphPtr &, const VectorRef &, const BaseRef &) {} | |||
| @@ -46,8 +46,9 @@ using TypedPrimitiveAbstractClosurePtr = std::shared_ptr<abstract::TypedPrimitiv | |||
| std::vector<PrimitivePtr> nonlinear_ops = {prim::kPrimReturn, prim::kPrimPartial, prim::kPrimSwitch, | |||
| prim::kPrimMakeTuple, prim::kPrimBpropCut}; | |||
| const std::vector<PrimitivePtr> &GetMsNonlinearOps() { | |||
| static const std::vector<PrimitivePtr> ms_nonlinear_ops = {prim::kPrimReturn, prim::kPrimPartial, prim::kPrimSwitch, | |||
| prim::kPrimBpropCut}; | |||
| static const std::vector<PrimitivePtr> ms_nonlinear_ops = {prim::kPrimReturn, prim::kPrimPartial, | |||
| prim::kPrimSwitch, prim::kPrimMakeTuple, | |||
| prim::kPrimBpropCut, prim::kPrimSwitchLayer}; | |||
| return ms_nonlinear_ops; | |||
| } | |||
| @@ -187,6 +188,30 @@ std::vector<AnfNodePtr> SplitSort(const FuncGraphPtr &graph, const std::string & | |||
| std::reverse(result.begin(), result.end()); | |||
| return result; | |||
| } | |||
| bool IsSubGraph(const AnfNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| if (node->isa<CNode>()) { | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| auto &inputs = cnode->inputs(); | |||
| if (inputs.empty()) { | |||
| MS_LOG(EXCEPTION) << "Inputs of apply node is empty"; | |||
| } | |||
| AnfNodePtr fn = inputs[0]; | |||
| MS_EXCEPTION_IF_NULL(fn); | |||
| if (!IsValueNode<Primitive>(fn)) { | |||
| return false; | |||
| } | |||
| auto node_prim = GetValueNode<PrimitivePtr>(fn); | |||
| if (node_prim->name() == prim::kPrimPartial->name()) { | |||
| return true; | |||
| } | |||
| } else if (IsValueNode<FuncGraph>(node)) { | |||
| return true; | |||
| } | |||
| return false; | |||
| } | |||
| } // namespace | |||
| CompileGraph::CompileGraph(const BackendPtr &backend, const std::vector<PrimitivePtr> &cut_list) | |||
| @@ -235,6 +260,15 @@ bool CompileGraph::IsCut(const AnfNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(ms_context); | |||
| ms_context->set_enable_pynative_hook(true); | |||
| } | |||
| if (backend_->name() == kMsConvert && prim->name() == prim::kPrimMakeTuple->name()) { | |||
| if (inputs.size() < 2) { | |||
| return false; | |||
| } | |||
| auto ret = IsSubGraph(inputs[1]); | |||
| return ret; | |||
| } | |||
| return true; | |||
| } | |||
| } | |||
| @@ -466,6 +500,8 @@ int CompileGraph::InterpretNode(const FuncGraphPtr &graph, const CNodePtr &node) | |||
| } else if (IsPrimitive(fn, prim::kPrimSwitch)) { | |||
| AddSwitch(node); | |||
| AddSinkSwitch(node); | |||
| } else if (IsPrimitive(fn, prim::kPrimSwitchLayer)) { | |||
| AddSwitchLayer(node); | |||
| } else if (IsPrimitive(fn, prim::kPrimMakeTuple)) { | |||
| AddMakeTuple(node); | |||
| } else { | |||
| @@ -622,6 +658,17 @@ void CompileGraph::AddSwitch(const CNodePtr &node) { | |||
| AddInst(Instruction::kSwitch, args); | |||
| } | |||
| void CompileGraph::AddSwitchLayer(const CNodePtr &node) { | |||
| auto inputs = node->inputs(); | |||
| if (inputs.size() != 3) { | |||
| MS_LOG(EXCEPTION) << "Switch layer must have index and branches."; | |||
| } | |||
| VectorRef args; | |||
| args.emplace_back(Ref(inputs[1])); | |||
| args.emplace_back(Ref(inputs[2])); | |||
| AddInst(Instruction::kSwitchLayer, args); | |||
| } | |||
| void CompileGraph::AddReturn(const CNodePtr &node) { | |||
| VectorRef args; | |||
| if (backend_->simu_flag()) { | |||
| @@ -90,6 +90,7 @@ class CompileGraph { | |||
| void AddPartial(const CNodePtr &node); | |||
| void AddMakeTuple(const CNodePtr &node); | |||
| void AddSwitch(const CNodePtr &node); | |||
| void AddSwitchLayer(const CNodePtr &node); | |||
| void AddReturn(const CNodePtr &node); | |||
| void AddPrimitive(const CNodePtr &node, const PrimitivePtr &prim); | |||
| void AddInput(const AnfNodePtr &node); | |||
| @@ -480,6 +480,35 @@ void FinalVM::InstSwitch(const VectorRef &args) { | |||
| MS_LOG(DEBUG) << "End"; | |||
| } | |||
| void FinalVM::InstSwitchLayer(const VectorRef &args) { | |||
| MS_LOG(DEBUG) << "Start"; | |||
| const size_t args_size = 2; | |||
| if (args.size() != args_size) { | |||
| MS_LOG(ERROR) << __FUNCTION__ << " requires " << args_size << " parameters, while the input size is " << args.size() | |||
| << "."; | |||
| return; | |||
| } | |||
| int idx = utils::cast<int>(args[0]); | |||
| VectorRef branches = utils::cast<VectorRef>(Ref(utils::cast<int>(args[1]))); | |||
| int size = static_cast<int>(branches.size()); | |||
| BaseRef index = Ref(idx); | |||
| int idx_value = 0; | |||
| if (!backend_->GetIndex(index, &idx_value)) { | |||
| MS_LOG(EXCEPTION) << "Not supported type to be casted to int."; | |||
| } | |||
| if (idx_value < 0) { | |||
| // Add support negative index range [-size, -1]. | |||
| idx_value += size; | |||
| } | |||
| if (idx_value < 0 || idx_value >= size) { | |||
| MS_LOG(EXCEPTION) << __FUNCTION__ << " given index " << idx_value << " out of range."; | |||
| } | |||
| Push(branches[idx_value]); | |||
| MS_LOG(DEBUG) << "End"; | |||
| } | |||
| void FinalVM::InstTuple(const VectorRef &args) { | |||
| MS_LOG(DEBUG) << "Start"; | |||
| VectorRef tuple; | |||
| @@ -51,15 +51,17 @@ enum Instruction { | |||
| kPush, | |||
| kPrim, | |||
| kGraph, | |||
| kPadStack | |||
| kPadStack, | |||
| kSwitchLayer | |||
| }; | |||
| using InstType = std::pair<Instruction, VectorRef>; | |||
| using InstSet = std::vector<InstType>; | |||
| using InstFunctionMap = std::map<Instruction, std::function<void(const VectorRef &)>>; | |||
| const std::vector<std::string> inst_str{"call", "tail_call", "return", "partial", "switch", "switch_return", "tuple", | |||
| "input", "external", "push", "primitive", "graph", "pad_stack"}; | |||
| const std::vector<std::string> inst_str{"call", "tail_call", "return", "partial", "switch", | |||
| "switch_return", "tuple", "input", "external", "push", | |||
| "primitive", "graph", "pad_stack", "switch_layer"}; | |||
| class StructPartial : public Base { | |||
| public: | |||
| // Initialize StructPartial. | |||
| @@ -114,6 +116,7 @@ class FinalVM { | |||
| void InstExternal(const VectorRef &args); | |||
| void InstPushPrim(const VectorRef &args); | |||
| void InstSwitchReturn(const VectorRef &args); | |||
| void InstSwitchLayer(const VectorRef &args); | |||
| void set_insts(const InstSet &value) { insts_ = value; } | |||
| BaseRef RunHook(const PrimitivePtr &prim, const VectorRef &arg); | |||
| @@ -157,7 +160,7 @@ class FinalVM { | |||
| {Instruction::kExternal, [this](const VectorRef &args) { InstExternal(args); }}, | |||
| {Instruction::kPrim, [this](const VectorRef &args) { InstPushPrim(args); }}, | |||
| {Instruction::kSwitchReturn, [this](const VectorRef &args) { InstSwitchReturn(args); }}, | |||
| }; | |||
| {Instruction::kSwitchLayer, [this](const VectorRef &args) { InstSwitchLayer(args); }}}; | |||
| std::map<std::string, py::object> _hook_grad; | |||
| }; | |||