Merge pull request !2273 from lichen/fix_embeding_look_uptags/v0.5.0-beta
| @@ -28,9 +28,14 @@ namespace parallel { | |||||
| std::string GetOpPythonPath(const OperatorName &op_name) { | std::string GetOpPythonPath(const OperatorName &op_name) { | ||||
| // almost all ops are defined in two main paths | // almost all ops are defined in two main paths | ||||
| const std::string ops_module = OP_PATH; | const std::string ops_module = OP_PATH; | ||||
| const std::string inner_ops_module = INNER_OP_PATH; | |||||
| py::module mod = py::module::import(common::SafeCStr(ops_module)); | py::module mod = py::module::import(common::SafeCStr(ops_module)); | ||||
| py::module inner_mod = py::module::import(common::SafeCStr(inner_ops_module)); | |||||
| if (!py::hasattr(mod, common::SafeCStr(op_name))) { | if (!py::hasattr(mod, common::SafeCStr(op_name))) { | ||||
| MS_LOG(EXCEPTION) << ops_module << " don't have op:" << op_name; | |||||
| if (!py::hasattr(inner_mod, common::SafeCStr(op_name))) { | |||||
| MS_LOG(EXCEPTION) << ops_module << " or " << inner_ops_module << " don't have op:" << op_name; | |||||
| } | |||||
| return inner_ops_module; | |||||
| } | } | ||||
| return ops_module; | return ops_module; | ||||
| } | } | ||||
| @@ -56,6 +56,12 @@ Status GatherV2PInfo::GetAttrs() { | |||||
| } | } | ||||
| } | } | ||||
| // target=CPU, axis must be 0 | |||||
| if (target_ == "CPU" && axis_ != 0) { | |||||
| MS_LOG(ERROR) << name_ << ": target is CPU, axis must be 0, but got " << axis_; | |||||
| return FAILED; | |||||
| } | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -279,6 +285,11 @@ Status GatherV2PInfo::InferBias() { | |||||
| int32_t rank = g_device_manager->global_rank(); | int32_t rank = g_device_manager->global_rank(); | ||||
| auto input_shape = inputs_shape_.at(0); | auto input_shape = inputs_shape_.at(0); | ||||
| auto params_strategy = strategy_->GetInputDim().at(0); | auto params_strategy = strategy_->GetInputDim().at(0); | ||||
| // axis don't split | |||||
| if (params_strategy.at(axis_) == 1) { | |||||
| bias_ = 0; | |||||
| return SUCCESS; | |||||
| } | |||||
| // params_size=1, axis=0 | // params_size=1, axis=0 | ||||
| if ((input_shape.size() == 1) && (axis_ == 0)) { | if ((input_shape.size() == 1) && (axis_ == 0)) { | ||||
| slice_size_ = input_shape.at(0) / params_strategy.at(0); | slice_size_ = input_shape.at(0) / params_strategy.at(0); | ||||
| @@ -353,26 +364,35 @@ Status GatherV2PInfo::InferForwardCommunication() { | |||||
| } | } | ||||
| auto group_size = group_.GetDevNum(); | auto group_size = group_.GetDevNum(); | ||||
| Attr attr_group; | Attr attr_group; | ||||
| // group size <= 8 | |||||
| std::vector<int32_t> rank_list; | |||||
| if (group_size <= 8) { | |||||
| reduce_scatter_flag_ = false; | |||||
| operator_name = HOST_REDUCE_SCATTER; | |||||
| rank_list = GetRankFromGroup(group_); | |||||
| attr_group = std::make_pair(GROUP, MakeValue(rank_list)); | |||||
| if (host_reduce_scatter_) { | |||||
| // group size <= 8 | |||||
| std::vector<int32_t> rank_list; | |||||
| if (group_size <= 8) { | |||||
| reduce_scatter_flag_ = false; | |||||
| operator_name = HOST_REDUCE_SCATTER; | |||||
| rank_list = GetRankFromGroup(group_); | |||||
| attr_group = std::make_pair(GROUP, MakeValue(rank_list)); | |||||
| } else { | |||||
| // group size > 8, don't support host reduce_scatter | |||||
| reduce_scatter_flag_ = true; | |||||
| split_num_ = SizeToInt(group_size / 8); | |||||
| CheckGlobalDeviceManager(); | |||||
| operator_name = REDUCE_SCATTER; | |||||
| int32_t rank = g_device_manager->global_rank(); | |||||
| size_t repeat = group_size / 8; | |||||
| for (size_t i = 0; i < repeat; ++i) { | |||||
| rank_list.push_back(rank + SizeToInt(i * 8)); | |||||
| } | |||||
| Group g = g_device_manager->CreateGroup(rank_list); | |||||
| attr_group = std::make_pair(GROUP, MakeValue(g.name())); | |||||
| } | |||||
| } else { | } else { | ||||
| // group size > 8 | |||||
| reduce_scatter_flag_ = true; | |||||
| split_num_ = SizeToInt(group_size / 8); | |||||
| CheckGlobalDeviceManager(); | |||||
| operator_name = REDUCE_SCATTER; | operator_name = REDUCE_SCATTER; | ||||
| int32_t rank = g_device_manager->global_rank(); | |||||
| size_t repeat = group_size / 8; | |||||
| for (size_t i = 0; i < repeat; ++i) { | |||||
| rank_list.push_back(rank + SizeToInt(i * 8)); | |||||
| if (InferGroup() != SUCCESS) { | |||||
| MS_LOG(ERROR) << name_ << ": Infer Group failed."; | |||||
| return FAILED; | |||||
| } | } | ||||
| Group g = g_device_manager->CreateGroup(rank_list); | |||||
| attr_group = std::make_pair(GROUP, MakeValue(g.name())); | |||||
| attr_group = std::make_pair(GROUP, MakeValue(group_.name())); | |||||
| } | } | ||||
| Attr attr_op = std::make_pair(OP, MakeValue(REDUCE_OP_SUM)); | Attr attr_op = std::make_pair(OP, MakeValue(REDUCE_OP_SUM)); | ||||
| OperatorAttrs attrs = {attr_op, attr_group}; | OperatorAttrs attrs = {attr_op, attr_group}; | ||||
| @@ -446,8 +466,8 @@ Status GatherV2PInfo::ComputeReplaceOp() { | |||||
| Attr param_offset = std::make_pair("offset", MakeValue(bias_)); | Attr param_offset = std::make_pair("offset", MakeValue(bias_)); | ||||
| Attr param_flag = std::make_pair("reduce_scatter_flag", MakeValue(reduce_scatter_flag_)); | Attr param_flag = std::make_pair("reduce_scatter_flag", MakeValue(reduce_scatter_flag_)); | ||||
| Attr param_split_num = std::make_pair("split_num", MakeValue(split_num_)); | Attr param_split_num = std::make_pair("split_num", MakeValue(split_num_)); | ||||
| OperatorParams params = {std::make_pair(param_offset, 4), std::make_pair(param_flag, 5), | |||||
| std::make_pair(param_split_num, 6)}; | |||||
| OperatorParams params = {std::make_pair(param_offset, 3), std::make_pair(param_flag, 4), | |||||
| std::make_pair(param_split_num, 5)}; | |||||
| OperatorArgs args = std::make_pair(attrs, params); | OperatorArgs args = std::make_pair(attrs, params); | ||||
| Operator op = std::make_pair(op_name, args); | Operator op = std::make_pair(op_name, args); | ||||
| replace_op_.push_back(op); | replace_op_.push_back(op); | ||||
| @@ -70,6 +70,7 @@ class GatherV2PInfo : public OperatorInfo { | |||||
| Group group_; | Group group_; | ||||
| bool reduce_scatter_flag_ = false; | bool reduce_scatter_flag_ = false; | ||||
| int32_t split_num_ = 1; | int32_t split_num_ = 1; | ||||
| bool host_reduce_scatter_ = false; | |||||
| }; | }; | ||||
| class SparseGatherV2Info : public GatherV2PInfo { | class SparseGatherV2Info : public GatherV2PInfo { | ||||
| @@ -55,6 +55,7 @@ constexpr char REDUCE_OP_SUM[] = "sum"; | |||||
| constexpr char REDUCE_OP_MAX[] = "max"; | constexpr char REDUCE_OP_MAX[] = "max"; | ||||
| constexpr char REDUCE_OP_MIN[] = "min"; | constexpr char REDUCE_OP_MIN[] = "min"; | ||||
| constexpr char OP_PATH[] = "mindspore.ops.operations"; | constexpr char OP_PATH[] = "mindspore.ops.operations"; | ||||
| constexpr char INNER_OP_PATH[] = "mindspore.ops.operations._inner_ops"; | |||||
| constexpr char GET_OP_FUNCTION_PATH[] = "mindspore.parallel._utils"; | constexpr char GET_OP_FUNCTION_PATH[] = "mindspore.parallel._utils"; | ||||
| constexpr char GET_OP_FUNCTION[] = "_get_python_op"; | constexpr char GET_OP_FUNCTION[] = "_get_python_op"; | ||||
| constexpr char KEEP_DIMS[] = "keep_dims"; | constexpr char KEEP_DIMS[] = "keep_dims"; | ||||
| @@ -536,7 +536,7 @@ std::vector<AnfNodePtr> ReplaceOpInput(const Operator &replace_op, const std::st | |||||
| std::vector<AnfNodePtr> replace_input = {NewValueNode(pyop_instance), node->input(1)}; | std::vector<AnfNodePtr> replace_input = {NewValueNode(pyop_instance), node->input(1)}; | ||||
| auto prim = GetValueNode<PrimitivePtr>(node->input(0)); | auto prim = GetValueNode<PrimitivePtr>(node->input(0)); | ||||
| if (prim->name() == GATHERV2 || prim->name() == SPARSE_GATHERV2) { | if (prim->name() == GATHERV2 || prim->name() == SPARSE_GATHERV2) { | ||||
| replace_input = {NewValueNode(pyop_instance), node->input(1), node->input(2), node->input(3)}; | |||||
| replace_input = {NewValueNode(pyop_instance), node->input(1), node->input(2)}; | |||||
| } | } | ||||
| if (!params.empty()) { | if (!params.empty()) { | ||||
| Param param_first = *(params.begin()); | Param param_first = *(params.begin()); | ||||
| @@ -184,7 +184,7 @@ def test_gatherv2_auto1(): | |||||
| _executor.compile(net, x, y) | _executor.compile(net, x, y) | ||||
| def need_fix_test_gatherv2_cpu0(): | |||||
| def test_gatherv2_cpu0(): | |||||
| context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel") | context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel") | ||||
| strategy1 = ((8, 1), (1, 1)) | strategy1 = ((8, 1), (1, 1)) | ||||
| strategy2 = ((4, 2, 1), (4, 2, 1)) | strategy2 = ((4, 2, 1), (4, 2, 1)) | ||||
| @@ -196,7 +196,7 @@ def need_fix_test_gatherv2_cpu0(): | |||||
| _executor.compile(net, x, y) | _executor.compile(net, x, y) | ||||
| def need_fix_test_gatherv2_cpu1(): | |||||
| def test_gatherv2_cpu1(): | |||||
| context.set_auto_parallel_context(device_num=16, global_rank=0, parallel_mode="semi_auto_parallel") | context.set_auto_parallel_context(device_num=16, global_rank=0, parallel_mode="semi_auto_parallel") | ||||
| strategy1 = ((16, 1), (1, 1)) | strategy1 = ((16, 1), (1, 1)) | ||||
| strategy2 = ((4, 2, 1), (4, 2, 1)) | strategy2 = ((4, 2, 1), (4, 2, 1)) | ||||
| @@ -208,7 +208,7 @@ def need_fix_test_gatherv2_cpu1(): | |||||
| _executor.compile(net, x, y) | _executor.compile(net, x, y) | ||||
| def need_fix_test_gatherv2_cpu2(): | |||||
| def test_gatherv2_cpu2(): | |||||
| context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel") | context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel") | ||||
| strategy1 = ((1, 8), (1, 1)) | strategy1 = ((1, 8), (1, 1)) | ||||
| strategy2 = ((4, 2, 1), (4, 2, 1)) | strategy2 = ((4, 2, 1), (4, 2, 1)) | ||||
| @@ -184,7 +184,7 @@ def test_gatherv2_auto1(): | |||||
| _executor.compile(net, x, y) | _executor.compile(net, x, y) | ||||
| def need_fix_test_gatherv2_cpu0(): | |||||
| def test_gatherv2_cpu0(): | |||||
| context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel") | context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel") | ||||
| strategy1 = ((8, 1), (1, 1)) | strategy1 = ((8, 1), (1, 1)) | ||||
| strategy2 = ((4, 2, 1), (4, 2, 1)) | strategy2 = ((4, 2, 1), (4, 2, 1)) | ||||
| @@ -196,7 +196,7 @@ def need_fix_test_gatherv2_cpu0(): | |||||
| _executor.compile(net, x, y) | _executor.compile(net, x, y) | ||||
| def need_fix_test_gatherv2_cpu1(): | |||||
| def test_gatherv2_cpu1(): | |||||
| context.set_auto_parallel_context(device_num=16, global_rank=0, parallel_mode="semi_auto_parallel") | context.set_auto_parallel_context(device_num=16, global_rank=0, parallel_mode="semi_auto_parallel") | ||||
| strategy1 = ((16, 1), (1, 1)) | strategy1 = ((16, 1), (1, 1)) | ||||
| strategy2 = ((4, 2, 1), (4, 2, 1)) | strategy2 = ((4, 2, 1), (4, 2, 1)) | ||||
| @@ -208,7 +208,7 @@ def need_fix_test_gatherv2_cpu1(): | |||||
| _executor.compile(net, x, y) | _executor.compile(net, x, y) | ||||
| def need_fix_test_gatherv2_cpu2(): | |||||
| def test_gatherv2_cpu2(): | |||||
| context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel") | context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel") | ||||
| strategy1 = ((1, 8), (1, 1)) | strategy1 = ((1, 8), (1, 1)) | ||||
| strategy2 = ((4, 2, 1), (4, 2, 1)) | strategy2 = ((4, 2, 1), (4, 2, 1)) | ||||