
[auto-monad] Revert "Change backend execution order sorting policy"

This reverts commit 141c39b71c.
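A revert commit like this is typically produced with git revert 141c39b71c, which records the inverse of the original patch as a new commit on top of the current branch.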
He Wei committed 4 years ago
commit 3c44e731a2
tags/v1.2.0-rc1
4 changed files with 17 additions and 13 deletions
  1. mindspore/ccsrc/backend/session/kernel_graph.cc (+13, -6)
  2. tests/st/auto_monad/test_auto_monad.py (+1, -4)
  3. tests/st/networks/models/bert/bert_performance/test_bert_tdt_lossscale.py (+1, -1)
  4. tests/st/networks/models/bert/bert_precision/test_bert_tdt_lossscale.py (+2, -2)

mindspore/ccsrc/backend/session/kernel_graph.cc (+13, -6)

@@ -201,17 +201,21 @@ void KernelGraph::VisitNodeDescendants(const AnfNodePtr &node, std::queue<AnfNod
 }
 
 void KernelGraph::SetExecOrderByDefault() {
-  std::queue<AnfNodePtr> zero_input_nodes;
-  UpdateNodeEdgeList(&zero_input_nodes);
+  std::queue<AnfNodePtr> seed_nodes;
+  UpdateNodeEdgeList(&seed_nodes);
   execution_order_.clear();
   std::unordered_set<AnfNodePtr> visited_nodes;
+  std::queue<AnfNodePtr> zero_input_nodes;
   AnfNodePtr last_communication_node = nullptr;
   std::queue<AnfNodePtr> communication_descendants;
-  while (!zero_input_nodes.empty() || last_communication_node != nullptr) {
+  while (!seed_nodes.empty() || last_communication_node != nullptr) {
     // seed nodes first, then visit last all reduce node descendant
-    if (last_communication_node != nullptr) {
+    if (seed_nodes.empty()) {
       VisitNodeDescendants(last_communication_node, &communication_descendants, &visited_nodes);
       last_communication_node = nullptr;
+    } else {
+      zero_input_nodes.push(seed_nodes.front());
+      seed_nodes.pop();
     }
     // all reduce node descendant first, then common queue
     while (!zero_input_nodes.empty() || !communication_descendants.empty()) {
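For context, below is a minimal, self-contained sketch of the seed-queue ordering restored by this hunk. The Node struct and toy graph are hypothetical, not MindSpore's AnfNode/CNode API, and the real code additionally tracks visited nodes and only records real kernel CNodes in the execution order. The sketch shows the two queues cooperating: seed nodes (inputs with no dependencies) feed the common queue, while the users of the most recent communication node (an AllReduce-like op) are deferred and later drained with priority.

// exec_order_sketch.cc -- toy illustration only; compile with a C++17 compiler.
#include <cstddef>
#include <iostream>
#include <queue>
#include <string>
#include <unordered_map>
#include <vector>

struct Node {
  std::string name;
  bool is_communication = false;  // AllReduce-like collective
  std::vector<Node *> inputs;     // dependencies
  std::vector<Node *> outputs;    // users, filled in main()
};

int main() {
  // Toy graph: add = a + b; ar = AllReduce(add); mul = ar * a; sub = b - add
  Node a{"a"}, b{"b"}, add{"add"}, ar{"allreduce", true}, mul{"mul"}, sub{"sub"};
  add.inputs = {&a, &b};
  ar.inputs = {&add};
  mul.inputs = {&ar, &a};
  sub.inputs = {&b, &add};
  std::vector<Node *> graph = {&a, &b, &add, &ar, &mul, &sub};

  std::unordered_map<Node *, std::size_t> pending;  // unresolved input count per node
  std::queue<Node *> seed_nodes;                    // nodes with no inputs (parameters/constants)
  for (Node *n : graph) {
    pending[n] = n->inputs.size();
    for (Node *in : n->inputs) in->outputs.push_back(n);
    if (n->inputs.empty()) seed_nodes.push(n);
  }

  std::vector<Node *> execution_order;
  std::queue<Node *> zero_input_nodes;
  std::queue<Node *> communication_descendants;
  Node *last_communication_node = nullptr;

  // Push the users of `node` whose inputs have all been resolved onto `ready`.
  auto visit_descendants = [&](Node *node, std::queue<Node *> *ready) {
    for (Node *out : node->outputs) {
      if (--pending[out] == 0) ready->push(out);
    }
  };

  while (!seed_nodes.empty() || last_communication_node != nullptr) {
    // Seed nodes first, then the descendants of the last communication node.
    if (seed_nodes.empty()) {
      visit_descendants(last_communication_node, &communication_descendants);
      last_communication_node = nullptr;
    } else {
      zero_input_nodes.push(seed_nodes.front());
      seed_nodes.pop();
    }
    // Communication descendants take priority over the common queue.
    while (!zero_input_nodes.empty() || !communication_descendants.empty()) {
      bool from_communication = !communication_descendants.empty();
      Node *node = nullptr;
      if (from_communication) {
        node = communication_descendants.front();
        communication_descendants.pop();
      } else {
        node = zero_input_nodes.front();
        zero_input_nodes.pop();
      }
      execution_order.push_back(node);  // the real code only records real kernel CNodes here
      if (node->is_communication) {
        // Defer this node's users; they are drained by the outer loop or at the next collective.
        if (last_communication_node != nullptr) {
          visit_descendants(last_communication_node, &communication_descendants);
        }
        last_communication_node = node;
      } else {
        visit_descendants(node, from_communication ? &communication_descendants : &zero_input_nodes);
      }
    }
  }

  for (Node *n : execution_order) std::cout << n->name << " ";
  std::cout << std::endl;  // prints: a b add allreduce sub mul
}

Note how mul, the only user of the AllReduce, is emitted after the independent sub node; the hunk above only changes how seed nodes feed this loop, while the communication-descendant handling itself is unchanged context.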
@@ -900,11 +904,14 @@ void KernelGraph::UpdateNodeEdgeList(std::queue<AnfNodePtr> *seed_nodes) {
       seed_nodes->push(node);
       continue;
     }
-    auto cnode = node->cast<CNodePtr>();
+    auto cnode = dyn_cast<CNode>(node);
     if (cnode == nullptr) {
       continue;
     }
-    for (auto &input : cnode->inputs()) {
+    auto &inputs = cnode->inputs();
+    // We push inputs from right to left, so that they can be evaluated from left to right.
+    for (auto iter = inputs.rbegin(); iter != inputs.rend(); ++iter) {
+      auto &input = *iter;
       PushNoVisitedNode(input, &que, &visited_nodes);
       AddDependEdge(node, input, 1);
     }
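The restored loop iterates the node's inputs with rbegin()/rend(), i.e. it pushes them right to left so that they are handed back left to right later. A tiny, generic sketch of that idiom follows; it uses plain strings and a LIFO worklist purely as an assumption for illustration, whereas the actual traversal in kernel_graph.cc also maintains visited sets and dependency edges.

// reverse_push_sketch.cc -- generic illustration of the rbegin()/rend() push order.
#include <iostream>
#include <stack>
#include <string>
#include <vector>

int main() {
  std::vector<std::string> inputs = {"x", "y", "z"};  // a node's inputs, left to right
  std::stack<std::string> worklist;
  // Push from right to left ...
  for (auto iter = inputs.rbegin(); iter != inputs.rend(); ++iter) {
    worklist.push(*iter);
  }
  // ... so they come off the worklist in the original left-to-right order.
  while (!worklist.empty()) {
    std::cout << worklist.top() << " ";  // prints: x y z
    worklist.pop();
  }
  std::cout << std::endl;
}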


tests/st/auto_monad/test_auto_monad.py (+1, -4)

@@ -1429,10 +1429,7 @@ def test_if_cast():
     np.testing.assert_array_equal(r1.asnumpy(), expect.asnumpy())
 
 
-@pytest.mark.level0
-@pytest.mark.platform_arm_ascend_training
-@pytest.mark.platform_x86_ascend_training
-@pytest.mark.env_onecard
+@pytest.mark.skip(reason="not supported yet")
 def test_multi_add_assign():
     class Net(Cell):
         def __init__(self, i1):


tests/st/networks/models/bert/bert_performance/test_bert_tdt_lossscale.py (+1, -1)

@@ -229,7 +229,7 @@ def test_bert_performance():
 
     # assertion occurs while the loss value, overflow state or loss_scale value is wrong
     loss_value = np.array(callback.loss_list)
-    expect_loss_value = [11.3246, 11.2834, 11.2833]
+    expect_loss_value = [11.3660, 11.3265, 11.3264]
     print("loss value: {}".format(loss_value))
     assert np.allclose(loss_value, expect_loss_value, 0, 0.0005)
 


tests/st/networks/models/bert/bert_precision/test_bert_tdt_lossscale.py (+2, -2)

@@ -229,8 +229,8 @@ def test_bert_precision(enable_graph_kernel=False):
         expect_loss_value = [12.206627, 11.840489, 11.798470, 11.796345, 11.790964, 12.366766, 11.971539, 12.576565,
                              12.185522, 12.386192]
     else:
-        expect_loss_value = [12.206587, 11.940709, 11.930911, 11.937369, 11.932178, 12.556069, 12.130172, 12.783402,
-                             12.359581, 12.578078]
+        expect_loss_value = [12.206587, 11.966410, 11.965916, 11.975922, 11.970262, 12.608881, 12.174048, 12.840656,
+                             12.407923, 12.631133]
     print("loss value: {}".format(loss_value))
     assert np.allclose(loss_value, expect_loss_value, 0, 0.0005)
 

