|
|
|
@@ -218,6 +218,11 @@ def _check_similar_layout(tensor_layout1, tensor_layout2): |
|
|
|
return True |
|
|
|
|
|
|
|
|
|
|
|
def _check_same_layout(tensor_layout1, tensor_layout2): |
|
|
|
"""check if two tensor layouts are same""" |
|
|
|
return tensor_layout1[0] == tensor_layout2[0] and tensor_layout1[1] == tensor_layout2[1] |
|
|
|
|
|
|
|
|
|
|
|
def _remove_repeated_slices(tensor_layout): |
|
|
|
"""generate unrepeated tensor layout""" |
|
|
|
import copy |
|
|
|
@@ -236,9 +241,14 @@ def _infer_rank_list(train_map, predict_map=None): |
|
|
|
ret = {} |
|
|
|
for param_name in train_map: |
|
|
|
train_layout = train_map[param_name] |
|
|
|
new_train_layout = _remove_repeated_slices(train_layout) |
|
|
|
predict_layout = predict_map[param_name] |
|
|
|
train_dev_mat = train_layout[0] |
|
|
|
dev_num = np.array(train_dev_mat).prod() |
|
|
|
if _check_same_layout(train_layout, predict_layout): |
|
|
|
dev_rank = _get_global_rank() |
|
|
|
ret[param_name] = ([dev_rank], True) |
|
|
|
continue |
|
|
|
new_train_layout = _remove_repeated_slices(train_layout) |
|
|
|
array = np.arange(dev_num).reshape(train_dev_mat) |
|
|
|
index = () |
|
|
|
for i in new_train_layout[0]: |
|
|
|
@@ -248,16 +258,20 @@ def _infer_rank_list(train_map, predict_map=None): |
|
|
|
index = index + (slice(None),) |
|
|
|
rank_list = array[index].flatten() |
|
|
|
if not predict_map: |
|
|
|
ret[param_name] = rank_list |
|
|
|
ret[param_name] = (rank_list, False) |
|
|
|
continue |
|
|
|
if param_name not in predict_map: |
|
|
|
logger.warning("predict_map does not contain %s", param_name) |
|
|
|
continue |
|
|
|
predict_layout = predict_map[param_name] |
|
|
|
# optimization pass |
|
|
|
if _check_similar_layout(train_layout, predict_layout): |
|
|
|
dev_rank = _get_global_rank() |
|
|
|
ret[param_name] = [rank_list[dev_rank]] |
|
|
|
if len(rank_list) == 1: |
|
|
|
ret[param_name] = (rank_list, True) |
|
|
|
elif len(rank_list) == dev_num: |
|
|
|
dev_rank = _get_global_rank() |
|
|
|
ret[param_name] = ([rank_list[dev_rank]], True) |
|
|
|
else: |
|
|
|
ret[param_name] = (rank_list, False) |
|
|
|
else: |
|
|
|
ret[param_name] = rank_list |
|
|
|
ret[param_name] = (rank_list, False) |
|
|
|
return ret |