| @@ -90,6 +90,7 @@ class Cell(Cell_): | |||
| self._phase = 'train' | |||
| self._parameter_layout_dict = {} | |||
| self._parallel_parameter_name_list = () | |||
| self._parallel_parameter_merge_net_dict = {} | |||
| self._create_time = int(time.time() * 1e9) | |||
| self.phase_prefix = "" | |||
| self.parameter_broadcast_done = False | |||
| @@ -224,6 +225,16 @@ class Cell(Cell_): | |||
| raise TypeError("'parallel_parameter_name_list' must be list type.") | |||
| self._parallel_parameter_name_list = value | |||
| @property | |||
| def parallel_parameter_merge_net_dict(self): | |||
| return self._parallel_parameter_merge_net_dict | |||
| @parallel_parameter_merge_net_dict.setter | |||
| def parallel_parameter_merge_net_dict(self, value): | |||
| if not isinstance(value, dict): | |||
| raise TypeError("'parallel_parameter_merge_net_dict' must be dict type.") | |||
| self._parallel_parameter_merge_net_dict = value | |||
| def get_func_graph_proto(self): | |||
| """Return graph binary proto.""" | |||
| return _executor._get_func_graph_proto(self, self.phase + "." + str(self.create_time), "anf_ir", True) | |||
| @@ -382,46 +393,63 @@ class Cell(Cell_): | |||
| self._add_attr(key, value) | |||
| self._attr_synced = True | |||
| def __setattr__(self, name, value): | |||
| def _set_attr_for_parameter(self, name, value): | |||
| """Set attr for parameter.""" | |||
| cells = self.__dict__.get('_cells') | |||
| params = self.__dict__.get('_params') | |||
| if params is None: | |||
| raise AttributeError("Can not assign params before Cell.__init__() call.") | |||
| if name in self.__dict__: | |||
| if self.__dict__[name] is not None: | |||
| raise TypeError("The type of value should not be Parameter or Cell, but got Parameter.") | |||
| del self.__dict__[name] | |||
| if cells and name in cells: | |||
| raise TypeError("The type of value should be Cell, but got Parameter.") | |||
| self.insert_param_to_cell(name, value) | |||
| def _set_attr_for_parameter_tuple(self, name, value): | |||
| """Set attr for parameter tuple.""" | |||
| params = self.__dict__.get('_params') | |||
| params_list = self.__dict__.get('_params_list') | |||
| tensor_list = self.__dict__.get('_tensor_list') | |||
| if isinstance(value, Parameter): | |||
| if params is None: | |||
| raise AttributeError("Can not assign params before Cell.__init__() call.") | |||
| if params is None: | |||
| raise AttributeError("Can not assign params before Cell.__init__() call.") | |||
| for item in value: | |||
| self.insert_param_to_cell(item.name, item, check_name=False) | |||
| if context.get_context("mode") == context.PYNATIVE_MODE: | |||
| if name in self.__dict__: | |||
| if self.__dict__[name] is not None: | |||
| raise TypeError("The type of value should not be Parameter or Cell, but got Parameter.") | |||
| del self.__dict__[name] | |||
| if cells and name in cells: | |||
| raise TypeError("The type of value should be Cell, but got Parameter.") | |||
| self.insert_param_to_cell(name, value) | |||
| if name in params: | |||
| del params[name] | |||
| params_list[name] = value | |||
| else: | |||
| object.__setattr__(self, name, value) | |||
| def _set_attr_for_cell(self, name, value): | |||
| """Set attr for cell.""" | |||
| cells = self.__dict__.get('_cells') | |||
| params = self.__dict__.get('_params') | |||
| if cells is None: | |||
| raise AttributeError("Can not assign cells before Cell.__init__() call.") | |||
| if name in self.__dict__: | |||
| del self.__dict__[name] | |||
| if params and name in params: | |||
| raise TypeError("The type of value should be Parameter, but got Cell.") | |||
| if self._auto_prefix: | |||
| value.update_parameters_name(name + '.') | |||
| cells[name] = value | |||
| if hasattr(self, '_cell_init_args'): | |||
| self.cell_init_args += str({name: value}) | |||
| def __setattr__(self, name, value): | |||
| cells = self.__dict__.get('_cells') | |||
| params = self.__dict__.get('_params') | |||
| tensor_list = self.__dict__.get('_tensor_list') | |||
| if isinstance(value, Parameter): | |||
| self._set_attr_for_parameter(name, value) | |||
| elif isinstance(value, ParameterTuple): | |||
| if params is None: | |||
| raise AttributeError("Can not assign params before Cell.__init__() call.") | |||
| for item in value: | |||
| self.insert_param_to_cell(item.name, item, check_name=False) | |||
| if context.get_context("mode") == context.PYNATIVE_MODE: | |||
| if name in self.__dict__: | |||
| del self.__dict__[name] | |||
| if name in params: | |||
| del params[name] | |||
| params_list[name] = value | |||
| else: | |||
| object.__setattr__(self, name, value) | |||
| self._set_attr_for_parameter_tuple(name, value) | |||
| elif isinstance(value, Cell): | |||
| if cells is None: | |||
| raise AttributeError("Can not assign cells before Cell.__init__() call.") | |||
| if name in self.__dict__: | |||
| del self.__dict__[name] | |||
| if params and name in params: | |||
| raise TypeError("The type of value should be Parameter, but got Cell.") | |||
| if self._auto_prefix: | |||
| value.update_parameters_name(name + '.') | |||
| cells[name] = value | |||
| if hasattr(self, '_cell_init_args'): | |||
| self.cell_init_args += str({name: value}) | |||
| self._set_attr_for_cell(name, value) | |||
| elif params and name in params: | |||
| if isinstance(value, Tensor) and self._params[name] is not None: | |||
| self._params[name].set_data(value) | |||
| @@ -455,6 +455,12 @@ def _get_merged_param_data(net, param_name, param_data, integrated_save): | |||
| uniform_split = layout[4] | |||
| opt_shard_group = layout[5] | |||
| allgather_net = None | |||
| if param_name in net.parallel_parameter_merge_net_dict: | |||
| allgather_net = net.parallel_parameter_merge_net_dict[param_name] | |||
| else: | |||
| logger.info("need to create allgather net for %s", param_name) | |||
| if integrated_save: | |||
| if uniform_split == 0: | |||
| raise RuntimeError("Integrated save checkpoint only support uniform split tensor now.") | |||
| @@ -462,19 +468,25 @@ def _get_merged_param_data(net, param_name, param_data, integrated_save): | |||
| # pipeline parallel need to be supported here later | |||
| for dim in tensor_map: | |||
| if dim != -1: | |||
| if opt_shard_group: | |||
| allgather_net = get_allgather_cell(opt_shard_group, True) | |||
| else: | |||
| allgather_net = get_allgather_cell(opt_shard_group, False) | |||
| if allgather_net is None: | |||
| if opt_shard_group: | |||
| allgather_net = get_allgather_cell(opt_shard_group, True) | |||
| else: | |||
| allgather_net = get_allgather_cell(opt_shard_group, False) | |||
| net.parallel_parameter_merge_net_dict[param_name] = allgather_net | |||
| param_data = allgather_net(param_data) | |||
| if field_size: | |||
| return _reshape_param_data_with_weight(param_data, dev_mat, field_size) | |||
| return _reshape_param_data(param_data, dev_mat, tensor_map) | |||
| if opt_shard_group: | |||
| allgather_net = get_allgather_cell(opt_shard_group, False) | |||
| if allgather_net is None: | |||
| allgather_net = get_allgather_cell(opt_shard_group, False) | |||
| net.parallel_parameter_merge_net_dict[param_name] = allgather_net | |||
| param_data = allgather_net(param_data) | |||
| elif opt_shard_group: | |||
| allgather_net = get_allgather_cell(opt_shard_group, False) | |||
| if allgather_net is None: | |||
| allgather_net = get_allgather_cell(opt_shard_group, False) | |||
| net.parallel_parameter_merge_net_dict[param_name] = allgather_net | |||
| param_data = allgather_net(param_data) | |||
| return param_data | |||
| @@ -0,0 +1,105 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import os | |||
| import numpy as np | |||
| import mindspore as ms | |||
| from mindspore import context, Tensor, Parameter | |||
| from mindspore.nn import Cell, Momentum | |||
| from mindspore.ops import operations as P | |||
| from mindspore.train import Model | |||
| from mindspore.train.callback import CheckpointConfig, ModelCheckpoint | |||
| from tests.dataset_mock import MindData | |||
| class Dataset(MindData): | |||
| def __init__(self, predict, label, length=3): | |||
| super(Dataset, self).__init__(size=length) | |||
| self.predict = predict | |||
| self.label = label | |||
| self.index = 0 | |||
| self.length = length | |||
| def __iter__(self): | |||
| return self | |||
| def __next__(self): | |||
| if self.index >= self.length: | |||
| raise StopIteration | |||
| self.index += 1 | |||
| return self.predict, self.label | |||
| def reset(self): | |||
| self.index = 0 | |||
| class Net(Cell): | |||
| def __init__(self, weight, w2, begin, end, strides, strategy1=None, strategy2=None, mask=0): | |||
| super().__init__() | |||
| self.mul = P.Mul().shard(strategy1) | |||
| self.strided_slice = P.StridedSlice(begin_mask=mask).shard(strategy2) | |||
| self.weight = Parameter(weight, "w1") | |||
| self.mul2 = P.Mul() | |||
| self.weight2 = Parameter(w2, "w2") | |||
| self.begin = begin | |||
| self.end = end | |||
| self.strides = strides | |||
| def construct(self, x, b): | |||
| out = self.strided_slice( | |||
| self.weight, self.begin, self.end, self.strides) | |||
| out = self.mul(x, out) | |||
| out = self.mul2(out, self.weight2) | |||
| return out | |||
| _x = Tensor(np.ones([16, 64, 1]), dtype=ms.float32) | |||
| _b = Tensor(np.ones([16, 64, 32]), dtype=ms.float32) | |||
| _w1 = Tensor(np.ones([256, 64, 32]), dtype=ms.float32) | |||
| _w2 = Tensor(np.ones([128, 64, 1]), dtype=ms.float32) | |||
| def clean_all_ckpt_files(folder_path): | |||
| if os.path.exists(folder_path): | |||
| for file_name in os.listdir(folder_path): | |||
| if file_name.endswith('.ckpt') or file_name.endswith('.meta'): | |||
| os.remove(os.path.join(folder_path, file_name)) | |||
| def compile_net(net): | |||
| context.set_context(save_graphs=True) | |||
| learning_rate = 0.1 | |||
| momentum = 0.9 | |||
| epoch_size = 2 | |||
| dataset = Dataset(_x, _b) | |||
| opt = Momentum(net.trainable_params(), learning_rate, momentum) | |||
| model = Model(net, optimizer=opt) | |||
| ckpt_config = CheckpointConfig(keep_checkpoint_max=1) | |||
| ckpt_path = "./parallel_ckpt" | |||
| ckpt_cb = ModelCheckpoint(prefix="parallel", directory=ckpt_path, config=ckpt_config) | |||
| model.train(epoch_size, dataset, dataset_sink_mode=False, callbacks=[ckpt_cb]) | |||
| assert len(model._train_network.parallel_parameter_merge_net_dict) == 4 | |||
| clean_all_ckpt_files(ckpt_path) | |||
| context.reset_auto_parallel_context() | |||
| def test_stridedslice_parameter(): | |||
| context.set_auto_parallel_context( | |||
| parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) | |||
| strategy1 = ((1, 4, 1), (1, 4, 2)) | |||
| strategy2 = ((1, 4, 2),) | |||
| net = Net(_w1, _w2, (0, 0, 0), (128, 64, 32), (1, 1, 1), | |||
| strategy1, strategy2) | |||
| compile_net(net) | |||