|
|
|
@@ -37,7 +37,7 @@ namespace ad { |
|
|
|
std::unordered_map<FuncGraphPtr, DFunctorPtr> DFunctor::func_graph_to_functor_; |
|
|
|
std::unordered_map<AnfNodePtr, AdjointPtr> DFunctor::anfnode_to_adjoin_definition_; |
|
|
|
|
|
|
|
int lift_fv_before_grad = -1; |
|
|
|
bool lift_fv_before_grad = true; |
|
|
|
|
|
|
|
DFunctor::DFunctor(const FuncGraphPtr &primal_graph, const pipeline::ResourceBasePtr &resources) |
|
|
|
: primal_graph_(primal_graph), resources_(resources), need_cut_(false), is_top_(false) { |
|
|
|
@@ -76,7 +76,7 @@ void DFunctor::Clear() { |
|
|
|
|
|
|
|
void DFunctor::BackPropagateFv(const AnfNodePtr &fv, const AnfNodePtr &din) { |
|
|
|
MS_EXCEPTION_IF_NULL(fv); |
|
|
|
if (lift_fv_before_grad == 1) { |
|
|
|
if (lift_fv_before_grad) { |
|
|
|
MS_EXCEPTION_IF_NULL(fv->func_graph()); |
|
|
|
MS_LOG(EXCEPTION) << "BackPropagateFv can not find adjoint in anfnode_to_adjoin_ fv:" |
|
|
|
<< fv->func_graph()->ToString() << " " << fv->ToString() << "."; |
|
|
|
@@ -446,7 +446,7 @@ AnfNodePtr DFunctor::AttachFvDoutToTape(const AnfNodePtr &grad_fv) { |
|
|
|
// Add grads wrt fv. |
|
|
|
const auto &free_variables_nodes = primal_graph_->free_variables_nodes(); |
|
|
|
if (!is_top_ && free_variables_nodes.size() != 0) { |
|
|
|
if (lift_fv_before_grad == 1) { |
|
|
|
if (lift_fv_before_grad) { |
|
|
|
MS_LOG(EXCEPTION) << "direct fv size is: " << free_variables_nodes.size() << " in " << primal_graph_->ToString() |
|
|
|
<< "."; |
|
|
|
} |
|
|
|
@@ -475,7 +475,7 @@ AnfNodePtr DFunctor::AttachFvDoutToTape(const AnfNodePtr &grad_fv) { |
|
|
|
} |
|
|
|
|
|
|
|
AnfNodePtr DFunctor::AttachIndirectFvDoutToTape(const AnfNodePtr &grad_fv) { |
|
|
|
if (lift_fv_before_grad == 1) { |
|
|
|
if (lift_fv_before_grad) { |
|
|
|
MS_LOG(EXCEPTION) << "Lift free variable case: AttachIndirectFvDoutToTape backprop indirect fv " |
|
|
|
<< grad_fv->ToString() << " " << primal_graph_->ToString() << "."; |
|
|
|
} |
|
|
|
@@ -517,7 +517,7 @@ void DFunctor::MapMorphism() { |
|
|
|
|
|
|
|
// Set output for tape closure. |
|
|
|
AnfNodePtr grad_fv; |
|
|
|
if (lift_fv_before_grad == 1) { |
|
|
|
if (lift_fv_before_grad) { |
|
|
|
grad_fv = AttachFvDoutToTape(NewValueNode(newenv)); |
|
|
|
} else { |
|
|
|
grad_fv = AttachIndirectFvDoutToTape(AttachFvDoutToTape(NewValueNode(newenv))); |
|
|
|
|