Merge pull request !5661 from yao_yf/auto_parallel_reshape_fix (tags/v1.0.0)
| @@ -1565,24 +1565,33 @@ Status CostGraph::InitSelectedStrategy() { | |||
| auto next_iter = std::find_if(out_edges.begin(), out_edges.end(), [&](std::shared_ptr<Edge> edge) { | |||
| return edge->next_operator()->name() == reshape_info->next_operator_name(); | |||
| }); | |||
| if (pre_iter != in_edges.end()) { | |||
| bool reshape_is_first_op = reshape_info->pre_operator_name() == reshape_info->name(); | |||
| if (reshape_is_first_op) { | |||
| reshape_info->InitSelectedStrategy(reshape_info->selected_strategy()); | |||
| } | |||
| if (pre_iter != in_edges.end() || reshape_is_first_op) { | |||
| MS_LOG(DEBUG) << "Set reshape input layout by " << reshape_info->pre_operator_name(); | |||
| int32_t pre_index = reshape_info->pre_operator_index(); | |||
| TensorInfo pre_info; | |||
| if (ops_[i]->name() == (*pre_iter)->prev_operator()->name()) { | |||
| pre_info = (*pre_iter)->prev_operator()->inputs_tensor_info()[pre_index]; | |||
| std::shared_ptr<OperatorInfo> pre_op_info; | |||
| if (reshape_is_first_op) { | |||
| pre_op_info = reshape_info; | |||
| pre_info = pre_op_info->inputs_tensor_info()[pre_index]; | |||
| } else { | |||
| pre_info = (*pre_iter)->prev_operator()->outputs_tensor_info()[pre_index]; | |||
| pre_op_info = (*pre_iter)->prev_operator(); | |||
| pre_info = pre_op_info->outputs_tensor_info()[pre_index]; | |||
| } | |||
| reshape_info->SetInputLayout(pre_info.tensor_layout()); | |||
| Dimensions stra = pre_info.InferStrategy(); | |||
| if (stra.empty()) { | |||
| MS_LOG(EXCEPTION) << "Infer strategy by tensor_info failed"; | |||
| if (pre_iter != in_edges.end()) { | |||
| Dimensions stra = pre_info.InferStrategy(); | |||
| if (stra.empty()) { | |||
| MS_LOG(EXCEPTION) << "Infer strategy by tensor_info failed"; | |||
| } | |||
| Strategys stra_inputs = {stra}; | |||
| StrategyPtr reshape_stra = | |||
| std::make_shared<Strategy>((*pre_iter)->prev_operator()->strategy()->GetInputStage(), stra_inputs); | |||
| reshape_info->set_strategy(reshape_stra); | |||
| } | |||
| Strategys stra_inputs = {stra}; | |||
| StrategyPtr reshape_stra = | |||
| std::make_shared<Strategy>((*pre_iter)->prev_operator()->strategy()->GetInputStage(), stra_inputs); | |||
| reshape_info->set_strategy(reshape_stra); | |||
| } | |||
| if (next_iter != out_edges.end()) { | |||
| MS_LOG(DEBUG) << "Set reshape output layout by " << reshape_info->next_operator_name(); | |||
| @@ -245,3 +245,50 @@ def test_reshape_auto_5(): | |||
| context.set_auto_parallel_context(parallel_mode="auto_parallel") | |||
| net.set_auto_parallel() | |||
| _executor.compile(net, x, y) | |||
def test_reshape_auto_6():
    """Compile, in auto_parallel mode on 8 devices, a net whose Reshape reads a Parameter directly.

    The reshaped weight is consumed by both a subtraction and a Mul, so the
    Reshape has no upstream operator output feeding it, only the parameter.
    The test only checks that compilation completes; it asserts nothing
    about the resulting strategies.
    """

    class NetWithLoss6(nn.Cell):
        """Wraps a two-input network and passes its prediction to VirtualLoss."""

        def __init__(self, network):
            super().__init__()
            self.loss = VirtualLoss()
            self.network = network

        def construct(self, x, y):
            return self.loss(self.network(x, y))

    class GradWrap6(nn.Cell):
        """Applies grad_all to the wrapped network and evaluates it on (x, y)."""

        def __init__(self, network):
            super().__init__()
            self.network = network

        def construct(self, x, y):
            grad_fn = grad_all(self.network)
            return grad_fn(x, y)

    class Net(nn.Cell):
        """Network under test: the weight is used both as-is and through Reshape."""

        def __init__(self):
            super(Net, self).__init__()
            self.relu = P.ReLU()
            self.mul = P.Mul()
            self.reshape = P.Reshape()
            self.reduce_mean = P.ReduceMean()
            weight_init = Tensor(np.ones([4, 1024, 1]), dtype=ms.float32)
            self.wide_w = Parameter(weight_init, name="weight")

        def construct(self, x, y):
            # Statement order is kept as in the original test on purpose.
            biased = x + self.wide_w
            # Reshape is applied to the Parameter itself: (4, 1024, 1) -> (4, 1024).
            flat_w = self.reshape(self.wide_w, (4, 1024))
            reduced = self.reduce_mean(biased, 1)
            left = reduced - flat_w
            right = self.mul(y, flat_w)
            return left + right

    device_count = 8
    context.set_auto_parallel_context(device_num=device_count, global_rank=0)
    x = Tensor(np.ones([4, 1024, 1]), dtype=ms.float32)
    y = Tensor(np.ones([4, 1024]), dtype=ms.float32)
    net = GradWrap6(NetWithLoss6(Net()))
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    net.set_auto_parallel()
    _executor.compile(net, x, y)