Merge pull request !4723 from yangzhenzhang/concat-more-than-3-tensors
@@ -223,17 +223,32 @@ Status ConcatInfo::GenerateStrategies(int32_t stage_id) {
       input_split.push_back(1);
     }
   }
-  Shapes splittable_inputs;
-  for (size_t i = 0; i < inputs_shape_.size(); ++i) {
-    splittable_inputs.push_back(input_split);
-  }
+  // generate the first input's strategy
+  Shapes splittable_input = {input_split};
+  Shapes tmp_inputs_shape = {inputs_shape_[0]};
   std::vector<StrategyPtr> sp_vector;
   is_auto_parallel_ = true;
-  if (GenerateStrategiesWithBroadcast(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) {
+  if (GenerateStrategiesForIndependentInputs(stage_id, tmp_inputs_shape, splittable_input, &sp_vector) != SUCCESS) {
     MS_LOG(ERROR) << name_ << ": Generate strategies failed";
     return FAILED;
   }
+  // the other inputs' strategies are equal to the first input's strategy
+  for (auto &sp : sp_vector) {
+    if ((sp == nullptr) || sp->GetInputDim().empty()) {
+      MS_LOG(ERROR) << name_ << ": The strategy is null or empty";
+      return FAILED;
+    }
+    Strategys tmp_strategy;
+    Dimensions first_input_strategy = sp->GetInputDim()[0];
+    for (size_t i = 0; i < inputs_shape_.size(); ++i) {
+      tmp_strategy.push_back(first_input_strategy);
+    }
+    sp->ResetInputs(tmp_strategy);
+  }
   size_t success = 0;
   for (auto &sp : sp_vector) {
     PrintStrategy(sp);
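The change above makes Concat generate candidate strategies for the first input only and then copy that strategy to every input, so an arbitrary number of concatenated tensors is supported. A minimal Python sketch of that idea (hypothetical helper names, not the MindSpore C++ implementation):

```python
# Conceptual sketch of the new GenerateStrategies flow (hypothetical names,
# not the real MindSpore implementation).
def generate_concat_strategies(inputs_shape, candidate_first_strategies):
    """For each candidate strategy of the first input, replicate it to all inputs."""
    strategies = []
    for first in candidate_first_strategies:
        if not first:
            raise ValueError("The strategy is null or empty")
        # every concatenated tensor gets the same split as the first input
        strategies.append(tuple(first for _ in inputs_shape))
    return strategies

# Example: three inputs concatenated along axis 0, candidate split (1, 4, 2)
print(generate_concat_strategies([(48, 64, 32), (16, 64, 32), (64, 64, 32)],
                                 [(1, 4, 2)]))
# -> [((1, 4, 2), (1, 4, 2), (1, 4, 2))]
```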
@@ -111,7 +111,6 @@ Status StridedSliceInfo::CheckStrategy(const StrategyPtr &strategy) {
   Dimensions strategy_value = stra[0];
   bool has_split = std::any_of(strategy_value.begin(), strategy_value.end(), [](int32_t v) { return v > 1; });
   if (has_split && has_mask_) {
     MS_LOG(ERROR) << name_ << ": When there is a mask, the input is not supported to be split";
     return FAILED;
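For StridedSlice, the check shown above rejects any split of the input whenever a mask attribute is set. A rough Python restatement of that condition (illustrative only, not the real code):

```python
# Illustrative re-statement of the CheckStrategy condition above (not the real code).
def strided_slice_strategy_ok(strategy_value, has_mask):
    has_split = any(v > 1 for v in strategy_value)
    # When there is a mask, the input is not supported to be split.
    return not (has_split and has_mask)

assert strided_slice_strategy_ok((1, 1, 1), has_mask=True)
assert not strided_slice_strategy_ok((2, 1, 1), has_mask=True)
```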
@@ -50,12 +50,34 @@ class Net2(Cell):
         return out
 
 
+class Net3(Cell):
+    def __init__(self, weight, weight2, weight3, strategy1=None, strategy2=None, is_parameter=True):
+        super().__init__()
+        self.concat = P.Concat(axis=0).set_strategy(strategy1)
+        if is_parameter:
+            self.weight = Parameter(weight, "w1")
+        else:
+            self.weight = weight
+        self.mul = P.Mul().set_strategy(strategy2)
+        self.weight2 = Parameter(weight2, "w2")
+        self.weight3 = Parameter(weight3, "w3")
+
+    def construct(self, x, b):
+        out = self.concat((self.weight, self.weight2, self.weight3))
+        out = self.mul(x, out)
+        return out
+
+
 _x = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)
 _w1 = Tensor(np.ones([96, 64, 32]), dtype=ms.float32)
 _w2 = Tensor(np.ones([32, 64, 32]), dtype=ms.float32)
 _w3 = Tensor(np.ones([128, 16, 32]), dtype=ms.float32)
 _b = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)
+w1 = Tensor(np.ones([48, 64, 32]), dtype=ms.float32)
+w2 = Tensor(np.ones([16, 64, 32]), dtype=ms.float32)
+w3 = Tensor(np.ones([64, 64, 32]), dtype=ms.float32)
 
 
 def compile_net(net):
     context.set_context(save_graphs=True)
@@ -126,3 +148,9 @@ def test_concat_auto_parallel2():
     strategy2 = None
     net = Net2(_w3, strategy1, strategy2, axis=1)
     compile_net(net)
+
+
+def test_concat_auto_parallel_3_tensor():
+    context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
+    net = Net3(w1, w2, w3)
+    compile_net(net)
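The new test only exercises auto-parallel. A semi-auto-parallel counterpart could be written in the same style as the existing Net/Net2 tests; the strategy tuples below are an assumption (8 devices split over the last two axes, with the concat axis left unsplit) and are not part of this merge request:

```python
def test_concat_semi_auto_parallel_3_tensor():
    # Hypothetical companion test mirroring the existing Net2 tests;
    # the strategies below are assumed, not taken from this merge request.
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((1, 4, 2), (1, 4, 2), (1, 4, 2))  # do not split the concat axis (axis 0)
    strategy2 = ((1, 4, 2), (1, 4, 2))
    net = Net3(w1, w2, w3, strategy1, strategy2)
    compile_net(net)
```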