|
|
|
@@ -227,17 +227,25 @@ AdjointPtr DFunctor::MapMorphism(const AnfNodePtr &morph) { |
|
|
|
std::vector<AdjointPtr> param_adjoints; |
|
|
|
for (size_t i = 0; i < cnode_morph->size(); i++) { |
|
|
|
auto node = cnode_morph->input(i); |
|
|
|
auto node_adjoint_iter = anfnode_to_adjoin_.find(node); |
|
|
|
AdjointPtr node_adjoint = nullptr; |
|
|
|
AnfNodePtr k = nullptr; |
|
|
|
if (node_adjoint_iter != anfnode_to_adjoin_.end()) { |
|
|
|
node_adjoint = node_adjoint_iter->second; |
|
|
|
if (IsValueNode<Primitive>(node)) { |
|
|
|
TraceManager::DebugTrace(std::make_shared<TraceEquiv>(cnode_morph->debug_info())); |
|
|
|
k = MapToK(node); |
|
|
|
TraceManager::EndTrace(); |
|
|
|
node_adjoint = std::make_shared<Adjoint>(node, k, tape_); |
|
|
|
anfnode_to_adjoin_[node] = node_adjoint; |
|
|
|
} else { |
|
|
|
// Input might be a CNode that needs to be handled before hand. |
|
|
|
node_adjoint = MapMorphism(node); |
|
|
|
auto node_adjoint_iter = anfnode_to_adjoin_.find(node); |
|
|
|
if (node_adjoint_iter != anfnode_to_adjoin_.end()) { |
|
|
|
node_adjoint = node_adjoint_iter->second; |
|
|
|
} else { |
|
|
|
// Input might be a CNode that needs to be handled previously. |
|
|
|
node_adjoint = MapMorphism(node); |
|
|
|
} |
|
|
|
MS_EXCEPTION_IF_NULL(node_adjoint); |
|
|
|
k = node_adjoint->k(); |
|
|
|
} |
|
|
|
MS_EXCEPTION_IF_NULL(node_adjoint); |
|
|
|
k = node_adjoint->k(); |
|
|
|
if (k == nullptr) { |
|
|
|
MS_LOG(EXCEPTION) << "MapMorphism adjoint node does not exist, input[" << i << "] " << node->ToString() << "."; |
|
|
|
} |
|
|
|
@@ -270,6 +278,7 @@ AdjointPtr DFunctor::MapMorphism(const AnfNodePtr &morph) { |
|
|
|
MS_LOG(DEBUG) << "MapMorphism node " << morph->DebugString(4) << "."; |
|
|
|
return node_adjoint; |
|
|
|
} |
|
|
|
|
|
|
|
void TensorSetAddress(const ValuePtr &value, std::map<std::string, tensor::TensorPtr> *tuple_tensors) { |
|
|
|
MS_LOG(DEBUG) << "Start set tensor address" << value->ToString() << value->isa<tensor::Tensor>(); |
|
|
|
if (value->isa<tensor::Tensor>()) { |
|
|
|
@@ -560,7 +569,7 @@ FuncGraphPtr DFunctor::KUserDefined(const FuncGraphPtr &primal) { |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
|
|
|
|
// MapToK(func) |
|
|
|
// Map func graph to K |
|
|
|
AnfNodePtr DFunctor::MapToK(const FuncGraphPtr &primal) { |
|
|
|
auto f = func_graph_to_functor_.find(primal); |
|
|
|
if (f != func_graph_to_functor_.end()) { |
|
|
|
@@ -586,7 +595,7 @@ AnfNodePtr DFunctor::MapToK(const FuncGraphPtr &primal) { |
|
|
|
// Construct representation graph for given node. |
|
|
|
AnfNodePtr DFunctor::MapToK(const AnfNodePtr &primal) { |
|
|
|
ScopeGuard scope_guard(primal->scope()); |
|
|
|
// MapToK(prim) |
|
|
|
// Map primitive to K |
|
|
|
if (IsValueNode<Primitive>(primal)) { |
|
|
|
auto value_node = primal->cast<ValueNodePtr>(); |
|
|
|
auto prim = GetValueNode<PrimitivePtr>(value_node); |
|
|
|
@@ -605,7 +614,7 @@ AnfNodePtr DFunctor::MapToK(const AnfNodePtr &primal) { |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// MapToK(func) |
|
|
|
// Map func graph to K |
|
|
|
if (IsValueNode<FuncGraph>(primal)) { |
|
|
|
auto func_graph = GetValueNode<FuncGraphPtr>(primal); |
|
|
|
auto k_func = MapToK(func_graph); |
|
|
|
@@ -681,7 +690,7 @@ void DFunctor::MapValueObject() { |
|
|
|
anfnode_to_adjoin_[node] = adjoint; |
|
|
|
continue; |
|
|
|
} |
|
|
|
// Skip Return. |
|
|
|
// Skip Primitive. |
|
|
|
if (IsValueNode<Primitive>(node) && GetValueNode<PrimitivePtr>(node) == prim::kPrimReturn) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
@@ -796,12 +805,14 @@ void DFunctor::EliminatePrimalGraph() { |
|
|
|
auto index = it.first->second; |
|
|
|
auto vnode = cnode->inputs()[index]; |
|
|
|
if (index != 0) { |
|
|
|
MS_LOG(INFO) << "Primal is used but not called, at {" << cnode->DebugString(3) << "/" << index << "}"; |
|
|
|
MS_LOG(DEBUG) << "Primal is used but not called, at {" << cnode->DebugString(3) << "/" << index << "}"; |
|
|
|
continue; |
|
|
|
} |
|
|
|
cnode->set_input(0, k_vnode); // Replace primal graph with k graph |
|
|
|
auto construct_wrapper = cnode->func_graph(); |
|
|
|
TraceManager::DebugTrace(std::make_shared<TraceGradFpropApp>(cnode->debug_info())); |
|
|
|
auto getitem0 = construct_wrapper->NewCNode({NewValueNode(prim::kPrimTupleGetItem), cnode, idx0}); |
|
|
|
TraceManager::EndTrace(); |
|
|
|
manager->Replace(cnode, getitem0); |
|
|
|
} |
|
|
|
} |
|
|
|
|