|
|
|
@@ -200,9 +200,11 @@ class HcclParser: |
|
|
|
info_set = set() |
|
|
|
for item in communication_info: |
|
|
|
# index0:step_num,index1:communication_cost,index2:communication_wait_cost,index3:link_info |
|
|
|
info_set.add(item[0]) |
|
|
|
if item[0].isdigit(): |
|
|
|
info_set.add(int(item[0])) |
|
|
|
info_set = sorted(info_set) |
|
|
|
for item in info_set: |
|
|
|
item = str(item) |
|
|
|
step_communication_info = [info for info in communication_info if info[0] == item] |
|
|
|
step_communication_cost = sum([i[1] for i in step_communication_info]) |
|
|
|
step_communication_wait_cost = sum([i[2] for i in step_communication_info]) |
|
|
|
@@ -250,6 +252,15 @@ class HcclParser: |
|
|
|
dst_rank = item.get("args").get("dst rank") |
|
|
|
if src_rank is None or dst_rank is None: |
|
|
|
continue |
|
|
|
|
|
|
|
# When the SDMA operation is in the card, |
|
|
|
# the source card or destination card is 0xffffffff, and it needs to be converted to localrank. |
|
|
|
if int(src_rank) == int('0xffffffff', 16): |
|
|
|
src_rank = dst_rank |
|
|
|
|
|
|
|
if int(dst_rank) == int('0xffffffff', 16): |
|
|
|
dst_rank = src_rank |
|
|
|
|
|
|
|
if item.get("args").get("transport type") == CommunicationInfo.LOCAL.value: |
|
|
|
item["args"]["src rank"] = dst_rank |
|
|
|
item["args"]["dst rank"] = src_rank |
|
|
|
|