| @@ -53,7 +53,7 @@ struct Cost { | |||
| communication_redis_backward_ = 0.0; | |||
| communication_forward_ = 0.0; | |||
| } | |||
| // 'memory_with_reuse_' calculates the peak memory usage in a training phase | |||
| // 'memory_with_reuse_' calculates the peak memory usage in a training (or inference) phase | |||
| double memory_with_reuse_; | |||
| // 'computation_cost_' models the training time of an iteration in a training phase. Currently, this is calculated | |||
| // by ONLY forward phase | |||
| @@ -300,5 +300,20 @@ Status Edge::CalculateMemoryCost() { | |||
| return SUCCESS; | |||
| } | |||
| Status Edge::CalculateMemoryCostForInference() { | |||
| // Currently, memory cost is NOT calculated for redistribution | |||
| if ((is_output_critical_ != 0) && (is_output_critical_ != 1)) { | |||
| MS_LOG(ERROR) << "Failure: unexpected output critical flag value: " << is_output_critical_; | |||
| return FAILED; | |||
| } | |||
| for (auto &cost_kv : cost_map_) { | |||
| auto &cost_v = cost_kv.second; | |||
| if (!cost_v.empty()) { | |||
| cost_v[0]->memory_with_reuse_ = 0; | |||
| } | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| } // namespace parallel | |||
| } // namespace mindspore | |||
| @@ -131,9 +131,13 @@ class Edge { | |||
| void set_selected_cost(const CostPtr &cost) { selected_cost_ = cost; } | |||
| const CostPtr &selected_cost() const { return selected_cost_; } | |||
| void set_parameter_involve(int para_invol) { is_output_parameter_involve_ = para_invol; } | |||
| // When the input of a operator contains WEIGHT or a output from other operators involving WEIGHT, then these input | |||
| // should stay in memory until it is used in the backward phase, which is kept in memory at the end of forward phase. | |||
| // In the training phase, when the input of a operator contains WEIGHT or a output from other operators involving | |||
| // WEIGHT, then these input should stay in memory until it is used in the backward phase, which is kept in memory | |||
| // at the end of forward phase. | |||
| Status CalculateMemoryCost(); | |||
| // In the inference phase, the memory cost is incurred only when the output of the previous operator is critical. | |||
| Status CalculateMemoryCostForInference(); | |||
| void mark_output_critical() { is_output_critical_ = 1; } | |||
| private: | |||
| std::string edge_name_; | |||
| @@ -156,7 +160,11 @@ class Edge { | |||
| // If it is true, then we should guarantee that the strategy for output tensor consistent with the input tensor. | |||
| bool is_identity_edge; | |||
| CostPtr selected_cost_; | |||
| // In the training phase, 'is_output_parameter_involve_' is used to mark whether the output of the previous operator | |||
| // is parameter-involved | |||
| int is_output_parameter_involve_ = -1; // -1: unset; 0: not parameter_involved; 1: parameter_involved | |||
| // In the inference phase, this is used to mark whether the output of the previous operator is critical. | |||
| int is_output_critical_ = 0; | |||
| }; | |||
| } // namespace parallel | |||
| } // namespace mindspore | |||
| @@ -369,7 +369,7 @@ CostPtr CostGraph::SelectCostWithMinInferenceTime(const CostPtrList &cost_list, | |||
| << ", communication_with_partial_para_: " << ret->communication_with_partial_para_ | |||
| << ", communication_cost_: " << ret->communication_cost_ | |||
| << ", communication_without_parameter_: " << ret->communication_without_parameter_ << "."; | |||
| MS_LOG(INFO) << "Cost 0: totoal_cost: " << minimum; | |||
| MS_LOG(INFO) << "Cost 0: total_cost: " << minimum; | |||
| for (size_t i = 1; i < after_mem_filter.size(); ++i) { | |||
| MS_EXCEPTION_IF_NULL(after_mem_filter[i]); | |||
| MS_LOG(INFO) << "Cost " << i << ": memory_cost: " << after_mem_filter[i]->memory_with_reuse_ | |||
| @@ -422,7 +422,7 @@ CostPtr CostGraph::SelectCostWithMinTrainingTime(const CostPtrList &cost_list, d | |||
| << ", communication_with_partial_para_: " << ret->communication_with_partial_para_ | |||
| << ", communication_cost_: " << ret->communication_cost_ | |||
| << ", communication_without_parameter_: " << ret->communication_without_parameter_ << "."; | |||
| MS_LOG(INFO) << "Cost 0: totoal_cost: " << minimum; | |||
| MS_LOG(INFO) << "Cost 0: total_cost: " << minimum; | |||
| for (size_t i = 1; i < after_mem_filter.size(); ++i) { | |||
| MS_EXCEPTION_IF_NULL(after_mem_filter[i]); | |||
| MS_LOG(INFO) << "Cost " << i << ": memory_cost: " << after_mem_filter[i]->memory_with_reuse_ | |||
| @@ -1351,6 +1351,14 @@ std::vector<std::shared_ptr<Edge>> CostGraph::EliminationStar(const OperatorInfo | |||
| return succ_edges; | |||
| } | |||
| size_t CostGraph::GetNumEdges() const { | |||
| size_t sum = 0; | |||
| for (const auto &kv : edges_) { | |||
| auto &edges = kv.second; | |||
| sum += edges.size(); | |||
| } | |||
| return sum; | |||
| } | |||
| Status CostGraph::InitSelectedStrategy() { | |||
| for (auto &op : ops_) { | |||
| MS_EXCEPTION_IF_NULL(op); | |||
| @@ -1416,6 +1424,122 @@ Status CostGraph::ComputeOpsAndEdgesParameterInvolved() { | |||
| return SUCCESS; | |||
| } | |||
| void CostGraph::DFSForTopoOrder(const OperatorInfoPtr ¤t_op, std::map<OperatorInfoPtr, bool> *visited, | |||
| std::vector<OperatorInfoPtr> *topo_order) { | |||
| MS_EXCEPTION_IF_NULL(current_op); | |||
| MS_EXCEPTION_IF_NULL(visited); | |||
| MS_EXCEPTION_IF_NULL(topo_order); | |||
| visited->at(current_op) = true; | |||
| for (const auto &s_edge : current_op->succ_edges()) { | |||
| if (!visited->at(s_edge->next_operator())) { | |||
| DFSForTopoOrder(s_edge->next_operator(), visited, topo_order); | |||
| } | |||
| } | |||
| topo_order->push_back(current_op); | |||
| } | |||
| // Compute a topological order of the costgraph | |||
| void CostGraph::TopologyOrder(std::vector<OperatorInfoPtr> *topo_order) { | |||
| std::map<OperatorInfoPtr, bool> visited; | |||
| for (auto &op : ops_) { | |||
| visited[op] = false; | |||
| } | |||
| for (auto &op : ops_) { | |||
| if (!visited[op]) { | |||
| DFSForTopoOrder(op, &visited, topo_order); | |||
| } | |||
| } | |||
| } | |||
| void CostGraph::MarkCriticalOpsAndEdges(const std::map<OperatorInfoPtr, int> &candidate_ops) { | |||
| for (auto &op : ops_) { | |||
| auto search = candidate_ops.find(op); | |||
| if (search != candidate_ops.end()) { | |||
| // Mark the critical operators | |||
| op->mark_output_critical(); | |||
| // Mark the successive edges | |||
| for (auto &s_edge : op->succ_edges()) { | |||
| s_edge->mark_output_critical(); | |||
| } | |||
| } else { | |||
| op->mark_output_not_critical(); | |||
| } | |||
| } | |||
| } | |||
| Status CostGraph::DetermineCriticalOps(const std::vector<OperatorInfoPtr> &topo_order) { | |||
| if (topo_order.size() == 0) { | |||
| MS_LOG(ERROR) << "0 operator in costgraph."; | |||
| return FAILED; | |||
| } | |||
| auto &first_op = topo_order[0]; | |||
| if (first_op->prev_edges().size() > 0) { | |||
| MS_LOG(ERROR) << "The first operator in the first of topological order of " | |||
| "costgraph should have 0 incoming edge, but has " | |||
| << first_op->prev_edges() << "edges."; | |||
| return FAILED; | |||
| } | |||
| // The 'curr_memory_state' records <OperatorInfo, remaining_output_cnt>, where remaining_output_cnt is the number | |||
| // of the output of OperatorInfo that currently has not been used | |||
| std::map<OperatorInfoPtr, int> curr_memory_state; | |||
| (void)curr_memory_state.emplace(std::make_pair(first_op, SizeToInt(first_op->succ_edges().size()))); | |||
| std::map<OperatorInfoPtr, int> max_memory_state = curr_memory_state; | |||
| // The 'curr_memory_size' records the current total memory size, which is the sum of outputs of operators that has | |||
| // not been used | |||
| double curr_memory_size = first_op->GetOutputsTotalSize(); | |||
| double max_memory_size = curr_memory_size; | |||
| for (size_t finished = 1; finished < topo_order.size(); ++finished) { | |||
| // Produce | |||
| (void)curr_memory_state.emplace( | |||
| std::make_pair(topo_order[finished], SizeToInt(topo_order[finished]->succ_edges().size()))); | |||
| curr_memory_size += topo_order[finished]->GetOutputsTotalSize(); | |||
| // Consume | |||
| for (const auto &prev_edge : topo_order[finished]->prev_edges()) { | |||
| const auto &prev_op = prev_edge->prev_operator(); | |||
| curr_memory_state[prev_op]--; | |||
| } | |||
| for (const auto &prev_edge : topo_order[finished]->prev_edges()) { | |||
| const auto &prev_op = prev_edge->prev_operator(); | |||
| if (curr_memory_state[prev_op] < 0) { | |||
| MS_LOG(ERROR) << "Failure: " << prev_op->name() << "'s current output count: " << curr_memory_state[prev_op]; | |||
| return FAILED; | |||
| } else if (curr_memory_state[prev_op] == 0) { | |||
| curr_memory_state.erase(prev_op); | |||
| curr_memory_size -= prev_op->GetOutputsTotalSize(); | |||
| } | |||
| } | |||
| if (curr_memory_size < 0) { | |||
| MS_LOG(ERROR) << "Memory size calculation failed: " << curr_memory_size; | |||
| } | |||
| // Modify the max | |||
| if (curr_memory_size > max_memory_size) { | |||
| max_memory_size = curr_memory_size; | |||
| max_memory_state = curr_memory_state; | |||
| } | |||
| } | |||
| // Mark those critical operators | |||
| MarkCriticalOpsAndEdges(max_memory_state); | |||
| return SUCCESS; | |||
| } | |||
| Status CostGraph::ComputeOpsAndEdgesOutputCritical() { | |||
| // Two steps to do: | |||
| // 1. Compute a topological order of the costgraph | |||
| // 2. Determine and mark the operators (and necessary edges) that are critical | |||
| std::vector<OperatorInfoPtr> topo_order; | |||
| TopologyOrder(&topo_order); | |||
| std::reverse(std::begin(topo_order), std::end(topo_order)); | |||
| if (DetermineCriticalOps(topo_order) != SUCCESS) { | |||
| MS_LOG(ERROR) << "Determining critical operators failed."; | |||
| return FAILED; | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| Status CostGraph::CalculateOpsMemoryCost() { | |||
| for (auto &op : ops_) { | |||
| MS_EXCEPTION_IF_NULL(op); | |||
| @@ -1427,6 +1551,17 @@ Status CostGraph::CalculateOpsMemoryCost() { | |||
| return SUCCESS; | |||
| } | |||
| Status CostGraph::CalculateOpsMemoryCostForInference() { | |||
| for (auto &op : ops_) { | |||
| MS_EXCEPTION_IF_NULL(op); | |||
| if (op->CalculateMemoryCostForInference() != SUCCESS) { | |||
| MS_LOG(ERROR) << "Calculate Operator: " << op->name() << " cost for memory usage failed."; | |||
| return FAILED; | |||
| } | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| Status CostGraph::CalculateEdgesMemoryCost() { | |||
| for (auto &edge_pair : edges_) { | |||
| const auto &edges = edge_pair.second; | |||
| @@ -1440,6 +1575,19 @@ Status CostGraph::CalculateEdgesMemoryCost() { | |||
| return SUCCESS; | |||
| } | |||
| Status CostGraph::CalculateEdgesMemoryCostForInference() { | |||
| for (auto &edge_pair : edges_) { | |||
| const auto &edges = edge_pair.second; | |||
| for (auto &one_edge : edges) { | |||
| if (one_edge->CalculateMemoryCostForInference() != SUCCESS) { | |||
| MS_LOG(ERROR) << "Calculate Edge: " << one_edge->edge_name() << " cost for memory usage failed."; | |||
| return FAILED; | |||
| } | |||
| } | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| OperatorInfoPtr CostGraph::FindTmpIdentityByParameterName(std::string &p_name) const { | |||
| for (auto one_op : ops_) { | |||
| if (one_op->name().find(IDENTITY_INFO) != std::string::npos) { | |||
| @@ -1480,5 +1628,49 @@ Status CostGraph::CorrectOpsMemoryCost() { | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| Status CostGraph::CalculateMemoryCost() { | |||
| if (RUN_PHASE == TRAINING_PHASE) { | |||
| // training phase | |||
| if (ComputeOpsAndEdgesParameterInvolved() == SUCCESS) { | |||
| // Calculate operators' memory usage | |||
| if (CalculateOpsMemoryCost() != SUCCESS) { | |||
| MS_LOG(ERROR) << "Calculating operators' cost for memory cost failed."; | |||
| return FAILED; | |||
| } | |||
| // Calculate edges' memory usage | |||
| if (CalculateEdgesMemoryCost() != SUCCESS) { | |||
| MS_LOG(ERROR) << "Calculating edges' cost for memory cost failed."; | |||
| return FAILED; | |||
| } | |||
| // Correct memory usage caused by TmpIdentity | |||
| if (CorrectOpsMemoryCost() != SUCCESS) { | |||
| MS_LOG(ERROR) << "Correcting operators' cost for memory cost failed."; | |||
| return FAILED; | |||
| } | |||
| } else { | |||
| MS_LOG(ERROR) << "Computing operators' parameter_involved failed."; | |||
| return FAILED; | |||
| } | |||
| } else { | |||
| // inference phase | |||
| if (ComputeOpsAndEdgesOutputCritical() == SUCCESS) { | |||
| // Calculate operators' memory usage | |||
| if (CalculateOpsMemoryCostForInference() != SUCCESS) { | |||
| MS_LOG(ERROR) << "Calculating operators' memory cost for inference failed."; | |||
| return FAILED; | |||
| } | |||
| // Calculate edges's memory usage | |||
| if (CalculateEdgesMemoryCostForInference() != SUCCESS) { | |||
| MS_LOG(ERROR) << "Calculating operators' memory cost for inference failed."; | |||
| return FAILED; | |||
| } | |||
| } else { | |||
| MS_LOG(ERROR) << "Computing operators' critical flag failed."; | |||
| return FAILED; | |||
| } | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| } // namespace parallel | |||
| } // namespace mindspore | |||
| @@ -179,16 +179,24 @@ class CostGraph { | |||
| void CreateStarEliminationSubCostList(const StrategyPtr &, const CostPtrList &, const CostPtrList &, | |||
| const StrategyPtr &, const CostPtrList &, std::vector<StrategyPtr>, | |||
| CostPtrList &, CostPtrList &, CostPtrList *); | |||
| // Calculate memory cost for training phase or inference phase. | |||
| Status CalculateMemoryCost(); | |||
| // When the input of a operator is neither a WEIGHT, nor a output of a subsequent operator involving WEIGHT, then | |||
| // the memory cost can be resused. | |||
| // the memory cost can be reused. This is used to calculate memory in the training phase. | |||
| Status CalculateOpsMemoryCost(); | |||
| // When the input of the edge is neither a WEIGHT, nor a output of a subsequent operator involving WEIGHT, then | |||
| // the memory cost can be resused. | |||
| // the memory cost can be reused. This is used to calculate memory in the training phase. | |||
| Status CalculateEdgesMemoryCost(); | |||
| // Calculate memory cost of operators in the inference phase. | |||
| Status CalculateOpsMemoryCostForInference(); | |||
| // Calculate memory cost of edges in the inference phase. | |||
| Status CalculateEdgesMemoryCostForInference(); | |||
| Status ComputeOpsAndEdgesParameterInvolved(); | |||
| // Compute for each operator whether the output is critical. | |||
| Status ComputeOpsAndEdgesOutputCritical(); | |||
| std::vector<OperatorInfoPtr> GetOperators() const { return ops_; } | |||
| size_t GetNumPairs() const { return edges_.size(); } | |||
| size_t GetNumEdges() const; | |||
| Status InitSelectedStrategy(); | |||
| OperatorInfoPtr FindTmpIdentityByParameterName(std::string &) const; | |||
| // When TmpIdentity is used by multiple operators, the corresponding parameter's memory cost should be calculated only | |||
| @@ -208,6 +216,10 @@ class CostGraph { | |||
| const std::map<std::string, std::string> get_tuple_getitem_list() const { return tuple_getitem_list_; } | |||
| private: | |||
| void TopologyOrder(std::vector<OperatorInfoPtr> *); | |||
| void DFSForTopoOrder(const OperatorInfoPtr &, std::map<OperatorInfoPtr, bool> *, std::vector<OperatorInfoPtr> *); | |||
| Status DetermineCriticalOps(const std::vector<OperatorInfoPtr> &); | |||
| void MarkCriticalOpsAndEdges(const std::map<OperatorInfoPtr, int> &); | |||
| // Needed by rec_parser | |||
| std::vector<std::vector<std::string>> inputs_tensor_name_list_; | |||
| std::map<std::string, std::string> tuple_getitem_list_; | |||
| @@ -37,6 +37,8 @@ void OperatorCost::SetInputAndOutputTypeLength(const std::vector<size_t> &input_ | |||
| outputs_type_lengths_ = output_lengths; | |||
| } | |||
| void OperatorCost::set_output_critical(int critical) { is_outputs_critical_ = critical; } | |||
| double OperatorCost::GetMemoryCost(const std::vector<TensorInfo> &inputs, | |||
| const std::vector<TensorInfo> &outputs) const { | |||
| double result = 0.0; | |||
| @@ -63,6 +65,20 @@ double OperatorCost::GetMemoryCost(const std::vector<TensorInfo> &inputs, | |||
| return result; | |||
| } | |||
| double OperatorCost::GetMemoryCostForInference(const std::vector<TensorInfo> &, | |||
| const std::vector<TensorInfo> &outputs) const { | |||
| double result = 0.0; | |||
| if (is_outputs_critical_ == -1) { | |||
| MS_LOG(EXCEPTION) << "The critical flag is not set."; | |||
| } | |||
| if (is_outputs_critical_ == 1) { | |||
| for (size_t i = 0; i < outputs.size(); ++i) { | |||
| result += ListProduct(outputs[i].slice_shape()) * static_cast<double>(outputs_type_lengths_[i]); | |||
| } | |||
| } | |||
| return result; | |||
| } | |||
| // return the per device communication cost in the forward phase. | |||
| double MatMulCost::GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs, | |||
| int32_t) const { | |||
| @@ -70,6 +70,7 @@ class OperatorCost { | |||
| void set_is_parameter(const std::vector<bool> &is_parameter); | |||
| void set_is_parameter_involve(const std::vector<bool> &); | |||
| void set_output_parameter_involve(int); | |||
| void set_output_critical(int); | |||
| void SetInputAndOutputTypeLength(const std::vector<size_t> &input_lengths, const std::vector<size_t> &output_lengths); | |||
| std::vector<size_t> inputs_type_lengths() const { return inputs_type_lengths_; } | |||
| std::vector<size_t> outputs_type_lengths() const { return outputs_type_lengths_; } | |||
| @@ -92,6 +93,8 @@ class OperatorCost { | |||
| // Typically, the PEAK memory cost contributed by an operator is its output (if the output is parameter-involved), | |||
| // plus necessary inputs. | |||
| virtual double GetMemoryCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs) const; | |||
| // per device memory cost in a inference phase | |||
| double GetMemoryCostForInference(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &) const; | |||
| protected: | |||
| // For each input in 'inputs_', a bool variable is true if the corresponding one is a parameter or a output of | |||
| @@ -106,6 +109,9 @@ class OperatorCost { | |||
| // for each input and output, the followings record the number of bytes of each element | |||
| std::vector<size_t> inputs_type_lengths_; | |||
| std::vector<size_t> outputs_type_lengths_; | |||
| // Whether the output is critical, which means that this output is included in calculating peak memory cost | |||
| // in the inference phase. | |||
| int is_outputs_critical_ = -1; | |||
| }; | |||
| using OperatorCostPtr = std::shared_ptr<OperatorCost>; | |||
| @@ -1119,6 +1119,21 @@ Status OperatorInfo::CalculateMemoryCost() { | |||
| return SUCCESS; | |||
| } | |||
| Status OperatorInfo::CalculateMemoryCostForInference() { | |||
| // First, set the 'is_outputs_critical_' flag into OperatorCost. | |||
| if (is_output_critical_ == -1) { | |||
| MS_LOG(EXCEPTION) << "The critical flag is not set."; | |||
| return FAILED; | |||
| } | |||
| operator_cost()->set_output_critical(is_output_critical_); | |||
| // Set the memory cost in the 'strategy_cost_' | |||
| for (auto &swc : strategy_cost_) { | |||
| auto mem_cost = operator_cost()->GetMemoryCostForInference(swc->inputs_ptr, swc->outputs_ptr); | |||
| swc->cost_list[0]->memory_with_reuse_ = mem_cost; | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| Status OperatorInfo::CorrectMemoryCost(size_t input_index) { | |||
| for (auto &swc : strategy_cost_) { | |||
| double parameter_mem_cost = ListProduct(swc->inputs_ptr[input_index].slice_shape()) * | |||
| @@ -1230,6 +1245,25 @@ Status OperatorInfo::SetInputAndOutputTypeLength(const std::vector<size_t> &inpu | |||
| return SUCCESS; | |||
| } | |||
| double OperatorInfo::GetOutputsTotalSize() { | |||
| if (is_calculated_outputs_size_) { | |||
| return outputs_total_size_; | |||
| } | |||
| if (outputs_type_lengths_.size() != outputs_shape_.size()) { | |||
| MS_LOG(EXCEPTION) << "Output_lengths: " << outputs_type_lengths_.size() | |||
| << " do not have the same number of outputs shape: " << outputs_shape_.size(); | |||
| } | |||
| double sum = 0.0; | |||
| for (size_t i = 0; i < outputs_type_lengths_.size(); ++i) { | |||
| auto size = std::accumulate(outputs_shape_[i].begin(), outputs_shape_[i].end(), static_cast<double>(1.0), | |||
| std::multiplies<double>()); | |||
| sum += size * static_cast<double>(outputs_type_lengths_[i]); | |||
| } | |||
| is_calculated_outputs_size_ = true; | |||
| outputs_total_size_ = sum; | |||
| return outputs_total_size_; | |||
| } | |||
| Status OperatorInfo::set_outputs_type(const std::vector<TypePtr> &outputs_type) { | |||
| if (outputs_type.size() != outputs_shape_.size()) { | |||
| MS_LOG(ERROR) << "Outputs type: " << outputs_type.size() | |||
| @@ -72,6 +72,7 @@ class OperatorInfo { | |||
| Status set_is_parameter(const std::vector<bool> &is_parameter); | |||
| Status SetInputAndOutputTypeLength(const std::vector<size_t> &input_lengths, | |||
| const std::vector<size_t> &output_lengths); | |||
| double GetOutputsTotalSize(); | |||
| // Set outputs dtype. | |||
| // If only one output, outputs_type.size() is 1. | |||
| // If output is tuple, outputs_type.size() is greater than 1. | |||
| @@ -96,9 +97,13 @@ class OperatorInfo { | |||
| // is checked | |||
| Status SetCostUnderStrategyBase(const StrategyPtr &strategy); | |||
| std::vector<std::shared_ptr<StrategyWithCost>> GetStrategyCost() { return strategy_cost_; } | |||
| // When the input of a operator contains WEIGHT or a output from other operators involving WEIGHT, then these input | |||
| // should stay in memory until it is used in the backward phase, which is kept in memory at the end of forward phase. | |||
| // In the training phase, when the input of a operator contains WEIGHT or a output from other operators involving | |||
| // WEIGHT, then these input should stay in memory until it is used in the backward phase, which is kept in memory | |||
| // at the end of forward phase. | |||
| Status CalculateMemoryCost(); | |||
| // In the inference phase, the memory cost is incurred only when the operator is critical. The size is calculated | |||
| // by the output | |||
| Status CalculateMemoryCostForInference(); | |||
| int ComputeOpAndPrevEdgeParameterInvolved(); | |||
| ForwardOp forward_op() const { return forward_op_; } | |||
| @@ -147,6 +152,9 @@ class OperatorInfo { | |||
| // multiple times. This method is to correct this, and makes the cost calculated only once. | |||
| Status CorrectMemoryCost(size_t input_index); | |||
| int is_output_parameter_involve() const { return is_output_parameter_involve_; } | |||
| int is_output_critical() const { return is_output_critical_; } | |||
| void mark_output_critical() { is_output_critical_ = 1; } | |||
| void mark_output_not_critical() { is_output_critical_ = 0; } | |||
| int used_devices() const { return used_devices_; } | |||
| // needed by rec_parser | |||
| void set_type(const std::string &type) { type_ = type; } | |||
| @@ -220,7 +228,16 @@ class OperatorInfo { | |||
| // For each input in 'inputs_', a bool variable is true if the corresponding one is a parameter or a output of | |||
| // pre-operator that has parameters as input. | |||
| std::vector<bool> is_parameter_involve_; | |||
| int is_output_parameter_involve_ = -1; // -1: unset; 0: not parameter_involved; 1: parameter_involved | |||
| // If any input is parameter-involved, the output is parameter-involved. This variable is used in calculating | |||
| // peak memory cost in the training phase. | |||
| // -1: unset; 0: not parameter_involved; 1: parameter_involved | |||
| int is_output_parameter_involve_ = -1; | |||
| // Whether this output is critical, which means that this output is included in calculating peak memory cost | |||
| // in the inference phase. | |||
| // -1 : unset; 0: not critical; 1: critical | |||
| int is_output_critical_ = -1; | |||
| double outputs_total_size_ = 0.0; | |||
| bool is_calculated_outputs_size_ = false; | |||
| // for each input and output, the followings record the number of bytes of each element | |||
| std::vector<size_t> inputs_type_lengths_; | |||
| std::vector<size_t> outputs_type_lengths_; | |||
| @@ -1055,6 +1055,9 @@ Status ParallelStrategySearch(const std::vector<AnfNodePtr> &all_nodes, const Fu | |||
| // Step 1: Traverse the ANF graph, and create NODEs for costgraph: | |||
| // create the OperatorInfo object for each primitive, and enumerate the parallelization strategies | |||
| // for each OperatorInfo; | |||
| // Step 1.1: Deal with 'Reshape': | |||
| // For 'Reshape', it takes its previous operator's layout as its input layout, and takes its next operator's | |||
| // layout as its output layout. | |||
| // Step 2: Traverse the ANF graph, and create EDGES for costgraph: | |||
| // create the Edge object for each pair of OperatorInfo, and enumerate the parallelization strategies | |||
| // for each edge, based on the strategies of two OperatorInfos; | |||
| @@ -1062,7 +1065,8 @@ Status ParallelStrategySearch(const std::vector<AnfNodePtr> &all_nodes, const Fu | |||
| // taking care for the case of a single Parameter being used by multiple operators. Create a TmpIdentity | |||
| // operator for this Parameter, and add an edge for the use of this Parameter by each | |||
| // subsequent operator; | |||
| // Step 3.1: Calculate memory usage | |||
| // Step 3.1: Calculate memory usage: | |||
| // note the memory usage calculation is different in training phase and inference phase. | |||
| // Step 4: Run the Dynamic Programming algorithm: | |||
| // in this process, cost is calculated based on not only the operators, but also the edges. Here, the edge | |||
| // cost is caused by the redistribution of a operator's output tensor layout to the next operator's input | |||
| @@ -1087,35 +1091,21 @@ Status ParallelStrategySearch(const std::vector<AnfNodePtr> &all_nodes, const Fu | |||
| MS_LOG(EXCEPTION) << "Constructing nodes for cost graph failed."; | |||
| } | |||
| } | |||
| // reshape operator needs the next node's input_layout as its output_layout. | |||
| // and needs the previous node's output_layout as its input_layout. | |||
| // Step 1.1 | |||
| ReshapeCostCompute(all_nodes); | |||
| // Step 2 | |||
| ConstructCostGraphEdges(all_nodes); | |||
| MS_LOG(INFO) << "Constructing edges for cost graph succeeded. There are " << entire_costgraph->GetOperators().size() | |||
| << " operators, and " << entire_costgraph->GetNumPairs() << " edges.", | |||
| << " operators, and " << entire_costgraph->GetNumEdges() << " edges."; | |||
| // Step 3: Augment the costgraph. | |||
| AugmentCostGraph(all_nodes); | |||
| // Step 3: Augment the costgraph. | |||
| AugmentCostGraph(all_nodes); | |||
| MS_LOG(INFO) << "After the augmenting procedure, there are " << entire_costgraph->GetOperators().size() | |||
| << " operators, and " << entire_costgraph->GetNumPairs() << " edges."; | |||
| << " operators, and " << entire_costgraph->GetNumEdges() << " edges."; | |||
| // Step 3.1: Calculate the memory usage | |||
| if (entire_costgraph->ComputeOpsAndEdgesParameterInvolved() == SUCCESS) { | |||
| // Calculate operators' memory usage | |||
| if (entire_costgraph->CalculateOpsMemoryCost() != SUCCESS) { | |||
| MS_LOG(EXCEPTION) << "Calculating operators' cost for memory cost failed."; | |||
| } | |||
| // Calculate edges' memory usage | |||
| if (entire_costgraph->CalculateEdgesMemoryCost() != SUCCESS) { | |||
| MS_LOG(EXCEPTION) << "Calculating edges' cost for memory cost failed."; | |||
| } | |||
| // Correct memory usage caused by TmpIdentity | |||
| if (entire_costgraph->CorrectOpsMemoryCost() != SUCCESS) { | |||
| MS_LOG(EXCEPTION) << "Correcting operators' cost for memory cost failed."; | |||
| } | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "Computing operators' parameter_involved failed."; | |||
| if (entire_costgraph->CalculateMemoryCost() != SUCCESS) { | |||
| MS_LOG(EXCEPTION) << "Calculating memory cost failed."; | |||
| } | |||
| // Step 4: run DP algorithm on the costgraph. | |||
| @@ -32,5 +32,6 @@ def test_inference_phase(): | |||
| net_with_loss = WithLossCell(net, loss) | |||
| train_network = TrainOneStepCell(net_with_loss, optimizer) | |||
| train_network.set_train() | |||
| train_network.set_auto_parallel() | |||
| output = train_network(predict, label) | |||