Merge pull request !4356 from yangzhenzhang/update-field-split
@@ -44,14 +44,15 @@ py::dict GetParameterLayout(const FuncGraphPtr &graph) {
     auto device_arrangement = tensor_layout->device_arrangement().array();
     auto tensor_map = tensor_layout->tensor_map().array();
     auto slice_shape = tensor_layout->slice_shape().array();
-    int32_t _field_size = tensor_layout->get_field_size();
-    Shape field_size;
-    if (_field_size != 0) {
-      field_size.push_back(_field_size);
+    Shape field_size = {tensor_layout->get_field_size()};
+    Shape uniform_split;
+    if (tensor_layout->uniform_split()) {
+      uniform_split.push_back(1);
     } else {
-      field_size = {0};
+      uniform_split.push_back(0);
     }
-    std::vector<Shape> layout = {device_arrangement, tensor_map, slice_shape, field_size};
+    std::vector<Shape> layout = {device_arrangement, tensor_map, slice_shape, field_size, uniform_split};
     dict[py::str(name)] = layout;
     MS_LOG(INFO) << "GetParameterLayout name = " << name << ", layout " << tensor_layout->ToString();
   }
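// A sketch of the resulting Python-side entry (values taken from the updated
// test_get_parameter_layout below), assuming the vector order built above:
//   dict["x"] == [[2, 4], [1, -1], [16, 32], [0], [1]]
//   i.e. [device_arrangement, tensor_map, slice_shape, field_size, uniform_split],
//   where uniform_split is [1] for a uniform split and [0] for a manual split.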
@@ -27,6 +27,92 @@
 namespace mindspore {
 namespace parallel {
+Status GatherV2PInfo::GetManualSplitWithoutOffsetAttr() {
+  auto manual_split_without_offset_iter = attrs_.find("manual_split");
+  if (manual_split_without_offset_iter != attrs_.end()) {
+    manual_split_ = true;
+    MS_EXCEPTION_IF_NULL(manual_split_without_offset_iter->second);
+    if (manual_split_without_offset_iter->second->cast<ValueTuplePtr>() == nullptr) {
+      MS_LOG(ERROR) << name_ << ": Manual split without offset strategy's format is wrong! Need ValueSequeue";
+      return FAILED;
+    }
+    std::vector<ValuePtr> value_vector = manual_split_without_offset_iter->second->cast<ValueTuplePtr>()->value();
+    MS_LOG(INFO) << name_ << ": manual split without offset is " << manual_split_without_offset_iter->second->ToString();
+    int64_t offset = 0;
+    for (auto &ele : value_vector) {
+      index_offsets_.push_back(offset);
+      if (!ele->isa<Int32Imm>()) {
+        MS_LOG(ERROR) << name_ << ": The element of manual split must be int";
+        return FAILED;
+      }
+      int64_t param_split_shape = static_cast<int64_t>(GetValue<int>(ele));
+      if (param_split_shape <= 0) {
+        MS_LOG(ERROR) << name_ << ": The value of manual split must be positive, but got " << param_split_shape;
+        return FAILED;
+      }
+      param_split_shapes_.push_back(param_split_shape);
+      offset += param_split_shape;
+    }
+    if (param_split_shapes_.empty()) {
+      MS_LOG(ERROR) << name_ << ": Failed to extract param split's split info";
+      return FAILED;
+    }
+  }
+  return SUCCESS;
+}
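// A sketch of the derivation above, using the attribute values from the tests
// below: each element of "manual_split" is a split row count, and the offsets
// are the running prefix sums:
//   manual_split = (4, 4)           ->  param_split_shapes_ = {4, 4},          index_offsets_ = {0, 4}
//   manual_split = (10, 20, 30, 4)  ->  param_split_shapes_ = {10, 20, 30, 4}, index_offsets_ = {0, 10, 30, 60}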
+Status GatherV2PInfo::GetManualSplitAttr() {
+  auto manual_split_with_offset_iter = attrs_.find("manual_split_with_offset");
+  if (manual_split_with_offset_iter != attrs_.end()) {
+    manual_split_ = true;
+    auto var = manual_split_with_offset_iter->second->cast<ValueTuplePtr>();
+    if (var == nullptr) {
+      MS_LOG(ERROR) << name_ << ": Manual split with offset strategy's format is wrong! Need ValueSequeue";
+      return FAILED;
+    }
+    MS_LOG(INFO) << name_ << ": manual split with offset strategy " << var->ToString();
+    for (auto &ele : var->value()) {
+      if (!ele->isa<ValueSequeue>()) {
+        MS_LOG(ERROR) << name_ << ": Manual split with offset strategy's format is wrong! Need ValueSequeue";
+        return FAILED;
+      }
+      std::vector<ValuePtr> value_vector = ele->cast<ValueTuplePtr>()->value();
+      if (value_vector.size() != 2) {
+        MS_LOG(ERROR) << name_ << ": Size of manual split with offset's element must be 2";
+        return FAILED;
+      }
+      int64_t param_split_row = static_cast<int64_t>(GetValue<int>(value_vector[0]));
+      int64_t offset = static_cast<int64_t>(GetValue<int>(value_vector[1]));
+      if ((param_split_row <= 0) || (offset < 0)) {
+        MS_LOG(ERROR) << name_
+                      << ": The value of param split shape must be positive, and the offset must be larger than or equal to 0";
+        return FAILED;
+      }
+      param_split_shapes_.push_back(param_split_row);
+      index_offsets_.push_back(offset);
+    }
+    if (param_split_shapes_.empty()) {
+      MS_LOG(ERROR) << name_ << ": Failed to extract param split with offset's split info";
+      return FAILED;
+    }
+    if (std::any_of(index_offsets_.begin(), index_offsets_.end(), [](const int64_t &offset) { return offset < 0; })) {
+      MS_LOG(ERROR) << name_ << ": Index offset must not be less than 0";
+      return FAILED;
+    }
+    return SUCCESS;
+  }
+  if (GetManualSplitWithoutOffsetAttr() != SUCCESS) {
+    return FAILED;
+  }
+  return SUCCESS;
+}
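// The two attribute forms are therefore equivalent whenever the explicit
// offsets equal the prefix sums (values from the tests below):
//   manual_split             = (4, 4)            ->  shapes {4, 4}, offsets {0, 4}
//   manual_split_with_offset = ((4, 0), (4, 4))  ->  shapes {4, 4}, offsets {0, 4}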
 Status GatherV2PInfo::GetAttrs() {
   // Get the axis: the third input is the axis and is a ValueNode; EmbeddingLookup doesn't have an axis.
   if (target_ != CPU) {
@@ -53,58 +139,76 @@ Status GatherV2PInfo::GetAttrs() {
     if (target_iter->second->isa<StringImm>()) {
       target_ = target_iter->second->cast<StringImmPtr>()->value();
     } else {
-      MS_LOG(ERROR) << name_ << " : The value of target is not a string.";
+      MS_LOG(ERROR) << name_ << ": The value of target is not a string.";
     }
   }
-  auto manual_split_iter = attrs_.find("manual_split");
-  if (manual_split_iter != attrs_.end()) {
-    param_split_shapes_.clear();
-    manual_split_ = true;
-    auto var = manual_split_iter->second->cast<ValueTuplePtr>();
-    MS_LOG(DEBUG) << "Extract manual split strategy " << manual_split_iter->second->ToString();
-    if (var->size() > 0) {
-      std::vector<ValuePtr> elements = var->value();
-      for (auto &ele : elements) {
-        if (ele->isa<ValueSequeue>()) {
-          auto value_tuple = ele->cast<ValueTuplePtr>();
-          std::vector<ValuePtr> value_vector = value_tuple->value();
-          if (value_vector.size() != 2) {
-            MS_LOG(ERROR) << "Failure: Size of manual_split element must be 2.";
-            return FAILED;
-          }
-          param_split_shapes_.push_back(static_cast<int64_t>(GetValue<int>(value_vector[0])));
-          index_offsets_.push_back(static_cast<int64_t>(GetValue<int>(value_vector[1])));
-        } else {
-          MS_LOG(ERROR) << "Failure: Manual split strategy's format is wrong! Need ValueSequeue";
-          return FAILED;
-        }
-      }
-      if (param_split_shapes_.empty()) {
-        MS_LOG(ERROR) << "Failed to extract param split strategy.";
-        return FAILED;
-      }
-    }
-  }
+  if (GetManualSplitAttr() != SUCCESS) {
+    return FAILED;
+  }
+  if (manual_split_ && (axis_ != 0)) {
+    MS_LOG(ERROR) << name_ << ": The axis or offset must be 0 if manual split, but got " << axis_;
+    return FAILED;
+  }
   return SUCCESS;
 }
-Status GatherV2PInfo::CheckManualSplit() {
-  auto param_shape = inputs_shape_.at(0);
-  int64_t split_shape_sum = std::accumulate(param_split_shapes_.begin(), param_split_shapes_.end(), 0,
-                                            [](int64_t s, int64_t shape) { return s + shape; });
-  if (split_shape_sum < param_shape.at(0)) {
-    MS_LOG(ERROR) << "Failure: Sum of splited shapes should not be smaller than param_shape.";
+Status GatherV2PInfo::CheckManualSplit(const Strategys &strategy) {
+  if (strategy.size() != 2) {
+    MS_LOG(ERROR) << name_ << ": The size of strategy must be 2, but got " << strategy.size();
     return FAILED;
   }
+  Dimensions param_strategy = strategy[0];
+  Dimensions indices_strategy = strategy[1];
+  if (param_strategy.size() != 2 || indices_strategy.size() != 2) {
+    MS_LOG(ERROR) << name_ << ": The size of param strategy or indices strategy must be 2";
+    return FAILED;
+  }
+  if (indices_strategy[0] != 1) {
+    MS_LOG(ERROR) << name_ << ": The indices_strategy[0] must be 1, but got " << indices_strategy[0];
+    return FAILED;
+  }
+  if (param_strategy[0] != indices_strategy[1]) {
+    MS_LOG(ERROR) << name_ << ": The param_strategy[0] must be equal to indices_strategy[1]";
+    return FAILED;
+  }
+  if (indices_strategy[1] != SizeToInt(param_split_shapes_.size())) {
+    MS_LOG(ERROR) << name_ << ": The indices_strategy[1] must be equal to the manual split size";
+    return FAILED;
+  }
+  int64_t min_param_slice_row = inputs_shape_[1][1] / indices_strategy[1];
+  bool invalid = std::any_of(param_split_shapes_.begin(), param_split_shapes_.end(),
+                             [&min_param_slice_row](int64_t v) { return v < min_param_slice_row; });
+  if (invalid) {
+    MS_LOG(ERROR) << name_ << ": The split value must be larger than or equal to the indices slice's column num";
+    return FAILED;
+  }
+  if (inputs_shape_[0][0] < inputs_shape_[1][1]) {
+    MS_LOG(ERROR) << name_ << ": The param's row is smaller than the indices' column";
+    return FAILED;
+  }
-  if (std::any_of(index_offsets_.begin(), index_offsets_.end(), [](const int64_t &offset) { return offset < 0; })) {
-    MS_LOG(ERROR) << "Failure: Index offset must not less than 0.";
+  // Repeated calculation is not supported
+  CheckGlobalDeviceManager();
+  size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size();
+  auto product_p = std::accumulate(param_strategy.begin(), param_strategy.end(), 1, std::multiplies<int>());
+  if (IntToSize(product_p) < dev_num) {
+    MS_LOG(ERROR) << name_ << ": Manual split doesn't support repeated calc";
     return FAILED;
   }
+  int64_t split_shape_sum = std::accumulate(param_split_shapes_.begin(), param_split_shapes_.end(), 0,
+                                            [](int64_t s, int64_t shape) { return s + shape; });
+  if (split_shape_sum != inputs_shape_[0][0]) {
+    MS_LOG(ERROR) << name_ << ": Sum of split shapes must be equal to param_shape[0]";
+    return FAILED;
+  }
   return SUCCESS;
 }
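// A worked example of the checks above, with the numbers from
// test_normal_split2 below: param shape (64, 8), indices shape (8, 8),
// param_strategy (4, 1), indices_strategy (1, 4), split tuple (10, 20, 30, 4):
//   indices_strategy[0] == 1; param_strategy[0] == indices_strategy[1] == 4;
//   manual split size == 4; min_param_slice_row == 8 / 4 == 2 and every split
//   value >= 2; product of param_strategy == 4 == dev_num (no repeated calc);
//   and 10 + 20 + 30 + 4 == 64 == inputs_shape_[0][0].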
@@ -147,7 +251,7 @@ Status GatherV2PInfo::CheckStrategy(const StrategyPtr &strategy) {
   }
   if (manual_split_) {
-    if (CheckManualSplit() != SUCCESS) {
+    if (CheckManualSplit(strategy->GetInputDim()) != SUCCESS) {
       return FAILED;
     }
     // When using manual_split, there is no need to check the following.
@@ -343,14 +447,15 @@ Status GatherV2PInfo::InferTensorInfo() {
       SUCCESS)) {
     return FAILED;
   }
+  if (manual_split_) {
+    input_tensor_layout.set_uniform_split(false);
+  }
   // infer tensor info
   TensorInfo input_tensor_info(input_tensor_layout);
   TensorInfo input_index_info(input_index_layout);
   TensorInfo output_tensor_info(output_tensor_layout);
-  Shape slice_shape = input_tensor_info.slice_shape();
-  MS_LOG(DEBUG) << "The fake slice shape is: " << ShapeToString(slice_shape);
   inputs_tensor_info_.push_back(input_tensor_info);
   inputs_tensor_info_.push_back(input_index_info);
   outputs_tensor_info_.push_back(output_tensor_info);
@@ -392,9 +497,17 @@ Status GatherV2PInfo::InferBias() {
 Status GatherV2PInfo::InferOffset() {
   CheckGlobalDeviceManager();
   size_t rank = g_device_manager->global_rank();
-  if (rank < index_offsets_.size()) {
-    index_offset_ = index_offsets_.at(rank);
-    MS_LOG(DEBUG) << name_ << ": Device rank " << rank << ", Index Offset: " << index_offset_;
+  MS_EXCEPTION_IF_NULL(strategy_);
+  auto param_strategy = strategy_->GetInputDim()[0];
+  if (param_strategy.size() != 2) {
+    MS_LOG(ERROR) << "The size of param strategy must be 2";
+    return FAILED;
+  }
+  size_t index = rank / param_strategy[1];
+  if (index < index_offsets_.size()) {
+    index_offset_ = index_offsets_[index];
+    MS_LOG(INFO) << name_ << ": Device rank " << rank << ", Index Offset: " << index_offset_;
     return SUCCESS;
   }
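// An example of the rank-to-offset mapping above, with the numbers from
// test_normal_split3 below: 32 devices, param_strategy = (4, 8), split tuple
// (10, 20, 30, 4), so index_offsets_ = {0, 10, 30, 60}; for global rank 17:
//   index = 17 / 8 = 2  ->  index_offset_ = index_offsets_[2] = 30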
@@ -524,8 +637,7 @@ Status GatherV2PInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
 ReplaceGraphPtr GatherV2PInfo::replace_graph(const CNodePtr &cnode) {
   if (manual_split_ && target_ != CPU) {
     if (ComputeReplaceGraph(cnode) != SUCCESS) {
-      MS_LOG(ERROR) << name_ << ": ComputeReplaceGraph failed.";
-      return nullptr;
+      MS_LOG(EXCEPTION) << name_ << ": ComputeReplaceGraph failed.";
     }
     return replace_graph_;
   }
@@ -536,8 +648,7 @@ ReplaceGraphPtr GatherV2PInfo::replace_graph(const CNodePtr &cnode) {
     return nullptr;
   }
   if (param_strategy.at(IntToSize(axis_)) != 1 && ComputeReplaceGraph(cnode) != SUCCESS) {
-    MS_LOG(ERROR) << name_ << ": ComputeReplaceGraph failed.";
-    return nullptr;
+    MS_LOG(EXCEPTION) << name_ << ": ComputeReplaceGraph failed.";
   }
   return replace_graph_;
 }
@@ -614,6 +725,13 @@ Status GatherV2PInfo::SetCostUnderStrategy(const StrategyPtr &strategy) {
 }
 Status GatherV2PInfo::GenerateStrategies(int32_t stage_id) {
+  if (GetAttrs() != SUCCESS) {
+    return FAILED;
+  }
+  if (manual_split_) {
+    MS_LOG(ERROR) << name_ << ": Manual split does not support searching for strategies";
+    return FAILED;
+  }
   is_auto_parallel_ = true;
   Shape input0_split(inputs_shape_[0].size(), 1);
   Shape input1_split(inputs_shape_[1].size(), 1);
@@ -621,14 +739,14 @@ Status GatherV2PInfo::GenerateStrategies(int32_t stage_id) {
   std::vector<StrategyPtr> sp_vector;
   if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) {
-    MS_LOG(ERROR) << name_ << " : Generate strategies for independent inputs() failed.";
+    MS_LOG(ERROR) << name_ << ": Generate strategies for independent inputs() failed.";
     return FAILED;
   }
   size_t success = 0;
   for (auto &sp : sp_vector) {
     if (SetCostUnderStrategy(sp) == SUCCESS) {
       success++;
-      MS_LOG(INFO) << name_ << " : Successfully generated " << success << " strategy";
+      MS_LOG(INFO) << name_ << ": Successfully generated " << success << " strategy";
       PrintStrategy(sp);
     }
   }
@@ -636,6 +754,12 @@ Status GatherV2PInfo::GenerateStrategies(int32_t stage_id) {
 }
 std::shared_ptr<Strategys> GatherV2PInfo::GenerateBatchStrategies() {
+  if (GetAttrs() != SUCCESS) {
+    MS_LOG(EXCEPTION) << name_ << ": Get attr failed";
+  }
+  if (manual_split_) {
+    MS_LOG(EXCEPTION) << name_ << ": Manual split does not support generating batch strategies";
+  }
   CheckGlobalDeviceManager();
   size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size();
   Dimensions param_strategy(inputs_shape_[0].size(), 1);
@@ -59,7 +59,9 @@ class GatherV2PInfo : public OperatorInfo {
   Status GetAttrs() override;
   Status ComputeReplaceGraph(const CNodePtr &cnode);
-  Status CheckManualSplit();
+  Status CheckManualSplit(const Strategys &strategy);
+  Status GetManualSplitAttr();
+  Status GetManualSplitWithoutOffsetAttr();
   Status ComputeReplaceOp();
   Status InferBias();
   Status InferOffset();
@@ -48,6 +48,10 @@ class TensorLayout {
   void set_field_size(int32_t field_size) { field_size_ = field_size; }
+  bool uniform_split() const { return uniform_split_; }
+  void set_uniform_split(bool flag) { uniform_split_ = flag; }
   Arrangement device_arrangement() const { return device_arrangement_; }
   Map tensor_map() const { return tensor_map_; }
@@ -104,6 +108,7 @@ class TensorLayout {
   Arrangement tensor_shape_;
   bool skip_redistribution_ = false;
   int32_t field_size_ = 0;
+  bool uniform_split_ = true;
 };
 }  // namespace parallel
 }  // namespace mindspore
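// A minimal usage sketch of the new flag (hypothetical caller, mirroring how
// GatherV2PInfo::InferTensorInfo uses it above):
//   TensorLayout input_tensor_layout;
//   input_tensor_layout.set_uniform_split(false);  // manual split: slice sizes differ per rank
//   if (!input_tensor_layout.uniform_split()) {
//     // skip logic that assumes identical slice shapes on every device
//   }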
@@ -229,10 +229,13 @@ def _load_tensor_by_layout(tensor, layout):
     """
     if not isinstance(layout, list):
         raise TypeError("The layout should be list! layout is {}".format(layout))
-    if len(layout) < 3:
-        raise ValueError("The length of layout must be larger than 3! layout is {}".format(layout))
+    if len(layout) < 5:
+        raise ValueError("The length of layout must be at least 5! layout is {}".format(layout))
     dev_mat = layout[0]
     tensor_map = layout[1]
+    uniform_split = layout[4]
+    if uniform_split[0] == 0:
+        raise RuntimeError("The load tensor only supports uniform split now")
     if tensor.size() == 1:
         return tensor
     return _load_tensor(tensor, dev_mat, tensor_map)
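# A sketch of a layout this function now expects (format produced by
# GetParameterLayout above): [dev_mat, tensor_map, slice_shape, field_size,
# uniform_split], e.g. [[2, 4], [1, -1], [16, 32], [0], [1]]; a layout with
# uniform_split == [0] (manual split) is rejected with a RuntimeError.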
@@ -49,8 +49,8 @@ def test_get_parameter_layout():
     net.set_auto_parallel()
     exe = me._executor
     exe.compile(net, x, phase='train', auto_parallel_mode=True)
-    x_layout = [[2, 4], [1, -1], [16, 32], [0]]  # device_arrangement = [2, 4], tensor_map = [1, -1]
-    weight_layout = [[2, 4], [0, -1], [16, 32], [0]]  # device_arrangement = [2, 4], tensor_map = [0, -1]
+    x_layout = [[2, 4], [1, -1], [16, 32], [0], [1]]  # device_arrangement = [2, 4], tensor_map = [1, -1]
+    weight_layout = [[2, 4], [0, -1], [16, 32], [0], [1]]  # device_arrangement = [2, 4], tensor_map = [0, -1]
     expect_dict = {'x': x_layout, 'w1': weight_layout}
     # To be resolved: the static local variable count_p used in step_parallel.cc needs to be reset between each UT.
     assert net.parameter_layout_dict == expect_dict
@@ -14,6 +14,7 @@
 # ============================================================================
 import numpy as np
+import pytest
 import mindspore as ms
 from mindspore import context, Tensor, Parameter
 from mindspore.common.api import _executor
@@ -22,40 +23,170 @@ from mindspore.ops import operations as P
 from mindspore.common.initializer import initializer
 class Net(Cell):
-    def __init__(self, strategy1=None, strategy2=None, strategy3=None):
+    def __init__(self,
+                 strategy1=None,
+                 strategy2=None,
+                 strategy3=None,
+                 axis=0,
+                 init_flag=True,
+                 split_tuple=(4, 4),
+                 split_string="manual_split",
+                 param_shape=(8, 8)):
         super().__init__()
         self.gatherv2 = P.GatherV2().set_strategy(strategy1)
-        self.gatherv2.add_prim_attr("manual_split", ((1, 0), (7, 1)))
+        self.gatherv2.add_prim_attr(split_string, split_tuple)
         self.mul = P.Mul().set_strategy(strategy2)
         self.reshape = P.Reshape()
         self.matmul = P.MatMul().set_strategy(strategy3)
         self.matmul.add_prim_attr("forward_reduce_scatter", True)
-        self.param = Parameter(initializer("ones", (8, 64), ms.float32), name="gatherv2_param")
-        self.mul_weight = Parameter(initializer("ones", (2, 4, 64), ms.float32), name="mul_weight")
-        self.matmul_weight = Parameter(initializer("ones", (256, 16), ms.float32), name="matmul_weight")
+        if init_flag:
+            self.param = Parameter(initializer("ones", param_shape, ms.float32), name="gatherv2_param")
+        else:
+            self.param = Parameter(Tensor(np.ones(param_shape), dtype=ms.float32), name="gatherv2_param")
+        self.mul_weight = Parameter(initializer("ones", (8, 8, 8), ms.float32), name="mul_weight")
+        self.matmul_weight = Parameter(initializer("ones", (64, 16), ms.float32), name="matmul_weight")
+        self.axis = axis
     def construct(self, x, b):
-        out = self.gatherv2(self.param, x, 0)
+        out = self.gatherv2(self.param, x, self.axis)
         out = self.mul(out, self.mul_weight)
-        out = self.reshape(out, (2, 256))
+        out = self.reshape(out, (8, 64))
         out = self.matmul(out, self.matmul_weight)
         return out
-_x = Tensor(np.ones([2, 4]), dtype=ms.int32)
+_x = Tensor(np.ones([8, 8]), dtype=ms.int32)
 _b = Tensor(np.ones([64, 8]), dtype=ms.float32)
 def compile_net(net):
     context.set_context(save_graphs=True)
     optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
     train_net = TrainOneStepCell(net, optimizer)
     train_net.set_auto_parallel()
-    _executor.compile(train_net, _x, _b)
+    _executor.compile(train_net, _x, _b, auto_parallel_mode=True)
+    context.reset_auto_parallel_context()
-def test_neg_data_parallel():
-    context.set_context(save_graphs=True)
+def test_normal_split():
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=2, global_rank=0)
     strategy1 = ((2, 1), (1, 2))
     strategy2 = ((1, 2, 1), (1, 2, 1))
     strategy3 = ((1, 2), (2, 1))
     net = Net(strategy1, strategy2, strategy3)
     compile_net(net)
+def test_normal_split2():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=4, global_rank=0)
+    strategy1 = ((4, 1), (1, 4))
+    strategy2 = ((1, 4, 1), (1, 4, 1))
+    strategy3 = ((1, 4), (4, 1))
+    net = Net(strategy1, strategy2, strategy3, split_tuple=(10, 20, 30, 4), param_shape=(64, 8))
+    compile_net(net)
+def test_normal_split3():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=32, global_rank=17)
+    strategy1 = ((4, 8), (1, 4))
+    strategy2 = ((1, 4, 8), (1, 4, 8))
+    strategy3 = ((1, 32), (32, 1))
+    net = Net(strategy1, strategy2, strategy3, split_tuple=(10, 20, 30, 4), param_shape=(64, 8))
+    compile_net(net)
+def test_normal_split_with_offset():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=2, global_rank=0)
+    strategy1 = ((2, 1), (1, 2))
+    strategy2 = ((1, 2, 1), (1, 2, 1))
+    strategy3 = ((1, 2), (2, 1))
+    net = Net(strategy1, strategy2, strategy3, split_string="manual_split_with_offset", split_tuple=((4, 0), (4, 4)))
+    compile_net(net)
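# Note: test_normal_split and test_normal_split_with_offset are equivalent
# configurations; "manual_split" (4, 4) derives offsets (0, 4) as prefix sums,
# while "manual_split_with_offset" ((4, 0), (4, 4)) states them explicitly.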
+def test_auto_parallel_error():
+    context.set_context(save_graphs=True)
+    context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=2, global_rank=0)
+    net = Net()
+    with pytest.raises(RuntimeError):
+        compile_net(net)
+def test_axis_error():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=2, global_rank=0)
+    strategy1 = ((2, 1), (1, 2))
+    strategy2 = ((1, 2, 1), (1, 2, 1))
+    strategy3 = ((1, 2), (2, 1))
+    net = Net(strategy1, strategy2, strategy3, axis=1)
+    with pytest.raises(RuntimeError):
+        compile_net(net)
+def test_strategy_error():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
+    strategy1 = ((4, 1), (8, 1))
+    strategy2 = ((1, 2, 1), (1, 2, 1))
+    strategy3 = ((1, 2), (2, 1))
+    net = Net(strategy1, strategy2, strategy3)
+    with pytest.raises(RuntimeError):
+        compile_net(net)
+def test_strategy_error2():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
+    strategy1 = ((4, 1), (1, 8))
+    strategy2 = ((1, 2, 1), (1, 2, 1))
+    strategy3 = ((1, 2), (2, 1))
+    net = Net(strategy1, strategy2, strategy3)
+    with pytest.raises(RuntimeError):
+        compile_net(net)
+def test_strategy_error3():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
+    strategy1 = ((2, 1), (1, 2))
+    strategy2 = ((1, 2, 1), (1, 2, 1))
+    strategy3 = ((1, 2), (2, 1))
+    net = Net(strategy1, strategy2, strategy3)
+    with pytest.raises(RuntimeError):
+        compile_net(net)
+def test_strategy_error4():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=2, global_rank=0)
+    strategy1 = ((2, 8), (1, 2))
+    strategy2 = ((1, 2, 1), (1, 2, 1))
+    strategy3 = ((1, 2), (2, 1))
+    net = Net(strategy1, strategy2, strategy3)
+    with pytest.raises(RuntimeError):
+        compile_net(net)
+def test_strategy_error5():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=4, global_rank=0)
+    strategy1 = ((4, 1), (1, 4))
+    strategy2 = ((1, 2, 1), (1, 2, 1))
+    strategy3 = ((1, 2), (2, 1))
+    net = Net(strategy1, strategy2, strategy3)
+    with pytest.raises(RuntimeError):
+        compile_net(net)
+def test_split_tuple_error():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=2, global_rank=0)
+    strategy1 = ((2, 1), (1, 2))
+    strategy2 = ((1, 2, 1), (1, 2, 1))
+    strategy3 = ((1, 2), (2, 1))
+    net = Net(strategy1, strategy2, strategy3, split_tuple=((5, 0), (5, 5)))
+    with pytest.raises(RuntimeError):
+        compile_net(net)
+def test_parameter_use_tensor_error():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=2, global_rank=0)
+    strategy1 = ((2, 1), (1, 2))
+    strategy2 = ((1, 2, 1), (1, 2, 1))
+    strategy3 = ((1, 2), (2, 1))
+    net = Net(strategy1, strategy2, strategy3, init_flag=False)
+    with pytest.raises(RuntimeError):
+        compile_net(net)