Browse Source

!12663 [auto-monad] Fix backend control flow bug found by igamma test

From: @hwhewei
Reviewed-by: @zh_qh,@ginfung
Signed-off-by: @zh_qh
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 4 years ago
parent
commit
10692e9cd7
3 changed files with 44 additions and 0 deletions
  1. +3
    -0
      mindspore/ccsrc/backend/session/ascend_auto_monad.cc
  2. +19
    -0
      mindspore/ccsrc/backend/session/kernel_graph.cc
  3. +22
    -0
      tests/st/auto_monad/test_effect_ops.py

+ 3
- 0
mindspore/ccsrc/backend/session/ascend_auto_monad.cc View File

@@ -305,6 +305,9 @@ class AscendAutoMonadConverter {
} }
if (return_label_ != kNoLabel) { if (return_label_ != kNoLabel) {
(void)LabelGoto(return_label_); (void)LabelGoto(return_label_);
} else {
// Clear end goto if return label not set.
kernel_graph_->set_end_goto(nullptr);
} }
} }
} }


+ 19
- 0
mindspore/ccsrc/backend/session/kernel_graph.cc View File

@@ -274,6 +274,25 @@ std::vector<CNodePtr> KernelGraph::SortStartLabelAndEndGoto() {
continue; continue;
} }


//
// Re-order:
// u = LabelGoto(...)
// x = Mul(...)
// LabelSet(u)
// To:
// u = LabelGoto(...)
// LabelSet(u)
// x = Mul(...)
// This prevent Mul be skipped.
//
if (IsPrimitiveCNode(node, prim::kPrimLabelSet) && (re_order.back() != node->input(1))) {
auto iter = std::find(re_order.rbegin() + 1, re_order.rend(), node->input(1));
if (iter != re_order.rend()) {
re_order.insert(iter.base(), node);
continue;
}
}

re_order.push_back(node); re_order.push_back(node);
} }
if (end_goto_ != nullptr) { if (end_goto_ != nullptr) {


+ 22
- 0
tests/st/auto_monad/test_effect_ops.py View File

@@ -15,6 +15,7 @@
import os import os
import tempfile import tempfile
import pytest import pytest
import scipy
import numpy as np import numpy as np
import mindspore.nn as nn import mindspore.nn as nn
import mindspore.ops.operations as P import mindspore.ops.operations as P
@@ -395,3 +396,24 @@ def test_summary():
event = summary_writer.read_event() event = summary_writer.read_event()
tags = set(value.tag for value in event.summary.value) tags = set(value.tag for value in event.summary.value)
assert tags == {'tensor', 'histogram', 'scalar', 'image'} assert tags == {'tensor', 'histogram', 'scalar', 'image'}


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_igamma():
class IGammaTest(nn.Cell):
def __init__(self):
super().__init__()
self.igamma = nn.IGamma()

def construct(self, x, a):
return self.igamma(a=a, x=x)

x = 4.22
a = 2.29
net = IGammaTest()
out = net(Tensor(x, mstype.float32), Tensor(a, mstype.float32))
expect = scipy.special.gammainc(a, x)
assert np.allclose(out.asnumpy(), expect, rtol=1e-5, atol=1e-5, equal_nan=True)

Loading…
Cancel
Save