|
|
|
@@ -130,6 +130,7 @@ void DFunctor::BackPropagateSwitchLayer(const CNodePtr &cnode_morph, const CNode |
|
|
|
if (!IsPrimitiveCNode(input, prim::kPrimMakeTuple)) { |
|
|
|
MS_LOG(EXCEPTION) << "The 2th input of switch_layer expect a tuple of graphs, but got " << input->ToString() << "."; |
|
|
|
} |
|
|
|
std::unordered_map<AnfNodePtr, FuncGraphPtr> node_to_fg; |
|
|
|
auto tuple_graphs = input->cast<CNodePtr>(); |
|
|
|
for (size_t i = 1; i < tuple_graphs->size(); ++i) { |
|
|
|
auto graph = tuple_graphs->input(i); |
|
|
|
@@ -145,11 +146,19 @@ void DFunctor::BackPropagateSwitchLayer(const CNodePtr &cnode_morph, const CNode |
|
|
|
} |
|
|
|
// Consider direct and indirect fvs. |
|
|
|
for (auto fv : func_graph->free_variables_nodes()) { |
|
|
|
if (node_to_fg.find(fv) != node_to_fg.end()) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
node_to_fg[fv] = func_graph; |
|
|
|
BackPropagateFv(fv, env); |
|
|
|
} |
|
|
|
for (auto indirect_fv : functor->second->anfnode_to_adjoin_indirect_fv_) { |
|
|
|
MS_LOG(DEBUG) << "BackPropagateSwitchLayer backprop indirect fv " << func_graph->ToString() << " " |
|
|
|
<< indirect_fv.first->ToString() << "."; |
|
|
|
if (node_to_fg.find(indirect_fv.first) != node_to_fg.end()) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
node_to_fg[indirect_fv.first] = func_graph; |
|
|
|
BackPropagateFv(indirect_fv.first, env); |
|
|
|
} |
|
|
|
} |
|
|
|
|