From: @yao_yf Reviewed-by: Signed-off-by:tags/v1.1.0
| @@ -27,7 +27,7 @@ void DynamicShapeKernel::Execute() { | |||
| } | |||
| auto prev_output_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode_ptr_, 0); | |||
| auto output_shape = std::vector<int64_t>(SizeToLong(prev_output_shape.size())); | |||
| std::vector<int64_t> output_shape = {SizeToLong(prev_output_shape.size())}; | |||
| auto output_type = TypeId::kNumberTypeInt64; | |||
| @@ -62,7 +62,7 @@ const AnfNodePtr ConvertConstInputToAttr::Process(const FuncGraphPtr &, const An | |||
| continue; | |||
| } | |||
| } | |||
| if (AnfAlgo::IsNodeDynamicShape(cnode) && | |||
| if (AnfAlgo::IsDynamicShape(cnode) && | |||
| DynamicShapeConstInputToAttr.find(AnfAlgo::GetCNodeName(cnode)) == DynamicShapeConstInputToAttr.end()) { | |||
| MS_LOG(INFO) << "current node is dynamic shape " << cnode->fullname_with_scope(); | |||
| continue; | |||
| @@ -569,7 +569,6 @@ void AscendSession::BuildDynamicKernel(const std::shared_ptr<KernelGraph> &kerne | |||
| void AscendSession::MemoryAlloc(KernelGraph *kernel_graph) const { | |||
| MS_LOG(INFO) << "Start!"; | |||
| MS_EXCEPTION_IF_NULL(kernel_graph); | |||
| opt::RemoveNopNode(kernel_graph); | |||
| auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_); | |||
| MS_EXCEPTION_IF_NULL(runtime_instance); | |||
| runtime_instance->AssignMemory(kernel_graph); | |||
| @@ -471,7 +471,21 @@ bool KernelRuntime::KernelMemNotReuse(const AnfNodePtr &node) { return false; } | |||
| DeviceAddressPtr KernelRuntime::PreAssignCNodeMemory(const AnfNodePtr &anf_node, size_t index) { | |||
| MS_EXCEPTION_IF_NULL(anf_node); | |||
| if (!anf_node->isa<CNode>()) { | |||
| MS_LOG(EXCEPTION) << "anf_node should be a cnode"; | |||
| } | |||
| auto cnode = anf_node->cast<CNodePtr>(); | |||
| if (opt::IsNopNode(cnode)) { | |||
| size_t kNopNodeInputSize = 2; | |||
| size_t kNopNodeRealInputIndex = 1; | |||
| if (cnode->size() != kNopNodeInputSize) { | |||
| MS_LOG(EXCEPTION) << cnode->fullname_with_scope() << " has invalid input size: " << cnode->size(); | |||
| } | |||
| auto input_node_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, index); | |||
| return PreAssignCNodeMemory(cnode->input(kNopNodeRealInputIndex), input_node_with_index.second); | |||
| } | |||
| auto kernel_mod = AnfAlgo::GetKernelMod(anf_node); | |||
| MS_EXCEPTION_IF_NULL(kernel_mod); | |||
| auto output_sizes = kernel_mod->GetOutputSizeList(); | |||
| if (output_sizes.size() <= index) { | |||
| MS_LOG(EXCEPTION) << "Previous node output size < node index"; | |||
| @@ -126,6 +126,7 @@ class Parameter(MetaTensor_): | |||
| self.is_param_ps = False | |||
| self._cast_type = None | |||
| self.init_in_server = False | |||
| self._unique = False | |||
| self.is_in_parallel = _is_in_parallel_mode() | |||
| @staticmethod | |||
| @@ -238,6 +239,15 @@ class Parameter(MetaTensor_): | |||
| def sliced(self, sliced_): | |||
| self._sliced = sliced_ | |||
| @property | |||
| def unique(self): | |||
| """whether the parameter is already unique or not.""" | |||
| return self._unique | |||
| @unique.setter | |||
| def unique(self, unique_): | |||
| self._unique = unique_ | |||
| @property | |||
| def is_init(self): | |||
| """ | |||
| @@ -561,6 +561,17 @@ AbstractBasePtr InferImplTranspose(const AnalysisEnginePtr &, const PrimitivePtr | |||
| result_shp.push_back(input_shp[idx]); | |||
| indices.insert(idx); | |||
| } | |||
| ShapeVector max_shp; | |||
| ShapeVector min_shp; | |||
| if (input->shape()->max_shape().size() == input_shp.size() && | |||
| input->shape()->min_shape().size() == input_shp.size()) { | |||
| for (size_t i = 0; i < perm_vec.size(); i++) { | |||
| size_t idx = static_cast<size_t>(perm_vec[i]); | |||
| max_shp.push_back(input->shape()->max_shape()[idx]); | |||
| min_shp.push_back(input->shape()->min_shape()[idx]); | |||
| } | |||
| return std::make_shared<AbstractTensor>(input->element(), std::make_shared<Shape>(result_shp, min_shp, max_shp)); | |||
| } | |||
| return std::make_shared<AbstractTensor>(input->element(), std::make_shared<Shape>(result_shp)); | |||
| } | |||
| @@ -405,10 +405,9 @@ AbstractBasePtr InferImplAllGather(const AnalysisEnginePtr &, const PrimitivePtr | |||
| if (tmp_shape.empty()) { | |||
| MS_LOG(EXCEPTION) << "shape size is 0"; | |||
| } | |||
| if (tmp_shape[0] % rank_size != 0) { | |||
| MS_LOG(EXCEPTION) << "first dimension of x should be divided by rank_size"; | |||
| if (tmp_shape[0] > 0) { | |||
| tmp_shape[0] = tmp_shape[0] * rank_size; | |||
| } | |||
| tmp_shape[0] = tmp_shape[0] / rank_size; | |||
| return std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(tmp_shape)); | |||
| } | |||
| @@ -150,6 +150,7 @@ class EmbeddingLookup(Cell): | |||
| max_norm (Union[float, None]): A maximum clipping value. The data type must be float16, float32 | |||
| or None. Default: None | |||
| sparse (bool): Using sparse mode. When 'target' is set to 'CPU', 'sparse' has to be true. Default: True. | |||
| Inputs: | |||
| - **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`. | |||
| Specifies the indices of elements of the original Tensor. Values can be out of range of embedding_table, | |||
| @@ -191,6 +192,12 @@ class EmbeddingLookup(Cell): | |||
| name='embedding_table') | |||
| parallel_mode = _get_parallel_mode() | |||
| is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL) | |||
| self.forward_unique = False | |||
| self.gather_revert = P.GatherV2() | |||
| self.unique = P.Unique().shard(((1,),)) | |||
| self.reshape = P.Reshape() | |||
| self.shape = P.Shape() | |||
| indices_shape_size = 2 | |||
| if slice_mode == "field_slice" and is_auto_parallel: | |||
| if not manual_shapes: | |||
| raise ValueError("in slice field mode, the manual_shapes should not be none") | |||
| @@ -203,18 +210,32 @@ class EmbeddingLookup(Cell): | |||
| self.gatherv2.shard(((get_group_size(), 1), (1, get_group_size()))) | |||
| self.embeddinglookup.shard(((get_group_size(), 1), (1, get_group_size()))) | |||
| elif slice_mode == "table_row_slice" and is_auto_parallel: | |||
| self.gatherv2.shard(((get_group_size(), 1), (1, 1))) | |||
| self.embeddinglookup.shard(((get_group_size(), 1), (1, 1))) | |||
| if target == 'DEVICE': | |||
| indices_shape_size = 1 | |||
| self.gather_revert.shard(((1, 1), (1,))) | |||
| self.forward_unique = True | |||
| indices_strategy = (1,)*indices_shape_size | |||
| self.gatherv2.shard(((get_group_size(), 1), indices_strategy)) | |||
| self.embeddinglookup.shard(((get_group_size(), 1), indices_strategy)) | |||
| elif slice_mode == "table_column_slice" and is_auto_parallel: | |||
| self.gatherv2.shard(((1, get_group_size()), (1, 1))) | |||
| self.embeddinglookup.shard(((1, get_group_size()), (1, 1))) | |||
| if target == 'DEVICE': | |||
| indices_shape_size = 1 | |||
| self.gather_revert.shard(((1, get_group_size()), (1,))) | |||
| self.forward_unique = True | |||
| indices_strategy = (1,)*indices_shape_size | |||
| self.gatherv2.shard(((1, get_group_size()), indices_strategy)) | |||
| self.embeddinglookup.shard(((1, get_group_size()), indices_strategy)) | |||
| elif slice_mode == "batch_slice" and is_auto_parallel: | |||
| self.gatherv2.shard(((1, 1), (get_group_size(), 1))) | |||
| self.embeddinglookup.shard(((1, 1), (get_group_size(), 1))) | |||
| indices_strategy = [get_group_size()] | |||
| indices_strategy.extend([1]*(indices_shape_size - 1)) | |||
| indices_strategy = tuple(indices_strategy) | |||
| self.gatherv2.shard(((1, 1), indices_strategy)) | |||
| self.embeddinglookup.shard(((1, 1), indices_strategy)) | |||
| else: | |||
| if is_auto_parallel: | |||
| raise ValueError("slice_mode should support mode in nn.EmbeddingLookup, but get " | |||
| + str(slice_mode)) | |||
| self.embedding_table.unique = self.forward_unique | |||
| self.max_norm = max_norm | |||
| if self.max_norm is not None: | |||
| self.max_norm = validator.check_positive_float(self.max_norm, 'max_norm', self.cls_name) | |||
| @@ -224,7 +245,15 @@ class EmbeddingLookup(Cell): | |||
| if self.target == "CPU": | |||
| out = self.embeddinglookup(self.embedding_table, indices, 0) | |||
| else: | |||
| out = self.gatherv2(self.embedding_table, indices, 0) | |||
| if self.forward_unique: | |||
| shp = self.shape(indices) + (self.embedding_size,) | |||
| indices_flatten = self.reshape(indices, (-1,)) | |||
| unique_id, unique_idx = self.unique(indices_flatten) | |||
| weight_unique = self.gatherv2(unique_id) | |||
| weight_flatten = self.gather_revert(weight_unique, unique_idx, 0) | |||
| out = self.reshape(weight_flatten, shp) | |||
| else: | |||
| out = self.gatherv2(self.embedding_table, indices, 0) | |||
| if self.max_norm is not None: | |||
| axis = _make_axis_range(F.rank(indices), F.rank(out)) | |||
| clip_by_norm = ClipByNorm(axis) | |||
| @@ -144,6 +144,11 @@ class Optimizer(Cell): | |||
| decay_filter = lambda x: 'beta' not in x.name and 'gamma' not in x.name | |||
| self.decay_flags = tuple(decay_filter(x) for x in self.parameters) | |||
| self.exec_weight_decay = self.weight_decay > 0 | |||
| # when a parameter has been unique, there is no need do another unique in optimizer. | |||
| for param in self.parameters: | |||
| if param.unique: | |||
| self._unique = False | |||
| break | |||
| ps_filter = lambda x: x.is_param_ps | |||
| self.ps_parameters = tuple(ps_filter(x) for x in self.parameters) | |||
| self.reciprocal_scale = 1.0 / loss_scale | |||
| @@ -67,6 +67,7 @@ from .mul_ds import _mul_ds_tbe | |||
| from .real_div import _real_div_tbe | |||
| from .real_div_ds import _real_div_ds_tbe | |||
| from .relu import _relu_tbe | |||
| from .relu_ds import _relu_ds_tbe | |||
| from .relu_grad import _relu_grad_tbe | |||
| from .relu6 import _relu6_tbe | |||
| from .relu6_grad import _relu6_grad_tbe | |||
| @@ -0,0 +1,40 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ReLU op""" | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| relu_op_info = TBERegOp("ReLU") \ | |||
| .fusion_type("ELEMWISE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("relu.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("relu") \ | |||
| .partial_flag(True) \ | |||
| .dynamic_shape(True) \ | |||
| .input(0, "x", False, "required", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .op_pattern("formatAgnostic") \ | |||
| .dtype_format(DataType.I8_None, DataType.I8_None) \ | |||
| .dtype_format(DataType.I32_None, DataType.I32_None) \ | |||
| .dtype_format(DataType.F16_None, DataType.F16_None) \ | |||
| .dtype_format(DataType.F32_None, DataType.F32_None) \ | |||
| .get_op_info() | |||
| @op_info_register(relu_op_info) | |||
| def _relu_ds_tbe(): | |||
| """Relu TBE register""" | |||
| return | |||
| @@ -274,7 +274,8 @@ class AllGather(PrimitiveWithInfer): | |||
| def infer_shape(self, x_shape): | |||
| validator.check_positive_int(len(x_shape), "x shape", self.name) | |||
| x_shape[0] = x_shape[0] * self.rank_size | |||
| if x_shape[0] > 0: | |||
| x_shape[0] = x_shape[0] * self.rank_size | |||
| return x_shape | |||
| def infer_dtype(self, x_dtype): | |||
| @@ -324,7 +325,8 @@ class _HostAllGather(PrimitiveWithInfer): | |||
| def infer_shape(self, x_shape): | |||
| validator.check_positive_int(len(x_shape), "x shape", self.name) | |||
| x_shape[0] = x_shape[0] * self.group_size | |||
| if x_shape[0] > 0: | |||
| x_shape[0] = x_shape[0] * self.group_size | |||
| return x_shape | |||
| def infer_dtype(self, x_dtype): | |||
| @@ -721,7 +721,7 @@ def _merge_param_with_strategy(sliced_data, parameter_name, strategy, is_even): | |||
| if field_size > 0: | |||
| from mindspore.parallel._tensor import _reshape_param_data_with_weight | |||
| merged_tensor = _reshape_param_data_with_weight(all_gather_tensor, dev_mat, [field_size]) | |||
| merged_tensor = _reshape_param_data_with_weight(all_gather_tensor, dev_mat, field_size) | |||
| else: | |||
| from mindspore.parallel._tensor import _reshape_param_data | |||
| @@ -43,7 +43,7 @@ do | |||
| python -s ${self_path}/../train_and_eval_auto_parallel.py --data_path=$DATASET --epochs=$EPOCH_SIZE --vocab_size=$VOCAB_SIZE --emb_dim=$EMB_DIM --dropout_flag=1 --host_device_mix=1 >train_deep$i.log 2>&1 & | |||
| elif [ $MODE == "field_slice_host_device_mix" ]; then | |||
| python -s ${self_path}/../train_and_eval_auto_parallel.py --data_path=$DATASET --epochs=$EPOCH_SIZE --vocab_size=$VOCAB_SIZE --emb_dim=$EMB_DIM --dropout_flag=1 --host_device_mix=1 --full_batch=1 --field_slice=1 >train_deep$i.log 2>&1 & | |||
| elif [ $MODE == "backward_unique" ]; then | |||
| elif [ $MODE == "forward_unique" ]; then | |||
| python -s ${self_path}/../train_and_eval_auto_parallel.py --data_path=$DATASET --epochs=$EPOCH_SIZE --vocab_size=$VOCAB_SIZE --emb_dim=$EMB_DIM --dropout_flag=1 --sparse=1 >train_deep$i.log 2>&1 & | |||
| else | |||
| python -s ${self_path}/../train_and_eval_auto_parallel.py --data_path=$DATASET --epochs=$EPOCH_SIZE --vocab_size=$VOCAB_SIZE --emb_dim=$EMB_DIM --dropout_flag=1 --host_device_mix=0 >train_deep$i.log 2>&1 & | |||
| @@ -38,7 +38,7 @@ do | |||
| user=$(get_node_user ${cluster_config_path} ${node}) | |||
| passwd=$(get_node_passwd ${cluster_config_path} ${node}) | |||
| echo "------------------${user}@${node}---------------------" | |||
| if [ $MODE == "host_device_mix" ] || [ $MODE == "field_slice_host_device_mix" ] || [ $MODE == "backward_unique" ]; then | |||
| if [ $MODE == "host_device_mix" ] || [ $MODE == "field_slice_host_device_mix" ] || [ $MODE == "forward_unique" ]; then | |||
| ssh_pass ${node} ${user} ${passwd} "mkdir -p ${execute_path}; cd ${execute_path}; bash ${SCRIPTPATH}/run_auto_parallel_train_cluster.sh ${RANK_SIZE} ${RANK_START} ${EPOCH_SIZE} ${VOCAB_SIZE} ${EMB_DIM} ${DATASET} ${ENV_SH} ${MODE} ${RANK_TABLE_FILE}" | |||
| else | |||
| echo "[ERROR] mode is wrong" | |||
| @@ -88,7 +88,7 @@ class EvalCallBack(Callback): | |||
| Args: | |||
| print_per_step (int): Print loss every times. Default: 1. | |||
| """ | |||
| def __init__(self, model, eval_dataset, auc_metric, config, print_per_step=1, host_device_mix=False): | |||
| def __init__(self, model, eval_dataset, auc_metric, config, print_per_step=1): | |||
| super(EvalCallBack, self).__init__() | |||
| if not isinstance(print_per_step, int) or print_per_step < 0: | |||
| raise ValueError("print_per_step must be int and >= 0.") | |||
| @@ -99,7 +99,7 @@ class EvalCallBack(Callback): | |||
| self.aucMetric.clear() | |||
| self.eval_file_name = config.eval_file_name | |||
| self.eval_values = [] | |||
| self.host_device_mix = host_device_mix | |||
| self.sparse = config.sparse | |||
| self.config = config | |||
| def epoch_end(self, run_context): | |||
| @@ -116,7 +116,7 @@ class EvalCallBack(Callback): | |||
| ParallelMode.DATA_PARALLEL): | |||
| rank_id = get_rank() | |||
| start_time = time.time() | |||
| out = self.model.eval(self.eval_dataset, dataset_sink_mode=(not self.host_device_mix)) | |||
| out = self.model.eval(self.eval_dataset, dataset_sink_mode=(not self.sparse)) | |||
| end_time = time.time() | |||
| eval_time = int(end_time - start_time) | |||
| @@ -48,6 +48,7 @@ def argparse_init(): | |||
| parser.add_argument("--parameter_server", type=int, default=0, help="Open parameter server of not") | |||
| parser.add_argument("--field_slice", type=int, default=0, help="Enable split field mode or not") | |||
| parser.add_argument("--sparse", type=int, default=0, help="Enable sparse or not") | |||
| parser.add_argument("--deep_table_slice_mode", type=str, default="column_slice", help="column_slice/row_slice") | |||
| return parser | |||
| @@ -86,6 +87,7 @@ class WideDeepConfig(): | |||
| self.field_slice = False | |||
| self.manual_shape = None | |||
| self.sparse = False | |||
| self.deep_table_slice_mode = "column_slice" | |||
| def argparse_init(self): | |||
| """ | |||
| @@ -121,5 +123,6 @@ class WideDeepConfig(): | |||
| self.parameter_server = args.parameter_server | |||
| self.field_slice = bool(args.field_slice) | |||
| self.sparse = bool(args.sparse) | |||
| self.deep_table_slice_mode = args.deep_table_slice_mode | |||
| if self.host_device_mix == 1: | |||
| self.sparse = True | |||
| @@ -198,19 +198,29 @@ class WideDeepModel(nn.Cell): | |||
| self.tile = P.Tile() | |||
| self.concat = P.Concat(axis=1) | |||
| self.cast = P.Cast() | |||
| self.unique = P.Unique().shard(((1,),)) | |||
| self.wide_gatherv2 = P.GatherV2() | |||
| self.deep_gatherv2 = P.GatherV2() | |||
| if is_auto_parallel and sparse and not is_field_slice: | |||
| self.dense_layer_1.dropout.dropout_do_mask.shard(((1, get_group_size()),)) | |||
| self.dense_layer_1.dropout.dropout.shard(((1, get_group_size()),)) | |||
| self.dense_layer_1.matmul.shard(((1, get_group_size()), (get_group_size(), 1))) | |||
| target = 'DEVICE' | |||
| if host_device_mix: | |||
| target = 'CPU' | |||
| self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim, target=target, | |||
| slice_mode=nn.EmbeddingLookup.TABLE_COLUMN_SLICE) | |||
| self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1, target=target, | |||
| slice_mode=nn.EmbeddingLookup.TABLE_ROW_SLICE) | |||
| self.deep_mul.shard(((1, 1, get_group_size()), (1, 1, 1))) | |||
| self.deep_reshape.add_prim_attr("skip_redistribution", True) | |||
| if target == 'DEVICE': | |||
| self.wide_mul.shard(((1, 1, 1), (1, 1, 1))) | |||
| if config.deep_table_slice_mode == "column_slice": | |||
| self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim, target=target, | |||
| slice_mode=nn.EmbeddingLookup.TABLE_COLUMN_SLICE) | |||
| self.dense_layer_1.dropout.dropout_do_mask.shard(((1, get_group_size()),)) | |||
| self.dense_layer_1.dropout.dropout.shard(((1, get_group_size()),)) | |||
| self.dense_layer_1.matmul.shard(((1, get_group_size()), (get_group_size(), 1))) | |||
| self.dense_layer_1.matmul.add_prim_attr("field_size", self.field_size) | |||
| self.deep_mul.shard(((1, 1, get_group_size()), (1, 1, 1))) | |||
| self.deep_reshape.add_prim_attr("skip_redistribution", True) | |||
| else: | |||
| self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim, target=target, | |||
| slice_mode=nn.EmbeddingLookup.TABLE_ROW_SLICE) | |||
| self.reduce_sum.add_prim_attr("cross_batch", True) | |||
| self.embedding_table = self.deep_embeddinglookup.embedding_table | |||
| elif is_auto_parallel and host_device_mix and is_field_slice and config.full_batch and config.manual_shape: | |||
| @@ -247,13 +257,15 @@ class WideDeepModel(nn.Cell): | |||
| id_hldr: batch ids; | |||
| wt_hldr: batch weights; | |||
| """ | |||
| mask = self.reshape(wt_hldr, (self.batch_size, self.field_size, 1)) | |||
| # Wide layer | |||
| wide_id_weight = self.wide_embeddinglookup(id_hldr) | |||
| # Deep layer | |||
| deep_id_embs = self.deep_embeddinglookup(id_hldr) | |||
| mask = self.reshape(wt_hldr, (self.batch_size, self.field_size, 1)) | |||
| # Wide layer | |||
| wx = self.wide_mul(wide_id_weight, mask) | |||
| wide_out = self.reshape(self.reduce_sum(wx, 1) + self.wide_b, (-1, 1)) | |||
| # Deep layer | |||
| deep_id_embs = self.deep_embeddinglookup(id_hldr) | |||
| vx = self.deep_mul(deep_id_embs, mask) | |||
| deep_in = self.deep_reshape(vx, (-1, self.field_size * self.emb_dim)) | |||
| deep_in = self.dense_layer_1(deep_in) | |||
| @@ -333,7 +345,8 @@ class TrainStepWrap(nn.Cell): | |||
| parameter_server (Bool): Whether run in parameter server mode. Default: False | |||
| """ | |||
| def __init__(self, network, sens=1024.0, host_device_mix=False, parameter_server=False, sparse=False): | |||
| def __init__(self, network, sens=1024.0, host_device_mix=False, parameter_server=False, | |||
| sparse=False): | |||
| super(TrainStepWrap, self).__init__() | |||
| parallel_mode = context.get_auto_parallel_context("parallel_mode") | |||
| is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL) | |||
| @@ -40,8 +40,8 @@ def get_WideDeep_net(config): | |||
| WideDeep_net = WideDeepModel(config) | |||
| loss_net = NetWithLossClass(WideDeep_net, config) | |||
| loss_net = VirtualDatasetCellTriple(loss_net) | |||
| train_net = TrainStepWrap( | |||
| loss_net, host_device_mix=bool(config.host_device_mix), sparse=config.sparse) | |||
| train_net = TrainStepWrap(loss_net, host_device_mix=bool(config.host_device_mix), | |||
| sparse=config.sparse) | |||
| eval_net = PredictWithSigmoid(WideDeep_net) | |||
| eval_net = VirtualDatasetCellTriple(eval_net) | |||
| return train_net, eval_net | |||
| @@ -122,7 +122,7 @@ def train_and_eval(config): | |||
| metrics={"auc": auc_metric}) | |||
| eval_callback = EvalCallBack( | |||
| model, ds_eval, auc_metric, config, host_device_mix=host_device_mix) | |||
| model, ds_eval, auc_metric, config) | |||
| callback = LossCallBack(config=config, per_print_times=20) | |||
| ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size()*epochs, | |||
| @@ -146,7 +146,7 @@ if __name__ == "__main__": | |||
| context.set_context(variable_memory_max_size="24GB") | |||
| context.set_context(enable_sparse=True) | |||
| init() | |||
| if wide_deep_config.host_device_mix == 1: | |||
| if wide_deep_config.sparse: | |||
| context.set_auto_parallel_context( | |||
| parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, gradients_mean=True) | |||
| else: | |||
| @@ -37,6 +37,8 @@ def argparse_init(): | |||
| parser.add_argument("--ckpt_path", type=str, default="./checkpoints/") | |||
| parser.add_argument("--eval_file_name", type=str, default="eval.log") | |||
| parser.add_argument("--loss_file_name", type=str, default="loss.log") | |||
| parser.add_argument("--sparse", type=int, default=0, help="Enable sparse or not") | |||
| parser.add_argument("--deep_table_slice_mode", type=str, default="column_slice", help="column_slice/row_slice") | |||
| return parser | |||
| @@ -66,6 +68,8 @@ class WideDeepConfig(): | |||
| self.loss_file_name = "loss.log" | |||
| self.ckpt_path = "./checkpoints/" | |||
| self.stra_ckpt = "./strategy_train.ckpt" | |||
| self.sparse = False | |||
| self.deep_table_slice_mode = "column_slice" | |||
| def argparse_init(self): | |||
| """ | |||
| @@ -94,3 +98,7 @@ class WideDeepConfig(): | |||
| self.loss_file_name = args.loss_file_name | |||
| self.ckpt_path = args.ckpt_path | |||
| self.stra_ckpt = args.stra_ckpt | |||
| self.sparse = bool(args.sparse) | |||
| self.deep_table_slice_mode = args.deep_table_slice_mode | |||
| if self.host_device_mix == 1: | |||
| self.sparse = True | |||
| @@ -93,7 +93,7 @@ def test_unique_row_split(): | |||
| self.embedding_lookp = P.GatherV2().shard(((8, 1), (1,))) | |||
| self.embedding_table = Parameter(initializer('normal', [2000, 128]), | |||
| name='embedding_table') | |||
| self.gatherv2 = P.GatherV2().shard(((1, 1), (8,))) | |||
| self.gatherv2 = P.GatherV2().shard(((1, 1), (1,))) | |||
| self.reshape = P.Reshape() | |||
| self.matmul = P.MatMul() | |||
| self.mul_weight = Parameter(Tensor(np.full([32, 64, 1], 0.5, dtype=np.float32)), name="mul_weight") | |||
| @@ -108,7 +108,7 @@ def test_unique_row_split(): | |||
| return vx | |||
| size = 8 | |||
| context.set_auto_parallel_context(device_num=size, global_rank=0, parallel_mode="stand_alone") | |||
| context.set_auto_parallel_context(device_num=size, global_rank=0, parallel_mode="semi_auto_parallel") | |||
| x = Tensor(np.ones([32, 64]), dtype=ms.int32) | |||
| net = Net() | |||
| optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9) | |||