Merge pull request !5045 from yao_yf/wide_and_deep_eval_host_device
@@ -343,7 +343,7 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &
   std::string strategy_key_name = "";
   auto param_names = NodeParameterName(cnode);
   if (!param_names.empty()) {
-    strategy_key_name = param_names[0].first;
+    strategy_key_name = prim->name() + "_" + param_names[0].first;
   }
   bool load_strategy_from_ckpt =
     StrategyCheckpoint::GetInstance().LoadCheckPointOn() && stra_map->find(strategy_key_name) != stra_map->end();
@@ -1523,7 +1523,7 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) {
     std::string strategy_key_name = "";
     auto param_names = NodeParameterName(cnode);
     if (!param_names.empty()) {
-      strategy_key_name = param_names[0].first;
+      strategy_key_name = prim->name() + "_" + param_names[0].first;
     }
     bool load_strategy_from_ckpt =
       StrategyCheckpoint::GetInstance().LoadCheckPointOn() && stra_map.find(strategy_key_name) != stra_map.end();
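The two hunks above change how entries in the strategy checkpoint are keyed when loading: the key is now the operator name plus the parameter name instead of the parameter name alone, so a parameter consumed by several operators no longer collapses onto a single shared entry. A minimal Python sketch of the keying scheme (the names and strategy tuples below are illustrative, not MindSpore's internal data structures):

    # Illustrative sketch of the new strategy-checkpoint key: "<operator>_<parameter>".
    def strategy_key(prim_name, param_name):
        # mirrors: strategy_key_name = prim->name() + "_" + param_names[0].first
        return prim_name + "_" + param_name

    # Hypothetical saved map: the same parameter can now carry a distinct
    # strategy per operator that consumes it.
    stra_map = {
        strategy_key("GatherV2", "deep_embedding.embedding_table"): ((1, 8), (1, 1)),
        strategy_key("MatMul", "dense_layer_1.weight"): ((1, 8), (8, 1)),
    }

    key = strategy_key("MatMul", "dense_layer_1.weight")
    load_strategy_from_ckpt = key in stra_map  # analogous to stra_map.find(...) != stra_map.end()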
@@ -2214,9 +2214,23 @@ std::vector<std::pair<std::string, int>> NodeParameterName(const CNodePtr &node)
     auto input = node_inputs[i];
     if (input->isa<Parameter>()) {
       auto input_parameter = input->cast<ParameterPtr>();
-      if (input_parameter->has_default()) {
-        if (ParameterRequireGrad(input_parameter)) {
-          param_names.push_back({input_parameter->name(), i});
-        }
+      if (input_parameter->has_default() && ParameterRequireGrad(input_parameter)) {
+        param_names.push_back({input_parameter->name(), i});
       }
+    } else if (input->isa<CNode>()) {
+      CNodePtr cnode = input->cast<CNodePtr>();
+      if (!IsValueNode<Primitive>(cnode->input(0))) {
+        return param_names;
+      }
+      ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
+      PrimitivePtr prim = prim_anf_node->value()->cast<PrimitivePtr>();
+      if (prim->name() == CAST && cnode->inputs().size() >= 1) {
+        auto cast_input = cnode->inputs()[1];
+        if (cast_input->isa<Parameter>()) {
+          auto cast_input_parameter = cast_input->cast<ParameterPtr>();
+          if (cast_input_parameter->has_default() && ParameterRequireGrad(cast_input_parameter)) {
+            param_names.push_back({cast_input_parameter->name(), i});
+          }
+        }
+      }
     }
@@ -2224,14 +2238,11 @@ std::vector<std::pair<std::string, int>> NodeParameterName(const CNodePtr &node)
   return param_names;
 }
 
-void CheckpointStrategy(const FuncGraphPtr &func_graph) {
-  MS_EXCEPTION_IF_NULL(func_graph);
+void CheckpointStrategy(const std::vector<AnfNodePtr> &all_nodes) {
   MS_LOG(DEBUG) << "Save strategy to checkpoint begin";
   StrategyMap stra_map;
   TensorInfoMap tensor_info_map;
   ManualShapeMap manual_shape_map;
-  auto ret = func_graph->get_return();
-  auto all_nodes = DeepScopedGraphSearch(ret);
   for (auto &node : all_nodes) {
     MS_EXCEPTION_IF_NULL(node);
     auto cnode = node->cast<CNodePtr>();
@@ -2253,7 +2264,8 @@ void CheckpointStrategy(const FuncGraphPtr &func_graph) {
       std::vector<TensorInfo> input_tensor_info = operator_info->inputs_tensor_info();
       StrategyPtr strategyPtr = operator_info->strategy();
       MS_EXCEPTION_IF_NULL(node->scope());
-      stra_map[param_name] = strategyPtr;
+      std::string stratey_key_name = prim->name() + "_" + param_name;
+      stra_map[stratey_key_name] = strategyPtr;
       for (auto param_name_pair : param_names) {
         if (param_name_pair.second - 1 >= UintToInt(input_tensor_info.size())) {
           continue;
@@ -2547,7 +2559,7 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer)
   // save strategy as checkpoint for multi-train
   if (StrategyCheckpoint::GetInstance().SaveCheckPointOn()) {
-    CheckpointStrategy(root);
+    CheckpointStrategy(all_nodes);
   }
 
   HandleSymbolicKeyInstance(root, all_nodes);
@@ -136,7 +136,7 @@ void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePt
 std::vector<std::pair<std::string, int>> NodeParameterName(const CNodePtr &node);
 
-void CheckpointStrategy(const FuncGraphPtr &func_graph);
+void CheckpointStrategy(const std::vector<AnfNodePtr> &all_nodes);
 
 // main step of Parallel
 bool StepParallel(const FuncGraphPtr &func_graph, const opt::OptimizerPtr &optimizer);
@@ -152,7 +152,11 @@ optional arguments:
   --keep_prob The keep rate in dropout layer.(Default:1.0)
   --dropout_flag Enable dropout.(Default:0)
   --output_path Deprecated
-  --ckpt_path The location of the checkpoint file.(Defalut:./checkpoints/)
+  --ckpt_path The location of the checkpoint file. If the checkpoint file
+              is a slice of the weights, multiple checkpoint files need to
+              be passed in. Use ';' to separate them, in order, like
+              "./checkpoints/0.ckpt;./checkpoints/1.ckpt".
+              (Default:./checkpoints/)
   --eval_file_name Eval output file.(Default:eval.log)
   --loss_file_name Loss output file.(Default:loss.log)
   --host_device_mix Enable host device mode or not.(Default:0)
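For reference, the sliced-checkpoint handling described by this help text is what the eval.py hunk further down implements: each rank's checkpoint is loaded, optimizer state is skipped, and the weight slices are merged back into full tensors using the saved parallel strategy. A condensed sketch of that flow, assuming example paths (the real script takes them from --ckpt_path and --stra_ckpt):

    from mindspore.train.serialization import (load_checkpoint, build_searched_strategy,
                                               merge_sliced_parameter)

    # Example inputs; ';' separates the per-rank checkpoint slices, in rank order.
    ckpt_paths = "./checkpoints/0.ckpt;./checkpoints/1.ckpt".split(';')
    strategy = build_searched_strategy("./checkpoints/strategy.ckpt")

    # Gather every parameter's slices across the rank checkpoints, skipping optimizer state.
    param_slices = {}
    for path in ckpt_paths:
        for name, param in load_checkpoint(path).items():
            if 'optimizer' in name:
                continue
            param_slices.setdefault(name, []).append(param)

    # Merge each list of slices; pass the strategy when one was recorded for that parameter.
    param_dict = {}
    for name, slices in param_slices.items():
        if name in strategy:
            param_dict[name] = merge_sliced_parameter(slices, strategy)
        else:
            param_dict[name] = merge_sliced_parameter(slices)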
@@ -18,7 +18,8 @@
 import os
 from mindspore import Model, context
-from mindspore.train.serialization import load_checkpoint, load_param_into_net
+from mindspore.train.serialization import load_checkpoint, load_param_into_net,\
+    build_searched_strategy, merge_sliced_parameter
 from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel
 from src.callbacks import LossCallBack, EvalCallBack
@@ -81,8 +82,28 @@ def test_eval(config):
     net_builder = ModelBuilder()
     train_net, eval_net = net_builder.get_net(config)
-    param_dict = load_checkpoint(config.ckpt_path)
+    ckpt_path = config.ckpt_path
+    if ";" in ckpt_path:
+        ckpt_paths = ckpt_path.split(';')
+        param_list_dict = {}
+        strategy = build_searched_strategy(config.stra_ckpt)
+        for slice_path in ckpt_paths:
+            param_slice_dict = load_checkpoint(slice_path)
+            for key, value in param_slice_dict.items():
+                if 'optimizer' in key:
+                    continue
+                if key not in param_list_dict:
+                    param_list_dict[key] = []
+                param_list_dict[key].append(value)
+        param_dict = {}
+        for key, value in param_list_dict.items():
+            if key in strategy:
+                merged_parameter = merge_sliced_parameter(value, strategy)
+            else:
+                merged_parameter = merge_sliced_parameter(value)
+            param_dict[key] = merged_parameter
+    else:
+        param_dict = load_checkpoint(ckpt_path)
     load_param_into_net(eval_net, param_dict)
     auc_metric = AUCMetric()
@@ -97,6 +97,7 @@ class EvalCallBack(Callback):
         self.eval_file_name = config.eval_file_name
         self.eval_values = []
         self.host_device_mix = host_device_mix
+        self.config = config
 
     def epoch_end(self, run_context):
         """
@@ -106,7 +107,7 @@ class EvalCallBack(Callback):
         parallel_mode = context.get_auto_parallel_context("parallel_mode")
         if parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
             context.set_auto_parallel_context(strategy_ckpt_save_file="",
-                                              strategy_ckpt_load_file="./strategy_train.ckpt")
+                                              strategy_ckpt_load_file=self.config.stra_ckpt)
         rank_id = 0
         if parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL,
                              ParallelMode.DATA_PARALLEL):
@@ -39,6 +39,8 @@ def argparse_init():
     parser.add_argument("--dropout_flag", type=int, default=0, help="Enable dropout")
     parser.add_argument("--output_path", type=str, default="./output/")
     parser.add_argument("--ckpt_path", type=str, default="./checkpoints/", help="The location of the checkpoint file.")
+    parser.add_argument("--stra_ckpt", type=str, default="./checkpoints/strategy.ckpt",
+                        help="The strategy checkpoint file.")
     parser.add_argument("--eval_file_name", type=str, default="eval.log", help="Eval output file.")
     parser.add_argument("--loss_file_name", type=str, default="loss.log", help="Loss output file.")
     parser.add_argument("--host_device_mix", type=int, default=0, help="Enable host device mode or not")
@@ -75,6 +77,7 @@ class WideDeepConfig():
         self.eval_file_name = "eval.log"
         self.loss_file_name = "loss.log"
         self.ckpt_path = "./checkpoints/"
+        self.stra_ckpt = './checkpoints/strategy.ckpt'
         self.host_device_mix = 0
         self.dataset_type = "tfrecord"
         self.parameter_server = 0
@@ -107,6 +110,7 @@ class WideDeepConfig():
         self.eval_file_name = args.eval_file_name
         self.loss_file_name = args.loss_file_name
         self.ckpt_path = args.ckpt_path
+        self.stra_ckpt = args.stra_ckpt
         self.host_device_mix = args.host_device_mix
         self.dataset_type = args.dataset_type
         self.parameter_server = args.parameter_server
@@ -200,6 +200,7 @@ class WideDeepModel(nn.Cell):
             self.dense_layer_1.dropout.dropout_do_mask.set_strategy(((1, get_group_size()),))
             self.dense_layer_1.dropout.dropout.set_strategy(((1, get_group_size()),))
             self.dense_layer_1.matmul.set_strategy(((1, get_group_size()), (get_group_size(), 1)))
+            self.dense_layer_1.matmul.add_prim_attr("field_size", config.field_size)
             self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim,
                                                            slice_mode=nn.EmbeddingLookUpSplitMode.TABLE_COLUMN_SLICE)
             self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1,
@@ -208,6 +209,10 @@ class WideDeepModel(nn.Cell):
             self.deep_reshape.add_prim_attr("skip_redistribution", True)
             self.reduce_sum.add_prim_attr("cross_batch", True)
             self.embedding_table = self.deep_embeddinglookup.embedding_table
+        elif host_device_mix:
+            self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim)
+            self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1)
+            self.embedding_table = self.deep_embeddinglookup.embedding_table
         elif parameter_server:
             self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim)
             self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1)
@@ -111,10 +111,11 @@ def train_and_eval(config):
     eval_callback = EvalCallBack(model, ds_eval, auc_metric, config, host_device_mix=host_device_mix)
 
     callback = LossCallBack(config=config, per_print_times=20)
-    ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size(), keep_checkpoint_max=5)
+    ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size()*epochs,
+                                  keep_checkpoint_max=5, integrated_save=False)
     ckpoint_cb = ModelCheckpoint(prefix='widedeep_train',
                                  directory=config.ckpt_path, config=ckptconfig)
-    context.set_auto_parallel_context(strategy_ckpt_save_file="./strategy_train.ckpt")
+    context.set_auto_parallel_context(strategy_ckpt_save_file=config.stra_ckpt)
     callback_list = [TimeMonitor(ds_train.get_dataset_size()), eval_callback, callback]
     if not host_device_mix:
         callback_list.append(ckpoint_cb)
@@ -30,6 +30,8 @@ def argparse_init():
     parser.add_argument("--deep_layer_dim", type=int, nargs='+', default=[1024, 512, 256, 128])
     parser.add_argument("--deep_layer_act", type=str, default='relu')
     parser.add_argument("--keep_prob", type=float, default=1.0)
+    parser.add_argument("--stra_ckpt", type=str, default="./strategy_train.ckpt",
+                        help="The strategy checkpoint file.")
     parser.add_argument("--output_path", type=str, default="./output/")
     parser.add_argument("--ckpt_path", type=str, default="./checkpoints/")
@@ -63,6 +65,7 @@ class WideDeepConfig():
         self.eval_file_name = "eval.log"
         self.loss_file_name = "loss.log"
         self.ckpt_path = "./checkpoints/"
+        self.stra_ckpt = "./strategy_train.ckpt"
 
     def argparse_init(self):
         """
@@ -90,3 +93,4 @@ class WideDeepConfig():
         self.eval_file_name = args.eval_file_name
         self.loss_file_name = args.loss_file_name
         self.ckpt_path = args.ckpt_path
+        self.stra_ckpt = args.stra_ckpt