@@ -346,6 +346,8 @@ bool IsAutoParallelCareNode(const CNodePtr &cnode) {
 }

 OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &cnode) {
+  MS_EXCEPTION_IF_NULL(prim);
+  MS_EXCEPTION_IF_NULL(cnode);
   auto attrs = prim->attrs();
   std::vector<Shapes> shape_list = ExtractShape(cnode);
   if (shape_list.empty()) {
@@ -381,8 +383,8 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &
   operator_info->set_outputs_dtype(cnode->Type());
   operator_info->set_cnode(cnode);
   // If no strategy has been configured for this operator, then candidate strategies are generated for
-  // auto-strategy searching
-  if (!StrategyFound(attrs)) {
+  // auto-strategy searching; if this primitive is Cast, we ignore the user-specified strategy
+  if (!StrategyFound(attrs) || prim->name() == CAST) {
     // Compute split_flag_list_, indicating which input has batch dimension. This is ONLY used for preparation for
     // BatchParallelInfo operator
     operator_info->ComputeBatchSplitFlagList();
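The branch above reduces to a simple rule: candidate strategies are generated whenever the user configured none, and always for Cast, whose user-specified strategy is now ignored. A minimal Python sketch of that rule (hypothetical names; the real logic is the C++ branch in CreateTheOperatorInfo):

CAST = "Cast"
STRATEGY = "strategy"

def should_generate_strategies(prim_name, attrs):
    # Generate candidates when no user strategy is set, and
    # unconditionally for Cast, whose user strategy is ignored.
    return STRATEGY not in attrs or prim_name == CAST

assert should_generate_strategies("Mul", {})                        # nothing configured
assert should_generate_strategies(CAST, {STRATEGY: ((8, 1),)})      # Cast: ignored anyway
assert not should_generate_strategies("Mul", {STRATEGY: ((2, 2), (2, 2))})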
@@ -370,7 +370,6 @@ bool IsParallelCareNode(const CNodePtr& cnode) {
   if (prim == nullptr) {
     return false;
   }
-  auto attrs = prim->attrs();
   if (IsInBlackList(prim)) {
     MS_LOG(INFO) << "Parallel don't care node: " << prim->name();
     return false;
@@ -379,10 +378,8 @@ bool IsParallelCareNode(const CNodePtr& cnode) {
   if (prim->name() == GET_NEXT) {
     return true;
   }
-  if ((prim->name() == CAST)) {
-    if ((!attrs.count(STRATEGY)) && (cnode->operator_info() == nullptr)) {
-      return false;
-    }
+  if ((prim->name() == CAST) && (cnode->operator_info() == nullptr)) {
+    return false;
   }

   return cnode->in_forward_flag();
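The simplified check reads: a Cast only becomes a parallel-care node once an OperatorInfo has been attached to it; the old extra test on a user STRATEGY attribute is dropped because that strategy is now ignored for Cast anyway. A rough sketch with invented Python names, not the real API:

def is_cast_parallel_care(prim_name, operator_info):
    # A Cast with no OperatorInfo attached is not a parallel-care node;
    # everything else falls through to the normal checks.
    if prim_name == "Cast" and operator_info is None:
        return False
    return True

assert not is_cast_parallel_care("Cast", None)
assert is_cast_parallel_care("Cast", object())  # info already attached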
@@ -653,6 +650,14 @@ LossNodeInfo GetLossNodeInfo(const AnfNodePtr& loss_node) {
   LossNodeInfo node_info;

+  // return -> cast
+  auto pre_cnode = pre_node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(pre_cnode);
+  auto pre_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0));
+  if (pre_prim->name() == CAST && pre_cnode->operator_info() == nullptr) {
+    pre_node = pre_cnode->input(1);
+  }
+
   // return -> loss
   if (pre_node == loss_node) {
     node_info.has_tuple_getitem = false;
@@ -1947,6 +1952,14 @@ CNodePtr FindLossCNode(const FuncGraphPtr& func_graph) {
   MS_EXCEPTION_IF_NULL(current_value);
   PrimitivePtr current_prim = current_value->value()->cast<PrimitivePtr>();
   MS_EXCEPTION_IF_NULL(current_prim);
+
+  // return -> cast
+  if (current_prim->name() == CAST && pre_cnode->operator_info() == nullptr) {
+    pre_cnode = pre_cnode->input(1)->cast<CNodePtr>();
+    MS_EXCEPTION_IF_NULL(pre_cnode);
+    current_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0));
+  }
+
   // notice: the GetNext op has not input
   if (INVALID_LOSS_OPS.find(current_prim->name()) != INVALID_LOSS_OPS.end()) {
     MS_LOG(INFO) << "The loss is: " << current_prim->name();
@@ -192,7 +192,6 @@ def test_cast_before_mirror():
     net = GradWrap(NetWithLoss(Net(strategy1)))
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
-
     x = Tensor(np.ones([128, 32]), dtype=ms.float32)
     y = Tensor(np.ones([32, 64]), dtype=ms.float32)
     b = Tensor(np.ones([64, 64]), dtype=ms.float16)
@@ -217,7 +216,6 @@ def test_cast_before_mirror1():
     net = GradWrap(NetWithLoss(Net(strategy1)))
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
-
     x = Tensor(np.ones([128, 32]), dtype=ms.float16)
     y = Tensor(np.ones([32, 64]), dtype=ms.float16)
     b = Tensor(np.ones([64, 64]), dtype=ms.float32)
@@ -242,7 +240,6 @@ def test_cast_before_mirror2():
     net = GradWrap(NetWithLoss(Net(strategy1)))
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
-
     x = Tensor(np.ones([128, 32]), dtype=ms.float16)
     y = Tensor(np.ones([32, 64]), dtype=ms.float16)
     b = Tensor(np.ones([64, 64]), dtype=ms.float32)
@@ -267,8 +264,36 @@ def test_cast_before_mirror3():
     net = GradWrap(NetWithLoss(Net(strategy1)))
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
-
     x = Tensor(np.ones([128, 32]), dtype=ms.float16)
     y = Tensor(np.ones([32, 64]), dtype=ms.float16)
     b = Tensor(np.ones([64, 64]), dtype=ms.float32)
     _executor.compile(net, x, y, b)
+
+
+def test_mul_two_cast():
+    class Net(nn.Cell):
+        def __init__(self, strategy1, strategy2, strategy3):
+            super().__init__()
+            self.mul = P.Mul().set_strategy(strategy1)
+            self.mul2 = P.Mul().set_strategy(strategy2)
+            self.cast = P.Cast().set_strategy(strategy3)
+            self.cast2 = P.Cast().set_strategy(strategy3)
+
+        def construct(self, x, y, b):
+            out = self.mul(x, y)
+            out = self.mul2(out, b)
+            out = self.cast(out, ms.int32)
+            out = self.cast2(out, ms.bool_)
+            return out
+
+    context.set_auto_parallel_context(device_num=8, global_rank=0)
+    strategy1 = ((2, 2), (2, 2))
+    strategy2 = ((8, 1), (8, 1))
+    strategy3 = ((8, 1), )
+    net = GradWrap(Net(strategy1, strategy2, strategy3))
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
+    x = Tensor(np.ones([128, 32]), dtype=ms.float32)
+    y = Tensor(np.ones([128, 32]), dtype=ms.float32)
+    b = Tensor(np.ones([128, 32]), dtype=ms.float32)
+    _executor.compile(net, x, y, b)
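The new test_mul_two_cast case exercises the change end to end: two Mul ops with conflicting strategies, ((2, 2), (2, 2)) and ((8, 1), (8, 1)), feed two chained Casts that carry a user strategy of ((8, 1), ). Under the rule added in CreateTheOperatorInfo, that Cast strategy should be ignored and candidate strategies generated during compilation instead.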