|
|
|
@@ -379,16 +379,15 @@ class _Executor: |
|
|
|
auto_split_param_names = (param_name for param_name in auto_split_params) |
|
|
|
return auto_split_param_names |
|
|
|
|
|
|
|
def _build_broadcast_graph(self, obj, broadcast_params, broadcast_phase): |
|
|
|
def _build_broadcast_graph(self, broadcast_params_dict, broadcast_phase): |
|
|
|
"""Build broadcast graph.""" |
|
|
|
from mindspore.nn.wrap.cell_wrapper import _BroadCastCell |
|
|
|
|
|
|
|
_broadcast_net = _BroadCastCell(broadcast_params) |
|
|
|
_broadcast_net = _BroadCastCell(broadcast_params_dict.values()) |
|
|
|
_broadcast_net.phase = broadcast_phase |
|
|
|
broadcasted_params = _broadcast_net() |
|
|
|
parameters_broadcast_dict = obj.parameters_broadcast_dict() |
|
|
|
for param_name, param in zip(parameters_broadcast_dict, broadcasted_params): |
|
|
|
parameters_broadcast_dict[param_name].set_data(param) |
|
|
|
for param_name, param in zip(broadcast_params_dict.keys(), broadcasted_params): |
|
|
|
broadcast_params_dict[param_name].set_data(param) |
|
|
|
|
|
|
|
def _set_dataset_mode(self, args_list): |
|
|
|
"""set dataset mode.""" |
|
|
|
@@ -476,10 +475,15 @@ class _Executor: |
|
|
|
auto_split_param_names = [] |
|
|
|
if auto_parallel_mode: |
|
|
|
auto_split_param_names = self._get_auto_split_param_names(obj.parameter_layout_dict) |
|
|
|
broadcast_params = [param for param_name, param in obj.parameters_broadcast_dict().items() if |
|
|
|
param_name not in auto_split_param_names] |
|
|
|
broadcast_phase = "broadcast_subgraph" + "." + str(obj.create_time) |
|
|
|
self._build_broadcast_graph(obj, broadcast_params, broadcast_phase) |
|
|
|
|
|
|
|
broadcast_params_dict = obj.parameters_broadcast_dict() |
|
|
|
if auto_split_param_names and broadcast_params_dict: |
|
|
|
broadcast_params_dict = OrderedDict() |
|
|
|
for param_name, param in obj.parameters_broadcast_dict().items(): |
|
|
|
if param_name not in auto_split_param_names: |
|
|
|
broadcast_params_dict[param_name] = param |
|
|
|
broadcast_phase = "_broadcast_subgraph" + "." + str(obj.create_time) |
|
|
|
self._build_broadcast_graph(broadcast_params_dict, broadcast_phase) |
|
|
|
self.compile_cache[phase] = broadcast_phase |
|
|
|
|
|
|
|
return phase, True |
|
|
|
|