Browse Source

fix bug of split optimizer.

tags/v1.2.0-rc1
liu_xiao_93 4 years ago
parent
commit
2efd12dec6
1 changed files with 8 additions and 5 deletions
  1. +8
    -5
      mindspore/ccsrc/backend/optimizer/ascend/enhancer/split_n_optimizer.cc

+ 8
- 5
mindspore/ccsrc/backend/optimizer/ascend/enhancer/split_n_optimizer.cc View File

@@ -90,6 +90,7 @@ KernelWithIndex VisitSplitKernel(const AnfNodePtr &anf_node, size_t index) {
bool InputCheck(const AnfNodePtr &node) { bool InputCheck(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
auto cnode = node->cast<CNodePtr>(); auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto in_nums = AnfAlgo::GetInputTensorNum(node); auto in_nums = AnfAlgo::GetInputTensorNum(node);
for (size_t i = 0; i < in_nums; i++) { for (size_t i = 0; i < in_nums; i++) {
auto in_node = VisitSplitKernel(AnfAlgo::GetInputNode(cnode, i), 0).first; auto in_node = VisitSplitKernel(AnfAlgo::GetInputNode(cnode, i), 0).first;
@@ -98,7 +99,9 @@ bool InputCheck(const AnfNodePtr &node) {
return false; return false;
} }
if (in_node->isa<CNode>()) { if (in_node->isa<CNode>()) {
auto in_node_name = AnfAlgo::GetCNodeName(in_node->cast<CNodePtr>());
auto in_cnode = in_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(in_cnode);
auto in_node_name = AnfAlgo::GetCNodeName(in_cnode);
auto trans_input = AnfAlgo::VisitKernel(in_node, 0).first; auto trans_input = AnfAlgo::VisitKernel(in_node, 0).first;
if (in_node_name == kTransDataOpName && (trans_input->isa<Parameter>() || trans_input->isa<ValueNode>())) { if (in_node_name == kTransDataOpName && (trans_input->isa<Parameter>() || trans_input->isa<ValueNode>())) {
MS_LOG(INFO) << "Data->TransData->split, can not optimizer."; MS_LOG(INFO) << "Data->TransData->split, can not optimizer.";
@@ -107,9 +110,9 @@ bool InputCheck(const AnfNodePtr &node) {
if (in_node_name == prim::kPrimControlDepend->name() || in_node_name == prim::kPrimDepend->name()) { if (in_node_name == prim::kPrimControlDepend->name() || in_node_name == prim::kPrimDepend->name()) {
return false; return false;
} }
if ((AnfAlgo::HasNodeAttr("non_task", cnode) && AnfAlgo::GetNodeAttr<bool>(cnode, "non_task")) ||
(AnfAlgo::HasNodeAttr("nop_node", cnode) && AnfAlgo::GetNodeAttr<bool>(cnode, "nop_node"))) {
MS_LOG(INFO) << "Input has non_task or nop_node attr, can not optimizer.";
if ((AnfAlgo::HasNodeAttr("non_task", in_cnode) && AnfAlgo::GetNodeAttr<bool>(in_node, "non_task")) ||
opt::IsNopNode(in_cnode)) {
MS_LOG(INFO) << "Input is nop node or has non_task attr, can not optimizer.";
return false; return false;
} }
} }
@@ -140,7 +143,7 @@ bool OutputCheck(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
return false; return false;
} }
auto op_name = AnfAlgo ::GetCNodeName(item); auto op_name = AnfAlgo ::GetCNodeName(item);
if (InvalidOps.find(op_name) != InvalidOps.end() || AnfAlgo::IsCommunicationOp(node)) {
if (InvalidOps.find(op_name) != InvalidOps.end() || AnfAlgo::IsCommunicationOp(item)) {
MS_LOG(INFO) << "Next node is " << item->fullname_with_scope() << ", not a invalid node, can not optimizer."; MS_LOG(INFO) << "Next node is " << item->fullname_with_scope() << ", not a invalid node, can not optimizer.";
return false; return false;
} }


Loading…
Cancel
Save