|
|
|
@@ -16,9 +16,9 @@ |
|
|
|
|
|
|
|
#include "frontend/optimizer/ad/dfunctor.h" |
|
|
|
|
|
|
|
#include <map> |
|
|
|
#include <memory> |
|
|
|
#include <string> |
|
|
|
#include <utility> |
|
|
|
|
|
|
|
#include "ir/anf.h" |
|
|
|
#include "utils/info.h" |
|
|
|
@@ -99,14 +99,23 @@ void DFunctor::BackPropagateFv(const AnfNodePtr &fv, const AnfNodePtr &din) { |
|
|
|
fv_adjoint = anfnode_to_adjoin_indirect_fv_.find(fv); |
|
|
|
} |
|
|
|
} |
|
|
|
auto node = tape_->NewCNode({NewValueNode(prim::kPrimEmbed), fv_adjoint->second->k()}); |
|
|
|
fv_adjoint->second->RegisterKUser(node, 1); |
|
|
|
auto default_val = tape_->NewCNode({NewValueNode(prim::GetPythonOps("zeros_like")), fv_adjoint->second->k()}); |
|
|
|
fv_adjoint->second->RegisterKUser(default_val, 1); |
|
|
|
auto dfv = tape_->NewCNode({NewValueNode(prim::kPrimEnvGetItem), din, node, default_val}); |
|
|
|
auto fv_node = fv_adjoint->second->k(); |
|
|
|
auto cached_envitem_iter = anfnode_to_envitem_.find(fv_node); |
|
|
|
CNodePtr embed_node, default_val_node; |
|
|
|
if (cached_envitem_iter != anfnode_to_envitem_.end()) { |
|
|
|
embed_node = cached_envitem_iter->second.first; |
|
|
|
default_val_node = cached_envitem_iter->second.second; |
|
|
|
} else { |
|
|
|
embed_node = tape_->NewCNode({NewValueNode(prim::kPrimEmbed), fv_node}); |
|
|
|
default_val_node = tape_->NewCNode({NewValueNode(prim::GetPythonOps("zeros_like")), fv_node}); |
|
|
|
fv_adjoint->second->RegisterKUser(embed_node, 1); |
|
|
|
fv_adjoint->second->RegisterKUser(default_val_node, 1); |
|
|
|
anfnode_to_envitem_[fv_node] = std::make_pair(embed_node, default_val_node); |
|
|
|
} |
|
|
|
auto dfv = tape_->NewCNode({NewValueNode(prim::kPrimEnvGetItem), din, embed_node, default_val_node}); |
|
|
|
MS_LOG(DEBUG) << "BackPropagateFv find adjoint in anfnode_to_adjoin_ or anfnode_to_adjoin_indirect_fv_ fv " |
|
|
|
<< fv->func_graph()->ToString() << " " << fv->ToString() << "."; |
|
|
|
MS_LOG(DEBUG) << "BackPropagateFv get item from " << din->ToString() << " key " << node->ToString() << "."; |
|
|
|
MS_LOG(DEBUG) << "BackPropagateFv get item from " << din->ToString() << " key " << embed_node->ToString() << "."; |
|
|
|
fv_adjoint->second->AccumulateDout(dfv); |
|
|
|
} |
|
|
|
|
|
|
|
|