Merge pull request !1672 from yihuaijie/devtags/v0.5.0-beta
| @@ -48,6 +48,7 @@ ParallelContext::ParallelContext() { Reset(); } | |||
| void ParallelContext::Reset() { | |||
| mirror_mean_ = false; | |||
| full_batch_ = false; | |||
| cast_before_mirror_ = true; | |||
| loss_repeated_mean_ = true; | |||
| device_num_ = 1; | |||
| @@ -75,6 +76,8 @@ void ParallelContext::set_global_rank(int32_t global_rank) { | |||
| void ParallelContext::set_mirror_mean(bool mirror_mean) { mirror_mean_ = mirror_mean; } | |||
| void ParallelContext::set_full_batch(bool full_batch) { full_batch_ = full_batch; } | |||
| void ParallelContext::set_cast_before_mirror(bool cast_before_mirror) { cast_before_mirror_ = cast_before_mirror; } | |||
| void ParallelContext::set_loss_repeated_mean(bool loss_repeated_mean) { loss_repeated_mean_ = loss_repeated_mean; } | |||
| @@ -55,6 +55,9 @@ class ParallelContext { | |||
| void set_mirror_mean(bool mirror_mean); | |||
| bool mirror_mean() const { return mirror_mean_; } | |||
| void set_full_batch(bool full_batch); | |||
| bool full_batch() const { return full_batch_; } | |||
| void set_cast_before_mirror(bool cast_before_mirror); | |||
| bool cast_before_mirror() const { return cast_before_mirror_; } | |||
| @@ -103,6 +106,7 @@ class ParallelContext { | |||
| ParallelContext(); | |||
| static std::shared_ptr<ParallelContext> inst_context_; | |||
| bool mirror_mean_; | |||
| bool full_batch_; | |||
| bool cast_before_mirror_; | |||
| bool loss_repeated_mean_; | |||
| int32_t device_num_; | |||
| @@ -24,15 +24,23 @@ | |||
| #include "ir/value.h" | |||
| #include "parallel/device_matrix.h" | |||
| #include "parallel/strategy.h" | |||
| #include "parallel/context.h" | |||
| #include "parallel/tensor_layout/tensor_redistribution.h" | |||
| namespace mindspore { | |||
| namespace parallel { | |||
| Status GetNextInfo::InferTensorMap() { | |||
| MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance()); | |||
| bool full_batch = ParallelContext::GetInstance()->full_batch(); | |||
| for (auto shp : shapes_) { | |||
| TensorMap out_tensor_map; | |||
| for (size_t i = 0; i < shp.size(); ++i) { | |||
| out_tensor_map.push_back(SizeToInt(dev_matrix_shape_.size() - i - 1)); | |||
| if (full_batch) { | |||
| out_tensor_map.push_back(MAP_NONE); | |||
| } else { | |||
| out_tensor_map.push_back(SizeToInt(dev_matrix_shape_.size() - i - 1)); | |||
| } | |||
| } | |||
| outputs_tensor_map_.push_back(out_tensor_map); | |||
| } | |||
| @@ -190,6 +198,9 @@ Status GetNextInfo::GetAttrs() { | |||
| } | |||
| Status GetNextInfo::InferReplaceOps(const StrategyPtr &) { | |||
| MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance()); | |||
| bool full_batch = ParallelContext::GetInstance()->full_batch(); | |||
| Shapes out_shapes = outputs_shape_; | |||
| for (size_t i = 0; i < out_shapes.size(); ++i) { | |||
| if (dev_num_ <= 0) { | |||
| @@ -200,7 +211,9 @@ Status GetNextInfo::InferReplaceOps(const StrategyPtr &) { | |||
| MS_LOG(ERROR) << name_ << " : batch num cannot floor div dev num."; | |||
| return FAILED; | |||
| } | |||
| out_shapes[i][0] = out_shapes[i][0] / dev_num_; | |||
| if (!full_batch) { | |||
| out_shapes[i][0] = out_shapes[i][0] / dev_num_; | |||
| } | |||
| } | |||
| ValuePtr new_shapes = MakeValue(out_shapes); | |||
| Attr attr_types = std::make_pair(TYPES, attrs_[TYPES]); | |||
| @@ -23,6 +23,7 @@ | |||
| #include "parallel/device_manager.h" | |||
| #include "parallel/device_matrix.h" | |||
| #include "parallel/step_parallel.h" | |||
| #include "parallel/context.h" | |||
| #include "utils/log_adapter.h" | |||
| namespace mindspore { | |||
| @@ -93,59 +94,21 @@ Status VirtualDatasetInfo::InferDevMatrixShape() { | |||
| return SUCCESS; | |||
| } | |||
| Status VirtualDatasetInfo::InferMirrorOps() { | |||
| mirror_ops_.clear(); | |||
| int32_t stage = strategy_->GetInputStage(); | |||
| CheckGlobalDeviceManager(); | |||
| RankList dev_list = g_device_manager->GetDeviceListByStageId(stage); | |||
| if (dev_list.empty()) { | |||
| MS_LOG(ERROR) << name_ << ": The current stage is empty!"; | |||
| return Status::FAILED; | |||
| } | |||
| if (dev_list.size() == 1) { | |||
| MS_LOG(INFO) << name_ << ": No need mirror ops."; | |||
| return Status::SUCCESS; | |||
| } | |||
| OperatorName operator_name = BROADCAST; | |||
| ValuePtr attr0_value = MakeValue(dev_list.front()); | |||
| std::vector<Group> group_list; | |||
| if (CreateGroupByDim(dev_matrix_shape_.size() - 1, &group_list) != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << ": Infer mirror ops, create group failed."; | |||
| return FAILED; | |||
| } else if (group_list.empty()) { | |||
| MS_LOG(INFO) << name_ << ": No need mirror ops."; | |||
| return SUCCESS; | |||
| } | |||
| std::string group = group_list[0].name(); | |||
| ValuePtr attr1_value = MakeValue(group); | |||
| Attr attr0 = std::make_pair(SRC, attr0_value); | |||
| Attr attr1 = std::make_pair(GROUP, attr1_value); | |||
| OperatorAttrs operator_attrs = {attr0, attr1}; | |||
| OperatorParams operator_param; | |||
| OperatorArgs operator_args = std::make_pair(operator_attrs, operator_param); | |||
| Operator op = std::make_pair(operator_name, operator_args); | |||
| OperatorVector op_vector = {op}; | |||
| size_t size = inputs_shape_.size(); | |||
| for (size_t i = 0; i < size; ++i) { | |||
| mirror_ops_.push_back(op_vector); | |||
| } | |||
| mirror_ops_.clear(); | |||
| return SUCCESS; | |||
| } | |||
| Status VirtualDatasetInfo::InferMirrorOps() { return SUCCESS; } | |||
| Status VirtualDatasetInfo::InferForwardCommunication() { return SUCCESS; } | |||
| Status VirtualDatasetInfo::InferTensorMap() { | |||
| MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance()); | |||
| bool full_batch = ParallelContext::GetInstance()->full_batch(); | |||
| for (size_t i = 0; i < strategy_->GetInputNumber(); i++) { | |||
| std::vector<int32_t> tensor_map_index; | |||
| tensor_map_index.push_back((int32_t)(LAST_INDEX(SizeToUint(dev_matrix_shape_.size())))); | |||
| if (full_batch) { | |||
| tensor_map_index.push_back(MAP_NONE); | |||
| } else { | |||
| tensor_map_index.push_back((int32_t)(LAST_INDEX(SizeToUint(dev_matrix_shape_.size())))); | |||
| } | |||
| for (size_t j = 1; j < strategy_->GetInputDim()[i].size(); ++j) { | |||
| tensor_map_index.push_back(MAP_NONE); | |||
| } | |||
| @@ -213,6 +176,10 @@ Status VirtualDatasetInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { | |||
| } | |||
| Status VirtualDatasetInfo::GenerateStrategies(int32_t stage_id) { | |||
| MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance()); | |||
| bool full_batch = ParallelContext::GetInstance()->full_batch(); | |||
| size_t total_dev_num; | |||
| if (GetAttrs() != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << ": GetAttrs failed"; | |||
| return FAILED; | |||
| @@ -220,7 +187,11 @@ Status VirtualDatasetInfo::GenerateStrategies(int32_t stage_id) { | |||
| CheckGlobalDeviceManager(); | |||
| is_auto_parallel_ = true; | |||
| size_t total_dev_num = g_device_manager->GetDeviceListByStageId(stage_id).size(); | |||
| if (full_batch) { | |||
| total_dev_num = 1; | |||
| } else { | |||
| total_dev_num = g_device_manager->GetDeviceListByStageId(stage_id).size(); | |||
| } | |||
| StrategyPtr sp; | |||
| std::vector<Dimensions> strategy; | |||
| for (auto &shape : inputs_shape_) { | |||
| @@ -232,10 +203,18 @@ Status VirtualDatasetInfo::GenerateStrategies(int32_t stage_id) { | |||
| sp = std::make_shared<Strategy>(stage_id, strategy); | |||
| if (SetCostUnderStrategy(sp) == SUCCESS) { | |||
| MS_LOG(INFO) << name_ << ": Successfully generated batch-parallel-strategy."; | |||
| if (full_batch) { | |||
| MS_LOG(INFO) << name_ << ": Successfully generated full-batch-parallel-strategy."; | |||
| } else { | |||
| MS_LOG(INFO) << name_ << ": Successfully generated batch-parallel-strategy."; | |||
| } | |||
| PrintStrategy(sp); | |||
| } else { | |||
| MS_LOG(ERROR) << name_ << ": Generating batch-parallel-strategy failed."; | |||
| if (full_batch) { | |||
| MS_LOG(ERROR) << name_ << ": Generating full-batch-parallel-strategy failed."; | |||
| } else { | |||
| MS_LOG(ERROR) << name_ << ": Generating batch-parallel-strategy failed."; | |||
| } | |||
| return FAILED; | |||
| } | |||
| return SUCCESS; | |||
| @@ -1375,11 +1375,19 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) { | |||
| void SetVirtualDatasetStrategy(const CNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance()); | |||
| bool full_batch = ParallelContext::GetInstance()->full_batch(); | |||
| PrimitivePtr prim = GetValueNode<PrimitivePtr>(node->input(0)); | |||
| MS_EXCEPTION_IF_NULL(prim); | |||
| if (prim->name() == VIRTUAL_DATA_SET) { | |||
| CheckGlobalDeviceManager(); | |||
| int32_t dev_num = SizeToInt(g_device_manager->GetDeviceListByStageId(0).size()); | |||
| int32_t dev_num; | |||
| if (full_batch) { | |||
| dev_num = 1; | |||
| } else { | |||
| dev_num = SizeToInt(g_device_manager->GetDeviceListByStageId(0).size()); | |||
| } | |||
| auto attrs_temp = prim->attrs(); | |||
| std::vector<Shapes> shape_list = ExtractShape(node); | |||
| if (shape_list.empty()) { | |||
| @@ -187,6 +187,8 @@ PYBIND11_MODULE(_c_expression, m) { | |||
| "Set strategy checkpoint save file.") | |||
| .def("get_strategy_ckpt_load_file", &ParallelContext::strategy_ckpt_load_file, "Get strategy checkpoint load file.") | |||
| .def("get_strategy_ckpt_save_file", &ParallelContext::strategy_ckpt_save_file, "Get strategy checkpoint save file.") | |||
| .def("set_full_batch", &ParallelContext::set_full_batch, "Set whether load full batch on each device.") | |||
| .def("get_full_batch", &ParallelContext::full_batch, "Get whether load full batch on each device.") | |||
| .def("reset", &ParallelContext::Reset, "Reset auto parallel context."); | |||
| (void)py::class_<CostModelContext, std::shared_ptr<CostModelContext>>(m, "CostModelContext") | |||
| @@ -367,7 +367,8 @@ def _context(): | |||
| @args_type_check(device_num=int, global_rank=int, mirror_mean=bool, cast_before_mirror=bool, parallel_mode=str, | |||
| parameter_broadcast=bool, strategy_ckpt_load_file=str, strategy_ckpt_save_file=str) | |||
| parameter_broadcast=bool, strategy_ckpt_load_file=str, strategy_ckpt_save_file=str, | |||
| full_batch=bool) | |||
| def set_auto_parallel_context(**kwargs): | |||
| """ | |||
| Set auto parallel context. | |||
| @@ -404,6 +405,7 @@ def set_auto_parallel_context(**kwargs): | |||
| broadcast. Default: False. | |||
| strategy_ckpt_load_file (str): The path to load parallel strategy checkpoint. Default: '' | |||
| strategy_ckpt_save_file (str): The path to save parallel strategy checkpoint. Default: '' | |||
| full_batch (bool): Whether to load the whole batch on each device. Default: False. | |||
| Raises: | |||
| ValueError: If input key is not attribute in auto parallel context. | |||
| @@ -225,6 +225,21 @@ class _AutoParallelContext: | |||
| self.check_context_handle() | |||
| return self._context_handle.get_strategy_ckpt_load_file() | |||
def set_full_batch(self, full_batch):
    """
    Set whether the full batch is loaded on each device.

    Args:
        full_batch (bool): True if the full batch is loaded on each device.
    """
    # Validate the context handle first, then forward the flag to it.
    self.check_context_handle()
    handle = self._context_handle
    handle.set_full_batch(full_batch)
def get_full_batch(self):
    """Get whether the full batch is loaded on each device."""
    # Validate the context handle first, then read the flag back from it.
    self.check_context_handle()
    flag = self._context_handle.get_full_batch()
    return flag
| def set_strategy_ckpt_save_file(self, strategy_ckpt_save_file): | |||
| """ | |||
| Set strategy checkpoint save path. | |||
| @@ -409,7 +424,8 @@ _set_auto_parallel_context_func_map = { | |||
| "parallel_mode": auto_parallel_context().set_parallel_mode, | |||
| "parameter_broadcast": auto_parallel_context().set_parameter_broadcast, | |||
| "strategy_ckpt_load_file": auto_parallel_context().set_strategy_ckpt_load_file, | |||
| "strategy_ckpt_save_file": auto_parallel_context().set_strategy_ckpt_save_file} | |||
| "strategy_ckpt_save_file": auto_parallel_context().set_strategy_ckpt_save_file, | |||
| "full_batch": auto_parallel_context().set_full_batch} | |||
| _get_auto_parallel_context_func_map = { | |||
| @@ -421,12 +437,13 @@ _get_auto_parallel_context_func_map = { | |||
| "parallel_mode": auto_parallel_context().get_parallel_mode, | |||
| "parameter_broadcast": auto_parallel_context().get_parameter_broadcast, | |||
| "strategy_ckpt_load_file": auto_parallel_context().get_strategy_ckpt_load_file, | |||
| "strategy_ckpt_save_file": auto_parallel_context().get_strategy_ckpt_save_file} | |||
| "strategy_ckpt_save_file": auto_parallel_context().get_strategy_ckpt_save_file, | |||
| "full_batch": auto_parallel_context().get_full_batch} | |||
| @args_type_check(device_num=int, global_rank=int, mirror_mean=bool, cast_before_mirror=bool, | |||
| loss_repeated_mean=bool, parallel_mode=str, parameter_broadcast=bool, | |||
| strategy_ckpt_load_file=str, strategy_ckpt_save_file=str) | |||
| strategy_ckpt_load_file=str, strategy_ckpt_save_file=str, full_batch=bool) | |||
| def _set_auto_parallel_context(**kwargs): | |||
| """ | |||
| Set auto parallel context. | |||
| @@ -459,6 +476,7 @@ def _set_auto_parallel_context(**kwargs): | |||
| broadcast. Default: False. | |||
| strategy_ckpt_load_file (str): The path to load parallel strategy checkpoint. Default: '' | |||
| strategy_ckpt_save_file (str): The path to save parallel strategy checkpoint. Default: '' | |||
| full_batch (bool): Whether to load the whole batch on each device. Default: False. | |||
| Raises: | |||
| ValueError: If input key is not attribute in auto parallel context. | |||
| @@ -20,10 +20,26 @@ from mindspore.parallel._auto_parallel_context import auto_parallel_context | |||
def _get_parallel_mode():
    """Return the parallel mode stored in the auto parallel context."""
    ctx = auto_parallel_context()
    return ctx.get_parallel_mode()
def _get_full_batch():
    """Return the full_batch flag stored in the auto parallel context."""
    ctx = auto_parallel_context()
    return ctx.get_full_batch()
def _need_to_full():
    """
    Check whether the input must be converted to a full shape or tensor.

    Returns:
        bool, True only when the parallel mode is "semi_auto_parallel" or
        "auto_parallel" and full_batch is not enabled.
    """
    # Both settings are always queried, parallel mode first.
    mode = _get_parallel_mode()
    uses_full_batch = _get_full_batch()
    if mode not in ("semi_auto_parallel", "auto_parallel"):
        return False
    return not uses_full_batch
def _get_mirror_mean():
    """Return the mirror_mean flag stored in the auto parallel context."""
    ctx = auto_parallel_context()
    return ctx.get_mirror_mean()
| @@ -17,11 +17,10 @@ import math | |||
| from mindspore._checkparam import check_bool | |||
| from .. import context | |||
| from .parallel_utils import ParallelMode | |||
| from ._utils import _exec_datagraph, _get_types_and_shapes, _to_tensor, \ | |||
| _construct_tensor_list, _to_full_shapes, _to_full_tensor | |||
| from ..nn.wrap import GetNextSingleOp | |||
| from ..parallel._utils import _get_device_num, _get_global_rank, _get_parallel_mode | |||
| from ..parallel._utils import _get_device_num, _get_global_rank, _need_to_full | |||
| class DatasetHelper: | |||
| @@ -118,10 +117,10 @@ class _DatasetIterMSLoopSink(_DatasetIter): | |||
| def __init__(self, dataset): | |||
| super(_DatasetIterMSLoopSink, self).__init__(dataset) | |||
| self.loop_count = self.get_loop_count(dataset) | |||
| # for self._parallel_mode equal to semi_auto_parallel or auto_parallel, use a complete tensor to | |||
| # compile, and slice tensor to run. The batch dimension of tensors for compile is device_number | |||
| # times the batch dimension of tensors for run. Now only support LoopSink. | |||
| if _get_parallel_mode() in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL): | |||
| # for self._parallel_mode equal to semi_auto_parallel or auto_parallel, and not using full_batch, | |||
| # use a complete tensor to compile, and slice tensor to run. The batch dimension of tensors for | |||
| # compile is device_number times the batch dimension of tensors for run. Now only support LoopSink. | |||
| if _need_to_full(): | |||
| device_num = _get_device_num() | |||
| self.dataset_shapes = _to_full_shapes(self.dataset_shapes, device_num) | |||
| @@ -146,10 +145,8 @@ class _DatasetIterGE(_DatasetIter): | |||
| def __init__(self, dataset): | |||
| super(_DatasetIterGE, self).__init__(dataset) | |||
| self.loop_count = self.get_loop_count(dataset) | |||
| parallel_mode = _get_parallel_mode() | |||
| self.need_to_full = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL) | |||
| batch_expand_num = 1 | |||
| if self.need_to_full: | |||
| if _need_to_full(): | |||
| batch_expand_num = _get_device_num() | |||
| tensor_list_run = _construct_tensor_list(self.dataset_types, self.dataset_shapes, batch_expand_num) | |||
| @@ -170,9 +167,6 @@ class _DatasetIterFeed: | |||
| self.loop_count = dataset.get_dataset_size() | |||
| self.ind = 0 | |||
| parallel_mode = context.get_auto_parallel_context("parallel_mode") | |||
| self.need_to_full = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL) | |||
| def __iter__(self): | |||
| if self.repeat_ind % self.repeat_count == 0: | |||
| self.iter = self.dataset.__iter__() | |||
| @@ -186,6 +180,6 @@ class _DatasetIterFeed: | |||
| raise StopIteration() | |||
| self.ind += 1 | |||
| data = self.iter.__next__() | |||
| if self.need_to_full: | |||
| if _need_to_full(): | |||
| return _to_full_tensor(data, self.device_num, self.global_rank) | |||
| return _to_tensor(data) | |||
| @@ -22,6 +22,7 @@ def argparse_init(): | |||
def _str_to_bool(text):
    """Convert a command-line string to a bool.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is True
    because every non-empty string is truthy, so ``--full_batch False`` would
    silently enable the option.

    Args:
        text (str): Raw command-line value.

    Returns:
        bool, the parsed flag. The empty string maps to False, matching what
        ``bool("")`` returned before.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised boolean spelling.
    """
    lowered = str(text).strip().lower()
    if lowered in ("true", "t", "yes", "y", "on", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "off", "0", ""):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got {!r}".format(text))

parser = argparse.ArgumentParser(description='WideDeep')
parser.add_argument("--data_path", type=str, default="./test_raw_data/")
parser.add_argument("--epochs", type=int, default=15)
# Parsed explicitly so that "--full_batch False" really yields False.
parser.add_argument("--full_batch", type=_str_to_bool, default=False)
parser.add_argument("--batch_size", type=int, default=16000)
parser.add_argument("--eval_batch_size", type=int, default=16000)
parser.add_argument("--field_size", type=int, default=39)
| @@ -44,6 +45,7 @@ class WideDeepConfig(): | |||
| """ | |||
| def __init__(self): | |||
| self.data_path = "./test_raw_data/" | |||
| self.full_batch = False | |||
| self.epochs = 15 | |||
| self.batch_size = 16000 | |||
| self.eval_batch_size = 16000 | |||
| @@ -72,6 +74,7 @@ class WideDeepConfig(): | |||
| args, _ = parser.parse_known_args() | |||
| self.data_path = args.data_path | |||
| self.epochs = args.epochs | |||
| self.full_batch = args.full_batch | |||
| self.batch_size = args.batch_size | |||
| self.eval_batch_size = args.eval_batch_size | |||
| self.field_size = args.field_size | |||
| @@ -17,8 +17,10 @@ | |||
| Area under curve metric | |||
| """ | |||
| from mindspore.nn.metrics import Metric | |||
| from sklearn.metrics import roc_auc_score | |||
| from mindspore import context | |||
| from mindspore.nn.metrics import Metric | |||
| from mindspore.communication.management import get_rank, get_group_size | |||
| class AUCMetric(Metric): | |||
| """ | |||
| @@ -28,6 +30,7 @@ class AUCMetric(Metric): | |||
| def __init__(self): | |||
| super(AUCMetric, self).__init__() | |||
| self.clear() | |||
| self.full_batch = context.get_auto_parallel_context("full_batch") | |||
| def clear(self): | |||
| """Clear the internal evaluation result.""" | |||
| @@ -35,10 +38,17 @@ class AUCMetric(Metric): | |||
| self.pred_probs = [] | |||
| def update(self, *inputs): # inputs | |||
| all_predict = inputs[1].asnumpy() # predict | |||
| all_label = inputs[2].asnumpy() # label | |||
| self.true_labels.extend(all_label.flatten().tolist()) | |||
| self.pred_probs.extend(all_predict.flatten().tolist()) | |||
| """Update list of predicts and labels.""" | |||
| all_predict = inputs[1].asnumpy().flatten().tolist() # predict | |||
| all_label = inputs[2].asnumpy().flatten().tolist() # label | |||
| self.pred_probs.extend(all_predict) | |||
| if self.full_batch: | |||
| rank_id = get_rank() | |||
| group_size = get_group_size() | |||
| gap = len(all_label) // group_size | |||
| self.true_labels.extend(all_label[rank_id*gap: (rank_id+1)*gap]) | |||
| else: | |||
| self.true_labels.extend(all_label) | |||
| def eval(self): | |||
| if len(self.true_labels) != len(self.pred_probs): | |||
| @@ -17,6 +17,7 @@ | |||
| import os | |||
| import sys | |||
| import mindspore.dataset.engine as de | |||
| from mindspore import Model, context | |||
| from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor | |||
| from mindspore.train import ParallelMode | |||
| @@ -79,10 +80,18 @@ def test_train_eval(): | |||
| batch_size = config.batch_size | |||
| epochs = config.epochs | |||
| print("epochs is {}".format(epochs)) | |||
| ds_train = create_dataset(data_path, train_mode=True, epochs=epochs, | |||
| batch_size=batch_size, rank_id=get_rank(), rank_size=get_group_size()) | |||
| ds_eval = create_dataset(data_path, train_mode=False, epochs=epochs + 1, | |||
| batch_size=batch_size, rank_id=get_rank(), rank_size=get_group_size()) | |||
| if config.full_batch: | |||
| context.set_auto_parallel_context(full_batch=True) | |||
| de.config.set_seed(1) | |||
| ds_train = create_dataset(data_path, train_mode=True, epochs=epochs, | |||
| batch_size=batch_size*get_group_size()) | |||
| ds_eval = create_dataset(data_path, train_mode=False, epochs=epochs + 1, | |||
| batch_size=batch_size*get_group_size()) | |||
| else: | |||
| ds_train = create_dataset(data_path, train_mode=True, epochs=epochs, | |||
| batch_size=batch_size, rank_id=get_rank(), rank_size=get_group_size()) | |||
| ds_eval = create_dataset(data_path, train_mode=False, epochs=epochs + 1, | |||
| batch_size=batch_size, rank_id=get_rank(), rank_size=get_group_size()) | |||
| print("ds_train.size: {}".format(ds_train.get_dataset_size())) | |||
| print("ds_eval.size: {}".format(ds_eval.get_dataset_size())) | |||
| @@ -0,0 +1,89 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| import numpy as np | |||
| import mindspore as ms | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor | |||
| from mindspore import context | |||
| from mindspore.common.parameter import Parameter | |||
| from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits | |||
| from mindspore.nn.optim.momentum import Momentum | |||
| from mindspore.ops import operations as P | |||
| from mindspore.parallel._utils import _reset_op_id | |||
| from mindspore.train import Model, ParallelMode | |||
| from tests.dataset_mock import MindData | |||
class Dataset(MindData):
    """Mock dataset that hands out the same (predict, label) pair `length` times."""

    def __init__(self, predict, label, length=3):
        super(Dataset, self).__init__(size=length)
        self.predict = predict
        self.label = label
        self.length = length
        self.index = 0

    def __iter__(self):
        # The object is its own iterator; call reset() to rewind it.
        return self

    def __next__(self):
        if self.index < self.length:
            self.index += 1
            return self.predict, self.label
        raise StopIteration

    def reset(self):
        """Rewind the iteration counter to the start."""
        self.index = 0
class AllToAllNet(nn.Cell):
    """Cell that multiplies the input by a [128, 256] weight and transposes the product."""

    def __init__(self, strategy1):
        super(AllToAllNet, self).__init__()
        # MatMul strategy is fixed here; only the Transpose strategy is supplied by the caller.
        self.matmul = P.MatMul().set_strategy(((1, 1), (1, 8)))
        weight_init = Tensor(np.ones([128, 256]), dtype=ms.float32)
        self.matmul_weight = Parameter(weight_init, name="weight")
        self.transpose1 = P.Transpose().set_strategy(strategy1)

    def construct(self, x):
        product = self.matmul(x, self.matmul_weight)
        return self.transpose1(product, (1, 0))
def all_to_all_net(strategy1):
    """Build an AllToAllNet whose Transpose uses `strategy1`."""
    network = AllToAllNet(strategy1=strategy1)
    return network
def all_to_all_common(strategy1):
    """Train AllToAllNet for two epochs in semi-auto-parallel mode with full_batch enabled.

    Args:
        strategy1: Strategy passed on to the network's Transpose operator.
    """
    # Configure graph mode and an 8-device semi-auto-parallel context with full_batch on.
    context.set_context(mode=context.GRAPH_MODE, save_graphs=False)
    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, device_num=8, full_batch=True)

    # Mock data: two iterations of the same all-ones batch.
    inputs = Tensor(np.ones([256, 128]), dtype=ms.float32)
    targets = Tensor(np.ones([256]), dtype=ms.int32)
    train_data = Dataset(inputs, targets, 2)

    network = all_to_all_net(strategy1)

    loss_fn = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
    loss_fn.softmax_cross_entropy.set_strategy(((8, 1), (8, 1)))
    loss_fn.one_hot.set_strategy(((8, 1), (), ()))

    # Learning rate 0.1, momentum 0.9.
    optimizer = Momentum(network.trainable_params(), 0.1, 0.9)

    trainer = Model(network, loss_fn, optimizer)
    trainer.train(2, train_data, dataset_sink_mode=False)
def test_all_to_all():
    """Run the full_batch training scenario with the Transpose strategy ((8, 1),)."""
    transpose_strategy = ((8, 1),)
    _reset_op_id()
    all_to_all_common(transpose_strategy)