
!2273 [AutoParallel]update EmbeddingLookUp op

Merge pull request !2273 from lichen/fix_embeding_look_up
tags/v0.5.0-beta
mindspore-ci-bot committed 5 years ago
commit c0fe8c0322
7 changed files with 54 additions and 27 deletions
  1. +6  -1   mindspore/ccsrc/parallel/graph_util/generate_graph.cc
  2. +39 -19  mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.cc
  3. +1  -0   mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.h
  4. +1  -0   mindspore/ccsrc/parallel/ops_info/ops_utils.h
  5. +1  -1   mindspore/ccsrc/parallel/step_parallel.cc
  6. +3  -3   tests/ut/python/parallel/test_gather_v2.py
  7. +3  -3   tests/ut/python/parallel/test_sparse_gather_v2.py

+6 -1  mindspore/ccsrc/parallel/graph_util/generate_graph.cc

@@ -28,9 +28,14 @@ namespace parallel {
 std::string GetOpPythonPath(const OperatorName &op_name) {
   // almost all ops are defined in two main paths
   const std::string ops_module = OP_PATH;
+  const std::string inner_ops_module = INNER_OP_PATH;
   py::module mod = py::module::import(common::SafeCStr(ops_module));
+  py::module inner_mod = py::module::import(common::SafeCStr(inner_ops_module));
   if (!py::hasattr(mod, common::SafeCStr(op_name))) {
-    MS_LOG(EXCEPTION) << ops_module << " don't have op:" << op_name;
+    if (!py::hasattr(inner_mod, common::SafeCStr(op_name))) {
+      MS_LOG(EXCEPTION) << ops_module << " or " << inner_ops_module << " don't have op:" << op_name;
+    }
+    return inner_ops_module;
   }
   return ops_module;
 }
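As a quick illustration of the new lookup order in GetOpPythonPath: an operator is resolved from mindspore.ops.operations first and, if it is not found there, from mindspore.ops.operations._inner_ops. A minimal Python sketch, assuming only the two module paths from this diff (the helper name resolve_op_module is hypothetical):

    import importlib

    def resolve_op_module(op_name):
        # Module paths match OP_PATH and INNER_OP_PATH above.
        ops_module = "mindspore.ops.operations"
        inner_ops_module = "mindspore.ops.operations._inner_ops"
        mod = importlib.import_module(ops_module)
        inner_mod = importlib.import_module(inner_ops_module)
        if not hasattr(mod, op_name):
            if not hasattr(inner_mod, op_name):
                raise RuntimeError(f"{ops_module} or {inner_ops_module} don't have op: {op_name}")
            return inner_ops_module
        return ops_module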


+39 -19  mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.cc

@@ -56,6 +56,12 @@ Status GatherV2PInfo::GetAttrs() {
     }
   }
 
+  // target=CPU, axis must be 0
+  if (target_ == "CPU" && axis_ != 0) {
+    MS_LOG(ERROR) << name_ << ": target is CPU, axis must be 0, but got " << axis_;
+    return FAILED;
+  }
+
   return SUCCESS;
 }


@@ -279,6 +285,11 @@ Status GatherV2PInfo::InferBias() {
   int32_t rank = g_device_manager->global_rank();
   auto input_shape = inputs_shape_.at(0);
   auto params_strategy = strategy_->GetInputDim().at(0);
+  // axis don't split
+  if (params_strategy.at(axis_) == 1) {
+    bias_ = 0;
+    return SUCCESS;
+  }
   // params_size=1, axis=0
   if ((input_shape.size() == 1) && (axis_ == 0)) {
     slice_size_ = input_shape.at(0) / params_strategy.at(0);
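The early return above means a parameter table whose gather axis is not split never needs an offset. A minimal sketch of the 1-D, axis-0 case; only the unsplit branch and the slice_size formula appear in the diff, so the rank-to-slice mapping used here is an assumption:

    def infer_bias(param_rows, axis_split, rank):
        # Axis not split: every device holds the full table, so no offset.
        if axis_split == 1:
            return 0
        # Each device holds param_rows / axis_split rows (slice_size above).
        slice_size = param_rows // axis_split
        # Assumption: the device's slice index along axis 0 is rank % axis_split.
        return (rank % axis_split) * slice_size

    infer_bias(64, 1, 3)   # 0: axis not split
    infer_bias(64, 8, 3)   # 24: slice_size 8, slice index 3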
@@ -353,26 +364,35 @@ Status GatherV2PInfo::InferForwardCommunication() {
   }
   auto group_size = group_.GetDevNum();
   Attr attr_group;
-  // group size <= 8
-  std::vector<int32_t> rank_list;
-  if (group_size <= 8) {
-    reduce_scatter_flag_ = false;
-    operator_name = HOST_REDUCE_SCATTER;
-    rank_list = GetRankFromGroup(group_);
-    attr_group = std::make_pair(GROUP, MakeValue(rank_list));
+  if (host_reduce_scatter_) {
+    // group size <= 8
+    std::vector<int32_t> rank_list;
+    if (group_size <= 8) {
+      reduce_scatter_flag_ = false;
+      operator_name = HOST_REDUCE_SCATTER;
+      rank_list = GetRankFromGroup(group_);
+      attr_group = std::make_pair(GROUP, MakeValue(rank_list));
+    } else {
+      // group size > 8, don't support host reduce_scatter
+      reduce_scatter_flag_ = true;
+      split_num_ = SizeToInt(group_size / 8);
+      CheckGlobalDeviceManager();
+      operator_name = REDUCE_SCATTER;
+      int32_t rank = g_device_manager->global_rank();
+      size_t repeat = group_size / 8;
+      for (size_t i = 0; i < repeat; ++i) {
+        rank_list.push_back(rank + SizeToInt(i * 8));
+      }
+      Group g = g_device_manager->CreateGroup(rank_list);
+      attr_group = std::make_pair(GROUP, MakeValue(g.name()));
+    }
   } else {
-    // group size > 8
-    reduce_scatter_flag_ = true;
-    split_num_ = SizeToInt(group_size / 8);
-    CheckGlobalDeviceManager();
     operator_name = REDUCE_SCATTER;
-    int32_t rank = g_device_manager->global_rank();
-    size_t repeat = group_size / 8;
-    for (size_t i = 0; i < repeat; ++i) {
-      rank_list.push_back(rank + SizeToInt(i * 8));
+    if (InferGroup() != SUCCESS) {
+      MS_LOG(ERROR) << name_ << ": Infer Group failed.";
+      return FAILED;
     }
-    Group g = g_device_manager->CreateGroup(rank_list);
-    attr_group = std::make_pair(GROUP, MakeValue(g.name()));
+    attr_group = std::make_pair(GROUP, MakeValue(group_.name()));
   }
   Attr attr_op = std::make_pair(OP, MakeValue(REDUCE_OP_SUM));
   OperatorAttrs attrs = {attr_op, attr_group};
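For the host path with more than 8 devices, the branch above falls back to device ReduceScatter and builds a group from ranks spaced 8 apart, with split_num set to group_size / 8. A small Python sketch of just that rank-list arithmetic:

    def build_rank_list(rank, group_size):
        # One peer every 8 ranks; split_num counts how many 8-device blocks there are.
        split_num = group_size // 8
        rank_list = [rank + i * 8 for i in range(split_num)]
        return split_num, rank_list

    build_rank_list(3, 16)   # (2, [3, 11])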
@@ -446,8 +466,8 @@ Status GatherV2PInfo::ComputeReplaceOp() {
   Attr param_offset = std::make_pair("offset", MakeValue(bias_));
   Attr param_flag = std::make_pair("reduce_scatter_flag", MakeValue(reduce_scatter_flag_));
   Attr param_split_num = std::make_pair("split_num", MakeValue(split_num_));
-  OperatorParams params = {std::make_pair(param_offset, 4), std::make_pair(param_flag, 5),
-                           std::make_pair(param_split_num, 6)};
+  OperatorParams params = {std::make_pair(param_offset, 3), std::make_pair(param_flag, 4),
+                           std::make_pair(param_split_num, 5)};
   OperatorArgs args = std::make_pair(attrs, params);
   Operator op = std::make_pair(op_name, args);
   replace_op_.push_back(op);


+1 -0  mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.h

@@ -70,6 +70,7 @@ class GatherV2PInfo : public OperatorInfo {
   Group group_;
   bool reduce_scatter_flag_ = false;
   int32_t split_num_ = 1;
+  bool host_reduce_scatter_ = false;
 };
 
 class SparseGatherV2Info : public GatherV2PInfo {


+1 -0  mindspore/ccsrc/parallel/ops_info/ops_utils.h

@@ -55,6 +55,7 @@ constexpr char REDUCE_OP_SUM[] = "sum";
 constexpr char REDUCE_OP_MAX[] = "max";
 constexpr char REDUCE_OP_MIN[] = "min";
 constexpr char OP_PATH[] = "mindspore.ops.operations";
+constexpr char INNER_OP_PATH[] = "mindspore.ops.operations._inner_ops";
 constexpr char GET_OP_FUNCTION_PATH[] = "mindspore.parallel._utils";
 constexpr char GET_OP_FUNCTION[] = "_get_python_op";
 constexpr char KEEP_DIMS[] = "keep_dims";


+1 -1  mindspore/ccsrc/parallel/step_parallel.cc

@@ -536,7 +536,7 @@ std::vector<AnfNodePtr> ReplaceOpInput(const Operator &replace_op, const std::st
   std::vector<AnfNodePtr> replace_input = {NewValueNode(pyop_instance), node->input(1)};
   auto prim = GetValueNode<PrimitivePtr>(node->input(0));
   if (prim->name() == GATHERV2 || prim->name() == SPARSE_GATHERV2) {
-    replace_input = {NewValueNode(pyop_instance), node->input(1), node->input(2), node->input(3)};
+    replace_input = {NewValueNode(pyop_instance), node->input(1), node->input(2)};
   }
   if (!params.empty()) {
     Param param_first = *(params.begin());
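Since the replacement op for GatherV2/SparseGatherV2 no longer forwards node->input(3) (presumably the axis input), its input list has three entries, and the offset/reduce_scatter_flag/split_num parameters in ComputeReplaceOp shift from positions 4/5/6 to 3/4/5. A sketch of how the final argument list lines up, with placeholder values:

    # Position 0 is the op itself, 1-2 are the forwarded inputs.
    replace_input = ["embedding_lookup_op", "params", "indices"]
    appended = [(3, "offset"), (4, "reduce_scatter_flag"), (5, "split_num")]
    for pos, name in appended:
        assert pos == len(replace_input)   # each parameter lands at its declared position
        replace_input.append(name)
    # replace_input == [op, params, indices, offset, reduce_scatter_flag, split_num]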


+3 -3  tests/ut/python/parallel/test_gather_v2.py

@@ -184,7 +184,7 @@ def test_gatherv2_auto1():
     _executor.compile(net, x, y)
 
 
-def need_fix_test_gatherv2_cpu0():
+def test_gatherv2_cpu0():
     context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
     strategy1 = ((8, 1), (1, 1))
     strategy2 = ((4, 2, 1), (4, 2, 1))
@@ -196,7 +196,7 @@ def need_fix_test_gatherv2_cpu0():
     _executor.compile(net, x, y)
 
 
-def need_fix_test_gatherv2_cpu1():
+def test_gatherv2_cpu1():
     context.set_auto_parallel_context(device_num=16, global_rank=0, parallel_mode="semi_auto_parallel")
     strategy1 = ((16, 1), (1, 1))
     strategy2 = ((4, 2, 1), (4, 2, 1))
@@ -208,7 +208,7 @@ def need_fix_test_gatherv2_cpu1():
     _executor.compile(net, x, y)
 
 
-def need_fix_test_gatherv2_cpu2():
+def test_gatherv2_cpu2():
    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
     strategy1 = ((1, 8), (1, 1))
     strategy2 = ((4, 2, 1), (4, 2, 1))
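The re-enabled CPU tests go through the new GetAttrs check above: when the gather targets CPU (the host-side EmbeddingLookup path), only axis 0 is supported. A minimal sketch of that constraint (function name is illustrative):

    def check_cpu_gather(target, axis):
        # Mirrors the GetAttrs check: target=CPU requires axis == 0.
        if target == "CPU" and axis != 0:
            raise ValueError(f"target is CPU, axis must be 0, but got {axis}")

    check_cpu_gather("CPU", 0)    # passes
    # check_cpu_gather("CPU", 1)  # would raise ValueError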


+3 -3  tests/ut/python/parallel/test_sparse_gather_v2.py

@@ -184,7 +184,7 @@ def test_gatherv2_auto1():
     _executor.compile(net, x, y)
 
 
-def need_fix_test_gatherv2_cpu0():
+def test_gatherv2_cpu0():
     context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
     strategy1 = ((8, 1), (1, 1))
     strategy2 = ((4, 2, 1), (4, 2, 1))
@@ -196,7 +196,7 @@ def need_fix_test_gatherv2_cpu0():
     _executor.compile(net, x, y)
 
 
-def need_fix_test_gatherv2_cpu1():
+def test_gatherv2_cpu1():
     context.set_auto_parallel_context(device_num=16, global_rank=0, parallel_mode="semi_auto_parallel")
     strategy1 = ((16, 1), (1, 1))
     strategy2 = ((4, 2, 1), (4, 2, 1))
@@ -208,7 +208,7 @@ def need_fix_test_gatherv2_cpu1():
     _executor.compile(net, x, y)
 
 
-def need_fix_test_gatherv2_cpu2():
+def test_gatherv2_cpu2():
     context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
     strategy1 = ((1, 8), (1, 1))
     strategy2 = ((4, 2, 1), (4, 2, 1))

