Merge pull request !2273 from lichen/fix_embeding_look_uptags/v0.5.0-beta
| @@ -28,9 +28,14 @@ namespace parallel { | |||
| std::string GetOpPythonPath(const OperatorName &op_name) { | |||
| // almost all ops are defined in two main paths | |||
| const std::string ops_module = OP_PATH; | |||
| const std::string inner_ops_module = INNER_OP_PATH; | |||
| py::module mod = py::module::import(common::SafeCStr(ops_module)); | |||
| py::module inner_mod = py::module::import(common::SafeCStr(inner_ops_module)); | |||
| if (!py::hasattr(mod, common::SafeCStr(op_name))) { | |||
| MS_LOG(EXCEPTION) << ops_module << " don't have op:" << op_name; | |||
| if (!py::hasattr(inner_mod, common::SafeCStr(op_name))) { | |||
| MS_LOG(EXCEPTION) << ops_module << " or " << inner_ops_module << " don't have op:" << op_name; | |||
| } | |||
| return inner_ops_module; | |||
| } | |||
| return ops_module; | |||
| } | |||
| @@ -56,6 +56,12 @@ Status GatherV2PInfo::GetAttrs() { | |||
| } | |||
| } | |||
| // target=CPU, axis must be 0 | |||
| if (target_ == "CPU" && axis_ != 0) { | |||
| MS_LOG(ERROR) << name_ << ": target is CPU, axis must be 0, but got " << axis_; | |||
| return FAILED; | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| @@ -279,6 +285,11 @@ Status GatherV2PInfo::InferBias() { | |||
| int32_t rank = g_device_manager->global_rank(); | |||
| auto input_shape = inputs_shape_.at(0); | |||
| auto params_strategy = strategy_->GetInputDim().at(0); | |||
| // axis don't split | |||
| if (params_strategy.at(axis_) == 1) { | |||
| bias_ = 0; | |||
| return SUCCESS; | |||
| } | |||
| // params_size=1, axis=0 | |||
| if ((input_shape.size() == 1) && (axis_ == 0)) { | |||
| slice_size_ = input_shape.at(0) / params_strategy.at(0); | |||
| @@ -353,26 +364,35 @@ Status GatherV2PInfo::InferForwardCommunication() { | |||
| } | |||
| auto group_size = group_.GetDevNum(); | |||
| Attr attr_group; | |||
| // group size <= 8 | |||
| std::vector<int32_t> rank_list; | |||
| if (group_size <= 8) { | |||
| reduce_scatter_flag_ = false; | |||
| operator_name = HOST_REDUCE_SCATTER; | |||
| rank_list = GetRankFromGroup(group_); | |||
| attr_group = std::make_pair(GROUP, MakeValue(rank_list)); | |||
| if (host_reduce_scatter_) { | |||
| // group size <= 8 | |||
| std::vector<int32_t> rank_list; | |||
| if (group_size <= 8) { | |||
| reduce_scatter_flag_ = false; | |||
| operator_name = HOST_REDUCE_SCATTER; | |||
| rank_list = GetRankFromGroup(group_); | |||
| attr_group = std::make_pair(GROUP, MakeValue(rank_list)); | |||
| } else { | |||
| // group size > 8, don't support host reduce_scatter | |||
| reduce_scatter_flag_ = true; | |||
| split_num_ = SizeToInt(group_size / 8); | |||
| CheckGlobalDeviceManager(); | |||
| operator_name = REDUCE_SCATTER; | |||
| int32_t rank = g_device_manager->global_rank(); | |||
| size_t repeat = group_size / 8; | |||
| for (size_t i = 0; i < repeat; ++i) { | |||
| rank_list.push_back(rank + SizeToInt(i * 8)); | |||
| } | |||
| Group g = g_device_manager->CreateGroup(rank_list); | |||
| attr_group = std::make_pair(GROUP, MakeValue(g.name())); | |||
| } | |||
| } else { | |||
| // group size > 8 | |||
| reduce_scatter_flag_ = true; | |||
| split_num_ = SizeToInt(group_size / 8); | |||
| CheckGlobalDeviceManager(); | |||
| operator_name = REDUCE_SCATTER; | |||
| int32_t rank = g_device_manager->global_rank(); | |||
| size_t repeat = group_size / 8; | |||
| for (size_t i = 0; i < repeat; ++i) { | |||
| rank_list.push_back(rank + SizeToInt(i * 8)); | |||
| if (InferGroup() != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << ": Infer Group failed."; | |||
| return FAILED; | |||
| } | |||
| Group g = g_device_manager->CreateGroup(rank_list); | |||
| attr_group = std::make_pair(GROUP, MakeValue(g.name())); | |||
| attr_group = std::make_pair(GROUP, MakeValue(group_.name())); | |||
| } | |||
| Attr attr_op = std::make_pair(OP, MakeValue(REDUCE_OP_SUM)); | |||
| OperatorAttrs attrs = {attr_op, attr_group}; | |||
| @@ -446,8 +466,8 @@ Status GatherV2PInfo::ComputeReplaceOp() { | |||
| Attr param_offset = std::make_pair("offset", MakeValue(bias_)); | |||
| Attr param_flag = std::make_pair("reduce_scatter_flag", MakeValue(reduce_scatter_flag_)); | |||
| Attr param_split_num = std::make_pair("split_num", MakeValue(split_num_)); | |||
| OperatorParams params = {std::make_pair(param_offset, 4), std::make_pair(param_flag, 5), | |||
| std::make_pair(param_split_num, 6)}; | |||
| OperatorParams params = {std::make_pair(param_offset, 3), std::make_pair(param_flag, 4), | |||
| std::make_pair(param_split_num, 5)}; | |||
| OperatorArgs args = std::make_pair(attrs, params); | |||
| Operator op = std::make_pair(op_name, args); | |||
| replace_op_.push_back(op); | |||
| @@ -70,6 +70,7 @@ class GatherV2PInfo : public OperatorInfo { | |||
| Group group_; | |||
| bool reduce_scatter_flag_ = false; | |||
| int32_t split_num_ = 1; | |||
| bool host_reduce_scatter_ = false; | |||
| }; | |||
| class SparseGatherV2Info : public GatherV2PInfo { | |||
| @@ -55,6 +55,7 @@ constexpr char REDUCE_OP_SUM[] = "sum"; | |||
| constexpr char REDUCE_OP_MAX[] = "max"; | |||
| constexpr char REDUCE_OP_MIN[] = "min"; | |||
| constexpr char OP_PATH[] = "mindspore.ops.operations"; | |||
| constexpr char INNER_OP_PATH[] = "mindspore.ops.operations._inner_ops"; | |||
| constexpr char GET_OP_FUNCTION_PATH[] = "mindspore.parallel._utils"; | |||
| constexpr char GET_OP_FUNCTION[] = "_get_python_op"; | |||
| constexpr char KEEP_DIMS[] = "keep_dims"; | |||
| @@ -536,7 +536,7 @@ std::vector<AnfNodePtr> ReplaceOpInput(const Operator &replace_op, const std::st | |||
| std::vector<AnfNodePtr> replace_input = {NewValueNode(pyop_instance), node->input(1)}; | |||
| auto prim = GetValueNode<PrimitivePtr>(node->input(0)); | |||
| if (prim->name() == GATHERV2 || prim->name() == SPARSE_GATHERV2) { | |||
| replace_input = {NewValueNode(pyop_instance), node->input(1), node->input(2), node->input(3)}; | |||
| replace_input = {NewValueNode(pyop_instance), node->input(1), node->input(2)}; | |||
| } | |||
| if (!params.empty()) { | |||
| Param param_first = *(params.begin()); | |||
| @@ -184,7 +184,7 @@ def test_gatherv2_auto1(): | |||
| _executor.compile(net, x, y) | |||
| def need_fix_test_gatherv2_cpu0(): | |||
| def test_gatherv2_cpu0(): | |||
| context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel") | |||
| strategy1 = ((8, 1), (1, 1)) | |||
| strategy2 = ((4, 2, 1), (4, 2, 1)) | |||
| @@ -196,7 +196,7 @@ def need_fix_test_gatherv2_cpu0(): | |||
| _executor.compile(net, x, y) | |||
| def need_fix_test_gatherv2_cpu1(): | |||
| def test_gatherv2_cpu1(): | |||
| context.set_auto_parallel_context(device_num=16, global_rank=0, parallel_mode="semi_auto_parallel") | |||
| strategy1 = ((16, 1), (1, 1)) | |||
| strategy2 = ((4, 2, 1), (4, 2, 1)) | |||
| @@ -208,7 +208,7 @@ def need_fix_test_gatherv2_cpu1(): | |||
| _executor.compile(net, x, y) | |||
| def need_fix_test_gatherv2_cpu2(): | |||
| def test_gatherv2_cpu2(): | |||
| context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel") | |||
| strategy1 = ((1, 8), (1, 1)) | |||
| strategy2 = ((4, 2, 1), (4, 2, 1)) | |||
| @@ -184,7 +184,7 @@ def test_gatherv2_auto1(): | |||
| _executor.compile(net, x, y) | |||
| def need_fix_test_gatherv2_cpu0(): | |||
| def test_gatherv2_cpu0(): | |||
| context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel") | |||
| strategy1 = ((8, 1), (1, 1)) | |||
| strategy2 = ((4, 2, 1), (4, 2, 1)) | |||
| @@ -196,7 +196,7 @@ def need_fix_test_gatherv2_cpu0(): | |||
| _executor.compile(net, x, y) | |||
| def need_fix_test_gatherv2_cpu1(): | |||
| def test_gatherv2_cpu1(): | |||
| context.set_auto_parallel_context(device_num=16, global_rank=0, parallel_mode="semi_auto_parallel") | |||
| strategy1 = ((16, 1), (1, 1)) | |||
| strategy2 = ((4, 2, 1), (4, 2, 1)) | |||
| @@ -208,7 +208,7 @@ def need_fix_test_gatherv2_cpu1(): | |||
| _executor.compile(net, x, y) | |||
| def need_fix_test_gatherv2_cpu2(): | |||
| def test_gatherv2_cpu2(): | |||
| context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel") | |||
| strategy1 = ((1, 8), (1, 1)) | |||
| strategy2 = ((4, 2, 1), (4, 2, 1)) | |||