Browse Source

fix grpah mode loop sink bug in auto parallel

tags/v0.2.0-alpha
lichenever 6 years ago
parent
commit
f946aea10d
5 changed files with 17 additions and 12 deletions
  1. +2
    -2
      mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.h
  2. +4
    -0
      mindspore/ccsrc/parallel/step_parallel.cc
  3. +1
    -1
      mindspore/train/dataset_helper.py
  4. +6
    -6
      tests/ut/python/parallel/test_auto_parallel_parameter_cast.py
  5. +4
    -3
      tests/ut/python/parallel/test_auto_parallel_two_matmul.py

+ 2
- 2
mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.h View File

@@ -34,8 +34,8 @@ namespace parallel {
#define OPERATOR_TO_OPERATOR_CONNECTOR "-"
#define DEFAULT_DEVICE_MEMORY_CAPACITY (1024.0 * 1024.0 * 1024.0 * 16.0)
#define DEFAULT_COST_MODEL_ALPHA 1.0
#define DEFAULT_COST_MODEL_BETA 65.0
#define DEFAULT_COST_MODEL_GAMMA 0.02
#define DEFAULT_COST_MODEL_BETA 260.0
#define DEFAULT_COST_MODEL_GAMMA 0.001
#define DEFAULT_COST_MODEL_SIMPLIFY_CALCULATION true
#define DEFAULT_COST_MODEL_COMMUNI_THRESHOLD 2048.0
#define DEFAULT_COST_MODEL_COMMUNI_CONST 3072.0


+ 4
- 0
mindspore/ccsrc/parallel/step_parallel.cc View File

@@ -375,6 +375,10 @@ bool IsParallelCareNode(const CNodePtr& cnode) {
MS_LOG(INFO) << "Parallel don't care node: " << prim->name();
return false;
}
// get_next is not in the forward graph, we need mark the get_next as the forward node
if (prim->name() == GET_NEXT) {
return true;
}
if ((prim->name() == CAST)) {
if ((!attrs.count(STRATEGY)) && (cnode->operator_info() == nullptr)) {
return false;


+ 1
- 1
mindspore/train/dataset_helper.py View File

@@ -88,7 +88,7 @@ class _DatasetIter:
# times the batch dimension of tensors for run
if _get_parallel_mode() in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
device_num = _get_device_num()
dataset_shapes = _to_full_shapes(dataset_shapes, device_num)
self.dataset_shapes = _to_full_shapes(dataset_shapes, device_num)

def __iter__(self):
self.ind = 0


+ 6
- 6
tests/ut/python/parallel/test_auto_parallel_parameter_cast.py View File

@@ -80,9 +80,9 @@ def test_common_parameter():

_executor.compile(net, x, y, z, w, phase='train')
strategies = _executor._get_strategy(net)
expected_strategies = {'Default/network-Net/MatMul-op8': [[1, 1], [1, 8]],
'Default/network-Net/MatMul-op9': [[1, 1], [1, 8]],
'Default/network-Net/Cast-op10': [[1, 8]],
'Default/network-Net/MatMul-op0': [[1, 1], [1, 8]],
'Default/network-Net/Cast-op11': [[1, 8]]}
assert strategies == expected_strategies
expected_strategies = {'Default/network-Net/MatMul-op6': [[8, 1], [1, 1]],
'Default/network-Net/MatMul-op8': [[8, 1], [1, 1]],
'Default/network-Net/Cast-op7': [[1, 1]],
'Default/network-Net/MatMul-op0': [[8, 1], [1, 1]],
'Default/network-Net/Cast-op9': [[1, 1]]}
assert strategies == expected_strategies

+ 4
- 3
tests/ut/python/parallel/test_auto_parallel_two_matmul.py View File

@@ -86,9 +86,9 @@ def test_two_matmul():
costmodel_alpha = cost_model_context.get_cost_model_context("costmodel_alpha")
assert costmodel_alpha == 1.0
costmodel_beta = cost_model_context.get_cost_model_context("costmodel_beta")
assert costmodel_beta == 65.0
assert costmodel_beta == 260.0
costmodel_gamma = cost_model_context.get_cost_model_context("costmodel_gamma")
assert costmodel_gamma == 0.02
assert costmodel_gamma == 0.001
costmodel_communi_threshold = cost_model_context.get_cost_model_context("costmodel_communi_threshold")
assert costmodel_communi_threshold == 2048.0
costmodel_communi_const = cost_model_context.get_cost_model_context("costmodel_communi_const")
@@ -137,4 +137,5 @@ def test_two_matmul():
strategies = _executor._get_strategy(net)
expected_strategies = {'Default/network-Net/MatMul-op2': [[16, 1], [1, 1]],
'Default/network-Net/MatMul-op3': [[16, 1], [1, 1]]}
assert strategies == expected_strategies
assert strategies == expected_strategies


Loading…
Cancel
Save