|
|
|
@@ -91,21 +91,32 @@ void DFunctor::BackPropagateFv(const AnfNodePtr &fv, const AnfNodePtr &din) { |
|
|
|
if (fv_adjoint == anfnode_to_adjoin_.end()) { |
|
|
|
MS_LOG(DEBUG) << "BackPropagateFv can not find adjoint in anfnode_to_adjoin_ fv " << fv->func_graph()->ToString() |
|
|
|
<< " " << fv->ToString() << "."; |
|
|
|
fv_adjoint = anfnode_to_adjoin_indirect_fv_.find(fv); |
|
|
|
if (fv_adjoint == anfnode_to_adjoin_indirect_fv_.end()) { |
|
|
|
MS_LOG(DEBUG) << "BackPropagateFv can not find adjoint in anfnode_to_adjoin_indirect_fv_ fv " |
|
|
|
<< fv->func_graph()->ToString() << " " << fv->ToString() << "."; |
|
|
|
auto parent_adjoint = FindAdjoint(fv); |
|
|
|
AdjointPtr adjoint = nullptr; |
|
|
|
if (parent_adjoint != nullptr) { |
|
|
|
adjoint = std::make_shared<Adjoint>(fv, parent_adjoint->k(), tape_); |
|
|
|
} else { |
|
|
|
MS_LOG(DEBUG) << "BackPropagateFv failed can not find adjoint definition fv, add a k hole " |
|
|
|
<< fv->func_graph()->ToString() << " " << fv->ToString() << "."; |
|
|
|
adjoint = std::make_shared<Adjoint>(fv, nullptr, tape_); |
|
|
|
|
|
|
|
if (fv->func_graph() == primal_graph_) { |
|
|
|
// If this fv is not mapped by MapMorphism because of cnode order, then map it now. |
|
|
|
(void)MapMorphism(fv); |
|
|
|
fv_adjoint = anfnode_to_adjoin_.find(fv); |
|
|
|
if (fv_adjoint == anfnode_to_adjoin_.end()) { |
|
|
|
MS_LOG(EXCEPTION) << "Can not find adjoint in anfnode_to_adjoin_ fv " << fv->func_graph()->ToString() << " " |
|
|
|
<< fv->ToString() << "."; |
|
|
|
} |
|
|
|
anfnode_to_adjoin_indirect_fv_[fv] = adjoint; |
|
|
|
} else { |
|
|
|
fv_adjoint = anfnode_to_adjoin_indirect_fv_.find(fv); |
|
|
|
if (fv_adjoint == anfnode_to_adjoin_indirect_fv_.end()) { |
|
|
|
MS_LOG(DEBUG) << "BackPropagateFv can not find adjoint in anfnode_to_adjoin_indirect_fv_ fv " |
|
|
|
<< fv->func_graph()->ToString() << " " << fv->ToString() << "."; |
|
|
|
auto parent_adjoint = FindAdjoint(fv); |
|
|
|
AdjointPtr adjoint = nullptr; |
|
|
|
if (parent_adjoint != nullptr) { |
|
|
|
adjoint = std::make_shared<Adjoint>(fv, parent_adjoint->k(), tape_); |
|
|
|
} else { |
|
|
|
MS_LOG(DEBUG) << "BackPropagateFv failed can not find adjoint definition fv, add a k hole " |
|
|
|
<< fv->func_graph()->ToString() << " " << fv->ToString() << "."; |
|
|
|
adjoint = std::make_shared<Adjoint>(fv, nullptr, tape_); |
|
|
|
} |
|
|
|
anfnode_to_adjoin_indirect_fv_[fv] = adjoint; |
|
|
|
fv_adjoint = anfnode_to_adjoin_indirect_fv_.find(fv); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
auto fv_node = fv_adjoint->second->k(); |
|
|
|
|