Merge pull request !967 from panyifeng/switch_case_primitive
tags/v0.3.0-alpha
@@ -59,6 +59,7 @@ const PrimitivePtr kPrimHasType = std::make_shared<Primitive>("hastype");
 // Statements
 const PrimitivePtr kPrimSwitch = std::make_shared<Primitive>("switch");
+const PrimitivePtr kPrimSwitchLayer = std::make_shared<Primitive>("switch_layer");
 const PrimitivePtr kPrimReturn = std::make_shared<Primitive>("return");
 const PrimitivePtr kPrimAssign = std::make_shared<Primitive>("Assign");
 const PrimitivePtr kPrimAssignAdd = std::make_shared<Primitive>("AssignAdd");
@@ -65,6 +65,7 @@ extern const PrimitivePtr kPrimHasType;
 // Statements
 extern const PrimitivePtr kPrimSwitch;
+extern const PrimitivePtr kPrimSwitchLayer;
 extern const PrimitivePtr kPrimReturn;
 extern const PrimitivePtr kPrimAssign;
 extern const PrimitivePtr kPrimAssignAdd;
@@ -126,6 +126,30 @@ AbstractBasePtr InferImplSwitch(const AnalysisEnginePtr &, const PrimitivePtr &,
   MS_LOG(EXCEPTION) << "Invalid condition value for switch " << cond->ToString();
 }
+
+AbstractBasePtr InferImplSwitchLayer(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                     const AbstractBasePtrList &args_spec_list) {
+  // Inputs: index, branches
+  if (args_spec_list.size() != 2) {
+    MS_LOG(EXCEPTION) << "SwitchLayer evaluator requires 2 parameters, while the input size is "
+                      << args_spec_list.size() << ".";
+  }
+  AbstractTuplePtr branches_abs = CheckArg<AbstractTuple>(primitive->name(), args_spec_list, 1);
+  AbstractBasePtrList branches = branches_abs->elements();
+  const size_t maximum_layer_num = 1000;
+  if (branches.empty() || branches.size() > maximum_layer_num) {
+    MS_EXCEPTION(ValueError) << "SwitchLayer requires at least 1 and at most " << maximum_layer_num
+                             << " branches, but got " << branches.size() << ".";
+  }
+  MS_EXCEPTION_IF_NULL(branches[0]);
+  auto b = branches[0];
+  for (size_t i = 1; i < branches.size(); i++) {
+    MS_EXCEPTION_IF_NULL(branches[i]);
+    b = b->Join(branches[i]);
+  }
+  return b;
+}
+
 std::vector<ValuePtr> GetSupportedTargetValue() {
   std::vector<ValuePtr> list = {kNone, MakeValue(false), MakeValue(true)};
   return list;
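The evaluator folds all branch abstracts together with Join, so every graph passed to switch_layer must produce a compatible output type and shape. A minimal plain-Python analogue of that check (illustrative type strings, not MindSpore's AbstractBase API):

    def join_branch_abstracts(branch_types):
        # Mirrors b = b->Join(branches[i]) above: Join fails when the
        # abstract values of two branches are incompatible.
        first = branch_types[0]
        for t in branch_types[1:]:
            if t != first:
                raise ValueError(f"branch output {t} does not join with {first}")
        return first

    assert join_branch_abstracts(["Tensor[float32, (128, 96)]"] * 2) == "Tensor[float32, (128, 96)]"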
@@ -38,6 +38,7 @@ namespace mindspore {
 namespace ad {
 std::unordered_map<FuncGraphPtr, DFunctorPtr> DFunctor::func_graph_to_functor_;
 std::unordered_map<AnfNodePtr, AdjointPtr> DFunctor::anfnode_to_adjoin_definition_;
+FuncGraphSet DFunctor::scope_;

 DFunctor::DFunctor(const FuncGraphPtr &primal_graph, const pipeline::ResourceBasePtr &resources)
     : primal_graph_(primal_graph), resources_(resources), need_cut_(false), is_top_(false) {

@@ -55,11 +56,15 @@ DFunctor::DFunctor(const FuncGraphPtr &primal_graph, const pipeline::ResourceBas
 void DFunctor::Init(const DFunctorPtr &functor, bool is_top) {
   func_graph_to_functor_[primal_graph_] = functor;
   is_top_ = is_top;
+  if (is_top) {
+    scope_ = primal_graph_->scope();
+  }
 }

 void DFunctor::Clear() {
   func_graph_to_functor_.clear();
   anfnode_to_adjoin_definition_.clear();
+  scope_.clear();
 }

 void DFunctor::BackPropagateFv(const AnfNodePtr &fv, const AnfNodePtr &din) {
@@ -95,11 +100,48 @@ void DFunctor::BackPropagateFv(const AnfNodePtr &fv, const AnfNodePtr &din) {
   fv_adjoint->second->AccumulateDout(dfv);
 }
+
+void DFunctor::BackPropagateSwitchLayer(const CNodePtr &cnode_morph, const CNodePtr &env) {
+  // Take switch_layer as a set of candidate functions.
+  auto input = cnode_morph->input(2);
+  if (!IsPrimitiveCNode(input, prim::kPrimMakeTuple)) {
+    MS_LOG(EXCEPTION) << "The second input of switch_layer should be a tuple of graphs, but got "
+                      << input->ToString() << ".";
+  }
+  auto tuple_graphs = input->cast<CNodePtr>();
+  for (size_t i = 1; i < tuple_graphs->size(); ++i) {
+    auto graph = tuple_graphs->input(i);
+    if (!IsValueNode<FuncGraph>(graph)) {
| MS_LOG(EXCEPTION) << "The 2th input of switch_layer expect a tuple of graphs, but got " << graph->ToString() | |||
| << " as the " << i << "th element."; | |||
| } | |||
+    auto func_graph = GetValueNode<FuncGraphPtr>(graph);
+    auto functor = func_graph_to_functor_.find(func_graph);
+    if (functor == func_graph_to_functor_.end()) {
+      MS_LOG(EXCEPTION) << "BackPropagateSwitchLayer failed: no functor exists for the subgraph of input[" << i
+                        << "] " << func_graph->ToString() << ".";
+    }
+    // Consider direct and indirect fvs.
+    for (auto fv : func_graph->free_variables_nodes()) {
+      BackPropagateFv(fv, env);
+    }
+    for (auto indirect_fv : functor->second->anfnode_to_adjoin_indirect_fv_) {
+      MS_LOG(DEBUG) << "BackPropagateSwitchLayer backprop indirect fv " << func_graph->ToString() << " "
+                    << indirect_fv.first->ToString() << ".";
+      BackPropagateFv(indirect_fv.first, env);
+    }
+  }
+}
+
 void DFunctor::BackPropagate(const CNodePtr &cnode_morph, const CNodePtr &k_app, const AdjointPtr &node_adjoint) {
   auto bprop = k_graph_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), k_app, NewValueNode(1)});
   // Call with delimited continuation dout.
   auto bprop_app = tape_->NewCNode({bprop, node_adjoint->dout()});
   node_adjoint->RegisterDoutUser(bprop_app, 1);
+  // Special case for switch_layer
+  if (IsPrimitiveCNode(cnode_morph, prim::kPrimSwitchLayer)) {
+    auto din = tape_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), bprop_app, NewValueNode(0)});
+    BackPropagateSwitchLayer(cnode_morph, din);
+    return;
+  }
   for (size_t i = 0; i < cnode_morph->size(); i++) {
     auto din = tape_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), bprop_app, NewValueNode(SizeToInt(i))});
     auto input = cnode_morph->input(i);
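Each K-transformed application returns a pair: item 0 carries the forward result, item 1 the bprop closure, and applying that closure to dout yields a tuple whose i-th item is the din routed back to input i. switch_layer is special-cased because item 0 of its bprop output is not a per-input gradient but a din handed to BackPropagateSwitchLayer for the candidate graphs. A plain-Python sketch of the pair contract (illustrative names, not MindSpore APIs):

    def k_mul(x, y):
        # K-transform of a multiply: mirrors bprop = tuple_getitem(k_app, 1) above.
        out = x * y
        def bprop(dout):
            return (dout * y, dout * x)  # item i is the din for input i
        return out, bprop

    out, bprop = k_mul(3.0, 4.0)
    assert out == 12.0
    assert bprop(1.0) == (4.0, 3.0)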
@@ -402,6 +444,11 @@ AnfNodePtr DFunctor::MapToK(const AnfNodePtr &primal) {
   return primal;
 }
+
+bool DFunctor::IsInScope(const AnfNodePtr &node) {
+  return std::any_of(scope_.begin(), scope_.end(),
+                     [&](const FuncGraphPtr &graph) { return node->func_graph() == graph; });
+}

 void DFunctor::MapFvObject() {
   // Map free variable.
   const auto &free_variables_nodes = primal_graph_->free_variables_nodes();

@@ -414,8 +461,8 @@ void DFunctor::MapFvObject() {
     if (parent_adjoint != nullptr) {
       adjoint = std::make_shared<Adjoint>(node, parent_adjoint->k(), tape_);
     } else {
-      if (is_top_) {
-        // Top graph for ad, add adjoint for free variables.
+      if (is_top_ || node->isa<Parameter>() || !IsInScope(node)) {
+        // Out of ad scope, add adjoint for free variables.
         adjoint = std::make_shared<Adjoint>(node, node, tape_);
         UpdateAdjoint(adjoint);
       } else {
@@ -62,9 +62,11 @@ class DFunctor {
   // Map one morphism.
   AdjointPtr MapMorphism(const AnfNodePtr &morph);
   bool IsFreeMorphism(const AnfNodePtr &node);
+  bool IsInScope(const AnfNodePtr &node);
   // Map morphism that's not attached to output.
   void MapFreeMorphism();
   void BackPropagateFv(const AnfNodePtr &fv, const AnfNodePtr &din);
+  void BackPropagateSwitchLayer(const CNodePtr &cnode_morph, const CNodePtr &env);
   void BackPropagate(const CNodePtr &cnode_morph, const CNodePtr &k_app, const AdjointPtr &node_adjoint);
   AnfNodePtr AttachFvDoutToTape(const AnfNodePtr &grad_fv);
   AnfNodePtr AttachIndirectFvDoutToTape(const AnfNodePtr &grad_fv);

@@ -101,6 +103,7 @@ class DFunctor {
   bool is_top_;
   static std::unordered_map<FuncGraphPtr, std::shared_ptr<DFunctor>> func_graph_to_functor_;
   static std::unordered_map<AnfNodePtr, AdjointPtr> anfnode_to_adjoin_definition_;
+  static FuncGraphSet scope_;
 };

 // D Functor's rules to map primitive object.
@@ -120,6 +123,7 @@ class KPrim {
  private:
   FuncGraphPtr GetBprop(const PrimitivePtr &prim);
+  FuncGraphPtr GetFprop(const PrimitivePtr &prim);
   FuncGraphPtr FakeBprop(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources);
   // Given a bprop rule, do the K mapping.
   template <typename T>
@@ -62,6 +62,15 @@ FuncGraphPtr KPrim::GetBprop(const PrimitivePtr &prim) {
   return func_graph;
 }
+
+FuncGraphPtr KPrim::GetFprop(const PrimitivePtr &prim) {
+  static const std::string ad_module = "mindspore.ops._grad.grad_implementations";
+  std::string func_name = "_fprop_" + prim->name();
+  py::function fn = parse::python_adapter::GetPyFn(ad_module, func_name);
+  auto func_graph = parse::ParsePythonCode(fn);
+  MS_EXCEPTION_IF_NULL(func_graph);
+  return BasicClone(func_graph);
+}

 MetaFuncGraphPtr KPrim::KMetaFuncGraph(const PrimitivePtr &prim) {
   MS_EXCEPTION_IF_NULL(prim);
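GetFprop resolves its rule purely by naming convention: module mindspore.ops._grad.grad_implementations, function _fprop_<primitive name>. A sketch of the equivalent lookup from the Python side (assuming an installed mindspore that already contains the _fprop_switch_layer hunk below):

    import importlib

    ad_module = importlib.import_module("mindspore.ops._grad.grad_implementations")
    fprop = getattr(ad_module, "_fprop_" + "switch_layer")  # the same name GetPyFn resolves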
@@ -92,6 +101,13 @@ FuncGraphPtr KPrim::KPrimitive(const ValueNodePtr &value_node, const pipeline::R
     return iter->second;
   }
+
+  if (prim->Hash() == prim::kPrimSwitchLayer->Hash() && prim->name() == "switch_layer") {
+    auto fprop = GetFprop(prim);
+    fprop->transforms().emplace("primal", FuncGraphTransform(prim::kPrimSwitchLayer));
+    bprop_registry_[prim::kPrimSwitchLayer] = fprop;
+    return fprop;
+  }
+
   if (prim->name() == "make_tuple") {
     return nullptr;
   }
@@ -50,6 +50,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
     {prim::kPrimHasType, {InferImplHasType, false}},
     {prim::kPrimDot, {InferImplDot, true}},
     {prim::kPrimSwitch, {InferImplSwitch, true}},
+    {prim::kPrimSwitchLayer, {InferImplSwitchLayer, true}},
     {prim::kPrimIs_, {InferImplIs_, true}},
     {prim::kPrimIsNot, {InferImplIsNot, true}},
     {prim::kPrimInDict, {InferImplInDict, true}},
@@ -174,6 +174,8 @@ AbstractBasePtr InferImplDot(const AnalysisEnginePtr &, const PrimitivePtr &prim
                              const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplSwitch(const AnalysisEnginePtr &, const PrimitivePtr &,
                                 const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplSwitchLayer(const AnalysisEnginePtr &, const PrimitivePtr &,
+                                     const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplIs_(const AnalysisEnginePtr &, const PrimitivePtr &,
                              const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplIsNot(const AnalysisEnginePtr &, const PrimitivePtr &,
@@ -242,3 +242,9 @@ def bprop_switch(cond, tb, fb, out, dout):
     """Backpropagator for primitive `switch`."""
     return C.zeros_like(cond), F.switch(cond, dout, C.zeros_like(tb)), \
         F.switch(cond, C.zeros_like(fb), dout)
+
+def _fprop_switch_layer(index, layers):
+    """Forward propagator for primitive `switch_layer`, returning its backpropagator."""
+    def _bprop_switch_layer(dout):
+        return dout, C.zeros_like(index), ()
+    return F.switch_layer(index, layers), _bprop_switch_layer
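The returned closure routes dout through item 0, which BackPropagate above extracts and hands to BackPropagateSwitchLayer as the din for the candidate graphs' free variables; the integer index gets a zero gradient and the layers tuple an empty one. A hand-worked plain-numpy analogue of that contract (a hypothetical helper, not the MindSpore implementation):

    import numpy as np

    def fprop_switch_layer(index, layers):
        # Forward: select the layer; backward: (din for candidates, d_index, d_layers).
        def bprop(dout):
            return dout, np.zeros_like(index), ()
        return layers[index], bprop

    layers = (lambda x: 2 * x, lambda x: x + 1)
    selected, bprop = fprop_switch_layer(1, layers)
    assert selected(3) == 4
    assert bprop(np.float32(1.0))[1] == 0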
@@ -135,6 +135,7 @@ env_getitem = Primitive('env_getitem')
 env_add = Primitive('env_add')
 J = Primitive('J')
 switch = Primitive('switch')
+switch_layer = Primitive('switch_layer')
 # for sum bprop
 reduced_shape = Primitive("reduced_shape")
 # shape_mul: the input must be a shape; multiplies the elements in tuple(shape)
@@ -19,6 +19,9 @@ from mindspore import nn
 from mindspore import Tensor
 from mindspore import context
 from mindspore.ops import operations as P
+from mindspore.ops import composite as C
+from mindspore.ops import functional as F
+from mindspore.common.parameter import Parameter, ParameterTuple

 context.set_context(mode=context.GRAPH_MODE)
@@ -358,3 +361,33 @@ def test_if_compile_true():
 def test_if_compile_false():
     output = if_compile_test(8, 3)
     print("test_if_compile_false:", output)
+
+
+def test_switch_layer():
+    class Layer1(nn.Cell):
+        def __init__(self):
+            super(Layer1, self).__init__()
+            self.z1 = Parameter(Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z1')
+        def construct(self, x):
+            return x * self.z1
+
+    class Layer2(nn.Cell):
+        def __init__(self):
+            super(Layer2, self).__init__()
+            self.z2 = Parameter(Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z2')
+        def construct(self, x):
+            return x * self.z2
+
+    class SwitchLayerCell(nn.Cell):
+        def __init__(self):
+            super(SwitchLayerCell, self).__init__()
+            self.layers = (Layer1(), Layer2())
+            self.z3 = Parameter(Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z3')
+        def construct(self, index, x):
+            ret = F.switch_layer(index, self.layers)(x) * self.z3
+            return ret
+
+    net = SwitchLayerCell()
+    net(1, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
+    C.grad_by_list(net, ParameterTuple(net.trainable_params()))(0, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
+    C.grad_all(net)(0, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))