|
|
|
@@ -181,7 +181,7 @@ void ForwardCommunication(OperatorVector forward_op, const CNodePtr &node) { |
|
|
|
} |
|
|
|
PrimitivePtr value_node_prim = GetValueNode<PrimitivePtr>(uses_cnode->input(0)); |
|
|
|
MS_EXCEPTION_IF_NULL(value_node_prim); |
|
|
|
if (value_node_prim->name() == prim::kTupleGetitem) { |
|
|
|
if (value_node_prim->name() == prim::kTupleGetItem) { |
|
|
|
if (uses_set.size() > 1) { |
|
|
|
MS_LOG(EXCEPTION) << "Now only support one output, but got " << uses_set.size(); |
|
|
|
} |
|
|
|
@@ -279,7 +279,7 @@ void InsertGetTensorSliceOp(const Operator &op, const CNodePtr &node, const Func |
|
|
|
TensorLayout GetTensorInLayout(const CNodePtr &middle_node, const PrimitivePtr &middle_prim, |
|
|
|
const OperatorInfoPtr &distribute_operator) { |
|
|
|
TensorInfo tensorinfo_in; |
|
|
|
if (middle_prim->name() == prim::kTupleGetitem) { |
|
|
|
if (middle_prim->name() == prim::kTupleGetItem) { |
|
|
|
auto value_node = middle_node->input(2)->cast<ValueNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(value_node); |
|
|
|
size_t index_s = LongToSize(GetValue<int64_t>(value_node->value())); |
|
|
|
@@ -473,7 +473,7 @@ void StepRedistribution(const CNodePtr &node, const OperatorInfoPtr &distribute_ |
|
|
|
MS_EXCEPTION_IF_NULL(current_value); |
|
|
|
PrimitivePtr current_prim = current_value->value()->cast<PrimitivePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(current_prim); |
|
|
|
insert_node_new = ((current_prim->name() == prim::kTupleGetitem) ? node : insert_node); |
|
|
|
insert_node_new = ((current_prim->name() == prim::kTupleGetItem) ? node : insert_node); |
|
|
|
} else { |
|
|
|
insert_node_new = insert_node; |
|
|
|
} |
|
|
|
@@ -1964,7 +1964,7 @@ std::shared_ptr<TensorLayout> FindPrevLayout(const AnfNodePtr &node) { |
|
|
|
} |
|
|
|
ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>(); |
|
|
|
PrimitivePtr prim = prim_anf_node->value()->cast<PrimitivePtr>(); |
|
|
|
if (prim->name() == prim::kTupleGetitem) { |
|
|
|
if (prim->name() == prim::kTupleGetItem) { |
|
|
|
auto tuple_index = GetTupleGetItemIndex(cnode); |
|
|
|
auto layout_ptr = FindPrevParallelCareNodeLayout(cnode->input(1), LongToSize(tuple_index)); |
|
|
|
if (!layout_ptr) { |
|
|
|
@@ -2081,7 +2081,7 @@ LossNodeInfo FindLossCNode(const FuncGraphPtr &func_graph) { |
|
|
|
} |
|
|
|
|
|
|
|
// return -> tuple_getitem -> loss |
|
|
|
if (current_prim->name() == prim::kTupleGetitem) { |
|
|
|
if (current_prim->name() == prim::kTupleGetItem) { |
|
|
|
auto tuple_index = GetTupleGetItemIndex(pre_cnode); |
|
|
|
AnfNodePtr pre_pre_node = pre_cnode->input(1); |
|
|
|
MS_EXCEPTION_IF_NULL(pre_pre_node); |
|
|
|
@@ -2338,7 +2338,7 @@ std::vector<std::pair<CNodePtr, LossNodeInfo>> GetSensLossPairs(const FuncGraphP |
|
|
|
} |
|
|
|
|
|
|
|
auto expect_tuple_getitem_cnode = expect_tuple_getitem->cast<CNodePtr>(); |
|
|
|
if (!IsSomePrimitive(expect_tuple_getitem_cnode, prim::kTupleGetitem)) { |
|
|
|
if (!IsSomePrimitive(expect_tuple_getitem_cnode, prim::kTupleGetItem)) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
|
|
|
|
|