| @@ -46,6 +46,8 @@ class GatherV2PInfo : public OperatorInfo { | |||
| Status SetCostUnderStrategy(const StrategyPtr &strategy) override; | |||
| ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override; | |||
| std::shared_ptr<Strategys> GenerateBatchStrategies() override; | |||
| const std::vector<int64_t> &param_split_shapes() const { return param_split_shapes_; } | |||
| const std::vector<int64_t> &index_offsets() const { return index_offsets_; } | |||
| protected: | |||
| Status CheckStrategy(const StrategyPtr &strategy) override; | |||
| @@ -334,7 +334,11 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr & | |||
| operator_info->set_outputs_dtype(cnode->Type()); | |||
| operator_info->set_cnode(cnode); | |||
| // key of strategy map | |||
| std::string strategy_key_name = NodeParameterName(cnode); | |||
| std::string strategy_key_name = ""; | |||
| auto param_names = NodeParameterName(cnode); | |||
| if (!param_names.empty()) { | |||
| strategy_key_name = param_names[0].first; | |||
| } | |||
| bool load_strategy_from_ckpt = | |||
| StrategyCheckpoint::GetInstance().LoadCheckPointOn() && stra_map->find(strategy_key_name) != stra_map->end(); | |||
| // If no strategy has been configured for this operator, then candidate strategies are generated for | |||
| @@ -1480,7 +1480,11 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) { | |||
| } | |||
| // load strategy checkpoint | |||
| // key of strategy map | |||
| std::string strategy_key_name = NodeParameterName(cnode); | |||
| std::string strategy_key_name = ""; | |||
| auto param_names = NodeParameterName(cnode); | |||
| if (!param_names.empty()) { | |||
| strategy_key_name = param_names[0].first; | |||
| } | |||
| bool load_strategy_from_ckpt = | |||
| StrategyCheckpoint::GetInstance().LoadCheckPointOn() && stra_map.find(strategy_key_name) != stra_map.end(); | |||
| if (!StrategyFound(attrs) && !load_strategy_from_ckpt) { | |||
| @@ -2118,23 +2122,29 @@ void HandleSymbolicKeyInstance(const FuncGraphPtr &root, const std::vector<AnfNo | |||
| } | |||
| } | |||
| std::string NodeParameterName(const CNodePtr &node) { /* NOTE(review): this row and the next declare the same function with different return types; presumably this is the pre-change signature kept by the diff view - confirm. */ | |||
| std::vector<std::pair<std::string, int>> NodeParameterName(const CNodePtr &node) { /* Collects a (parameter name, input index) pair for every input of node that is a Parameter, has a default value and satisfies ParameterRequireGrad(); the result is empty when no input qualifies. */ | |||
| std::vector<AnfNodePtr> node_inputs{node->inputs()}; | |||
| for (auto input : node_inputs) { | |||
| std::vector<std::pair<std::string, int>> param_names; | |||
| for (int i = 0; i < UintToInt(node_inputs.size()); ++i) { /* index loop so that the position of each input can be recorded */ | |||
| auto input = node_inputs[i]; | |||
| if (input->isa<Parameter>()) { | |||
| auto input_parameter = input->cast<ParameterPtr>(); | |||
| if (input_parameter->has_default()) { | |||
| input_parameter->name(); | |||
| if (ParameterRequireGrad(input_parameter)) { | |||
| param_names.push_back({input_parameter->name(), i}); /* i is the position inside node->inputs(); CheckpointStrategy subtracts 1 from it before indexing inputs_tensor_info() */ | |||
| } | |||
| } | |||
| } | |||
| } | |||
| return ""; /* NOTE(review): presumably the pre-change return statement that matches the std::string signature - confirm. */ | |||
| return param_names; | |||
| } | |||
| void CheckpointStrategy(const FuncGraphPtr &func_graph) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| MS_LOG(DEBUG) << "Save strategy to checkpoint begin"; | |||
| StrategyMap stra_map; | |||
| TensorInfoMap tensor_info_map; | |||
| ManualShapeMap manual_shape_map; | |||
| auto ret = func_graph->get_return(); | |||
| auto all_nodes = DeepScopedGraphSearch(ret); | |||
| for (auto &node : all_nodes) { | |||
| @@ -2143,10 +2153,11 @@ void CheckpointStrategy(const FuncGraphPtr &func_graph) { | |||
| if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) { | |||
| continue; | |||
| } | |||
| std::string param_name = NodeParameterName(cnode); | |||
| if (param_name.empty()) { | |||
| auto param_names = NodeParameterName(cnode); | |||
| if (param_names.empty()) { | |||
| continue; | |||
| } | |||
| string param_name = param_names[0].first; | |||
| PrimitivePtr prim = GetValueNode<PrimitivePtr>(cnode->input(0)); | |||
| MS_EXCEPTION_IF_NULL(prim); | |||
| OperatorInfoPtr operator_info = cnode->user_data<OperatorInfo>(); | |||
| @@ -2154,12 +2165,33 @@ void CheckpointStrategy(const FuncGraphPtr &func_graph) { | |||
| if (operator_info->name().find(RESHAPEINFO) != std::string::npos) { | |||
| continue; | |||
| } | |||
| std::vector<TensorInfo> input_tensor_info = operator_info->inputs_tensor_info(); | |||
| StrategyPtr strategyPtr = operator_info->strategy(); | |||
| MS_EXCEPTION_IF_NULL(node->scope()); | |||
| stra_map[param_name] = strategyPtr; | |||
| for (auto param_name_pair : param_names) { | |||
| if (param_name_pair.second - 1 >= UintToInt(input_tensor_info.size())) { | |||
| continue; | |||
| } | |||
| tensor_info_map[param_name_pair.first] = input_tensor_info[param_name_pair.second - 1]; | |||
| } | |||
| // Save the manual-split description (split shape and index offset pairs) of EmbeddingLookup/GatherV2 operators. | |||
| if (operator_info->name().find(EMBEDDING_LOOKUP) != std::string::npos || | |||
| operator_info->name().find(GATHERV2) != std::string::npos) { | |||
| // The name test above is only a substring match, so the downcast may yield nullptr. An operator that is not a | |||
| // GatherV2PInfo offers no split shapes through this interface, so it is skipped instead of being dereferenced. | |||
| auto gatherv2_info = std::dynamic_pointer_cast<GatherV2PInfo>(operator_info); | |||
| if (gatherv2_info != nullptr) { | |||
| // Both accessors return const references; bind them instead of copying the vectors. | |||
| const auto &param_split_shapes = gatherv2_info->param_split_shapes(); | |||
| const auto &index_offsets = gatherv2_info->index_offsets(); | |||
| if (param_split_shapes.size() != index_offsets.size()) { | |||
| MS_LOG(EXCEPTION) << "In manual split, the param_split_shapes and index_offsets length should be same."; | |||
| } | |||
| std::vector<std::pair<int32_t, int32_t>> manual_shape; | |||
| manual_shape.reserve(param_split_shapes.size()); | |||
| for (size_t i = 0; i < param_split_shapes.size(); ++i) { | |||
| // NOTE(review): the accessors hold int64_t while ManualShapeMap stores int32_t; the narrowing that used to | |||
| // happen implicitly is spelled out here - confirm the values always fit. | |||
| manual_shape.push_back({static_cast<int32_t>(param_split_shapes[i]), static_cast<int32_t>(index_offsets[i])}); | |||
| } | |||
| manual_shape_map[param_name] = manual_shape; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| if (StrategyCheckpoint::GetInstance().Save(stra_map) != SUCCESS) { | |||
| if (StrategyCheckpoint::GetInstance().Save(stra_map, tensor_info_map, &manual_shape_map) != SUCCESS) { | |||
| MS_LOG(EXCEPTION) << "Save strategy checkpoint failed"; | |||
| } | |||
| } | |||
| @@ -135,7 +135,7 @@ void ReshapeInit(const std::vector<AnfNodePtr> &all_nodes); | |||
| void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes, | |||
| const FuncGraphManagerPtr &manager); | |||
| std::string NodeParameterName(const CNodePtr &node); | |||
| std::vector<std::pair<std::string, int>> NodeParameterName(const CNodePtr &node); | |||
| void CheckpointStrategy(const FuncGraphPtr &func_graph); | |||
| @@ -84,7 +84,8 @@ Status StrategyCheckpoint::Load(StrategyMap *strategy_map) { | |||
| return SUCCESS; | |||
| } | |||
| Status StrategyCheckpoint::Save(const StrategyMap &strategy_map) { | |||
| Status StrategyCheckpoint::Save(const StrategyMap &strategy_map, const TensorInfoMap &tensor_info_map, | |||
| ManualShapeMap *manual_shape_map) { | |||
| straspb::ParallelStrategyMap parallel_strategy_map; | |||
| parallel_strategy_map.set_current_stage(IntToUint(++current_stage_)); | |||
| for (auto &node_stra : strategy_map) { | |||
| @@ -103,6 +104,33 @@ Status StrategyCheckpoint::Save(const StrategyMap &strategy_map) { | |||
| } | |||
| } | |||
| } | |||
| // Emit one ParallelLayoutItem per parameter: its device matrix, its tensor map and, when recorded, its manual-split pairs. | |||
| for (auto &node_tensor_info : tensor_info_map) { | |||
| TensorInfo tensor_info = node_tensor_info.second; | |||
| TensorLayout tensor_layout = tensor_info.tensor_layout(); | |||
| straspb::ParallelLayoutItem *parallel_layout_item = parallel_strategy_map.add_parallel_layout_item(); | |||
| MS_EXCEPTION_IF_NULL(parallel_layout_item); | |||
| parallel_layout_item->set_param_name(node_tensor_info.first); | |||
| straspb::ParallelLayouts *parallel_layouts = parallel_layout_item->mutable_parallel_layouts(); | |||
| MS_EXCEPTION_IF_NULL(parallel_layouts); | |||
| straspb::DevMatrix *dev_matrix = parallel_layouts->add_dev_matrix(); | |||
| MS_EXCEPTION_IF_NULL(dev_matrix); | |||
| for (auto dim : tensor_layout.device_arrangement().array()) { | |||
| dev_matrix->add_dim(IntToUint(dim)); | |||
| } | |||
| straspb::TensorMap *tensor_map = parallel_layouts->add_tensor_map(); | |||
| MS_EXCEPTION_IF_NULL(tensor_map); | |||
| for (auto dim : tensor_layout.tensor_map().array()) { | |||
| tensor_map->add_dim(dim); | |||
| } | |||
| // Both entries are always appended, so every repeated field of ParallelLayouts gets one element per parameter | |||
| // even when no manual split was recorded for it. | |||
| straspb::ParamSplitShape *param_split_shape = parallel_layouts->add_param_split_shape(); | |||
| MS_EXCEPTION_IF_NULL(param_split_shape); | |||
| straspb::IndicesOffset *indices_offset = parallel_layouts->add_indices_offset(); | |||
| MS_EXCEPTION_IF_NULL(indices_offset); | |||
| MS_EXCEPTION_IF_NULL(manual_shape_map); | |||
| // Use find() for the lookup: operator[] would insert an empty entry into the caller's map for every parameter | |||
| // without a manual split and would also copy the stored vector. | |||
| auto manual_shape_iter = manual_shape_map->find(node_tensor_info.first); | |||
| if (manual_shape_iter != manual_shape_map->end()) { | |||
| for (const auto &dim_pair : manual_shape_iter->second) { | |||
| param_split_shape->add_dim(dim_pair.first); | |||
| indices_offset->add_dim(dim_pair.second); | |||
| } | |||
| } | |||
| } | |||
| std::fstream output(save_file_, std::ios::out | std::ios::trunc | std::ios::binary); | |||
| if (!parallel_strategy_map.SerializeToOstream(&output)) { | |||
| MS_LOG(ERROR) << "Save strategy file failed"; | |||
| @@ -19,13 +19,19 @@ | |||
| #include <string> | |||
| #include <unordered_map> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "frontend/parallel/ops_info/ops_utils.h" | |||
| #include "frontend/parallel/strategy.h" | |||
| #include "frontend/parallel/context.h" | |||
| #include "frontend/parallel/tensor_layout/tensor_layout.h" | |||
| #include "frontend/parallel/tensor_layout/tensor_info.h" | |||
| namespace mindspore { | |||
| namespace parallel { | |||
| using StrategyMap = std::unordered_map<std::string, StrategyPtr>; | |||
| using TensorInfoMap = std::unordered_map<std::string, TensorInfo>; | |||
| using ManualShapeMap = std::unordered_map<std::string, std::vector<std::pair<int32_t, int32_t>>>; | |||
| class StrategyCheckpoint { | |||
| public: | |||
| StrategyCheckpoint() { | |||
| @@ -38,7 +44,7 @@ class StrategyCheckpoint { | |||
| ~StrategyCheckpoint() = default; | |||
| Status Load(StrategyMap *strategy_map); | |||
| Status Save(const StrategyMap &strategy_map); | |||
| Status Save(const StrategyMap &strategy_map, const TensorInfoMap &tensor_info_map, ManualShapeMap *manual_shape_map); | |||
| static StrategyCheckpoint &GetInstance(); | |||
| bool LoadCheckPointOn() const { return load_checkpoint_on_; } | |||
| @@ -32,7 +32,36 @@ message ParallelStrategyItem { | |||
| required ParallelStrategys parallel_strategys = 2; | |||
| } | |||
| message DevMatrix { /* Dimension list of one device arrangement; StrategyCheckpoint::Save fills it from tensor_layout.device_arrangement().array(). */ | |||
| repeated uint32 dim = 1; | |||
| } | |||
| message TensorMap { /* Dimension list of one tensor map; StrategyCheckpoint::Save fills it from tensor_layout.tensor_map().array(). */ | |||
| repeated int32 dim = 1; | |||
| } | |||
| message ParamSplitShape { /* Receives the first element of every manual-shape pair stored for a parameter; stays empty when the parameter has none. */ | |||
| repeated int64 dim = 1; | |||
| } | |||
| message IndicesOffset { /* Receives the second element of every manual-shape pair stored for a parameter, in the same order as ParamSplitShape.dim. */ | |||
| repeated int64 dim = 1; | |||
| } | |||
| message ParallelLayouts { /* Layout description of one parameter; StrategyCheckpoint::Save appends exactly one element to each of the four repeated fields. */ | |||
| repeated DevMatrix dev_matrix = 1; | |||
| repeated TensorMap tensor_map = 2; | |||
| repeated ParamSplitShape param_split_shape = 3; | |||
| repeated IndicesOffset indices_offset = 4; | |||
| } | |||
| message ParallelLayoutItem { /* Associates one parameter with its layouts. */ | |||
| required string param_name = 1; /* set from the key of the TensorInfoMap entry being saved */ | |||
| required ParallelLayouts parallel_layouts = 2; | |||
| } | |||
| message ParallelStrategyMap { /* Top-level message that StrategyCheckpoint::Save serializes to the strategy file. */ | |||
| required uint32 current_stage = 1; /* Save stores the incremented current_stage_ counter here */ | |||
| repeated ParallelStrategyItem parallel_strategy_item = 2; | |||
| repeated ParallelLayoutItem parallel_layout_item = 3; /* one item is added for every entry of the TensorInfoMap passed to Save */ | |||
| } | |||
| @@ -29,6 +29,7 @@ bool StrategyCheckpoint::CheckPointExit(const std::string path) const { return f | |||
| Status StrategyCheckpoint::Load(StrategyMap* strategy_map) { return SUCCESS; } | |||
| Status StrategyCheckpoint::Save(const StrategyMap& strategy_map) { return SUCCESS; } | |||
| Status StrategyCheckpoint::Save(const StrategyMap &strategy_map, const TensorInfoMap &tensor_info_map, | |||
| ManualShapeMap *manual_shape_map) { return SUCCESS; } | |||
| } // namespace parallel | |||
| } // namespace mindspore | |||