Browse Source

!22941 parallel_sparse_attention_ops_fix_repeated_cal

Merge pull request !22941 from yao_yf/parallel_sparse_attention_ops_fix_repeated_cal
tags/v1.5.0-rc1
i-robot Gitee 4 years ago
parent
commit
d37d57d0fd
4 changed files with 28 additions and 6 deletions
  1. +7
    -6
      mindspore/ccsrc/frontend/parallel/ops_info/matmul_dds_info.cc
  2. +1
    -0
      mindspore/ccsrc/frontend/parallel/ops_info/matmul_dds_info.h
  3. +2
    -0
      mindspore/train/model.py
  4. +18
    -0
      tests/ut/python/parallel/test_cus_matmul_dds.py

+ 7
- 6
mindspore/ccsrc/frontend/parallel/ops_info/matmul_dds_info.cc View File

@@ -117,6 +117,7 @@ Status MatmulDDSInfo::InferDevMatrixShape() {
dev_matrix_shape_.push_back(1);
dev_matrix_shape_.push_back(1);
dev_matrix_shape_.push_back(1);
dev_matrix_shape_origin_ = dev_matrix_shape_;
return SUCCESS;
}

@@ -168,22 +169,22 @@ Status MatmulDDSInfo::InferTensorMap() {
}
TensorMap output_tensor_map_local_prob;
// output_tensor_map_local_prob [5, 6, -1, -1, -1, -1, -1]
for (size_t i = 0; i < dev_matrix_shape_.size(); ++i) {
for (size_t i = 0; i < dev_matrix_shape_origin_.size(); ++i) {
if (i == 0) {
output_tensor_map_local_prob.push_back((int64_t)(dev_matrix_shape_.size() - 2));
output_tensor_map_local_prob.push_back((int64_t)(dev_matrix_shape_origin_.size() - 2));
} else if (i == 1) {
output_tensor_map_local_prob.push_back((int64_t)(dev_matrix_shape_.size() - 1));
output_tensor_map_local_prob.push_back((int64_t)(dev_matrix_shape_origin_.size() - 1));
} else {
output_tensor_map_local_prob.push_back((int64_t)(MAP_NONE));
}
}
TensorMap output_tensor_map_global_prob;
// output_tensor_map_global_prob [5, 6, -1, -1, -1, -1, -1]
for (size_t i = 0; i < dev_matrix_shape_.size(); ++i) {
for (size_t i = 0; i < dev_matrix_shape_origin_.size(); ++i) {
if (i == 0) {
output_tensor_map_global_prob.push_back((int64_t)(dev_matrix_shape_.size() - 2));
output_tensor_map_global_prob.push_back((int64_t)(dev_matrix_shape_origin_.size() - 2));
} else if (i == 1) {
output_tensor_map_global_prob.push_back((int64_t)(dev_matrix_shape_.size() - 1));
output_tensor_map_global_prob.push_back((int64_t)(dev_matrix_shape_origin_.size() - 1));
} else {
output_tensor_map_global_prob.push_back((int64_t)(MAP_NONE));
}


+ 1
- 0
mindspore/ccsrc/frontend/parallel/ops_info/matmul_dds_info.h View File

@@ -55,6 +55,7 @@ class MatmulDDSInfo : public OperatorInfo {

private:
Dimensions input_strategy_;
Shape dev_matrix_shape_origin_;
int64_t batch_size_ = 0;
int64_t num_heads_ = 0;
};


+ 2
- 0
mindspore/train/model.py View File

@@ -233,6 +233,8 @@ class Model:
acc_level=self._acc_level,
keep_batchnorm_fp32=self._keep_bn_fp32)
elif self._loss_fn:
if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
network = _VirtualDatasetCell(network)
network = nn.WithLossCell(network, self._loss_fn)
# If need to check if loss_fn is not None, but optimizer is None



+ 18
- 0
tests/ut/python/parallel/test_cus_matmul_dds.py View File

@@ -167,3 +167,21 @@ def test_cus_matmul_dds_model_parallel_auto():
dp = 1
mp = 16
compile_graph(batch_size, num_heads, dp, mp, auto=True, shard=False)

def test_cus_matmul_dds_repeat_cal_auto():
    """Compile CusMatmulDDS under auto-parallel with dp*mp (=2) smaller than the
    16-device world, i.e. a strategy that repeats calculation across devices.

    NOTE(review): presumably this exercises the repeated-calculation path fixed
    in matmul_dds_info.cc — confirm against compile_graph's definition.
    """
    # Permit strategies that leave some devices unused.
    set_algo_parameters(fully_use_devices=False)
    context.set_auto_parallel_context(device_num=16, global_rank=0)
    batch_size, num_heads = 128, 32
    dp, mp = 1, 2
    compile_graph(batch_size, num_heads, dp, mp, auto=True, shard=False)

def test_cus_matmul_dds_repeat1_cal_auto():
    """Compile CusMatmulDDS under auto-parallel with the mirrored split
    (dp=2, mp=1) on a 16-device world, again with dp*mp < device_num.

    NOTE(review): counterpart of test_cus_matmul_dds_repeat_cal_auto with the
    data/model-parallel factors swapped — confirm intent with compile_graph.
    """
    # Permit strategies that leave some devices unused.
    set_algo_parameters(fully_use_devices=False)
    context.set_auto_parallel_context(device_num=16, global_rank=0)
    batch_size, num_heads = 128, 32
    dp, mp = 2, 1
    compile_graph(batch_size, num_heads, dp, mp, auto=True, shard=False)

Loading…
Cancel
Save