| @@ -95,6 +95,10 @@ OptimizeIRPassLib::OptimizeIRPassLib() { | |||||
| incorporate_env_getitem_ = | incorporate_env_getitem_ = | ||||
| MakeSubstitution(std::make_shared<IncorporateEnvGetitem>(), "incorporate_env_get_item", prim::kPrimEnvGetItem); | MakeSubstitution(std::make_shared<IncorporateEnvGetitem>(), "incorporate_env_get_item", prim::kPrimEnvGetItem); | ||||
| incorporate_env_getitem_switch_layer_ = | |||||
| MakeSubstitution(std::make_shared<IncorporateEnvGetitemSwitchLayer>(), "incorporate_env_getitem_switch_layer", | |||||
| prim::kPrimEnvGetItem); | |||||
| // Ref eliminate | // Ref eliminate | ||||
| make_ref_eliminate_ = | make_ref_eliminate_ = | ||||
| MakeSubstitution(std::make_shared<MakeRefEliminater>(), "make_ref_eliminate", prim::kPrimMakeRef); | MakeSubstitution(std::make_shared<MakeRefEliminater>(), "make_ref_eliminate", prim::kPrimMakeRef); | ||||
| @@ -58,6 +58,7 @@ class OptimizeIRPassLib { | |||||
| SubstitutionPtr incorporate_env_getitem_; | SubstitutionPtr incorporate_env_getitem_; | ||||
| SubstitutionPtr incorporate_env_getitem_bypass_recursive_; | SubstitutionPtr incorporate_env_getitem_bypass_recursive_; | ||||
| SubstitutionPtr incorporate_env_getitem_switch_; | SubstitutionPtr incorporate_env_getitem_switch_; | ||||
| SubstitutionPtr incorporate_env_getitem_switch_layer_; | |||||
| // Ref eliminate | // Ref eliminate | ||||
| SubstitutionPtr make_ref_eliminate_; | SubstitutionPtr make_ref_eliminate_; | ||||
| @@ -91,6 +91,69 @@ class EnvGetitemTransform { | |||||
| std::unordered_map<std::pair<SymbolicKeyInstancePtr, AnfNodePtr>, FuncGraphPtr, PairHasher>> | std::unordered_map<std::pair<SymbolicKeyInstancePtr, AnfNodePtr>, FuncGraphPtr, PairHasher>> | ||||
| cache_; | cache_; | ||||
| }; | }; | ||||
| class EnvGetitemTransformACrossGraph { | |||||
| public: | |||||
| EnvGetitemTransformACrossGraph() : cache_() {} | |||||
| ~EnvGetitemTransformACrossGraph() = default; | |||||
| FuncGraphPtr operator()(const FuncGraphPtr &fg, const SymbolicKeyInstancePtr &key, const AnfNodePtr &default_node) { | |||||
| if (cache_.find(fg) == cache_.end()) { | |||||
| cache_[fg] = {}; | |||||
| } | |||||
| auto &cache = cache_[fg]; | |||||
| auto hash_key = std::make_pair(key, default_node); | |||||
| if (cache.find(hash_key) == cache.end()) { | |||||
| std::ostringstream ss("env", std::ostringstream::app); | |||||
| if (key->node() != nullptr) { | |||||
| ss << key->node()->ToString(); | |||||
| } | |||||
| auto new_fg_outer = TransformableClone(fg, std::make_shared<TraceTransform>(ss.str())); | |||||
| auto output_outer = new_fg_outer->output(); | |||||
| if (!IsValueNode<FuncGraph>(output_outer)) { | |||||
| MS_LOG(WARNING) << "Output of outer graph should be a func_graph"; | |||||
| return nullptr; | |||||
| } | |||||
| auto fg_inner = GetValueNode<FuncGraphPtr>(output_outer); | |||||
| auto new_fg = TransformableClone(fg_inner, std::make_shared<TraceTransform>(ss.str())); | |||||
| new_fg_outer->set_output(NewValueNode(new_fg)); | |||||
| auto env = new_fg->output(); | |||||
| while (IsPrimitiveCNode(env, prim::kPrimEnvSetItem)) { | |||||
| // {prim::kPrimEnvSetItem, env, symbolickey, value} | |||||
| auto &inputs = env->cast<CNodePtr>()->inputs(); | |||||
| if (inputs.size() != 4) { | |||||
| MS_LOG(WARNING) << "Input size should be 4"; | |||||
| return nullptr; | |||||
| } | |||||
| if (!IsValueNode<SymbolicKeyInstance>(inputs[2])) { | |||||
| MS_LOG(DEBUG) << "Input 2 is not a SymbolicKeyInstance?"; | |||||
| return nullptr; | |||||
| } | |||||
| env = inputs[1]; | |||||
| auto value = inputs[3]; | |||||
| auto key2 = GetValueNode<SymbolicKeyInstancePtr>(inputs[2]); | |||||
| if (*key2 == *key) { | |||||
| new_fg->set_output(value); | |||||
| cache[hash_key] = new_fg_outer; | |||||
| return new_fg_outer; | |||||
| } | |||||
| } | |||||
| new_fg->set_output(new_fg->NewCNode({NewValueNode(prim::kPrimEnvGetItem), env, NewValueNode(key), default_node})); | |||||
| cache[hash_key] = new_fg_outer; | |||||
| } | |||||
| return cache[hash_key]; | |||||
| } | |||||
| private: | |||||
| std::unordered_map<FuncGraphPtr, | |||||
| std::unordered_map<std::pair<SymbolicKeyInstancePtr, AnfNodePtr>, FuncGraphPtr, PairHasher>> | |||||
| cache_; | |||||
| }; | |||||
| } // namespace internal | } // namespace internal | ||||
| // {prim::kPrimEnvGetItem, C1, C2, Y} -> Y | // {prim::kPrimEnvGetItem, C1, C2, Y} -> Y | ||||
| @@ -358,6 +421,78 @@ class IncorporateEnvGetitemSwitch : public AnfVisitor { | |||||
| bool is_match_{false}; | bool is_match_{false}; | ||||
| internal::EnvGetitemTransform env_get_item_transform_; | internal::EnvGetitemTransform env_get_item_transform_; | ||||
| }; | }; | ||||
| // {prim::kPrimEnvGetItem, {{{prim::kPrimSwitchLayer, X, {prim::kPrimMakeTuple, G1, G2...}}, Xs}, Ys}, C, Y} | |||||
| class IncorporateEnvGetitemSwitchLayer : public AnfVisitor { | |||||
| public: | |||||
| IncorporateEnvGetitemSwitchLayer() : env_get_item_transform_() {} | |||||
| ~IncorporateEnvGetitemSwitchLayer() override = default; | |||||
| AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { | |||||
| is_match_ = false; | |||||
| AnfVisitor::Match(prim::kPrimEnvGetItem, {IsCNode, IsValueNode<SymbolicKeyInstance>, IsNode})(node); | |||||
| if (!is_match_ || node->func_graph() == nullptr) { | |||||
| return nullptr; | |||||
| } | |||||
| // {prim::kPrimEnvGetItem, {...}, C, Y} | |||||
| auto cnode = node->cast<CNodePtr>(); | |||||
| auto inp1 = cnode->input(1)->cast<CNodePtr>(); | |||||
| auto key = GetValueNode<SymbolicKeyInstancePtr>(cnode->input(2)); | |||||
| auto default_v = cnode->input(3); | |||||
| // {{prim::kPrimSwitchLayer, X, {prim::kPrimMakeTuple, G1, G2...}}, Xs}, Ys} | |||||
| auto &inputs_outer = inp1->inputs(); | |||||
| if (!inputs_outer[0]->isa<CNode>()) { | |||||
| return nullptr; | |||||
| } | |||||
| std::vector<AnfNodePtr> args_outer; | |||||
| args_outer.insert(args_outer.end(), inputs_outer.begin() + 1, inputs_outer.end()); | |||||
| auto &input_switch_layer = inputs_outer[0]->cast<CNodePtr>()->inputs(); | |||||
| is_match_ = false; | |||||
| AnfVisitor::Match(prim::kPrimSwitchLayer, {IsNode, IsCNode})(input_switch_layer[0]); | |||||
| if (!is_match_) { | |||||
| return nullptr; | |||||
| } | |||||
| std::vector<AnfNodePtr> args; | |||||
| (void)args.insert(args.end(), input_switch_layer.begin() + 1, input_switch_layer.end()); | |||||
| // {prim::kPrimSwitchLayers, X, {prim::kPrimMakeTuple, G1, G2...}} | |||||
| auto sw = input_switch_layer[0]->cast<CNodePtr>(); | |||||
| std::vector<FuncGraphPtr> graphs{}; | |||||
| auto graphs_cnode = sw->input(2)->cast<CNodePtr>(); | |||||
| auto &graphs_inputs = graphs_cnode->inputs(); | |||||
| if (IsPrimitiveCNode(graphs_cnode, prim::kPrimMakeTuple) && IsValueNode<FuncGraph>(graphs_inputs[1])) { | |||||
| (void)std::transform(graphs_inputs.begin() + 1, graphs_inputs.end(), std::back_inserter(graphs), | |||||
| [](const AnfNodePtr &vnode) { return GetValueNode<FuncGraphPtr>(vnode); }); | |||||
| } | |||||
| if (graphs.empty()) { | |||||
| return nullptr; | |||||
| } | |||||
| auto fg = node->func_graph(); | |||||
| std::vector<AnfNodePtr> layers; | |||||
| for (auto &graph : graphs) { | |||||
| auto fg_transform = env_get_item_transform_(graph, key, default_v); | |||||
| if (fg_transform == nullptr) { | |||||
| return nullptr; | |||||
| } | |||||
| layers.push_back(NewValueNode(fg_transform)); | |||||
| } | |||||
| auto layers_node = fg->NewCNode(prim::kPrimMakeTuple, layers); | |||||
| auto new_sw = fg->NewCNode({NewValueNode(prim::kPrimSwitchLayer), sw->input(1), layers_node}); | |||||
| args.insert(args.begin(), new_sw); | |||||
| auto inner_call = fg->NewCNode(args); | |||||
| args_outer.insert(args_outer.begin(), inner_call); | |||||
| return fg->NewCNode(args_outer); | |||||
| } | |||||
| void Visit(const AnfNodePtr &) override { is_match_ = true; } | |||||
| private: | |||||
| bool is_match_{false}; | |||||
| internal::EnvGetitemTransformACrossGraph env_get_item_transform_; | |||||
| }; | |||||
| } // namespace irpass | } // namespace irpass | ||||
| } // namespace opt | } // namespace opt | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -72,6 +72,52 @@ class GetitemTransform { | |||||
| private: | private: | ||||
| std::unordered_map<FuncGraphPtr, std::unordered_map<int, FuncGraphPtr>> cache_; | std::unordered_map<FuncGraphPtr, std::unordered_map<int, FuncGraphPtr>> cache_; | ||||
| }; | }; | ||||
| class GetItemTransformACrossGraph { | |||||
| public: | |||||
| GetItemTransformACrossGraph() : cache_() {} | |||||
| ~GetItemTransformACrossGraph() = default; | |||||
| FuncGraphPtr operator()(const FuncGraphPtr &fg, int idx) { | |||||
| if (cache_.find(fg) == cache_.end()) { | |||||
| cache_[fg] = {}; | |||||
| } | |||||
| auto &cache = cache_[fg]; | |||||
| if (cache.find(idx) == cache.end()) { | |||||
| std::ostringstream ss("tp", std::ostringstream::app); | |||||
| ss << idx; | |||||
| auto new_fg_outer = TransformableClone(fg, std::make_shared<TraceTransform>(ss.str())); | |||||
| auto output_outer = new_fg_outer->output(); | |||||
| if (!IsValueNode<FuncGraph>(output_outer)) { | |||||
| MS_LOG(WARNING) << "Output of outer graph should be a func_graph"; | |||||
| return nullptr; | |||||
| } | |||||
| auto fg_inner = GetValueNode<FuncGraphPtr>(output_outer); | |||||
| auto new_fg = TransformableClone(fg_inner, std::make_shared<TraceTransform>(ss.str())); | |||||
| new_fg_outer->set_output(NewValueNode(new_fg)); | |||||
| auto output = new_fg->output(); | |||||
| if (IsPrimitiveCNode(output, prim::kPrimMakeTuple)) { | |||||
| auto cnode = output->cast<CNodePtr>(); | |||||
| auto ids = IntToSize(idx + 1); | |||||
| // Inputs should be [make_tuple, item1, item2, ...], so have to offset idx in tuple_getitem by 1. | |||||
| if (ids >= cnode->size()) { | |||||
| MS_LOG(EXCEPTION) << "index " << ids << " is out of inputs length " << cnode->size(); | |||||
| } | |||||
| new_fg->set_output(cnode->input(ids)); | |||||
| } else { | |||||
| new_fg->set_output(new_fg->NewCNode({NewValueNode(prim::kPrimTupleGetItem), output, NewValueNode(idx)})); | |||||
| } | |||||
| cache[idx] = new_fg_outer; | |||||
| } | |||||
| return cache[idx]; | |||||
| } | |||||
| private: | |||||
| std::unordered_map<FuncGraphPtr, std::unordered_map<int, FuncGraphPtr>> cache_; | |||||
| }; | |||||
| } // namespace internal | } // namespace internal | ||||
| // {prim::kPrimTupleGetItem, {G, Xs}, C} | // {prim::kPrimTupleGetItem, {G, Xs}, C} | ||||
| @@ -385,13 +431,199 @@ class IncorporateGetitemSwitch : public AnfVisitor { | |||||
| internal::GetitemTransform getitem_transform_; | internal::GetitemTransform getitem_transform_; | ||||
| }; | }; | ||||
| // {prim::kPrimTupleGetItem, {{prim::kPrimSwitchLayer, X, {prim::kPrimMakeTuple, G1, G2...}}, Xs}, C} | |||||
| class IncorporateGetitemSwitchLayerA : public AnfVisitor { | |||||
| public: | |||||
| IncorporateGetitemSwitchLayerA() : getitem_transform_() {} | |||||
| ~IncorporateGetitemSwitchLayerA() override = default; | |||||
| AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { | |||||
| Reset(); | |||||
| is_in_get_ = true; | |||||
| AnfVisitor::Match(prim::kPrimTupleGetItem, {IsCNode, IsValueNode<Int32Imm>})(node); | |||||
| is_in_get_ = false; | |||||
| auto fg = node->func_graph(); | |||||
| if (idx_ == -1 || switch_layer_ == nullptr || fg == nullptr) { | |||||
| return nullptr; | |||||
| } | |||||
| is_in_switch_ = true; | |||||
| AnfVisitor::Match(prim::kPrimSwitchLayer, {IsNode, IsCNode})(switch_layer_); | |||||
| is_in_switch_ = false; | |||||
| if (graphs_.empty()) { | |||||
| return nullptr; | |||||
| } | |||||
| std::vector<AnfNodePtr> layers; | |||||
| for (auto &graph : graphs_) { | |||||
| auto fg_transform = getitem_transform_(graph, idx_); | |||||
| if (fg_transform == nullptr) { | |||||
| return nullptr; | |||||
| } | |||||
| layers.push_back(NewValueNode(fg_transform)); | |||||
| } | |||||
| auto layers_node = fg->NewCNode(prim::kPrimMakeTuple, layers); | |||||
| std::vector<AnfNodePtr> sw_args{NewValueNode(prim::kPrimSwitchLayer), x_, layers_node}; | |||||
| auto sw_node = fg->NewCNode(sw_args); | |||||
| (void)args_.insert(args_.begin(), sw_node); | |||||
| return fg->NewCNode(args_); | |||||
| } | |||||
| void Visit(const AnfNodePtr &node) override { | |||||
| if (is_in_switch_ && x_ == nullptr) { | |||||
| x_ = node; | |||||
| return; | |||||
| } | |||||
| AnfVisitor::Visit(node); | |||||
| } | |||||
| void Visit(const CNodePtr &cnode) override { | |||||
| if (is_in_get_ && cnode->size() != 0) { | |||||
| auto &inputs = cnode->inputs(); | |||||
| switch_layer_ = inputs[0]; | |||||
| (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(args_)); | |||||
| } | |||||
| if (is_in_switch_ && cnode->size() > 2) { | |||||
| auto &inputs = cnode->inputs(); | |||||
| if (IsPrimitiveCNode(cnode, prim::kPrimMakeTuple) && IsValueNode<FuncGraph>(inputs[1])) { | |||||
| (void)std::transform(inputs.begin() + 1, inputs.end(), std::back_inserter(graphs_), | |||||
| [](const AnfNodePtr &vnode) { return GetValueNode<FuncGraphPtr>(vnode); }); | |||||
| } | |||||
| } | |||||
| } | |||||
| void Visit(const ValueNodePtr &vnode) override { | |||||
| if (is_in_get_) { | |||||
| idx_ = GetValue<int>(vnode->value()); | |||||
| } | |||||
| } | |||||
| void Reset() { | |||||
| x_ = nullptr; | |||||
| graphs_.clear(); | |||||
| switch_layer_ = nullptr; | |||||
| args_.clear(); | |||||
| is_in_get_ = false; | |||||
| is_in_switch_ = false; | |||||
| } | |||||
| private: | |||||
| int idx_{-1}; | |||||
| AnfNodePtr switch_layer_{nullptr}, x_{nullptr}; | |||||
| std::vector<FuncGraphPtr> graphs_{}; | |||||
| bool is_in_get_{false}, is_in_switch_{false}; | |||||
| std::vector<AnfNodePtr> args_{}; | |||||
| internal::GetitemTransform getitem_transform_; | |||||
| }; | |||||
| // {prim::kPrimTupleGetItem, {{{prim::kPrimSwitchLayer, X, {prim::kPrimMakeTuple, G1, G2...}}, Xs}, Ys}, C} | |||||
| class IncorporateGetitemSwitchLayerB : public AnfVisitor { | |||||
| public: | |||||
| IncorporateGetitemSwitchLayerB() : getitem_transform_() {} | |||||
| ~IncorporateGetitemSwitchLayerB() override = default; | |||||
| AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { | |||||
| Reset(); | |||||
| is_in_get_ = true; | |||||
| AnfVisitor::Match(prim::kPrimTupleGetItem, {IsCNode, IsValueNode<Int32Imm>})(node); | |||||
| is_in_get_ = false; | |||||
| auto fg = node->func_graph(); | |||||
| if (idx_ == -1 || switch_layer_call_ == nullptr || !switch_layer_call_->isa<CNode>() || fg == nullptr) { | |||||
| return nullptr; | |||||
| } | |||||
| auto &switch_layer_call_inputs = switch_layer_call_->cast<CNodePtr>()->inputs(); | |||||
| (void)std::copy(switch_layer_call_inputs.begin() + 1, switch_layer_call_inputs.end(), std::back_inserter(args_)); | |||||
| is_in_switch_ = true; | |||||
| AnfVisitor::Match(prim::kPrimSwitchLayer, {IsNode, IsCNode})(switch_layer_call_inputs[0]); | |||||
| is_in_switch_ = false; | |||||
| if (graphs_.empty()) { | |||||
| return nullptr; | |||||
| } | |||||
| std::vector<AnfNodePtr> layers; | |||||
| for (auto &graph : graphs_) { | |||||
| auto fg_transform = getitem_transform_(graph, idx_); | |||||
| if (fg_transform == nullptr) { | |||||
| return nullptr; | |||||
| } | |||||
| layers.push_back(NewValueNode(fg_transform)); | |||||
| } | |||||
| auto layers_node = fg->NewCNode(prim::kPrimMakeTuple, layers); | |||||
| std::vector<AnfNodePtr> sw_args{NewValueNode(prim::kPrimSwitchLayer), x_, layers_node}; | |||||
| auto sw_node = fg->NewCNode(sw_args); | |||||
| (void)args_.insert(args_.begin(), sw_node); | |||||
| auto call_switch_layer = fg->NewCNode(args_); | |||||
| (void)outer_call_args_.insert(outer_call_args_.begin(), call_switch_layer); | |||||
| return fg->NewCNode(outer_call_args_); | |||||
| } | |||||
| void Visit(const AnfNodePtr &node) override { | |||||
| if (is_in_switch_ && x_ == nullptr) { | |||||
| x_ = node; | |||||
| return; | |||||
| } | |||||
| AnfVisitor::Visit(node); | |||||
| } | |||||
| void Visit(const CNodePtr &cnode) override { | |||||
| if (is_in_get_ && cnode->size() != 0) { | |||||
| auto &inputs = cnode->inputs(); | |||||
| switch_layer_call_ = inputs[0]; | |||||
| (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(outer_call_args_)); | |||||
| } | |||||
| if (is_in_switch_ && cnode->size() > 2) { | |||||
| auto &inputs = cnode->inputs(); | |||||
| if (IsPrimitiveCNode(cnode, prim::kPrimMakeTuple) && IsValueNode<FuncGraph>(inputs[1])) { | |||||
| (void)std::transform(inputs.begin() + 1, inputs.end(), std::back_inserter(graphs_), | |||||
| [](const AnfNodePtr &vnode) { return GetValueNode<FuncGraphPtr>(vnode); }); | |||||
| } | |||||
| } | |||||
| } | |||||
| void Visit(const ValueNodePtr &vnode) override { | |||||
| if (is_in_get_) { | |||||
| idx_ = GetValue<int>(vnode->value()); | |||||
| } | |||||
| } | |||||
| void Reset() { | |||||
| x_ = nullptr; | |||||
| graphs_.clear(); | |||||
| switch_layer_call_ = nullptr; | |||||
| args_.clear(); | |||||
| outer_call_args_.clear(); | |||||
| is_in_get_ = false; | |||||
| is_in_switch_ = false; | |||||
| } | |||||
| private: | |||||
| int idx_{-1}; | |||||
| AnfNodePtr switch_layer_call_{nullptr}, x_{nullptr}; | |||||
| std::vector<FuncGraphPtr> graphs_{}; | |||||
| bool is_in_get_{false}, is_in_switch_{false}; | |||||
| std::vector<AnfNodePtr> args_{}; | |||||
| std::vector<AnfNodePtr> outer_call_args_{}; | |||||
| internal::GetItemTransformACrossGraph getitem_transform_; | |||||
| }; | |||||
| class IncorporateGetitemSet : public OptimizerCaller { | class IncorporateGetitemSet : public OptimizerCaller { | ||||
| public: | public: | ||||
| IncorporateGetitemSet() | IncorporateGetitemSet() | ||||
| : incorporate_getitem_(std::make_shared<IncorporateGetitem>()), | : incorporate_getitem_(std::make_shared<IncorporateGetitem>()), | ||||
| incorporate_getitem_switch_(std::make_shared<IncorporateGetitemSwitch>()) { | |||||
| incorporate_getitem_switch_(std::make_shared<IncorporateGetitemSwitch>()), | |||||
| incorporate_getitem_switch_layer_a_(std::make_shared<IncorporateGetitemSwitchLayerA>()), | |||||
| incorporate_getitem_switch_layer_b_(std::make_shared<IncorporateGetitemSwitchLayerB>()) { | |||||
| eliminaters_.emplace_back(incorporate_getitem_); | eliminaters_.emplace_back(incorporate_getitem_); | ||||
| eliminaters_.emplace_back(incorporate_getitem_switch_); | eliminaters_.emplace_back(incorporate_getitem_switch_); | ||||
| eliminaters_.emplace_back(incorporate_getitem_switch_layer_a_); | |||||
| eliminaters_.emplace_back(incorporate_getitem_switch_layer_b_); | |||||
| } | } | ||||
| ~IncorporateGetitemSet() = default; | ~IncorporateGetitemSet() = default; | ||||
| @@ -407,7 +639,8 @@ class IncorporateGetitemSet : public OptimizerCaller { | |||||
| } | } | ||||
| private: | private: | ||||
| OptimizerCallerPtr incorporate_getitem_, incorporate_getitem_switch_; | |||||
| OptimizerCallerPtr incorporate_getitem_, incorporate_getitem_switch_, incorporate_getitem_switch_layer_a_, | |||||
| incorporate_getitem_switch_layer_b_; | |||||
| std::vector<OptimizerCallerPtr> eliminaters_{}; | std::vector<OptimizerCallerPtr> eliminaters_{}; | ||||
| }; | }; | ||||
| } // namespace irpass | } // namespace irpass | ||||
| @@ -180,7 +180,7 @@ OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) { | |||||
| {irpass.zero_like_fill_zero_, irpass.item_tuple_eliminate_, irpass.float_tuple_getitem_switch_, | {irpass.zero_like_fill_zero_, irpass.item_tuple_eliminate_, irpass.float_tuple_getitem_switch_, | ||||
| irpass.reset_defer_inline_, irpass.inline_, irpass.special_op_eliminate_, irpass.get_make_ref_eliminate_, | irpass.reset_defer_inline_, irpass.inline_, irpass.special_op_eliminate_, irpass.get_make_ref_eliminate_, | ||||
| irpass.incorporate_env_getitem_, irpass.incorporate_env_getitem_switch_, irpass.env_get_item_eliminate_, | irpass.incorporate_env_getitem_, irpass.incorporate_env_getitem_switch_, irpass.env_get_item_eliminate_, | ||||
| irpass.value_based_eliminate_}); | |||||
| irpass.incorporate_env_getitem_switch_layer_, irpass.value_based_eliminate_}); | |||||
| opt::OptPassConfig b_2 = opt::OptPassConfig({ | opt::OptPassConfig b_2 = opt::OptPassConfig({ | ||||
| irpass.replace_refkey_by_param_, | irpass.replace_refkey_by_param_, | ||||
| irpass.make_ref_eliminate_, | irpass.make_ref_eliminate_, | ||||
| @@ -464,6 +464,36 @@ def test_switch_layer_with_single_prim(): | |||||
| C.grad_all(net)(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32))) | C.grad_all(net)(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32))) | ||||
| def test_switch_layer_env_eliminate(): | |||||
| class Net(nn.Cell): | |||||
| def __init__(self): | |||||
| super(Net, self).__init__() | |||||
| self.conv = nn.Conv2d(1, 1, 3, pad_mode='same') | |||||
| self.conv2 = nn.Conv2d(1, 1, 5, pad_mode='same') | |||||
| self.funs = (self.conv, self.conv2) | |||||
| def construct(self, x, index): | |||||
| x = self.funs[index](x) | |||||
| return x | |||||
| class NetGrad(nn.Cell): | |||||
| def __init__(self, net): | |||||
| super(NetGrad, self).__init__() | |||||
| self.grad_op = C.GradOperation('grad', get_by_list=True, sens_param=False) | |||||
| self.net = net | |||||
| self.weights = ParameterTuple(self.net.trainable_params()) | |||||
| def construct(self, x, index): | |||||
| weights = self.weights | |||||
| grad = self.grad_op(self.net, weights)(x, index) | |||||
| return grad | |||||
| net = Net() | |||||
| net2 = NetGrad(net) | |||||
| x = Tensor(np.ones((3, 1, 12, 12)), ms.float32) | |||||
| i = Tensor(1, ms.int32) | |||||
| net2(x, i) | |||||
| def test_control_depend_check(): | def test_control_depend_check(): | ||||
| with pytest.raises(TypeError) as e: | with pytest.raises(TypeError) as e: | ||||
| P.ControlDepend(0.0) | P.ControlDepend(0.0) | ||||