| @@ -1542,17 +1542,8 @@ size_t CostGraph::GetNumEdges() const { | |||
| } | |||
| return sum; | |||
| } | |||
| Status CostGraph::InitSelectedStrategy() { | |||
| for (auto &op : ops_) { | |||
| MS_EXCEPTION_IF_NULL(op); | |||
| if (op->name().find(RESHAPEINFO) != std::string::npos) { | |||
| continue; | |||
| } | |||
| auto result = op->InitSelectedStrategy(op->selected_strategy()); | |||
| if (result != SUCCESS) { | |||
| return result; | |||
| } | |||
| } | |||
| Status CostGraph::InitReshapeStrategy() { | |||
| // reshape init should be apply after the init of it's previous node and next node. | |||
| for (size_t i = 0; i < ops_.size(); ++i) { | |||
| if (ops_[i]->name().find(RESHAPEINFO) != std::string::npos) { | |||
| @@ -1606,6 +1597,21 @@ Status CostGraph::InitSelectedStrategy() { | |||
| return SUCCESS; | |||
| } | |||
| Status CostGraph::InitSelectedStrategy() { | |||
| for (auto &op : ops_) { | |||
| MS_EXCEPTION_IF_NULL(op); | |||
| if (op->name().find(RESHAPEINFO) != std::string::npos) { | |||
| continue; | |||
| } | |||
| auto result = op->InitSelectedStrategy(op->selected_strategy()); | |||
| if (result != SUCCESS) { | |||
| return result; | |||
| } | |||
| } | |||
| auto result = InitReshapeStrategy(); | |||
| return result; | |||
| } | |||
| Status CostGraph::ComputeOpsAndEdgesParameterInvolved() { | |||
| for (auto &op : ops_) { | |||
| MS_EXCEPTION_IF_NULL(op); | |||
| @@ -186,6 +186,7 @@ class CostGraph { | |||
| std::vector<OperatorInfoPtr> GetOperators() const { return ops_; } | |||
| size_t GetNumEdges() const; | |||
| Status InitReshapeStrategy(); | |||
| Status InitSelectedStrategy(); | |||
| OperatorInfoPtr FindTmpIdentityByParameterName(std::string &) const; | |||
| // When TmpIdentity is used by mulitple operators, the corresponding parameter's memory cost should be calculated only | |||
| @@ -2275,7 +2275,6 @@ std::vector<std::pair<std::string, int>> NodeParameterName(const CNodePtr &node) | |||
| } | |||
| void CheckpointStrategy(const std::vector<AnfNodePtr> &all_nodes) { | |||
| MS_LOG(DEBUG) << "Save strategy to checkpoint begin"; | |||
| StrategyMap stra_map; | |||
| TensorInfoMap tensor_info_map; | |||
| ManualShapeMap manual_shape_map; | |||
| @@ -2298,10 +2297,8 @@ void CheckpointStrategy(const std::vector<AnfNodePtr> &all_nodes) { | |||
| continue; | |||
| } | |||
| std::vector<TensorInfo> input_tensor_info = operator_info->inputs_tensor_info(); | |||
| StrategyPtr strategyPtr = operator_info->strategy(); | |||
| MS_EXCEPTION_IF_NULL(node->scope()); | |||
| std::string stratey_key_name = prim->name() + "_" + param_name; | |||
| stra_map[stratey_key_name] = strategyPtr; | |||
| stra_map[stratey_key_name] = operator_info->strategy(); | |||
| for (auto param_name_pair : param_names) { | |||
| if (param_name_pair.second - 1 >= UintToInt(input_tensor_info.size())) { | |||
| continue; | |||
| @@ -395,9 +395,10 @@ def set_auto_parallel_context(**kwargs): | |||
| should be set with True. Default: False. | |||
| enable_parallel_optimizer (bool): This is a developing feature, which shards the weight update computation for | |||
| data parallel training in the benefit of time and memory saving. For now, | |||
| `Lamb` and `AdamWeightDecay` are supported in data parallel mode. | |||
| `Lamb` and `AdamWeightDecay` are supported in data parallel mode. No Default, if it is not set, | |||
| the fusion is closed. | |||
| all_reduce_fusion_config (list): Set allreduce fusion strategy by parameters indices. Only support ReduceOp.SUM | |||
| and HCCL_WORLD_GROUP/NCCL_WORLD_GROUP. | |||
| and HCCL_WORLD_GROUP/NCCL_WORLD_GROUP. No Default, if it is not set, the fusion is closed. | |||
| Raises: | |||
| ValueError: If input key is not attribute in auto parallel context. | |||
| @@ -408,9 +409,13 @@ def set_auto_parallel_context(**kwargs): | |||
| >>> context.set_auto_parallel_context(gradients_mean=True) | |||
| >>> context.set_auto_parallel_context(gradient_fp32_sync=False) | |||
| >>> context.set_auto_parallel_context(parallel_mode="auto_parallel") | |||
| >>> context.set_auto_parallel_context(auto_parallel_search_mode="dynamic_programming") | |||
| >>> context.set_auto_parallel_context(parameter_broadcast=False) | |||
| >>> context.set_auto_parallel_context(strategy_ckpt_load_file="./strategy_stage1.ckpt") | |||
| >>> context.set_auto_parallel_context(strategy_ckpt_save_file="./strategy_stage1.ckpt") | |||
| >>> context.set_auto_parallel_context(full_batch=True) | |||
| >>> context.set_auto_parallel_context(enable_parallel_optimizer=False) | |||
| >>> context.set_auto_parallel_context(all_reduce_fusion_config=[8, 160]) | |||
| """ | |||
| _set_auto_parallel_context(**kwargs) | |||
| @@ -439,10 +444,12 @@ def reset_auto_parallel_context(): | |||
| - global_rank: 0. | |||
| - gradients_mean: False. | |||
| - gradient_fp32_sync: True. | |||
| - parallel_mode: "stand_alone". | |||
| - parallel_mode: 'stand_alone'. | |||
| - auto_parallel_search_mode: 'dynamic_programming'. | |||
| - parameter_broadcast: False. | |||
| - strategy_ckpt_load_file: "". | |||
| - strategy_ckpt_save_file: "". | |||
| - strategy_ckpt_load_file: ''. | |||
| - strategy_ckpt_save_file: ''. | |||
| - full_batch: False. | |||
| - enable_parallel_optimizer: False. | |||
| """ | |||
| _reset_auto_parallel_context() | |||
| @@ -245,6 +245,10 @@ class _AutoParallelContext: | |||
| strategy_ckpt_save_file (bool): Path to save parallel strategy checkpoint. | |||
| """ | |||
| self.check_context_handle() | |||
| import os | |||
| dir_path = os.path.dirname(strategy_ckpt_save_file) | |||
| if dir_path and not os.path.exists(dir_path): | |||
| os.makedirs(dir_path) | |||
| self._context_handle.set_strategy_ckpt_save_file(strategy_ckpt_save_file) | |||
| def get_strategy_ckpt_save_file(self): | |||