|
|
|
@@ -1,4 +1,4 @@ |
|
|
|
# Copyright 2020 Huawei Technologies Co., Ltd |
|
|
|
# Copyright 2020-2021 Huawei Technologies Co., Ltd |
|
|
|
# |
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
|
|
|
# you may not use this file except in compliance with the License. |
|
|
|
@@ -241,7 +241,7 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True, |
|
|
|
each_param = {"name": key} |
|
|
|
param_data = Tensor(value.data) |
|
|
|
|
|
|
|
# in automatic model parallel scenario, some parameters were spliteds to all the devices, |
|
|
|
# in automatic model parallel scenario, some parameters were split to all the devices, |
|
|
|
# which should be combined before saving |
|
|
|
if key in save_obj.parameter_layout_dict: |
|
|
|
param_data = _get_merged_param_data(save_obj, key, param_data, integrated_save) |
|
|
|
@@ -1365,9 +1365,9 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy= |
|
|
|
param_dict[param.name] = split_param |
|
|
|
|
|
|
|
if param_not_in_strategy: |
|
|
|
logger.warning("{} parameters in network are not in the sclice strategy.".format(param_not_in_strategy)) |
|
|
|
logger.warning("{} parameters in network are not in the slice strategy.".format(param_not_in_strategy)) |
|
|
|
if param_not_in_ckpt: |
|
|
|
logger.warning("{} parameters in sclice strategy but not in the checkpoint file.".format(param_not_in_ckpt)) |
|
|
|
logger.warning("{} parameters in slice strategy but not in the checkpoint file.".format(param_not_in_ckpt)) |
|
|
|
|
|
|
|
load_param_into_net(network, param_dict) |
|
|
|
|
|
|
|
|