Merge pull request !6417 from yao_yf/auto_parallel_func_mv
@@ -1542,17 +1542,8 @@ size_t CostGraph::GetNumEdges() const {
   }
   return sum;
 }
-Status CostGraph::InitSelectedStrategy() {
-  for (auto &op : ops_) {
-    MS_EXCEPTION_IF_NULL(op);
-    if (op->name().find(RESHAPEINFO) != std::string::npos) {
-      continue;
-    }
-    auto result = op->InitSelectedStrategy(op->selected_strategy());
-    if (result != SUCCESS) {
-      return result;
-    }
-  }
+Status CostGraph::InitReshapeStrategy() {
   // reshape init should be apply after the init of it's previous node and next node.
   for (size_t i = 0; i < ops_.size(); ++i) {
     if (ops_[i]->name().find(RESHAPEINFO) != std::string::npos) {
@@ -1606,6 +1597,21 @@ Status CostGraph::InitSelectedStrategy() {
   return SUCCESS;
 }
 
+Status CostGraph::InitSelectedStrategy() {
+  for (auto &op : ops_) {
+    MS_EXCEPTION_IF_NULL(op);
+    if (op->name().find(RESHAPEINFO) != std::string::npos) {
+      continue;
+    }
+    auto result = op->InitSelectedStrategy(op->selected_strategy());
+    if (result != SUCCESS) {
+      return result;
+    }
+  }
+  auto result = InitReshapeStrategy();
+  return result;
+}
+
 Status CostGraph::ComputeOpsAndEdgesParameterInvolved() {
   for (auto &op : ops_) {
     MS_EXCEPTION_IF_NULL(op);
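Note on the two hunks above: the old InitSelectedStrategy body is split in two. The loop over ops_ stays in InitSelectedStrategy, while the Reshape handling moves into the new InitReshapeStrategy, which InitSelectedStrategy now calls at the end; as the in-code comment says, a Reshape can only be initialized after its previous and next operators. Below is a minimal Python sketch of that two-pass ordering, for illustration only; `Op`, `is_reshape` and `init_strategy` are hypothetical stand-ins, not the actual C++ CostGraph/OperatorInfo interfaces.

```python
# Two-pass strategy initialization, sketched in Python for illustration only.
from dataclasses import dataclass
from typing import Callable, List


@dataclass
class Op:
    name: str
    is_reshape: bool
    init_strategy: Callable[[], bool]  # returns True on SUCCESS


def init_selected_strategy(ops: List[Op]) -> bool:
    # Pass 1: initialize every non-Reshape operator from its selected strategy.
    for op in ops:
        if not op.is_reshape and not op.init_strategy():
            return False
    # Pass 2 (the InitReshapeStrategy part): Reshape operators go last, so each
    # one can read the layouts of its already-initialized previous/next operators.
    for op in ops:
        if op.is_reshape and not op.init_strategy():
            return False
    return True
```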
@@ -186,6 +186,7 @@ class CostGraph {
   std::vector<OperatorInfoPtr> GetOperators() const { return ops_; }
   size_t GetNumEdges() const;
+  Status InitReshapeStrategy();
   Status InitSelectedStrategy();
   OperatorInfoPtr FindTmpIdentityByParameterName(std::string &) const;
   // When TmpIdentity is used by mulitple operators, the corresponding parameter's memory cost should be calculated only
@@ -2275,7 +2275,6 @@ std::vector<std::pair<std::string, int>> NodeParameterName(const CNodePtr &node)
 }
 
 void CheckpointStrategy(const std::vector<AnfNodePtr> &all_nodes) {
-  MS_LOG(DEBUG) << "Save strategy to checkpoint begin";
   StrategyMap stra_map;
   TensorInfoMap tensor_info_map;
   ManualShapeMap manual_shape_map;
@@ -2298,10 +2297,8 @@ void CheckpointStrategy(const std::vector<AnfNodePtr> &all_nodes) {
       continue;
     }
     std::vector<TensorInfo> input_tensor_info = operator_info->inputs_tensor_info();
-    StrategyPtr strategyPtr = operator_info->strategy();
-    MS_EXCEPTION_IF_NULL(node->scope());
     std::string stratey_key_name = prim->name() + "_" + param_name;
-    stra_map[stratey_key_name] = strategyPtr;
+    stra_map[stratey_key_name] = operator_info->strategy();
     for (auto param_name_pair : param_names) {
       if (param_name_pair.second - 1 >= UintToInt(input_tensor_info.size())) {
         continue;
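For context on the CheckpointStrategy change above: the function keeps a map from "&lt;primitive name&gt;_&lt;parameter name&gt;" to the operator's parallel strategy, and after this change it stores the strategy directly instead of going through a temporary StrategyPtr. A hedged Python sketch of that keying scheme; the plain dict and the tuple value are illustrative stand-ins for the C++ StrategyMap, not MindSpore APIs.

```python
# Illustrative only: a plain dict standing in for the C++ StrategyMap.
# Keys follow the "<primitive name>_<parameter name>" scheme used above.
def record_strategy(stra_map, prim_name, param_name, strategy):
    stra_map[f"{prim_name}_{param_name}"] = strategy  # store directly, no temporary


stra_map = {}
record_strategy(stra_map, "MatMul", "fc1.weight", ((2, 1), (1, 4)))
assert stra_map["MatMul_fc1.weight"] == ((2, 1), (1, 4))
```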
@@ -395,9 +395,10 @@ def set_auto_parallel_context(**kwargs):
                        should be set with True. Default: False.
         enable_parallel_optimizer (bool): This is a developing feature, which shards the weight update computation for
                        data parallel training in the benefit of time and memory saving. For now,
-                       `Lamb` and `AdamWeightDecay` are supported in data parallel mode.
+                       `Lamb` and `AdamWeightDecay` are supported in data parallel mode. No Default, if it is not set,
+                       the fusion is closed.
         all_reduce_fusion_config (list): Set allreduce fusion strategy by parameters indices. Only support ReduceOp.SUM
-                       and HCCL_WORLD_GROUP/NCCL_WORLD_GROUP.
+                       and HCCL_WORLD_GROUP/NCCL_WORLD_GROUP. No Default, if it is not set, the fusion is closed.
 
     Raises:
         ValueError: If input key is not attribute in auto parallel context.
@@ -408,9 +409,13 @@ def set_auto_parallel_context(**kwargs):
         >>> context.set_auto_parallel_context(gradients_mean=True)
         >>> context.set_auto_parallel_context(gradient_fp32_sync=False)
         >>> context.set_auto_parallel_context(parallel_mode="auto_parallel")
+        >>> context.set_auto_parallel_context(auto_parallel_search_mode="dynamic_programming")
         >>> context.set_auto_parallel_context(parameter_broadcast=False)
         >>> context.set_auto_parallel_context(strategy_ckpt_load_file="./strategy_stage1.ckpt")
         >>> context.set_auto_parallel_context(strategy_ckpt_save_file="./strategy_stage1.ckpt")
+        >>> context.set_auto_parallel_context(full_batch=True)
+        >>> context.set_auto_parallel_context(enable_parallel_optimizer=False)
+        >>> context.set_auto_parallel_context(all_reduce_fusion_config=[8, 160])
     """
     _set_auto_parallel_context(**kwargs)
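The options added to the docstring examples above can also be combined in a single call. A hedged setup snippet follows, using only parameter names that appear in the docstring; the values are illustrative, and a real run still needs the usual distributed initialization.

```python
# Combined use of the newly documented options; values are illustrative.
from mindspore import context

context.set_auto_parallel_context(parallel_mode="auto_parallel",
                                  auto_parallel_search_mode="dynamic_programming",
                                  full_batch=True,
                                  enable_parallel_optimizer=False,
                                  all_reduce_fusion_config=[8, 160],
                                  strategy_ckpt_save_file="./strategy_stage1.ckpt")
```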
@@ -439,10 +444,12 @@ def reset_auto_parallel_context():
         - global_rank: 0.
         - gradients_mean: False.
         - gradient_fp32_sync: True.
-        - parallel_mode: "stand_alone".
+        - parallel_mode: 'stand_alone'.
+        - auto_parallel_search_mode: 'dynamic_programming'.
         - parameter_broadcast: False.
-        - strategy_ckpt_load_file: "".
-        - strategy_ckpt_save_file: "".
+        - strategy_ckpt_load_file: ''.
+        - strategy_ckpt_save_file: ''.
+        - full_batch: False.
         - enable_parallel_optimizer: False.
     """
     _reset_auto_parallel_context()
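A short hedged example of falling back to the defaults listed above; `get_auto_parallel_context` is assumed here to accept the same keys when reading values back.

```python
from mindspore import context

context.set_auto_parallel_context(parallel_mode="auto_parallel", full_batch=True)
context.reset_auto_parallel_context()
# After the reset, the documented defaults should be reported, e.g.:
print(context.get_auto_parallel_context("parallel_mode"))  # 'stand_alone'
print(context.get_auto_parallel_context("full_batch"))     # False
```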
@@ -245,6 +245,10 @@ class _AutoParallelContext:
             strategy_ckpt_save_file (bool): Path to save parallel strategy checkpoint.
         """
         self.check_context_handle()
+        import os
+        dir_path = os.path.dirname(strategy_ckpt_save_file)
+        if dir_path and not os.path.exists(dir_path):
+            os.makedirs(dir_path)
         self._context_handle.set_strategy_ckpt_save_file(strategy_ckpt_save_file)
 
     def get_strategy_ckpt_save_file(self):
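The four added lines in set_strategy_ckpt_save_file create the checkpoint file's parent directory when it does not yet exist, so pointing strategy_ckpt_save_file at a fresh directory no longer fails at save time. Below is a standalone sketch of the same pattern, not the MindSpore code itself; os.makedirs(..., exist_ok=True) folds the existence check and the creation into one call, avoiding the small race between checking and creating.

```python
import os


def ensure_parent_dir(path: str) -> None:
    """Create the parent directory of `path` if it is missing."""
    dir_path = os.path.dirname(path)
    if dir_path:  # a bare filename has no parent directory to create
        os.makedirs(dir_path, exist_ok=True)


ensure_parent_dir("./strategies/strategy_stage1.ckpt")
assert os.path.isdir("./strategies")
```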