|
|
|
@@ -225,15 +225,6 @@ def load_param_into_net(net, parameter_dict): |
|
|
|
raise TypeError(msg) |
|
|
|
|
|
|
|
logger.info("Execute load parameter into net process.") |
|
|
|
for name in parameter_dict: |
|
|
|
for _, param in net.parameters_and_names(): |
|
|
|
if name == param.name and param.layerwise_parallel: |
|
|
|
# layerwise parallel parameter data loaded from checkpoint file, |
|
|
|
# was a complete(merged) data, need to be splited |
|
|
|
new_param = parameter_dict[param.name] |
|
|
|
_load_tensor_for_layerwise(new_param, param) |
|
|
|
break |
|
|
|
|
|
|
|
param_not_load = [] |
|
|
|
for _, param in net.parameters_and_names(): |
|
|
|
if param.name in parameter_dict: |
|
|
|
@@ -363,34 +354,6 @@ def _get_merged_param_data(net, param_name, param_data): |
|
|
|
return param_data |
|
|
|
|
|
|
|
|
|
|
|
def _load_tensor_for_layerwise(new_param, old_param): |
|
|
|
""" |
|
|
|
Replaces parameters with sliced tensors by layerwise parallel strategies. |
|
|
|
|
|
|
|
Args: |
|
|
|
new_param (Parameter): The new layerwise parallel parameter, will be loaded into net. |
|
|
|
old_param(Parameter): The current parameter in the net. |
|
|
|
""" |
|
|
|
if not isinstance(new_param.data, Tensor) or not isinstance(old_param.data, Tensor): |
|
|
|
logger.error("Failed to combine the net and the parameters.") |
|
|
|
msg = ("layerwise parallel parameter should be a Tensor, but got {}.".format(type(new_param.data))) |
|
|
|
raise TypeError(msg) |
|
|
|
|
|
|
|
if old_param.data.shape() == new_param.data.shape(): |
|
|
|
return |
|
|
|
|
|
|
|
from mindspore.parallel._tensor import _load_tensor |
|
|
|
from mindspore.communication.management import get_group_size |
|
|
|
dev_mat = [get_group_size()] |
|
|
|
shape = new_param.data.shape() |
|
|
|
for x in range(len(shape)): # dim 0 set 0, others set -1 |
|
|
|
if x: |
|
|
|
tensor_map.append(-1) |
|
|
|
|
|
|
|
new_tensor = _load_tensor(new_param.data, dev_mat, tensor_map) |
|
|
|
new_param.set_parameter_data(new_tensor) |
|
|
|
|
|
|
|
|
|
|
|
def _fill_param_into_net(net, parameter_list): |
|
|
|
""" |
|
|
|
Fills parameter_list into net. |
|
|
|
|