@@ -98,7 +98,7 @@ class DeviceManager {
   std::map<std::string, std::string> group_to_rank_;  // the key is hash name, value is rank list
   int64_t global_rank_ = 0;          // the real rank in all devices
-  int64_t stage_num_ = 0;            // the stage num
+  int64_t stage_num_ = 1;            // the stage num
   int64_t stage_id_ = 0;             // the stage id of the global_rank_
   int64_t rank_index_in_stage_ = 0;  // the index of this rank in it's stage
   int64_t stage_device_num_ = 0;     // the device num of one stage
@@ -75,7 +75,8 @@ const std::set<std::string> BLACK_LIST = {TUPLE_GETITEM,
                                           EMBED,
                                           CREATINSTANCE,
                                           REF_TO_EMBED,
-                                          STOP_GRADIENT};
+                                          STOP_GRADIENT,
+                                          SEND};
 const std::set<std::string> BATCH_PARALLEL_BLACK_LIST = {PACK, TENSOR_SCATTER_UPDATE, MIN_MAX_UPDATE_PER_LAYER};
@@ -182,6 +182,8 @@ constexpr char SIGMOID_CROSS_ENTROPY_WITH_LOGITS[] = "SigmoidCrossEntropyWithLog
 constexpr char MATMUL[] = "MatMul";
 constexpr char GELU[] = "Gelu";
 constexpr char TANH[] = "Tanh";
+constexpr char RECEIVE[] = "Receive";
+constexpr char SEND[] = "Send";
 constexpr char SHAPE_OP[] = "Shape";
 constexpr char SOFTMAX[] = "Softmax";
 constexpr char LOG_SOFTMAX[] = "LogSoftmax";
@@ -26,6 +26,8 @@
 #include "frontend/parallel/ops_info/ops_utils.h"
 #include "frontend/parallel/group_manager.h"
 #include "frontend/parallel/context.h"
+#include "frontend/parallel/step_parallel.h"
+#include "frontend/parallel/node_check.h"
 #include "utils/comm_manager.h"
 #include "utils/ms_context.h"
@@ -37,6 +39,7 @@ static int recv_tag = 0;
 void PipelineTransformer::Coloring() {
   auto need_coloring = true;
+  std::set<int64_t> stage_set;
   while (need_coloring) {
     need_coloring = false;
     for (auto &fg : manager_->func_graphs()) {
@@ -52,6 +55,9 @@ void PipelineTransformer::Coloring() {
         auto user_node = user_pair.first->cast<CNodePtr>();
         user_node->set_stage(graph->stage());
         auto user_node_graph = user_node->func_graph();
+        if (graph->stage() != -1) {
+          stage_set.insert(graph->stage());
+        }
         if (graph->stage() == stage_ && user_node_graph->stage() == -1) {
           user_node_graph->set_stage(graph->stage());
           need_coloring = true;
@@ -60,6 +66,12 @@ void PipelineTransformer::Coloring() {
       }
     }
   }
+  MS_EXCEPTION_IF_NULL(g_device_manager);
+  auto stage_num = g_device_manager->stage_num();
+  if (SizeToInt(stage_set.size()) != stage_num) {
| MS_LOG(EXCEPTION) << "Stage num is " << stage_num << " is not equal to stage used: " << stage_set.size(); | |||||
+  }
+  return;
 }
 void PipelineTransformer::BroadCastColoring() {
@@ -68,6 +80,96 @@ void PipelineTransformer::BroadCastColoring() {
   }
 }
+
+bool PipelineTransformer::IsPipelineCareNode(const CNodePtr &cnode) {
+  MS_EXCEPTION_IF_NULL(cnode);
+  auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
+  if (prim == nullptr) {
+    return false;
+  }
+  if (IsInBlackList(prim)) {
| MS_LOG(INFO) << "PipelineSplit don't care node:" << prim->name(); | |||||
+    return false;
+  }
+  return true;
+}
+
+OperatorInfoPtr PipelineTransformer::CreateOpInfo(const CNodePtr &cnode) {
+  MS_EXCEPTION_IF_NULL(cnode);
+  if (!IsPipelineCareNode(cnode)) {
+    MS_LOG(EXCEPTION) << "Node: " << cnode->ToString() << " is not a Pipeline Care Node.";
+  }
+  auto shape_list = ExtractShape(cnode);
+  if (shape_list.empty()) {
+    MS_LOG(EXCEPTION) << "Node: " << cnode->ToString() << " failed to extract shape.";
+  }
+  auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
+  MS_EXCEPTION_IF_NULL(prim);
+  if (prim->name() == RESHAPE) {
+    MS_LOG(EXCEPTION) << "Reshape op can't be a border.";
+  }
+  auto attrs = prim->attrs();
+  auto op_info = OperatorInstance(prim, attrs, shape_list);
+  auto &inputs = cnode->inputs();
+  std::vector<ValuePtr> input_value;
+  for (size_t index = 1; index < inputs.size(); ++index) {
+    if (inputs[index]->isa<ValueNode>()) {
+      input_value.push_back(GetValueNode(inputs[index]));
+    } else {
+      input_value.emplace_back(nullptr);
+    }
+  }
+  op_info->set_input_value(input_value);
+  op_info->set_outputs_dtype(cnode->Type());
+  op_info->set_cnode(cnode);
+  StrategyPtr strategy = nullptr;
+  if (!StrategyFound(attrs)) {
+    strategy = GenerateBatchParallelStrategy(op_info, prim);
+  } else {
+    strategy = ExtractStrategy(attrs);
+  }
+  MS_EXCEPTION_IF_NULL(strategy);
+  if (op_info->Init(strategy) == FAILED) {
+    MS_LOG(EXCEPTION) << "operator: " << prim->name() << " init failed.";
+  }
+  return op_info;
+}
+
+std::pair<OperatorInfoPtr, TensorInfoPtr> PipelineTransformer::GetOpInfo(const AnfNodePtr &node) {
+  MS_EXCEPTION_IF_NULL(node);
+  auto cnode = node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(cnode);
+  OperatorInfoPtr op_info = nullptr;
+  TensorInfo tensor_info;
+  // op1(stage1)->op2(stage2)
+  if (IsValueNode<Primitive>(cnode->input(0))) {
+    op_info = CreateOpInfo(cnode);
+    MS_EXCEPTION_IF_NULL(op_info);
+    tensor_info = op_info->outputs_tensor_info()[0];
+  } else if (IsValueNode<FuncGraph>(cnode->input(0))) {
+    auto graph = GetValueNode<FuncGraphPtr>(cnode->input(0));
+    MS_EXCEPTION_IF_NULL(graph);
+    auto output = graph->output();
+    MS_EXCEPTION_IF_NULL(output);
+    auto output_cnode = output->cast<CNodePtr>();
+    MS_EXCEPTION_IF_NULL(output_cnode);
+    auto prim = GetValueNode<PrimitivePtr>(output_cnode->input(0));
+    MS_EXCEPTION_IF_NULL(prim);
+    if (prim->name() == TUPLE_GETITEM) {
+      auto index = GetTupleGetItemIndex(output_cnode);
+      auto pre_getitem_node = output_cnode->input(1)->cast<CNodePtr>();
+      MS_EXCEPTION_IF_NULL(pre_getitem_node);
+      op_info = CreateOpInfo(pre_getitem_node);
+      MS_EXCEPTION_IF_NULL(op_info);
+      tensor_info = op_info->outputs_tensor_info()[index];
+    } else {
+      op_info = CreateOpInfo(output_cnode);
+      MS_EXCEPTION_IF_NULL(op_info);
+      tensor_info = op_info->outputs_tensor_info()[0];
+    }
+  }
+  return std::make_pair(op_info, std::make_shared<TensorInfo>(tensor_info));
+}
+
 void PipelineTransformer::DoBroadCast(const FuncGraphPtr &func) {
   auto need_coloring = true;
   while (need_coloring) {
@@ -168,26 +270,19 @@ void PipelineTransformer::ParameterColoring() {
   }
 }
-static std::pair<ValueListPtr, TypePtr> GetShapeType(const AnfNodePtr &node) {
-  abstract::ShapePtr shape_ptr;
+static std::pair<ValueListPtr, TypePtr> GetShapeType(const AnfNodePtr &node, const Shape &shape) {
   TypePtr type;
-  std::vector<int64_t> shape;
   auto cnode = node->cast<CNodePtr>();
   if (cnode != nullptr && IsValueNode<FuncGraph>(cnode->input(0))) {
     auto graph = GetValueNode<FuncGraphPtr>(cnode->input(0));
     auto graph_return = graph->get_return();
-    shape_ptr = dyn_cast<abstract::Shape>(graph_return->Shape());
     type = graph_return->Type();
   } else {
-    shape_ptr = dyn_cast<abstract::Shape>(node->Shape());
     type = node->Type();
   }
-  MS_EXCEPTION_IF_NULL(shape_ptr);
   MS_EXCEPTION_IF_NULL(type);
-  auto shape_int = shape_ptr->shape();
   std::vector<ValuePtr> element;
-  std::transform(shape_int.begin(), shape_int.end(), std::back_inserter(element),
-                 [](int elem) { return MakeValue(elem); });
+  std::transform(shape.begin(), shape.end(), std::back_inserter(element), [](int elem) { return MakeValue(elem); });
   auto shape_list = std::make_shared<ValueList>(element);
   auto tensor_type = type->cast<mindspore::TensorTypePtr>();
   MS_EXCEPTION_IF_NULL(tensor_type);
@@ -203,16 +298,20 @@ SendAttr PipelineTransformer::InsertSend(const FuncGraphPtr &graph, const AnfNod
   auto dest_rank = global_rank_ + (user_node_stage - node_stage) * per_stage_rank_num_;
   Attr attr_rank = std::make_pair("dest_rank", MakeValue(dest_rank));
   OperatorAttrs attrs = {attr_tag, attr_rank};
-  auto send_op = CreatOpInstance(attrs, "Send", "send");
+  auto send_op = CreatOpInstance(attrs, SEND, "send");
   auto send_node = NewValueNode(send_op);
   auto prim = GetValueNode<PrimitivePtr>(send_node);
-  auto shape_type_pair = GetShapeType(parameter);
+  auto op_info_pair = GetOpInfo(parameter);
+  auto tensor_info = op_info_pair.second;
+  MS_EXCEPTION_IF_NULL(tensor_info);
+  auto slice_shape = tensor_info->slice_shape();
+  auto shape_type_pair = GetShapeType(parameter, slice_shape);
   prim->set_attr("shape", shape_type_pair.first);
   prim->set_attr("dtype", shape_type_pair.second);
   std::vector<AnfNodePtr> send_input = {send_node, parameter};
   auto send = graph->NewCNode(send_input);
   OperatorAttrs depend_attrs;
-  auto depend_op = CreatOpInstance(depend_attrs, "Depend", "depend");
+  auto depend_op = CreatOpInstance(depend_attrs, DEPEND, "depend");
   std::vector<AnfNodePtr> depend_input = {NewValueNode(depend_op), parameter, send};
   auto depend = graph->NewCNode(depend_input);
   SendAttr send_out = {shape_type_pair.first, shape_type_pair.second, depend};
@@ -223,15 +322,23 @@ void PipelineTransformer::InsertReceive(const FuncGraphPtr &graph, const AnfNode
                                         int index, int user_node_stage, int node_stage) {
   Attr attr_tag = std::make_pair("sr_tag", MakeValue(recv_tag));
   recv_tag += 1;
-  auto src_rank = global_rank_ + (user_node_stage - node_stage) * per_stage_rank_num_;
+  auto src_rank = global_rank_ - (user_node_stage - node_stage) * per_stage_rank_num_;
   Attr attr_rank = std::make_pair("src_rank", MakeValue(src_rank));
-  auto shape_type_pair = GetShapeType(node);
+  auto op_info_pair = GetOpInfo(node);
+  auto tensor_info = op_info_pair.second;
+  MS_EXCEPTION_IF_NULL(tensor_info);
+  auto slice_shape = tensor_info->slice_shape();
+  auto shape_type_pair = GetShapeType(node, slice_shape);
   Attr attr_shape = std::make_pair("shape", shape_type_pair.first);
   Attr attr_dtype = std::make_pair("dtype", shape_type_pair.second);
   OperatorAttrs attrs = {attr_tag, attr_rank, attr_shape, attr_dtype};
-  auto recv_op = CreatOpInstance(attrs, "Receive", "recv");
+  auto recv_op = CreatOpInstance(attrs, RECEIVE, "recv");
   std::vector<AnfNodePtr> recv_input = {NewValueNode(recv_op), virtual_param_};
   auto recv = graph->NewCNode(recv_input);
+  auto node_abstract = node->abstract();
+  recv->set_abstract(node_abstract);
+  recv->set_user_data<TensorLayout>(std::make_shared<TensorLayout>(tensor_info->tensor_layout()));
+  recv->set_user_data<OperatorInfo>(op_info_pair.first);
   manager_->SetEdge(use_node, index, recv);
 }
@@ -317,36 +424,10 @@ void PipelineTransformer::CutBorder(const FuncGraphPtr &graph) {
 void PipelineTransformer::CutGraph() {
   for (auto &fg : manager_->func_graphs()) {
-    if (fg == root_) {
-      ElimRootParameter();
-      continue;
-    }
     CutBorder(fg);
   }
 }
-
-void PipelineTransformer::ElimRootParameter() {
-  auto output = root_->output()->cast<CNodePtr>();
-  MS_EXCEPTION_IF_NULL(output);
-  auto prim = GetValueNode<PrimitivePtr>(output->input(0));
-  if (prim->name() == DEPEND) {
-    auto opt_cnode = output->input(2)->cast<CNodePtr>();
-    auto prim_make_tuple = GetValueNode<PrimitivePtr>(opt_cnode->input(0));
-    if (prim_make_tuple->name() == MAKE_TUPLE) {
-      std::vector<AnfNodePtr> new_node_input = {opt_cnode->input(0)};
-      for (auto &input : opt_cnode->inputs()) {
-        if (input->isa<CNode>()) {
-          if (IsStageNode(input->cast<CNodePtr>())) {
-            new_node_input.push_back(input);
-          }
-        }
-      }
-      auto new_node = root_->NewCNode(new_node_input);
-      manager_->Replace(opt_cnode, new_node);
-    }
-  }
-}
-
 bool PipelineTransformer::IsStageNode(const CNodePtr &node) {
   for (auto &input : node->inputs()) {
     if (input->isa<Parameter>()) {
@@ -414,11 +495,16 @@ std::pair<CNodePtr, FuncGraphPtr> PipelineTransformer::FindSensNode() {
 }
 void PipelineTransformer::CoverSensShape() {
+  if (IsLastStage()) {
+    return;
+  }
   auto sens_graph_pair = FindSensNode();
   auto sens_cnode = sens_graph_pair.first;
   MS_EXCEPTION_IF_NULL(sens_cnode);
   OperatorAttrs attrs;
   auto fill_op = CreatOpInstance(attrs, "Fill", "");
+  MS_EXCEPTION_IF_NULL(type_ptr_);
+  MS_EXCEPTION_IF_NULL(shape_);
   std::vector<AnfNodePtr> fill_input = {NewValueNode(fill_op), NewValueNode(type_ptr_),
                                         NewValueNode(MakeValue(shape_->value())), NewValueNode(0)};
   auto fill = root_->NewCNode(fill_input);
@@ -19,13 +19,18 @@
 #include <utility>
 #include <string>
+#include <memory>
 #include "ir/value.h"
 #include "ir/graph_utils.h"
 #include "base/base.h"
+#include "frontend/parallel/step_parallel.h"
 #include "frontend/parallel/graph_util/generate_graph.h"
 namespace mindspore {
 namespace parallel {
+using TensorLayoutPtr = std::shared_ptr<TensorLayout>;
+using TensorInfoPtr = std::shared_ptr<TensorInfo>;
 typedef struct {
   ValueListPtr shape;
   TypePtr type;
@@ -59,8 +64,10 @@ class PipelineTransformer {
   void InsertReceive(const FuncGraphPtr &graph, const AnfNodePtr &node, const AnfNodePtr &use_node, int index,
                      int user_node_stage, int node_stage);
   void CutBorder(const FuncGraphPtr &graph);
-  void ElimRootParameter();
   bool IsStageNode(const CNodePtr &node);
+  std::pair<OperatorInfoPtr, TensorInfoPtr> GetOpInfo(const AnfNodePtr &node);
+  OperatorInfoPtr CreateOpInfo(const CNodePtr &cnode);
+  bool IsPipelineCareNode(const CNodePtr &cnode);
   std::pair<CNodePtr, FuncGraphPtr> FindSensNode();
   FuncGraphManagerPtr manager_;
   int64_t stage_;
@@ -1752,7 +1752,7 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes, bool is_traini
     SetVirtualDatasetStrategy(cnode);
     ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
     PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
-    if (prim->name() == MAKE_TUPLE || prim->name() == MAKE_LIST) {
+    if (prim->name() == MAKE_TUPLE || prim->name() == MAKE_LIST || prim->name() == RECEIVE) {
       continue;
     }
     auto attrs = prim->attrs();
@@ -2420,6 +2420,13 @@ std::vector<std::pair<CNodePtr, LossNodeInfo>> GetSensLossPairs(const FuncGraphP
   return sens_loss_pairs;
 }
+
+bool IsLastStage() {
+  MS_EXCEPTION_IF_NULL(g_device_manager);
+  auto stage_num = g_device_manager->stage_num();
+  auto stage_id = g_device_manager->stage_id();
+  return ((stage_num - 1) == stage_id);
+}
 void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes,
                            const FuncGraphManagerPtr &manager) {
   MS_EXCEPTION_IF_NULL(root);
@@ -2432,7 +2439,9 @@ void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePt
   for (auto &pair : sens_loss_pairs) {
     // If the shape of grad-sens tensor is not [] or [1], use get tensor slice to handel it.
     // If the type of sens node is not Tensor, it is unsupported now, do nothing default.
-    StepSplitSens(pair);
+    if (IsLastStage()) {
+      StepSplitSens(pair);
+    }
   }
   for (auto &node : all_nodes) {
@@ -2448,13 +2457,15 @@ void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePt
       MS_EXCEPTION_IF_NULL(distribute_operator);
       // insert forward ops
-      InsertForwardOps(distribute_operator, cnode);
+      if (!IsSomePrimitive(cnode, RECEIVE)) {
+        InsertForwardOps(distribute_operator, cnode);
+      }
       // insert redistribution ops
       StepRedistribution(cnode, distribute_operator, cnode, tensor_redistribution, cnode);
       // insert backward ops
-      if (has_backward) {
+      if (has_backward && !IsSomePrimitive(cnode, RECEIVE)) {
         BackwardCommunication(distribute_operator, cnode, sens_loss_pairs);
       }
@@ -2468,7 +2479,7 @@ void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePt
     MS_EXCEPTION_IF_NULL(node);
     if (node->isa<CNode>()) {
       auto cnode = node->cast<CNodePtr>();
-      if (!IsParallelCareNode(cnode) || !cnode->has_user_data<OperatorInfo>()) {
+      if (!IsParallelCareNode(cnode) || !cnode->has_user_data<OperatorInfo>() || IsSomePrimitive(cnode, RECEIVE)) {
         continue;
       }
@@ -2895,7 +2906,7 @@ ParameterUsersInfo FindParameterNodeUsers(const AnfNodePtr &node, bool (*IsCareN
   for (auto &candidate : candidate_set) {
     auto candidate_node = candidate.first;
     auto c = candidate_node->cast<CNodePtr>();
-    if (c == nullptr || !c->has_user_data<OperatorInfo>()) {
+    if (c == nullptr || !c->has_user_data<OperatorInfo>() || IsSomePrimitive(c, RECEIVE)) {
       continue;
     }
     (void)parameter_user_info.second.second.insert(candidate);
@@ -131,6 +131,10 @@ std::shared_ptr<TensorLayout> FindPrevLayout(const AnfNodePtr &node);
 void ReshapeInit(const std::vector<AnfNodePtr> &all_nodes);
+
+StrategyPtr GenerateBatchParallelStrategy(const OperatorInfoPtr operator_, const PrimitivePtr prim);
+
+bool IsLastStage();
 // Add node for whole graph
 void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes,
                            const FuncGraphManagerPtr &manager);
@@ -21,6 +21,7 @@
 #include "utils/comm_manager.h"
 #include "frontend/parallel/context.h"
 #include "frontend/parallel/pipeline_transformer/pipeline_transformer.h"
+#include "frontend/parallel/step_parallel.h"
 namespace mindspore {
 namespace pipeline {
@@ -59,7 +60,7 @@ static int64_t InferStage(int64_t rank_id, int64_t stage_num, int64_t device_num
 // Only auto_parallel and semi_auto_parallel support PipelineSplit
 bool PipelineSplit(const ResourcePtr &res) {
   auto parallel_mode = parallel::ParallelContext::GetInstance()->parallel_mode();
-  if (parallel_mode != parallel::SEMI_AUTO_PARALLEL || parallel_mode != parallel::AUTO_PARALLEL) {
+  if (parallel_mode != parallel::SEMI_AUTO_PARALLEL && parallel_mode != parallel::AUTO_PARALLEL) {
     MS_LOG(INFO) << "Only auto_parallel and semi_auto_parallel support pipeline split.";
     return true;
   }
@@ -80,6 +81,9 @@ bool PipelineSplit(const ResourcePtr &res) {
   }
   auto stage = InferStage(global_rank, stage_num, device_num);
   auto per_stage_rank_num = device_num / stage_num;
+  if (parallel::ParallelInit() != parallel::SUCCESS) {
+    MS_LOG(EXCEPTION) << "parallel init failed.";
+  }
   auto transformer =
     std::make_shared<parallel::PipelineTransformer>(manager, stage, root, global_rank, per_stage_rank_num);
   // step1: Do color graph
@@ -20,9 +20,10 @@ from .. import operations as P
 from ...common.tensor import RowTensor
 from ..composite.multitype_ops.zeros_like_impl import zeros_like
 from ..operations.comm_ops import (AllGather, _HostAllGather, AllReduce, _AlltoAll, Broadcast,
-                                   _GetTensorSlice, _MirrorOperator, ReduceOp, Send, Receive,
+                                   _GetTensorSlice, _MirrorOperator, ReduceOp,
                                    ReduceScatter, _HostReduceScatter, _VirtualDiv, AllSwap)
 from .grad_base import bprop_getters
+from ..operations._inner_ops import Send, Receive
 @bprop_getters.register(AllReduce)
@@ -36,7 +36,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
                         Unique, GatherD, Identity, SequenceMask)
 from .comm_ops import (AllGather, AllReduce, _AlltoAll, AllSwap, ReduceScatter, Broadcast,
                        _MirrorOperator, ReduceOp, _VirtualDataset,
-                       _VirtualDiv, _GetTensorSlice, Send, Receive,
+                       _VirtualDiv, _GetTensorSlice,
                        _HostAllGather, _HostReduceScatter)
 from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward, ScalarSummary,
                         TensorSummary, HistogramSummary, Print, Assert)
@@ -21,6 +21,7 @@ from ... import context
 from ...common import dtype as mstype
 from ..primitive import PrimitiveWithCheck, PrimitiveWithInfer, prim_attr_register
 from ..operations.math_ops import _infer_shape_reduce
+from ...communication.management import get_rank, GlobalComm, _get_group
 class ExtractImagePatches(PrimitiveWithInfer):
@@ -371,6 +372,116 @@ class MatrixDiagPart(PrimitiveWithInfer):
         return out_shape
+
+class Send(PrimitiveWithInfer):
+    """
+    Send tensors from src_rank to the specified dest_rank.
+
+    Note:
+        Send and Receive must be used in combination and have the same sr_tag.
+        Send must be used between servers.
+
+    Args:
+        sr_tag (int): A required integer identifying the send/recv message tag. The message will
+            be received by the Receive op with the same "sr_tag".
+        dest_rank (int): A required integer identifying the destination rank.
+        group (str): The communication group to work on. Default: "hccl_world_group/nccl_world_group".
+
+    Inputs:
+        - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
+
+    Examples:
+        >>> import mindspore.ops.operations as ops
+        >>> import mindspore.nn as nn
+        >>> from mindspore.communication import init
+        >>> from mindspore import Tensor
+        >>> import numpy as np
+        >>>
+        >>> init()
+        >>> class Net(nn.Cell):
+        >>>     def __init__(self):
+        >>>         super(Net, self).__init__()
+        >>>         self.depend = ops.Depend()
+        >>>         self.send = ops.Send(sr_tag=0, dest_rank=8, group="hccl_world_group")
+        >>>
+        >>>     def construct(self, x):
+        >>>         out = self.depend(x, self.send(x))
+        >>>         return out
+        >>>
+        >>> input_ = Tensor(np.ones([2, 8]).astype(np.float32))
+        >>> net = Net()
+        >>> output = net(input_)
+    """
+    @prim_attr_register
+    def __init__(self, sr_tag, dest_rank, group=GlobalComm.WORLD_COMM_GROUP):
+        self.rank = get_rank(_get_group(group))
+        self.sr_tag = sr_tag
+        self.group = group
+
+    def infer_shape(self, x_shape):
+        self.add_prim_attr("shape", x_shape)
+        return x_shape
+
+    def infer_dtype(self, x_dtype):
+        self.add_prim_attr("dtype", x_dtype)
+        return x_dtype
+
+
+class Receive(PrimitiveWithInfer):
+    """
+    Receive tensors from src_rank.
+
+    Note:
+        Send and Receive must be used in combination and have the same sr_tag.
+        Receive must be used between servers.
+
+    Args:
+        sr_tag (int): A required integer identifying the send/recv message tag. The message will
+            be sent by the Send op with the same "sr_tag".
+        src_rank (int): A required integer identifying the source rank.
+        shape (list[int]): A required list identifying the shape of the tensor to be received.
+        dtype (Type): A required Type identifying the type of the tensor to be received. The supported types:
+            int8, int16, int32, float16, float32.
+        group (str): The communication group to work on. Default: "hccl_world_group/nccl_world_group".
+
+    Inputs:
+        - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
+
+    Examples:
+        >>> import mindspore.ops.operations as ops
+        >>> import mindspore.nn as nn
+        >>> from mindspore.communication import init
+        >>> from mindspore import Tensor
+        >>> import numpy as np
+        >>>
+        >>> init()
+        >>> class Net(nn.Cell):
+        >>>     def __init__(self):
+        >>>         super(Net, self).__init__()
+        >>>         self.recv = ops.Receive(sr_tag=0, src_rank=0, shape=[2, 8], dtype=np.float32,
+        >>>                                 group="hccl_world_group")
+        >>>
+        >>>     def construct(self):
+        >>>         out = self.recv()
+        >>>         return out
+        >>>
+        >>> net = Net()
+        >>> output = net()
+    """
+    @prim_attr_register
+    def __init__(self, sr_tag, src_rank, shape, dtype, group=GlobalComm.WORLD_COMM_GROUP):
+        self.rank = get_rank(_get_group(group))
+        self.tag = sr_tag
+        self.shape = shape
+        self.dtype = dtype
+        self.group = group
+
+    def infer_shape(self, x_shape=None):
+        return self.shape
+
+    def infer_dtype(self, x_dtype=None):
+        return self.dtype
+
+
 class MatrixSetDiag(PrimitiveWithInfer):
     r"""
     Modifies the batched diagonal part of a batched tensor.
@@ -116,117 +116,6 @@ class AllReduce(PrimitiveWithInfer):
         return x_dtype
-
-
-class Send(PrimitiveWithInfer):
-    """
-    Send tensors from src_rank to the specified dest_rank.
-
-    Note:
-        Send and Recveive must be used in combination and have same sr_tag.
-        Send must be used between servers.
-
-    Args:
-        sr_tag (int): A required integer identifying the send/recv message tag. The message will
-            will be received by the Receive op with the same "sr_tag".
-        dest_rank (int): A required integer identifying the destination rank.
-        group (str): The communication group to work on. Default: "hccl_world_group/nccl_world_group".
-
-    Inputs:
-        - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
-
-    Examples:
-        >>> import mindspore.ops.operations as ops
-        >>> import mindspore.nn as nn
-        >>> from mindspore.communication import init
-        >>> from mindspore import Tensor
-        >>> import numpy as np
-        >>>
-        >>> init()
-        >>> class Net(nn.Cell):
-        >>>     def __init__(self):
-        >>>         super(Net, self).__init__()
-        >>>         self.depend = ops.Depend()
-        >>>         self.send = ops.Send(st_tag=0, dest_rank=8, group="hccl_world_group")
-        >>>
-        >>>     def construct(self, x):
-        >>>         out = self.depend(x, self.send(x))
-        >>>         return out
-        >>>
-        >>> input_ = Tensor(np.ones([2, 8]).astype(np.float32))
-        >>> net = Net()
-        >>> output = net(input_)
-    """
-    @prim_attr_register
-    def __init__(self, sr_tag, dest_rank, group=GlobalComm.WORLD_COMM_GROUP):
-        self.rank = get_rank(_get_group(group))
-        self.sr_tag = sr_tag
-        self.group = group
-
-    def infer_shape(self, x_shape):
-        self.add_prim_attr("shape", x_shape)
-        return x_shape
-
-    def infer_dtype(self, x_dtype):
-        self.add_prim_attr("dtype", x_dtype)
-        return x_dtype
-
-
-class Receive(PrimitiveWithInfer):
-    """
-    receive tensors from src_rank.
-
-    Note:
-        Send and Recveive must be used in combination and have same sr_tag.
-        Receive must be used between servers.
-
-    Args:
-        sr_tag (int): A required integer identifying the send/recv message tag. The message will
-            will be send by the Send op with the same "sr_tag".
-        src_rank (int): A required integer identifying the source rank.
-        shape (list[int]): A required list identifying the shape of the tensor to be received.
-        dtype (Type): A required Type indentifying the type of the tensor to be received. The supported types:
-            int8, int16, int32, float16, float32.
-        group (str): The communication group to work on. Default: "hccl_world_group/nccl_world_group".
-
-    Inputs:
-        - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
-
-    Examples:
-        >>> import mindspore.ops.operations as ops
-        >>> import mindspore.nn as nn
-        >>> from mindspore.communication import init
-        >>> from mindspore import Tensor
-        >>> import numpy as np
-        >>>
-        >>> init()
-        >>> class Net(nn.Cell):
-        >>>     def __init__(self):
-        >>>         super(Net, self).__init__()
-        >>>         self.recv = ops.Receive(st_tag=0, src_rank=0, shape=[2, 8], dtype=np.float32,
-        >>>                                 group="hccl_world_group")
-        >>>
-        >>>     def construct(self, x):
-        >>>         out = self.depend(x, self.recv(x))
-        >>>         return out
-        >>>
-        >>> input_ = Tensor(np.ones([2, 8]).astype(np.float32))
-        >>> net = Net()
-        >>> output = net(input_)
-    """
-    @prim_attr_register
-    def __init__(self, sr_tag, src_rank, shape, dtype, group=GlobalComm.WORLD_COMM_GROUP):
-        self.rank = get_rank(_get_group(group))
-        self.tag = sr_tag
-        self.shape = shape
-        self.dtype = dtype
-        self.group = group
-
-    def infer_shape(self, x_shape=None):
-        return self.shape
-
-    def infer_dtype(self, x_dtype=None):
-        return self.dtype
-
-
 class AllGather(PrimitiveWithInfer):
     """
     Gathers tensors from the specified communication group.
@@ -21,6 +21,7 @@ from mindspore.common.initializer import initializer
 from mindspore.common.parameter import Parameter
 from mindspore.communication.management import init, NCCL_WORLD_COMM_GROUP, get_rank, get_group_size
 from mindspore.ops import operations as P
+from mindspore.ops.operations._inner_ops import Send, Receive
 from mindspore.common import dtype as mstype
 context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
@@ -38,7 +39,7 @@ class SendNet(nn.Cell):
         super(SendNet, self).__init__()
         self.x = Parameter(initializer(Tensor(x), x.shape), name='x')
         self.depend = P.Depend()
-        self.send = P.Send(sr_tag=0, dest_rank=rank+size//2, group=NCCL_WORLD_COMM_GROUP)
+        self.send = Send(sr_tag=0, dest_rank=rank+size//2, group=NCCL_WORLD_COMM_GROUP)

     def construct(self):
         out = self.depend(self.x, self.send(self.x))
@@ -47,8 +48,8 @@ class SendNet(nn.Cell):
 class RecvNet(nn.Cell):
     def __init__(self):
         super(RecvNet, self).__init__()
-        self.recv = P.Receive(sr_tag=0, src_rank=rank-size//2, shape=[3, 3, 3, 3], dtype=mstype.float32,
-                              group=NCCL_WORLD_COMM_GROUP)
+        self.recv = Receive(sr_tag=0, src_rank=rank-size//2, shape=[3, 3, 3, 3], dtype=mstype.float32,
+                            group=NCCL_WORLD_COMM_GROUP)

     def construct(self):
         out = self.recv()
@@ -1,91 +0,0 @@
-# Copyright 2019 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-import numpy as np
-
-import mindspore as ms
-import mindspore.nn as nn
-from mindspore import Tensor
-from mindspore import context
-from mindspore.common.api import _executor
-from mindspore.ops import composite as C
-from mindspore.ops import operations as P
-from tests.ut.python.ops.test_math_ops import VirtualLoss
-
-grad_all = C.GradOperation(get_all=True)
-
-
-class NetWithLoss(nn.Cell):
-    def __init__(self, network):
-        super(NetWithLoss, self).__init__()
-        self.loss = VirtualLoss()
-        self.network = network
-
-    def construct(self, x, y):
-        predict = self.network(x, y)
-        return self.loss(predict)
-
-
-class GradWrap(nn.Cell):
-    def __init__(self, network):
-        super(GradWrap, self).__init__()
-        self.network = network
-
-    def construct(self, x, y):
-        return grad_all(self.network)(x, y)
-
-
-class Net(nn.Cell):
-    def __init__(self, axis=0, stage1=0, stage2=0, strategy1=None, strategy2=None, shape=None, target=""):
-        super().__init__()
-        if shape is None:
-            shape = [64, 64]
-        self.gatherv2 = P.GatherV2().shard(strategy1).add_prim_attr("primitive_target", target)
-        self.mul = P.Mul().shard(strategy2)
-        self.index = Tensor(np.ones(shape), dtype=ms.int32)
-        self.gatherv2.set_stage(stage1)
-        self.mul.set_stage(stage2)
-        self.axis = axis
-
-    def construct(self, x, y):
-        out = self.gatherv2(x, self.index, self.axis)
-        out = self.mul(out, y)
-        return out
-
-
-def test_gatherv2_semi_samestage1():
-    context.set_auto_parallel_context(device_num=8, global_rank=0, \
-                                      parallel_mode="semi_auto_parallel", pipeline_stages=2)
-    strategy1 = ((1, 2), (1, 1))
-    strategy2 = ((2, 1, 1), (2, 1, 1))
-    net = GradWrap(NetWithLoss(Net(0, 0, 0, strategy1, strategy2)))
-    net.set_auto_parallel()
-    x = Tensor(np.ones([64, 64]), dtype=ms.float32)
-    y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
-    net.set_train()
-    _executor.compile(net, x, y)
-
-
-def test_gatherv2_semi_samestage2():
-    context.set_auto_parallel_context(device_num=8, global_rank=5, \
-                                      parallel_mode="semi_auto_parallel", pipeline_stages=2)
-    strategy1 = ((1, 2), (1, 1))
-    strategy2 = ((2, 1, 1), (2, 1, 1))
-    net = GradWrap(NetWithLoss(Net(0, 1, 1, strategy1, strategy2)))
-    net.set_auto_parallel()
-    x = Tensor(np.ones([64, 64]), dtype=ms.float32)
-    y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
-    net.set_train()
-    _executor.compile(net, x, y)
@@ -0,0 +1,109 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import numpy as np
+
+import mindspore as ms
+import mindspore.nn as nn
+from mindspore import context
+from mindspore import Tensor
+from mindspore.ops import operations as P
+from mindspore.common.parameter import Parameter
+from mindspore.common.initializer import initializer
+from mindspore.train.model import Model
+
+
+class DatasetLenet():
+    def __init__(self, data, label, length=3):
+        self.data = data
+        self.label = label
+        self.index = 1
+        self.length = length
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        if self.index >= self.length:
+            raise StopIteration
+        self.index += 1
+        return self.data, self.label
+
+    def reset(self):
+        self.index = 0
+
+    def get_dataset_size(self):
+        return 32
+
+    def get_repeat_count(self):
+        return 1
+
+    def get_batch_size(self):
+        return 32
+
+    def create_tuple_iterator(self, num_epochs=1):
+        return self
+
+
+class MatMulCell(nn.Cell):
+    def __init__(self, strategy1, strategy2):
+        super().__init__()
+        self.param = Parameter(initializer("zeros", [64, 64]), name="param")
+        self.param1 = Parameter(initializer("zeros", [64, 64]), name="param1")
+        self.matmul = P.MatMul().shard(strategy1)
+        self.matmul1 = P.MatMul().shard(strategy2)
+
+    def construct(self, x):
+        out = self.matmul(x, self.param)
+        out = self.matmul1(out, self.param1)
+        return out
+
+
+class Net(nn.Cell):
+    def __init__(self, strategy1, strategy2):
+        super().__init__()
+        self.block = nn.CellList()
+        for i in range(2):
+            cell = MatMulCell(strategy1, strategy2)
+            cell.stage = i
+            self.block.append(cell)
+
+    def construct(self, x):
+        for i in range(2):
+            x = self.block[i](x)
+        return x
+
+
+class PipelineSplit(nn.Cell):
+    def __init__(self, strategy1, strategy2):
+        super().__init__()
+        self.cell = Net(strategy1, strategy2)
+
+    def construct(self, x, label):
+        x = self.cell(x)
+        return x
+
+
+def test_pipeline_split():
+    context.set_auto_parallel_context(device_num=8, global_rank=4, pipeline_stages=2)
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
+    data = Tensor(np.ones([32, 64]), dtype=ms.float32)
+    label = Tensor(np.ones([64, 64]), dtype=ms.float32)
+    strategy1 = ((4, 1), (1, 1))
+    strategy2 = ((2, 1), (1, 1))
+    net = PipelineSplit(strategy1, strategy2)
+    params = net.cell.block[1].trainable_params()
+    dataset = DatasetLenet(data, label, 3)
+    optimizer = nn.Lamb(params, learning_rate=0.01)
+    model = Model(net, optimizer=optimizer)
+    model.train(2, dataset, dataset_sink_mode=False)
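Review note, not part of the patch: a minimal two-rank sketch of how the relocated `Send`/`Receive` inner primitives pair up, mirroring `SendNet`/`RecvNet` from the NCCL test above. The GPU `GRAPH_MODE` settings, the `[2, 8]` payload, and the rank split are illustrative assumptions, not values mandated by this change. It also illustrates the sign flip in `InsertReceive`: with `per_stage_rank_num_ = 4`, stage-1 rank 5 now computes its peer as `5 - (1 - 0) * 4 = 1`, i.e. the matching stage-0 sender.

```python
import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor, context
from mindspore.communication.management import init, get_rank, get_group_size, NCCL_WORLD_COMM_GROUP
from mindspore.ops import operations as P
from mindspore.ops.operations._inner_ops import Send, Receive

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
init()
rank = get_rank()
size = get_group_size()


class SendOnce(nn.Cell):
    """Ranks in the first half push a tensor to their peer in the second half."""
    def __init__(self, peer):
        super().__init__()
        self.depend = P.Depend()
        self.send = Send(sr_tag=0, dest_rank=peer, group=NCCL_WORLD_COMM_GROUP)

    def construct(self, x):
        # Depend keeps the side-effecting Send alive even though it has no consumer.
        return self.depend(x, self.send(x))


class RecvOnce(nn.Cell):
    """Ranks in the second half receive the tensor sent with the same sr_tag."""
    def __init__(self, peer):
        super().__init__()
        self.recv = Receive(sr_tag=0, src_rank=peer, shape=[2, 8],
                            dtype=ms.float32, group=NCCL_WORLD_COMM_GROUP)

    def construct(self):
        return self.recv()


x = Tensor(np.ones([2, 8]).astype(np.float32))
if rank < size // 2:
    out = SendOnce(rank + size // 2)(x)
else:
    out = RecvOnce(rank - size // 2)()
```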