| @@ -350,6 +350,8 @@ bool IsAutoParallelCareNode(const CNodePtr &cnode) { | |||||
| } | } | ||||
| OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &cnode) { | OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &cnode) { | ||||
| MS_EXCEPTION_IF_NULL(prim); | |||||
| MS_EXCEPTION_IF_NULL(cnode); | |||||
| auto attrs = prim->attrs(); | auto attrs = prim->attrs(); | ||||
| std::vector<Shapes> shape_list = ExtractShape(cnode); | std::vector<Shapes> shape_list = ExtractShape(cnode); | ||||
| if (shape_list.empty()) { | if (shape_list.empty()) { | ||||
| @@ -374,7 +374,6 @@ bool IsParallelCareNode(const CNodePtr& cnode) { | |||||
| if (prim == nullptr) { | if (prim == nullptr) { | ||||
| return false; | return false; | ||||
| } | } | ||||
| auto attrs = prim->attrs(); | |||||
| if (IsInBlackList(prim)) { | if (IsInBlackList(prim)) { | ||||
| MS_LOG(INFO) << "Parallel don't care node: " << prim->name(); | MS_LOG(INFO) << "Parallel don't care node: " << prim->name(); | ||||
| return false; | return false; | ||||
| @@ -1971,11 +1970,7 @@ CNodePtr FindLossCNode(const FuncGraphPtr& func_graph) { | |||||
| MS_EXCEPTION_IF_NULL(current_value); | MS_EXCEPTION_IF_NULL(current_value); | ||||
| PrimitivePtr current_prim = current_value->value()->cast<PrimitivePtr>(); | PrimitivePtr current_prim = current_value->value()->cast<PrimitivePtr>(); | ||||
| MS_EXCEPTION_IF_NULL(current_prim); | MS_EXCEPTION_IF_NULL(current_prim); | ||||
| <<<<<<< HEAD | |||||
| <<<<<<< HEAD | |||||
| ======= | |||||
| >>>>>>> fix_cast_bug | |||||
| // return -> cast | // return -> cast | ||||
| if (current_prim->name() == CAST && pre_cnode->operator_info() == nullptr) { | if (current_prim->name() == CAST && pre_cnode->operator_info() == nullptr) { | ||||
| pre_cnode = pre_cnode->input(1)->cast<CNodePtr>(); | pre_cnode = pre_cnode->input(1)->cast<CNodePtr>(); | ||||
| @@ -1983,8 +1978,7 @@ CNodePtr FindLossCNode(const FuncGraphPtr& func_graph) { | |||||
| current_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0)); | current_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0)); | ||||
| } | } | ||||
| ======= | |||||
| >>>>>>> 回退 'Pull Request !17 : [AutoParallel]Fix bug in the case of two cast' | |||||
| // notice: the GetNext op has not input | // notice: the GetNext op has not input | ||||
| if (INVALID_LOSS_OPS.find(current_prim->name()) != INVALID_LOSS_OPS.end()) { | if (INVALID_LOSS_OPS.find(current_prim->name()) != INVALID_LOSS_OPS.end()) { | ||||
| MS_LOG(INFO) << "The loss is: " << current_prim->name(); | MS_LOG(INFO) << "The loss is: " << current_prim->name(); | ||||
| @@ -268,3 +268,32 @@ def test_cast_before_mirror3(): | |||||
| y = Tensor(np.ones([32, 64]), dtype=ms.float16) | y = Tensor(np.ones([32, 64]), dtype=ms.float16) | ||||
| b = Tensor(np.ones([64, 64]), dtype=ms.float32) | b = Tensor(np.ones([64, 64]), dtype=ms.float32) | ||||
| _executor.compile(net, x, y, b) | _executor.compile(net, x, y, b) | ||||
| def test_mul_two_cast(): | |||||
| class Net(nn.Cell): | |||||
| def __init__(self, strategy1, strategy2, strategy3): | |||||
| super().__init__() | |||||
| self.mul = P.Mul().set_strategy(strategy1) | |||||
| self.mul2 = P.Mul().set_strategy(strategy2) | |||||
| self.cast = P.Cast().set_strategy(strategy3) | |||||
| self.cast2 = P.Cast().set_strategy(strategy3) | |||||
| def construct(self, x, y, b): | |||||
| out = self.mul(x, y) | |||||
| out = self.mul2(out, b) | |||||
| out = self.cast(out, ms.int32) | |||||
| out = self.cast2(out, ms.bool_) | |||||
| return out | |||||
| context.set_auto_parallel_context(device_num=8, global_rank=0) | |||||
| strategy1 = ((2, 2), (2, 2)) | |||||
| strategy2 = ((8, 1), (8, 1)) | |||||
| strategy3 = ((8, 1), ) | |||||
| net = GradWrap(Net(strategy1, strategy2, strategy3)) | |||||
| context.set_auto_parallel_context(parallel_mode="semi_auto_parallel") | |||||
| x = Tensor(np.ones([128, 32]), dtype=ms.float32) | |||||
| y = Tensor(np.ones([128, 32]), dtype=ms.float32) | |||||
| b = Tensor(np.ones([128, 32]), dtype=ms.float32) | |||||
| _executor.compile(net, x, y, b) | |||||