| @@ -53,7 +53,7 @@ struct Cost { | |||||
| communication_redis_backward_ = 0.0; | communication_redis_backward_ = 0.0; | ||||
| communication_forward_ = 0.0; | communication_forward_ = 0.0; | ||||
| } | } | ||||
| // 'memory_with_reuse_' calculates the peak memory usage in a training phase | |||||
| // 'memory_with_reuse_' calculates the peak memory usage in a training (or inference) phase | |||||
| double memory_with_reuse_; | double memory_with_reuse_; | ||||
| // 'computation_cost_' models the training time of an iteration in a training phase. Currently, this is calculated | // 'computation_cost_' models the training time of an iteration in a training phase. Currently, this is calculated | ||||
| // by ONLY forward phase | // by ONLY forward phase | ||||
| @@ -300,5 +300,20 @@ Status Edge::CalculateMemoryCost() { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status Edge::CalculateMemoryCostForInference() { | |||||
| // Currently, memory cost is NOT calculated for redistribution | |||||
| if ((is_output_critical_ != 0) && (is_output_critical_ != 1)) { | |||||
| MS_LOG(ERROR) << "Failure: unexpected output critical flag value: " << is_output_critical_; | |||||
| return FAILED; | |||||
| } | |||||
| for (auto &cost_kv : cost_map_) { | |||||
| auto &cost_v = cost_kv.second; | |||||
| if (!cost_v.empty()) { | |||||
| cost_v[0]->memory_with_reuse_ = 0; | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| } // namespace parallel | } // namespace parallel | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -131,9 +131,13 @@ class Edge { | |||||
| void set_selected_cost(const CostPtr &cost) { selected_cost_ = cost; } | void set_selected_cost(const CostPtr &cost) { selected_cost_ = cost; } | ||||
| const CostPtr &selected_cost() const { return selected_cost_; } | const CostPtr &selected_cost() const { return selected_cost_; } | ||||
| void set_parameter_involve(int para_invol) { is_output_parameter_involve_ = para_invol; } | void set_parameter_involve(int para_invol) { is_output_parameter_involve_ = para_invol; } | ||||
| // When the input of a operator contains WEIGHT or a output from other operators involving WEIGHT, then these input | |||||
| // should stay in memory until it is used in the backward phase, which is kept in memory at the end of forward phase. | |||||
| // In the training phase, when the input of a operator contains WEIGHT or a output from other operators involving | |||||
| // WEIGHT, then these input should stay in memory until it is used in the backward phase, which is kept in memory | |||||
| // at the end of forward phase. | |||||
| Status CalculateMemoryCost(); | Status CalculateMemoryCost(); | ||||
| // In the inference phase, | |||||
| Status CalculateMemoryCostForInference(); | |||||
| void mark_output_critical() { is_output_critical_ = 1; } | |||||
| private: | private: | ||||
| std::string edge_name_; | std::string edge_name_; | ||||
| @@ -156,7 +160,11 @@ class Edge { | |||||
| // If it is true, then we should guarantee that the strategy for output tensor consistent with the input tensor. | // If it is true, then we should guarantee that the strategy for output tensor consistent with the input tensor. | ||||
| bool is_identity_edge; | bool is_identity_edge; | ||||
| CostPtr selected_cost_; | CostPtr selected_cost_; | ||||
| // In the training phase, 'is_output_parameter_involve_' is used to mark whether the output of the previous operator | |||||
| // is parameter-involved | |||||
| int is_output_parameter_involve_ = -1; // -1: unset; 0: not parameter_involved; 1: parameter_involved | int is_output_parameter_involve_ = -1; // -1: unset; 0: not parameter_involved; 1: parameter_involved | ||||
| // In the inference phase, this is used to mark whether the output of the previous operator is critical. | |||||
| int is_output_critical_ = 0; | |||||
| }; | }; | ||||
| } // namespace parallel | } // namespace parallel | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -369,7 +369,7 @@ CostPtr CostGraph::SelectCostWithMinInferenceTime(const CostPtrList &cost_list, | |||||
| << ", communication_with_partial_para_: " << ret->communication_with_partial_para_ | << ", communication_with_partial_para_: " << ret->communication_with_partial_para_ | ||||
| << ", communication_cost_: " << ret->communication_cost_ | << ", communication_cost_: " << ret->communication_cost_ | ||||
| << ", communication_without_parameter_: " << ret->communication_without_parameter_ << "."; | << ", communication_without_parameter_: " << ret->communication_without_parameter_ << "."; | ||||
| MS_LOG(INFO) << "Cost 0: totoal_cost: " << minimum; | |||||
| MS_LOG(INFO) << "Cost 0: total_cost: " << minimum; | |||||
| for (size_t i = 1; i < after_mem_filter.size(); ++i) { | for (size_t i = 1; i < after_mem_filter.size(); ++i) { | ||||
| MS_EXCEPTION_IF_NULL(after_mem_filter[i]); | MS_EXCEPTION_IF_NULL(after_mem_filter[i]); | ||||
| MS_LOG(INFO) << "Cost " << i << ": memory_cost: " << after_mem_filter[i]->memory_with_reuse_ | MS_LOG(INFO) << "Cost " << i << ": memory_cost: " << after_mem_filter[i]->memory_with_reuse_ | ||||
| @@ -422,7 +422,7 @@ CostPtr CostGraph::SelectCostWithMinTrainingTime(const CostPtrList &cost_list, d | |||||
| << ", communication_with_partial_para_: " << ret->communication_with_partial_para_ | << ", communication_with_partial_para_: " << ret->communication_with_partial_para_ | ||||
| << ", communication_cost_: " << ret->communication_cost_ | << ", communication_cost_: " << ret->communication_cost_ | ||||
| << ", communication_without_parameter_: " << ret->communication_without_parameter_ << "."; | << ", communication_without_parameter_: " << ret->communication_without_parameter_ << "."; | ||||
| MS_LOG(INFO) << "Cost 0: totoal_cost: " << minimum; | |||||
| MS_LOG(INFO) << "Cost 0: total_cost: " << minimum; | |||||
| for (size_t i = 1; i < after_mem_filter.size(); ++i) { | for (size_t i = 1; i < after_mem_filter.size(); ++i) { | ||||
| MS_EXCEPTION_IF_NULL(after_mem_filter[i]); | MS_EXCEPTION_IF_NULL(after_mem_filter[i]); | ||||
| MS_LOG(INFO) << "Cost " << i << ": memory_cost: " << after_mem_filter[i]->memory_with_reuse_ | MS_LOG(INFO) << "Cost " << i << ": memory_cost: " << after_mem_filter[i]->memory_with_reuse_ | ||||
| @@ -1351,6 +1351,14 @@ std::vector<std::shared_ptr<Edge>> CostGraph::EliminationStar(const OperatorInfo | |||||
| return succ_edges; | return succ_edges; | ||||
| } | } | ||||
| size_t CostGraph::GetNumEdges() const { | |||||
| size_t sum = 0; | |||||
| for (const auto &kv : edges_) { | |||||
| auto &edges = kv.second; | |||||
| sum += edges.size(); | |||||
| } | |||||
| return sum; | |||||
| } | |||||
| Status CostGraph::InitSelectedStrategy() { | Status CostGraph::InitSelectedStrategy() { | ||||
| for (auto &op : ops_) { | for (auto &op : ops_) { | ||||
| MS_EXCEPTION_IF_NULL(op); | MS_EXCEPTION_IF_NULL(op); | ||||
| @@ -1416,6 +1424,122 @@ Status CostGraph::ComputeOpsAndEdgesParameterInvolved() { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| void CostGraph::DFSForTopoOrder(const OperatorInfoPtr ¤t_op, std::map<OperatorInfoPtr, bool> *visited, | |||||
| std::vector<OperatorInfoPtr> *topo_order) { | |||||
| MS_EXCEPTION_IF_NULL(current_op); | |||||
| MS_EXCEPTION_IF_NULL(visited); | |||||
| MS_EXCEPTION_IF_NULL(topo_order); | |||||
| visited->at(current_op) = true; | |||||
| for (const auto &s_edge : current_op->succ_edges()) { | |||||
| if (!visited->at(s_edge->next_operator())) { | |||||
| DFSForTopoOrder(s_edge->next_operator(), visited, topo_order); | |||||
| } | |||||
| } | |||||
| topo_order->push_back(current_op); | |||||
| } | |||||
| // Compute a topological order of the costgraph | |||||
| void CostGraph::TopologyOrder(std::vector<OperatorInfoPtr> *topo_order) { | |||||
| std::map<OperatorInfoPtr, bool> visited; | |||||
| for (auto &op : ops_) { | |||||
| visited[op] = false; | |||||
| } | |||||
| for (auto &op : ops_) { | |||||
| if (!visited[op]) { | |||||
| DFSForTopoOrder(op, &visited, topo_order); | |||||
| } | |||||
| } | |||||
| } | |||||
| void CostGraph::MarkCriticalOpsAndEdges(const std::map<OperatorInfoPtr, int> &candidate_ops) { | |||||
| for (auto &op : ops_) { | |||||
| auto search = candidate_ops.find(op); | |||||
| if (search != candidate_ops.end()) { | |||||
| // Mark the critical operators | |||||
| op->mark_output_critical(); | |||||
| // Mark the successive edges | |||||
| for (auto &s_edge : op->succ_edges()) { | |||||
| s_edge->mark_output_critical(); | |||||
| } | |||||
| } else { | |||||
| op->mark_output_not_critical(); | |||||
| } | |||||
| } | |||||
| } | |||||
| Status CostGraph::DetermineCriticalOps(const std::vector<OperatorInfoPtr> &topo_order) { | |||||
| if (topo_order.size() == 0) { | |||||
| MS_LOG(ERROR) << "0 operator in costgraph."; | |||||
| return FAILED; | |||||
| } | |||||
| auto &first_op = topo_order[0]; | |||||
| if (first_op->prev_edges().size() > 0) { | |||||
| MS_LOG(ERROR) << "The first operator in the first of topological order of " | |||||
| "costgraph should have 0 incoming edge, but has " | |||||
| << first_op->prev_edges() << "edges."; | |||||
| return FAILED; | |||||
| } | |||||
| // The 'curr_memory_state' records <OperatorInfo, remaining_output_cnt>, where remaining_output_cnt is the number | |||||
| // of the output of OperatorInfo that currently has not been used | |||||
| std::map<OperatorInfoPtr, int> curr_memory_state; | |||||
| (void)curr_memory_state.emplace(std::make_pair(first_op, SizeToInt(first_op->succ_edges().size()))); | |||||
| std::map<OperatorInfoPtr, int> max_memory_state = curr_memory_state; | |||||
| // The 'curr_memory_size' records the current total memory size, which is the sum of outputs of operators that has | |||||
| // not been used | |||||
| double curr_memory_size = first_op->GetOutputsTotalSize(); | |||||
| double max_memory_size = curr_memory_size; | |||||
| for (size_t finished = 1; finished < topo_order.size(); ++finished) { | |||||
| // Produce | |||||
| (void)curr_memory_state.emplace( | |||||
| std::make_pair(topo_order[finished], SizeToInt(topo_order[finished]->succ_edges().size()))); | |||||
| curr_memory_size += topo_order[finished]->GetOutputsTotalSize(); | |||||
| // Consume | |||||
| for (const auto &prev_edge : topo_order[finished]->prev_edges()) { | |||||
| const auto &prev_op = prev_edge->prev_operator(); | |||||
| curr_memory_state[prev_op]--; | |||||
| } | |||||
| for (const auto &prev_edge : topo_order[finished]->prev_edges()) { | |||||
| const auto &prev_op = prev_edge->prev_operator(); | |||||
| if (curr_memory_state[prev_op] < 0) { | |||||
| MS_LOG(ERROR) << "Failure: " << prev_op->name() << "'s current output count: " << curr_memory_state[prev_op]; | |||||
| return FAILED; | |||||
| } else if (curr_memory_state[prev_op] == 0) { | |||||
| curr_memory_state.erase(prev_op); | |||||
| curr_memory_size -= prev_op->GetOutputsTotalSize(); | |||||
| } | |||||
| } | |||||
| if (curr_memory_size < 0) { | |||||
| MS_LOG(ERROR) << "Memory size calculation failed: " << curr_memory_size; | |||||
| } | |||||
| // Modify the max | |||||
| if (curr_memory_size > max_memory_size) { | |||||
| max_memory_size = curr_memory_size; | |||||
| max_memory_state = curr_memory_state; | |||||
| } | |||||
| } | |||||
| // Mark those critical operators | |||||
| MarkCriticalOpsAndEdges(max_memory_state); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status CostGraph::ComputeOpsAndEdgesOutputCritical() { | |||||
| // Two steps to do: | |||||
| // 1. Compute a topological order of the costgraph | |||||
| // 2. Determine and mark the operators (and necessary edges) that are critical | |||||
| std::vector<OperatorInfoPtr> topo_order; | |||||
| TopologyOrder(&topo_order); | |||||
| std::reverse(std::begin(topo_order), std::end(topo_order)); | |||||
| if (DetermineCriticalOps(topo_order) != SUCCESS) { | |||||
| MS_LOG(ERROR) << "Determining critical operators failed."; | |||||
| return FAILED; | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status CostGraph::CalculateOpsMemoryCost() { | Status CostGraph::CalculateOpsMemoryCost() { | ||||
| for (auto &op : ops_) { | for (auto &op : ops_) { | ||||
| MS_EXCEPTION_IF_NULL(op); | MS_EXCEPTION_IF_NULL(op); | ||||
| @@ -1427,6 +1551,17 @@ Status CostGraph::CalculateOpsMemoryCost() { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status CostGraph::CalculateOpsMemoryCostForInference() { | |||||
| for (auto &op : ops_) { | |||||
| MS_EXCEPTION_IF_NULL(op); | |||||
| if (op->CalculateMemoryCostForInference() != SUCCESS) { | |||||
| MS_LOG(ERROR) << "Calculate Operator: " << op->name() << " cost for memory usage failed."; | |||||
| return FAILED; | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status CostGraph::CalculateEdgesMemoryCost() { | Status CostGraph::CalculateEdgesMemoryCost() { | ||||
| for (auto &edge_pair : edges_) { | for (auto &edge_pair : edges_) { | ||||
| const auto &edges = edge_pair.second; | const auto &edges = edge_pair.second; | ||||
| @@ -1440,6 +1575,19 @@ Status CostGraph::CalculateEdgesMemoryCost() { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status CostGraph::CalculateEdgesMemoryCostForInference() { | |||||
| for (auto &edge_pair : edges_) { | |||||
| const auto &edges = edge_pair.second; | |||||
| for (auto &one_edge : edges) { | |||||
| if (one_edge->CalculateMemoryCostForInference() != SUCCESS) { | |||||
| MS_LOG(ERROR) << "Calculate Edge: " << one_edge->edge_name() << " cost for memory usage failed."; | |||||
| return FAILED; | |||||
| } | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| OperatorInfoPtr CostGraph::FindTmpIdentityByParameterName(std::string &p_name) const { | OperatorInfoPtr CostGraph::FindTmpIdentityByParameterName(std::string &p_name) const { | ||||
| for (auto one_op : ops_) { | for (auto one_op : ops_) { | ||||
| if (one_op->name().find(IDENTITY_INFO) != std::string::npos) { | if (one_op->name().find(IDENTITY_INFO) != std::string::npos) { | ||||
| @@ -1480,5 +1628,49 @@ Status CostGraph::CorrectOpsMemoryCost() { | |||||
| } | } | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status CostGraph::CalculateMemoryCost() { | |||||
| if (RUN_PHASE == TRAINING_PHASE) { | |||||
| // training phase | |||||
| if (ComputeOpsAndEdgesParameterInvolved() == SUCCESS) { | |||||
| // Calculate operators' memory usage | |||||
| if (CalculateOpsMemoryCost() != SUCCESS) { | |||||
| MS_LOG(ERROR) << "Calculating operators' cost for memory cost failed."; | |||||
| return FAILED; | |||||
| } | |||||
| // Calculate edges' memory usage | |||||
| if (CalculateEdgesMemoryCost() != SUCCESS) { | |||||
| MS_LOG(ERROR) << "Calculating edges' cost for memory cost failed."; | |||||
| return FAILED; | |||||
| } | |||||
| // Correct memory usage caused by TmpIdentity | |||||
| if (CorrectOpsMemoryCost() != SUCCESS) { | |||||
| MS_LOG(ERROR) << "Correcting operators' cost for memory cost failed."; | |||||
| return FAILED; | |||||
| } | |||||
| } else { | |||||
| MS_LOG(ERROR) << "Computing operators' parameter_involved failed."; | |||||
| return FAILED; | |||||
| } | |||||
| } else { | |||||
| // inference phase | |||||
| if (ComputeOpsAndEdgesOutputCritical() == SUCCESS) { | |||||
| // Calculate operators' memory usage | |||||
| if (CalculateOpsMemoryCostForInference() != SUCCESS) { | |||||
| MS_LOG(ERROR) << "Calculating operators' memory cost for inference failed."; | |||||
| return FAILED; | |||||
| } | |||||
| // Calculate edges's memory usage | |||||
| if (CalculateEdgesMemoryCostForInference() != SUCCESS) { | |||||
| MS_LOG(ERROR) << "Calculating operators' memory cost for inference failed."; | |||||
| return FAILED; | |||||
| } | |||||
| } else { | |||||
| MS_LOG(ERROR) << "Computing operators' critical flag failed."; | |||||
| return FAILED; | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| } // namespace parallel | } // namespace parallel | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -179,16 +179,24 @@ class CostGraph { | |||||
| void CreateStarEliminationSubCostList(const StrategyPtr &, const CostPtrList &, const CostPtrList &, | void CreateStarEliminationSubCostList(const StrategyPtr &, const CostPtrList &, const CostPtrList &, | ||||
| const StrategyPtr &, const CostPtrList &, std::vector<StrategyPtr>, | const StrategyPtr &, const CostPtrList &, std::vector<StrategyPtr>, | ||||
| CostPtrList &, CostPtrList &, CostPtrList *); | CostPtrList &, CostPtrList &, CostPtrList *); | ||||
| // Calculate memory cost for training phase or inference phase. | |||||
| Status CalculateMemoryCost(); | |||||
| // When the input of a operator is neither a WEIGHT, nor a output of a subsequent operator involving WEIGHT, then | // When the input of a operator is neither a WEIGHT, nor a output of a subsequent operator involving WEIGHT, then | ||||
| // the memory cost can be resused. | |||||
| // the memory cost can be resused. This is used to calculate memory in the training phase. | |||||
| Status CalculateOpsMemoryCost(); | Status CalculateOpsMemoryCost(); | ||||
| // When the input of the edge is neither a WEIGHT, nor a output of a subsequent operator involving WEIGHT, then | // When the input of the edge is neither a WEIGHT, nor a output of a subsequent operator involving WEIGHT, then | ||||
| // the memory cost can be resused. | |||||
| // the memory cost can be reused. This is used to calculate memory in the training phase. | |||||
| Status CalculateEdgesMemoryCost(); | Status CalculateEdgesMemoryCost(); | ||||
| // Calculate memory cost of operators in the inference phase. | |||||
| Status CalculateOpsMemoryCostForInference(); | |||||
| // Calculate memory cost of edges in the inference phase. | |||||
| Status CalculateEdgesMemoryCostForInference(); | |||||
| Status ComputeOpsAndEdgesParameterInvolved(); | Status ComputeOpsAndEdgesParameterInvolved(); | ||||
| // Compute for each operator whether the output is critical. | |||||
| Status ComputeOpsAndEdgesOutputCritical(); | |||||
| std::vector<OperatorInfoPtr> GetOperators() const { return ops_; } | std::vector<OperatorInfoPtr> GetOperators() const { return ops_; } | ||||
| size_t GetNumPairs() const { return edges_.size(); } | |||||
| size_t GetNumEdges() const; | |||||
| Status InitSelectedStrategy(); | Status InitSelectedStrategy(); | ||||
| OperatorInfoPtr FindTmpIdentityByParameterName(std::string &) const; | OperatorInfoPtr FindTmpIdentityByParameterName(std::string &) const; | ||||
| // When TmpIdentity is used by mulitple operators, the corresponding parameter's memory cost should be calculated only | // When TmpIdentity is used by mulitple operators, the corresponding parameter's memory cost should be calculated only | ||||
| @@ -208,6 +216,10 @@ class CostGraph { | |||||
| const std::map<std::string, std::string> get_tuple_getitem_list() const { return tuple_getitem_list_; } | const std::map<std::string, std::string> get_tuple_getitem_list() const { return tuple_getitem_list_; } | ||||
| private: | private: | ||||
| void TopologyOrder(std::vector<OperatorInfoPtr> *); | |||||
| void DFSForTopoOrder(const OperatorInfoPtr &, std::map<OperatorInfoPtr, bool> *, std::vector<OperatorInfoPtr> *); | |||||
| Status DetermineCriticalOps(const std::vector<OperatorInfoPtr> &); | |||||
| void MarkCriticalOpsAndEdges(const std::map<OperatorInfoPtr, int> &); | |||||
| // Needed by rec_parser | // Needed by rec_parser | ||||
| std::vector<std::vector<std::string>> inputs_tensor_name_list_; | std::vector<std::vector<std::string>> inputs_tensor_name_list_; | ||||
| std::map<std::string, std::string> tuple_getitem_list_; | std::map<std::string, std::string> tuple_getitem_list_; | ||||
| @@ -37,6 +37,8 @@ void OperatorCost::SetInputAndOutputTypeLength(const std::vector<size_t> &input_ | |||||
| outputs_type_lengths_ = output_lengths; | outputs_type_lengths_ = output_lengths; | ||||
| } | } | ||||
| void OperatorCost::set_output_critical(int critical) { is_outputs_critical_ = critical; } | |||||
| double OperatorCost::GetMemoryCost(const std::vector<TensorInfo> &inputs, | double OperatorCost::GetMemoryCost(const std::vector<TensorInfo> &inputs, | ||||
| const std::vector<TensorInfo> &outputs) const { | const std::vector<TensorInfo> &outputs) const { | ||||
| double result = 0.0; | double result = 0.0; | ||||
| @@ -63,6 +65,20 @@ double OperatorCost::GetMemoryCost(const std::vector<TensorInfo> &inputs, | |||||
| return result; | return result; | ||||
| } | } | ||||
| double OperatorCost::GetMemoryCostForInference(const std::vector<TensorInfo> &, | |||||
| const std::vector<TensorInfo> &outputs) const { | |||||
| double result = 0.0; | |||||
| if (is_outputs_critical_ == -1) { | |||||
| MS_LOG(EXCEPTION) << "The critical flag is not set."; | |||||
| } | |||||
| if (is_outputs_critical_ == 1) { | |||||
| for (size_t i = 0; i < outputs.size(); ++i) { | |||||
| result += ListProduct(outputs[i].slice_shape()) * static_cast<double>(outputs_type_lengths_[i]); | |||||
| } | |||||
| } | |||||
| return result; | |||||
| } | |||||
| // return the per device communication cost in the forward phase. | // return the per device communication cost in the forward phase. | ||||
| double MatMulCost::GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs, | double MatMulCost::GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs, | ||||
| int32_t) const { | int32_t) const { | ||||
| @@ -70,6 +70,7 @@ class OperatorCost { | |||||
| void set_is_parameter(const std::vector<bool> &is_parameter); | void set_is_parameter(const std::vector<bool> &is_parameter); | ||||
| void set_is_parameter_involve(const std::vector<bool> &); | void set_is_parameter_involve(const std::vector<bool> &); | ||||
| void set_output_parameter_involve(int); | void set_output_parameter_involve(int); | ||||
| void set_output_critical(int); | |||||
| void SetInputAndOutputTypeLength(const std::vector<size_t> &input_lengths, const std::vector<size_t> &output_lengths); | void SetInputAndOutputTypeLength(const std::vector<size_t> &input_lengths, const std::vector<size_t> &output_lengths); | ||||
| std::vector<size_t> inputs_type_lengths() const { return inputs_type_lengths_; } | std::vector<size_t> inputs_type_lengths() const { return inputs_type_lengths_; } | ||||
| std::vector<size_t> outputs_type_lengths() const { return outputs_type_lengths_; } | std::vector<size_t> outputs_type_lengths() const { return outputs_type_lengths_; } | ||||
| @@ -92,6 +93,8 @@ class OperatorCost { | |||||
| // Typically, the PEAK memory cost contributed by an operator is its output (if the output is parameter-invovled), | // Typically, the PEAK memory cost contributed by an operator is its output (if the output is parameter-invovled), | ||||
| // plus necessary inputs. | // plus necessary inputs. | ||||
| virtual double GetMemoryCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs) const; | virtual double GetMemoryCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs) const; | ||||
| // per device memory cost in a inference phase | |||||
| double GetMemoryCostForInference(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &) const; | |||||
| protected: | protected: | ||||
| // For each input in 'inputs_', a bool variable is true if the corresponding one is a parameter or a output of | // For each input in 'inputs_', a bool variable is true if the corresponding one is a parameter or a output of | ||||
| @@ -106,6 +109,9 @@ class OperatorCost { | |||||
| // for each input and output, the followings record the number of bytes of each element | // for each input and output, the followings record the number of bytes of each element | ||||
| std::vector<size_t> inputs_type_lengths_; | std::vector<size_t> inputs_type_lengths_; | ||||
| std::vector<size_t> outputs_type_lengths_; | std::vector<size_t> outputs_type_lengths_; | ||||
| // Whether the output is critical, which means that this output is included in calculating peak memory cost | |||||
| // in the inference phase. | |||||
| int is_outputs_critical_ = -1; | |||||
| }; | }; | ||||
| using OperatorCostPtr = std::shared_ptr<OperatorCost>; | using OperatorCostPtr = std::shared_ptr<OperatorCost>; | ||||
| @@ -1119,6 +1119,21 @@ Status OperatorInfo::CalculateMemoryCost() { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status OperatorInfo::CalculateMemoryCostForInference() { | |||||
| // First, set the 'is_outputs_critical_' flag into OperatorCost. | |||||
| if (is_output_critical_ == -1) { | |||||
| MS_LOG(EXCEPTION) << "The critical flag is not set."; | |||||
| return FAILED; | |||||
| } | |||||
| operator_cost()->set_output_critical(is_output_critical_); | |||||
| // Set the memory cost in the 'strategy_cost_' | |||||
| for (auto &swc : strategy_cost_) { | |||||
| auto mem_cost = operator_cost()->GetMemoryCostForInference(swc->inputs_ptr, swc->outputs_ptr); | |||||
| swc->cost_list[0]->memory_with_reuse_ = mem_cost; | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status OperatorInfo::CorrectMemoryCost(size_t input_index) { | Status OperatorInfo::CorrectMemoryCost(size_t input_index) { | ||||
| for (auto &swc : strategy_cost_) { | for (auto &swc : strategy_cost_) { | ||||
| double parameter_mem_cost = ListProduct(swc->inputs_ptr[input_index].slice_shape()) * | double parameter_mem_cost = ListProduct(swc->inputs_ptr[input_index].slice_shape()) * | ||||
| @@ -1230,6 +1245,25 @@ Status OperatorInfo::SetInputAndOutputTypeLength(const std::vector<size_t> &inpu | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| double OperatorInfo::GetOutputsTotalSize() { | |||||
| if (is_calculated_outputs_size_) { | |||||
| return outputs_total_size_; | |||||
| } | |||||
| if (outputs_type_lengths_.size() != outputs_shape_.size()) { | |||||
| MS_LOG(EXCEPTION) << "Output_lengths: " << outputs_type_lengths_.size() | |||||
| << " do not have the same number of outputs shape: " << outputs_shape_.size(); | |||||
| } | |||||
| double sum = 0.0; | |||||
| for (size_t i = 0; i < outputs_type_lengths_.size(); ++i) { | |||||
| auto size = std::accumulate(outputs_shape_[i].begin(), outputs_shape_[i].end(), static_cast<double>(1.0), | |||||
| std::multiplies<double>()); | |||||
| sum += size * static_cast<double>(outputs_type_lengths_[i]); | |||||
| } | |||||
| is_calculated_outputs_size_ = true; | |||||
| outputs_total_size_ = sum; | |||||
| return outputs_total_size_; | |||||
| } | |||||
| Status OperatorInfo::set_outputs_type(const std::vector<TypePtr> &outputs_type) { | Status OperatorInfo::set_outputs_type(const std::vector<TypePtr> &outputs_type) { | ||||
| if (outputs_type.size() != outputs_shape_.size()) { | if (outputs_type.size() != outputs_shape_.size()) { | ||||
| MS_LOG(ERROR) << "Outputs type: " << outputs_type.size() | MS_LOG(ERROR) << "Outputs type: " << outputs_type.size() | ||||
| @@ -72,6 +72,7 @@ class OperatorInfo { | |||||
| Status set_is_parameter(const std::vector<bool> &is_parameter); | Status set_is_parameter(const std::vector<bool> &is_parameter); | ||||
| Status SetInputAndOutputTypeLength(const std::vector<size_t> &input_lengths, | Status SetInputAndOutputTypeLength(const std::vector<size_t> &input_lengths, | ||||
| const std::vector<size_t> &output_lengths); | const std::vector<size_t> &output_lengths); | ||||
| double GetOutputsTotalSize(); | |||||
| // Set outputs dtype. | // Set outputs dtype. | ||||
| // If only one output, outputs_type.size() is 1. | // If only one output, outputs_type.size() is 1. | ||||
| // If output is tuple, outputs_type.size() is greater than 1. | // If output is tuple, outputs_type.size() is greater than 1. | ||||
| @@ -96,9 +97,13 @@ class OperatorInfo { | |||||
| // is checked | // is checked | ||||
| Status SetCostUnderStrategyBase(const StrategyPtr &strategy); | Status SetCostUnderStrategyBase(const StrategyPtr &strategy); | ||||
| std::vector<std::shared_ptr<StrategyWithCost>> GetStrategyCost() { return strategy_cost_; } | std::vector<std::shared_ptr<StrategyWithCost>> GetStrategyCost() { return strategy_cost_; } | ||||
| // When the input of a operator contains WEIGHT or a output from other operators involving WEIGHT, then these input | |||||
| // should stay in memory until it is used in the backward phase, which is kept in memory at the end of forward phase. | |||||
| // In the training phase, when the input of a operator contains WEIGHT or a output from other operators involving | |||||
| // WEIGHT, then these input should stay in memory until it is used in the backward phase, which is kept in memory | |||||
| // at the end of forward phase. | |||||
| Status CalculateMemoryCost(); | Status CalculateMemoryCost(); | ||||
| // In the inference phase, the memory cost is incurred only when the operator is critical. The size is calculated | |||||
| // by the output | |||||
| Status CalculateMemoryCostForInference(); | |||||
| int ComputeOpAndPrevEdgeParameterInvolved(); | int ComputeOpAndPrevEdgeParameterInvolved(); | ||||
| ForwardOp forward_op() const { return forward_op_; } | ForwardOp forward_op() const { return forward_op_; } | ||||
| @@ -147,6 +152,9 @@ class OperatorInfo { | |||||
| // multiple times. This method is to correct this, and makes the cost is calulated only once. | // multiple times. This method is to correct this, and makes the cost is calulated only once. | ||||
| Status CorrectMemoryCost(size_t input_index); | Status CorrectMemoryCost(size_t input_index); | ||||
| int is_output_parameter_involve() const { return is_output_parameter_involve_; } | int is_output_parameter_involve() const { return is_output_parameter_involve_; } | ||||
| int is_output_critical() const { return is_output_critical_; } | |||||
| void mark_output_critical() { is_output_critical_ = 1; } | |||||
| void mark_output_not_critical() { is_output_critical_ = 0; } | |||||
| int used_devices() const { return used_devices_; } | int used_devices() const { return used_devices_; } | ||||
| // needed by rec_parser | // needed by rec_parser | ||||
| void set_type(const std::string &type) { type_ = type; } | void set_type(const std::string &type) { type_ = type; } | ||||
| @@ -220,7 +228,16 @@ class OperatorInfo { | |||||
| // For each input in 'inputs_', a bool variable is true if the corresponding one is a parameter or a output of | // For each input in 'inputs_', a bool variable is true if the corresponding one is a parameter or a output of | ||||
| // pre-operator that has parameters as input. | // pre-operator that has parameters as input. | ||||
| std::vector<bool> is_parameter_involve_; | std::vector<bool> is_parameter_involve_; | ||||
| int is_output_parameter_involve_ = -1; // -1: unset; 0: not parameter_involved; 1: parameter_involved | |||||
| // If any input is parameter-involved, the output is parameter-involved. This variable is used in calculating | |||||
| // peak memory cost in the training phase. | |||||
| // -1: unset; 0: not parameter_involved; 1: parameter_involved | |||||
| int is_output_parameter_involve_ = -1; | |||||
| // Whether this output is critical, which means that this output is included in calculating peak memory cost | |||||
| // in the inference phase. | |||||
| // -1 : unset; 0: not critical; 1: critical | |||||
| int is_output_critical_ = -1; | |||||
| double outputs_total_size_ = 0.0; | |||||
| bool is_calculated_outputs_size_ = false; | |||||
| // for each input and output, the followings record the number of bytes of each element | // for each input and output, the followings record the number of bytes of each element | ||||
| std::vector<size_t> inputs_type_lengths_; | std::vector<size_t> inputs_type_lengths_; | ||||
| std::vector<size_t> outputs_type_lengths_; | std::vector<size_t> outputs_type_lengths_; | ||||
| @@ -1055,6 +1055,9 @@ Status ParallelStrategySearch(const std::vector<AnfNodePtr> &all_nodes, const Fu | |||||
| // Step 1: Traverse the ANF graph, and create NODEs for costgraph: | // Step 1: Traverse the ANF graph, and create NODEs for costgraph: | ||||
| // create the OperatorInfo object for each primitive, and enumerate the parallelization strategies | // create the OperatorInfo object for each primitive, and enumerate the parallelization strategies | ||||
| // for each OperatorInfo; | // for each OperatorInfo; | ||||
| // Step 1.1: Deal with 'Reshape': | |||||
| // For 'Reshape', it takes its previous operator's layout as its input layout, and takes its next operator's | |||||
| // layout as its output layout. | |||||
| // Step 2: Traverse the ANF graph, and create EDGES for costgraph: | // Step 2: Traverse the ANF graph, and create EDGES for costgraph: | ||||
| // create the Edge object for each pair of OperatorInfo, and enumerate the parallelization strategies | // create the Edge object for each pair of OperatorInfo, and enumerate the parallelization strategies | ||||
| // for each edge, based on the strategies of two OperatorInfos; | // for each edge, based on the strategies of two OperatorInfos; | ||||
| @@ -1062,7 +1065,8 @@ Status ParallelStrategySearch(const std::vector<AnfNodePtr> &all_nodes, const Fu | |||||
| // taking care for the case of a single Parameter being used by multiple operators. Create a TmpIdentity | // taking care for the case of a single Parameter being used by multiple operators. Create a TmpIdentity | ||||
| // operator for this Parameter, and add an edge for the use of this Parameter by each | // operator for this Parameter, and add an edge for the use of this Parameter by each | ||||
| // subsequent operator; | // subsequent operator; | ||||
| // Step 3.1: Calculate memory usage | |||||
| // Step 3.1: Calculate memory usage: | |||||
| // note the memory usage calculation is different in the training phase and the inference phase. | |||||
| // Step 4: Run the Dynamic Programming algorithm: | // Step 4: Run the Dynamic Programming algorithm: | ||||
| // in this process, cost is calculated based on not only the operators, but also the edges. Here, the edge | // in this process, cost is calculated based on not only the operators, but also the edges. Here, the edge | ||||
| // cost is caused by the redistribution of an operator's output tensor layout to the next operator's input | // cost is caused by the redistribution of an operator's output tensor layout to the next operator's input | ||||
| @@ -1087,35 +1091,21 @@ Status ParallelStrategySearch(const std::vector<AnfNodePtr> &all_nodes, const Fu | |||||
| MS_LOG(EXCEPTION) << "Constructing nodes for cost graph failed."; | MS_LOG(EXCEPTION) << "Constructing nodes for cost graph failed."; | ||||
| } | } | ||||
| } | } | ||||
| // reshape operator needs the next node's input_layout as its output_layout. | |||||
| // and needs the previous node's output_layout as its input_layout. | |||||
| // Step 1.1 | |||||
| ReshapeCostCompute(all_nodes); | ReshapeCostCompute(all_nodes); | ||||
| // Step 2 | // Step 2 | ||||
| ConstructCostGraphEdges(all_nodes); | ConstructCostGraphEdges(all_nodes); | ||||
| MS_LOG(INFO) << "Constructing edges for cost graph succeeded. There are " << entire_costgraph->GetOperators().size() | MS_LOG(INFO) << "Constructing edges for cost graph succeeded. There are " << entire_costgraph->GetOperators().size() | ||||
| << " operators, and " << entire_costgraph->GetNumPairs() << " edges.", | |||||
| << " operators, and " << entire_costgraph->GetNumEdges() << " edges."; | |||||
| // Step 3: Augment the costgraph. | |||||
| AugmentCostGraph(all_nodes); | |||||
| // Step 3: Augment the costgraph. | |||||
| AugmentCostGraph(all_nodes); | |||||
| MS_LOG(INFO) << "After the augmenting procedure, there are " << entire_costgraph->GetOperators().size() | MS_LOG(INFO) << "After the augmenting procedure, there are " << entire_costgraph->GetOperators().size() | ||||
| << " operators, and " << entire_costgraph->GetNumPairs() << " edges."; | |||||
| << " operators, and " << entire_costgraph->GetNumEdges() << " edges."; | |||||
| // Step 3.1: Calculate the memory usage | // Step 3.1: Calculate the memory usage | ||||
| if (entire_costgraph->ComputeOpsAndEdgesParameterInvolved() == SUCCESS) { | |||||
| // Calculate operators' memory usage | |||||
| if (entire_costgraph->CalculateOpsMemoryCost() != SUCCESS) { | |||||
| MS_LOG(EXCEPTION) << "Calculating operators' cost for memory cost failed."; | |||||
| } | |||||
| // Calculate edges' memory usage | |||||
| if (entire_costgraph->CalculateEdgesMemoryCost() != SUCCESS) { | |||||
| MS_LOG(EXCEPTION) << "Calculating edges' cost for memory cost failed."; | |||||
| } | |||||
| // Correct memory usage caused by TmpIdentity | |||||
| if (entire_costgraph->CorrectOpsMemoryCost() != SUCCESS) { | |||||
| MS_LOG(EXCEPTION) << "Correcting operators' cost for memory cost failed."; | |||||
| } | |||||
| } else { | |||||
| MS_LOG(EXCEPTION) << "Computing operators' parameter_involved failed."; | |||||
| if (entire_costgraph->CalculateMemoryCost() != SUCCESS) { | |||||
| MS_LOG(EXCEPTION) << "Calculating memory cost failed."; | |||||
| } | } | ||||
| // Step 4: run DP algorithm on the costgraph. | // Step 4: run DP algorithm on the costgraph. | ||||
| @@ -32,5 +32,6 @@ def test_inference_phase(): | |||||
| net_with_loss = WithLossCell(net, loss) | net_with_loss = WithLossCell(net, loss) | ||||
| train_network = TrainOneStepCell(net_with_loss, optimizer) | train_network = TrainOneStepCell(net_with_loss, optimizer) | ||||
| train_network.set_train() | train_network.set_train() | ||||
| train_network.set_auto_parallel() | |||||
| output = train_network(predict, label) | output = train_network(predict, label) | ||||