Browse Source

!8642 Keep debug info. and trace info. after Grad Operation.

From: @zh_qh
Reviewed-by: 
Signed-off-by:
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
610f06b92d
2 changed files with 23 additions and 14 deletions
  1. +23
    -12
      mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc
  2. +0
    -2
      mindspore/ccsrc/frontend/optimizer/ad/dfunctor.h

+ 23
- 12
mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc View File

@@ -227,17 +227,25 @@ AdjointPtr DFunctor::MapMorphism(const AnfNodePtr &morph) {
std::vector<AdjointPtr> param_adjoints; std::vector<AdjointPtr> param_adjoints;
for (size_t i = 0; i < cnode_morph->size(); i++) { for (size_t i = 0; i < cnode_morph->size(); i++) {
auto node = cnode_morph->input(i); auto node = cnode_morph->input(i);
auto node_adjoint_iter = anfnode_to_adjoin_.find(node);
AdjointPtr node_adjoint = nullptr; AdjointPtr node_adjoint = nullptr;
AnfNodePtr k = nullptr; AnfNodePtr k = nullptr;
if (node_adjoint_iter != anfnode_to_adjoin_.end()) {
node_adjoint = node_adjoint_iter->second;
if (IsValueNode<Primitive>(node)) {
TraceManager::DebugTrace(std::make_shared<TraceEquiv>(cnode_morph->debug_info()));
k = MapToK(node);
TraceManager::EndTrace();
node_adjoint = std::make_shared<Adjoint>(node, k, tape_);
anfnode_to_adjoin_[node] = node_adjoint;
} else { } else {
// Input might be a CNode that needs to be handled before hand.
node_adjoint = MapMorphism(node);
auto node_adjoint_iter = anfnode_to_adjoin_.find(node);
if (node_adjoint_iter != anfnode_to_adjoin_.end()) {
node_adjoint = node_adjoint_iter->second;
} else {
// Input might be a CNode that needs to be handled previously.
node_adjoint = MapMorphism(node);
}
MS_EXCEPTION_IF_NULL(node_adjoint);
k = node_adjoint->k();
} }
MS_EXCEPTION_IF_NULL(node_adjoint);
k = node_adjoint->k();
if (k == nullptr) { if (k == nullptr) {
MS_LOG(EXCEPTION) << "MapMorphism adjoint node does not exist, input[" << i << "] " << node->ToString() << "."; MS_LOG(EXCEPTION) << "MapMorphism adjoint node does not exist, input[" << i << "] " << node->ToString() << ".";
} }
@@ -270,6 +278,7 @@ AdjointPtr DFunctor::MapMorphism(const AnfNodePtr &morph) {
MS_LOG(DEBUG) << "MapMorphism node " << morph->DebugString(4) << "."; MS_LOG(DEBUG) << "MapMorphism node " << morph->DebugString(4) << ".";
return node_adjoint; return node_adjoint;
} }

void TensorSetAddress(const ValuePtr &value, std::map<std::string, tensor::TensorPtr> *tuple_tensors) { void TensorSetAddress(const ValuePtr &value, std::map<std::string, tensor::TensorPtr> *tuple_tensors) {
MS_LOG(DEBUG) << "Start set tensor address" << value->ToString() << value->isa<tensor::Tensor>(); MS_LOG(DEBUG) << "Start set tensor address" << value->ToString() << value->isa<tensor::Tensor>();
if (value->isa<tensor::Tensor>()) { if (value->isa<tensor::Tensor>()) {
@@ -560,7 +569,7 @@ FuncGraphPtr DFunctor::KUserDefined(const FuncGraphPtr &primal) {
return nullptr; return nullptr;
} }


// MapToK(func)
// Map func graph to K
AnfNodePtr DFunctor::MapToK(const FuncGraphPtr &primal) { AnfNodePtr DFunctor::MapToK(const FuncGraphPtr &primal) {
auto f = func_graph_to_functor_.find(primal); auto f = func_graph_to_functor_.find(primal);
if (f != func_graph_to_functor_.end()) { if (f != func_graph_to_functor_.end()) {
@@ -586,7 +595,7 @@ AnfNodePtr DFunctor::MapToK(const FuncGraphPtr &primal) {
// Construct representation graph for given node. // Construct representation graph for given node.
AnfNodePtr DFunctor::MapToK(const AnfNodePtr &primal) { AnfNodePtr DFunctor::MapToK(const AnfNodePtr &primal) {
ScopeGuard scope_guard(primal->scope()); ScopeGuard scope_guard(primal->scope());
// MapToK(prim)
// Map primitive to K
if (IsValueNode<Primitive>(primal)) { if (IsValueNode<Primitive>(primal)) {
auto value_node = primal->cast<ValueNodePtr>(); auto value_node = primal->cast<ValueNodePtr>();
auto prim = GetValueNode<PrimitivePtr>(value_node); auto prim = GetValueNode<PrimitivePtr>(value_node);
@@ -605,7 +614,7 @@ AnfNodePtr DFunctor::MapToK(const AnfNodePtr &primal) {
} }
} }


// MapToK(func)
// Map func graph to K
if (IsValueNode<FuncGraph>(primal)) { if (IsValueNode<FuncGraph>(primal)) {
auto func_graph = GetValueNode<FuncGraphPtr>(primal); auto func_graph = GetValueNode<FuncGraphPtr>(primal);
auto k_func = MapToK(func_graph); auto k_func = MapToK(func_graph);
@@ -681,7 +690,7 @@ void DFunctor::MapValueObject() {
anfnode_to_adjoin_[node] = adjoint; anfnode_to_adjoin_[node] = adjoint;
continue; continue;
} }
// Skip Return.
// Skip Primitive.
if (IsValueNode<Primitive>(node) && GetValueNode<PrimitivePtr>(node) == prim::kPrimReturn) { if (IsValueNode<Primitive>(node) && GetValueNode<PrimitivePtr>(node) == prim::kPrimReturn) {
continue; continue;
} }
@@ -796,12 +805,14 @@ void DFunctor::EliminatePrimalGraph() {
auto index = it.first->second; auto index = it.first->second;
auto vnode = cnode->inputs()[index]; auto vnode = cnode->inputs()[index];
if (index != 0) { if (index != 0) {
MS_LOG(INFO) << "Primal is used but not called, at {" << cnode->DebugString(3) << "/" << index << "}";
MS_LOG(DEBUG) << "Primal is used but not called, at {" << cnode->DebugString(3) << "/" << index << "}";
continue; continue;
} }
cnode->set_input(0, k_vnode); // Replace primal graph with k graph cnode->set_input(0, k_vnode); // Replace primal graph with k graph
auto construct_wrapper = cnode->func_graph(); auto construct_wrapper = cnode->func_graph();
TraceManager::DebugTrace(std::make_shared<TraceGradFpropApp>(cnode->debug_info()));
auto getitem0 = construct_wrapper->NewCNode({NewValueNode(prim::kPrimTupleGetItem), cnode, idx0}); auto getitem0 = construct_wrapper->NewCNode({NewValueNode(prim::kPrimTupleGetItem), cnode, idx0});
TraceManager::EndTrace();
manager->Replace(cnode, getitem0); manager->Replace(cnode, getitem0);
} }
} }


+ 0
- 2
mindspore/ccsrc/frontend/optimizer/ad/dfunctor.h View File

@@ -194,10 +194,8 @@ FuncGraphPtr KPrim::BpropToK(const T &primal, const FuncGraphPtr &bprop_fg) {
std::vector<AnfNodePtr> transf_args; std::vector<AnfNodePtr> transf_args;
TransformArgs(mng, cloned_bprop_fg, outer, &transf_args); TransformArgs(mng, cloned_bprop_fg, outer, &transf_args);


TraceManager::DebugTrace(std::make_shared<TraceEquiv>(dout->debug_info()));
(void)transf_args.insert(transf_args.begin(), NewValueNode(primal)); (void)transf_args.insert(transf_args.begin(), NewValueNode(primal));
auto out_value = outer->NewCNode(transf_args); auto out_value = outer->NewCNode(transf_args);
TraceManager::EndTrace();


(void)mng->Replace(out_param, out_value); (void)mng->Replace(out_param, out_value);




Loading…
Cancel
Save