| @@ -46,6 +46,8 @@ class GatherV2PInfo : public OperatorInfo { | |||||
| Status SetCostUnderStrategy(const StrategyPtr &strategy) override; | Status SetCostUnderStrategy(const StrategyPtr &strategy) override; | ||||
| ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override; | ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override; | ||||
| std::shared_ptr<Strategys> GenerateBatchStrategies() override; | std::shared_ptr<Strategys> GenerateBatchStrategies() override; | ||||
| const std::vector<int64_t> ¶m_split_shapes() const { return param_split_shapes_; } | |||||
| const std::vector<int64_t> &index_offsets() const { return index_offsets_; } | |||||
| protected: | protected: | ||||
| Status CheckStrategy(const StrategyPtr &strategy) override; | Status CheckStrategy(const StrategyPtr &strategy) override; | ||||
| @@ -334,7 +334,11 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr & | |||||
| operator_info->set_outputs_dtype(cnode->Type()); | operator_info->set_outputs_dtype(cnode->Type()); | ||||
| operator_info->set_cnode(cnode); | operator_info->set_cnode(cnode); | ||||
| // key of strategy map | // key of strategy map | ||||
| std::string strategy_key_name = NodeParameterName(cnode); | |||||
| std::string strategy_key_name = ""; | |||||
| auto param_names = NodeParameterName(cnode); | |||||
| if (!param_names.empty()) { | |||||
| strategy_key_name = param_names[0].first; | |||||
| } | |||||
| bool load_strategy_from_ckpt = | bool load_strategy_from_ckpt = | ||||
| StrategyCheckpoint::GetInstance().LoadCheckPointOn() && stra_map->find(strategy_key_name) != stra_map->end(); | StrategyCheckpoint::GetInstance().LoadCheckPointOn() && stra_map->find(strategy_key_name) != stra_map->end(); | ||||
| // If no strategy has been configured for this operator, then candidate strategies are generated for | // If no strategy has been configured for this operator, then candidate strategies are generated for | ||||
| @@ -1480,7 +1480,11 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) { | |||||
| } | } | ||||
| // load strategy checkpoint | // load strategy checkpoint | ||||
| // key of strategy map | // key of strategy map | ||||
| std::string strategy_key_name = NodeParameterName(cnode); | |||||
| std::string strategy_key_name = ""; | |||||
| auto param_names = NodeParameterName(cnode); | |||||
| if (!param_names.empty()) { | |||||
| strategy_key_name = param_names[0].first; | |||||
| } | |||||
| bool load_strategy_from_ckpt = | bool load_strategy_from_ckpt = | ||||
| StrategyCheckpoint::GetInstance().LoadCheckPointOn() && stra_map.find(strategy_key_name) != stra_map.end(); | StrategyCheckpoint::GetInstance().LoadCheckPointOn() && stra_map.find(strategy_key_name) != stra_map.end(); | ||||
| if (!StrategyFound(attrs) && !load_strategy_from_ckpt) { | if (!StrategyFound(attrs) && !load_strategy_from_ckpt) { | ||||
| @@ -2118,23 +2122,29 @@ void HandleSymbolicKeyInstance(const FuncGraphPtr &root, const std::vector<AnfNo | |||||
| } | } | ||||
| } | } | ||||
| std::string NodeParameterName(const CNodePtr &node) { | |||||
| std::vector<std::pair<std::string, int>> NodeParameterName(const CNodePtr &node) { | |||||
| std::vector<AnfNodePtr> node_inputs{node->inputs()}; | std::vector<AnfNodePtr> node_inputs{node->inputs()}; | ||||
| for (auto input : node_inputs) { | |||||
| std::vector<std::pair<std::string, int>> param_names; | |||||
| for (int i = 0; i < UintToInt(node_inputs.size()); ++i) { | |||||
| auto input = node_inputs[i]; | |||||
| if (input->isa<Parameter>()) { | if (input->isa<Parameter>()) { | ||||
| auto input_parameter = input->cast<ParameterPtr>(); | auto input_parameter = input->cast<ParameterPtr>(); | ||||
| if (input_parameter->has_default()) { | if (input_parameter->has_default()) { | ||||
| input_parameter->name(); | |||||
| if (ParameterRequireGrad(input_parameter)) { | |||||
| param_names.push_back({input_parameter->name(), i}); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| return ""; | |||||
| return param_names; | |||||
| } | } | ||||
| void CheckpointStrategy(const FuncGraphPtr &func_graph) { | void CheckpointStrategy(const FuncGraphPtr &func_graph) { | ||||
| MS_EXCEPTION_IF_NULL(func_graph); | MS_EXCEPTION_IF_NULL(func_graph); | ||||
| MS_LOG(DEBUG) << "Save strategy to checkpoint begin"; | MS_LOG(DEBUG) << "Save strategy to checkpoint begin"; | ||||
| StrategyMap stra_map; | StrategyMap stra_map; | ||||
| TensorInfoMap tensor_info_map; | |||||
| ManualShapeMap manual_shape_map; | |||||
| auto ret = func_graph->get_return(); | auto ret = func_graph->get_return(); | ||||
| auto all_nodes = DeepScopedGraphSearch(ret); | auto all_nodes = DeepScopedGraphSearch(ret); | ||||
| for (auto &node : all_nodes) { | for (auto &node : all_nodes) { | ||||
| @@ -2143,10 +2153,11 @@ void CheckpointStrategy(const FuncGraphPtr &func_graph) { | |||||
| if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) { | if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) { | ||||
| continue; | continue; | ||||
| } | } | ||||
| std::string param_name = NodeParameterName(cnode); | |||||
| if (param_name.empty()) { | |||||
| auto param_names = NodeParameterName(cnode); | |||||
| if (param_names.empty()) { | |||||
| continue; | continue; | ||||
| } | } | ||||
| string param_name = param_names[0].first; | |||||
| PrimitivePtr prim = GetValueNode<PrimitivePtr>(cnode->input(0)); | PrimitivePtr prim = GetValueNode<PrimitivePtr>(cnode->input(0)); | ||||
| MS_EXCEPTION_IF_NULL(prim); | MS_EXCEPTION_IF_NULL(prim); | ||||
| OperatorInfoPtr operator_info = cnode->user_data<OperatorInfo>(); | OperatorInfoPtr operator_info = cnode->user_data<OperatorInfo>(); | ||||
| @@ -2154,12 +2165,33 @@ void CheckpointStrategy(const FuncGraphPtr &func_graph) { | |||||
| if (operator_info->name().find(RESHAPEINFO) != std::string::npos) { | if (operator_info->name().find(RESHAPEINFO) != std::string::npos) { | ||||
| continue; | continue; | ||||
| } | } | ||||
| std::vector<TensorInfo> input_tensor_info = operator_info->inputs_tensor_info(); | |||||
| StrategyPtr strategyPtr = operator_info->strategy(); | StrategyPtr strategyPtr = operator_info->strategy(); | ||||
| MS_EXCEPTION_IF_NULL(node->scope()); | MS_EXCEPTION_IF_NULL(node->scope()); | ||||
| stra_map[param_name] = strategyPtr; | stra_map[param_name] = strategyPtr; | ||||
| for (auto param_name_pair : param_names) { | |||||
| if (param_name_pair.second - 1 >= UintToInt(input_tensor_info.size())) { | |||||
| continue; | |||||
| } | |||||
| tensor_info_map[param_name_pair.first] = input_tensor_info[param_name_pair.second - 1]; | |||||
| } | |||||
| if (operator_info->name().find(EMBEDDING_LOOKUP) != std::string::npos || | |||||
| operator_info->name().find(GATHERV2) != std::string::npos) { | |||||
| auto gatherv2_info = std::dynamic_pointer_cast<GatherV2PInfo>(operator_info); | |||||
| auto param_split_shapes = gatherv2_info->param_split_shapes(); | |||||
| auto index_offsets = gatherv2_info->index_offsets(); | |||||
| if (param_split_shapes.size() != index_offsets.size()) { | |||||
| MS_LOG(EXCEPTION) << "In manual split, the param_split_shapes and index_offsets lenght should be same."; | |||||
| } | |||||
| std::vector<std::pair<int32_t, int32_t>> manual_shape; | |||||
| for (int i = 0; i < UintToInt(param_split_shapes.size()); ++i) { | |||||
| manual_shape.push_back({param_split_shapes[i], index_offsets[i]}); | |||||
| } | |||||
| manual_shape_map[param_name] = manual_shape; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| if (StrategyCheckpoint::GetInstance().Save(stra_map) != SUCCESS) { | |||||
| if (StrategyCheckpoint::GetInstance().Save(stra_map, tensor_info_map, &manual_shape_map) != SUCCESS) { | |||||
| MS_LOG(EXCEPTION) << "Save strategy checkpoint failed"; | MS_LOG(EXCEPTION) << "Save strategy checkpoint failed"; | ||||
| } | } | ||||
| } | } | ||||
| @@ -135,7 +135,7 @@ void ReshapeInit(const std::vector<AnfNodePtr> &all_nodes); | |||||
| void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes, | void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes, | ||||
| const FuncGraphManagerPtr &manager); | const FuncGraphManagerPtr &manager); | ||||
| std::string NodeParameterName(const CNodePtr &node); | |||||
| std::vector<std::pair<std::string, int>> NodeParameterName(const CNodePtr &node); | |||||
| void CheckpointStrategy(const FuncGraphPtr &func_graph); | void CheckpointStrategy(const FuncGraphPtr &func_graph); | ||||
| @@ -84,7 +84,8 @@ Status StrategyCheckpoint::Load(StrategyMap *strategy_map) { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status StrategyCheckpoint::Save(const StrategyMap &strategy_map) { | |||||
| Status StrategyCheckpoint::Save(const StrategyMap &strategy_map, const TensorInfoMap &tensor_info_map, | |||||
| ManualShapeMap *manual_shape_map) { | |||||
| straspb::ParallelStrategyMap parallel_strategy_map; | straspb::ParallelStrategyMap parallel_strategy_map; | ||||
| parallel_strategy_map.set_current_stage(IntToUint(++current_stage_)); | parallel_strategy_map.set_current_stage(IntToUint(++current_stage_)); | ||||
| for (auto &node_stra : strategy_map) { | for (auto &node_stra : strategy_map) { | ||||
| @@ -103,6 +104,33 @@ Status StrategyCheckpoint::Save(const StrategyMap &strategy_map) { | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| for (auto &node_tensor_info : tensor_info_map) { | |||||
| TensorInfo tensor_info = node_tensor_info.second; | |||||
| TensorLayout tensor_layout = tensor_info.tensor_layout(); | |||||
| straspb::ParallelLayoutItem *parallel_layout_item = parallel_strategy_map.add_parallel_layout_item(); | |||||
| MS_EXCEPTION_IF_NULL(parallel_layout_item); | |||||
| parallel_layout_item->set_param_name(node_tensor_info.first); | |||||
| straspb::ParallelLayouts *parallel_layouts = parallel_layout_item->mutable_parallel_layouts(); | |||||
| straspb::DevMatrix *dev_matrix = parallel_layouts->add_dev_matrix(); | |||||
| MS_EXCEPTION_IF_NULL(dev_matrix); | |||||
| for (auto dim : tensor_layout.device_arrangement().array()) { | |||||
| dev_matrix->add_dim(IntToUint(dim)); | |||||
| } | |||||
| straspb::TensorMap *tensor_map = parallel_layouts->add_tensor_map(); | |||||
| MS_EXCEPTION_IF_NULL(tensor_map); | |||||
| for (auto dim : tensor_layout.tensor_map().array()) { | |||||
| tensor_map->add_dim(dim); | |||||
| } | |||||
| straspb::ParamSplitShape *param_split_shape = parallel_layouts->add_param_split_shape(); | |||||
| straspb::IndicesOffset *indices_offset = parallel_layouts->add_indices_offset(); | |||||
| MS_EXCEPTION_IF_NULL(manual_shape_map); | |||||
| auto manual_shape = (*manual_shape_map)[node_tensor_info.first]; | |||||
| for (auto dim_pair : manual_shape) { | |||||
| param_split_shape->add_dim(dim_pair.first); | |||||
| indices_offset->add_dim(dim_pair.second); | |||||
| } | |||||
| } | |||||
| std::fstream output(save_file_, std::ios::out | std::ios::trunc | std::ios::binary); | std::fstream output(save_file_, std::ios::out | std::ios::trunc | std::ios::binary); | ||||
| if (!parallel_strategy_map.SerializeToOstream(&output)) { | if (!parallel_strategy_map.SerializeToOstream(&output)) { | ||||
| MS_LOG(ERROR) << "Save strategy file failed"; | MS_LOG(ERROR) << "Save strategy file failed"; | ||||
| @@ -19,13 +19,19 @@ | |||||
| #include <string> | #include <string> | ||||
| #include <unordered_map> | #include <unordered_map> | ||||
| #include <utility> | |||||
| #include <vector> | |||||
| #include "frontend/parallel/ops_info/ops_utils.h" | #include "frontend/parallel/ops_info/ops_utils.h" | ||||
| #include "frontend/parallel/strategy.h" | #include "frontend/parallel/strategy.h" | ||||
| #include "frontend/parallel/context.h" | #include "frontend/parallel/context.h" | ||||
| #include "frontend/parallel/tensor_layout/tensor_layout.h" | |||||
| #include "frontend/parallel/tensor_layout/tensor_info.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace parallel { | namespace parallel { | ||||
| using StrategyMap = std::unordered_map<std::string, StrategyPtr>; | using StrategyMap = std::unordered_map<std::string, StrategyPtr>; | ||||
| using TensorInfoMap = std::unordered_map<std::string, TensorInfo>; | |||||
| using ManualShapeMap = std::unordered_map<std::string, std::vector<std::pair<int32_t, int32_t>>>; | |||||
| class StrategyCheckpoint { | class StrategyCheckpoint { | ||||
| public: | public: | ||||
| StrategyCheckpoint() { | StrategyCheckpoint() { | ||||
| @@ -38,7 +44,7 @@ class StrategyCheckpoint { | |||||
| ~StrategyCheckpoint() = default; | ~StrategyCheckpoint() = default; | ||||
| Status Load(StrategyMap *strategy_map); | Status Load(StrategyMap *strategy_map); | ||||
| Status Save(const StrategyMap &strategy_map); | |||||
| Status Save(const StrategyMap &strategy_map, const TensorInfoMap &tensor_info_map, ManualShapeMap *manual_shape_map); | |||||
| static StrategyCheckpoint &GetInstance(); | static StrategyCheckpoint &GetInstance(); | ||||
| bool LoadCheckPointOn() const { return load_checkpoint_on_; } | bool LoadCheckPointOn() const { return load_checkpoint_on_; } | ||||
| @@ -32,7 +32,36 @@ message ParallelStrategyItem { | |||||
| required ParallelStrategys parallel_strategys = 2; | required ParallelStrategys parallel_strategys = 2; | ||||
| } | } | ||||
| message DevMatrix { | |||||
| repeated uint32 dim = 1; | |||||
| } | |||||
| message TensorMap { | |||||
| repeated int32 dim = 1; | |||||
| } | |||||
| message ParamSplitShape { | |||||
| repeated int64 dim = 1; | |||||
| } | |||||
| message IndicesOffset { | |||||
| repeated int64 dim = 1; | |||||
| } | |||||
| message ParallelLayouts { | |||||
| repeated DevMatrix dev_matrix = 1; | |||||
| repeated TensorMap tensor_map = 2; | |||||
| repeated ParamSplitShape param_split_shape = 3; | |||||
| repeated IndicesOffset indices_offset = 4; | |||||
| } | |||||
| message ParallelLayoutItem { | |||||
| required string param_name = 1; | |||||
| required ParallelLayouts parallel_layouts = 2; | |||||
| } | |||||
| message ParallelStrategyMap { | message ParallelStrategyMap { | ||||
| required uint32 current_stage = 1; | required uint32 current_stage = 1; | ||||
| repeated ParallelStrategyItem parallel_strategy_item = 2; | repeated ParallelStrategyItem parallel_strategy_item = 2; | ||||
| repeated ParallelLayoutItem parallel_layout_item = 3; | |||||
| } | } | ||||
| @@ -29,6 +29,7 @@ bool StrategyCheckpoint::CheckPointExit(const std::string path) const { return f | |||||
| Status StrategyCheckpoint::Load(StrategyMap* strategy_map) { return SUCCESS; } | Status StrategyCheckpoint::Load(StrategyMap* strategy_map) { return SUCCESS; } | ||||
| Status StrategyCheckpoint::Save(const StrategyMap& strategy_map) { return SUCCESS; } | |||||
| Status StrategyCheckpoint::Save(const StrategyMap &strategy_map, const TensorInfoMap &tensor_info_map, | |||||
| ManualShapeMap *manual_shape_map) { return SUCCESS; } | |||||
| } // namespace parallel | } // namespace parallel | ||||
| } // namespace mindspore | } // namespace mindspore | ||||