|
|
@@ -134,59 +134,6 @@ Status RangeInfo::InitForCostModel(const StrategyPtr &strategy) { |
|
|
return SUCCESS; |
|
|
return SUCCESS; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
Status RangeInfo::InferNewAttr() { |
|
|
|
|
|
CheckGlobalDeviceManager(); |
|
|
|
|
|
int64_t rank = g_device_manager->rank_index_in_stage(); |
|
|
|
|
|
|
|
|
|
|
|
// If repeated calculation and repeated num as the last dimension of dev-matrix, |
|
|
|
|
|
// the dev-matrix is [split_num_, repeated_calc_num_], so from rank 0 to rank repeated_calc_num_ |
|
|
|
|
|
// are repeated calculation, and these rank have the same 'new_start_'. |
|
|
|
|
|
// If repeated calculation and repeated num as the first dimension of dev-matrix, |
|
|
|
|
|
// the dev-matrix is [repeated_calc_num_, split_num_], so rank 0 and rank split_num_ and so on |
|
|
|
|
|
// are repeated calculation, and these rank have the same 'new_start_'. |
|
|
|
|
|
float start_bias = inputs_shape_[0][0] / split_num_ * delta_; |
|
|
|
|
|
if (repeated_num_in_dev_matrix_right_) { |
|
|
|
|
|
new_start_ = start_ + start_bias * (rank / repeated_calc_num_); |
|
|
|
|
|
} else { |
|
|
|
|
|
new_start_ = start_ + start_bias * (rank % split_num_); |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
new_limit_ = new_start_ + start_bias; |
|
|
|
|
|
MS_LOG(INFO) << name_ << ": The new start is " << new_start_ << ", the new limit is " << new_limit_; |
|
|
|
|
|
return SUCCESS; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
Status RangeInfo::ComputeReplaceGraph(const CNodePtr &cnode) { |
|
|
|
|
|
GenerateGraph gen_g = GenerateGraph(); |
|
|
|
|
|
if (gen_g.Init(cnode) != SUCCESS) { |
|
|
|
|
|
MS_LOG(ERROR) << name_ << ": GenerateGraph Init failed"; |
|
|
|
|
|
return FAILED; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
if (InferNewAttr() != SUCCESS) { |
|
|
|
|
|
MS_LOG(ERROR) << name_ << ": Infer new attr failed"; |
|
|
|
|
|
return FAILED; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
Attr attr_start = std::make_pair(START, MakeValue(new_start_)); |
|
|
|
|
|
Attr attr_limit = std::make_pair(LIMIT, MakeValue(new_limit_)); |
|
|
|
|
|
Attr attr_delta = std::make_pair(DELTA, MakeValue(delta_)); |
|
|
|
|
|
OperatorAttrs attrs = {attr_start, attr_limit, attr_delta}; |
|
|
|
|
|
auto new_range_op = gen_g.PushBack({gen_g.NewOpInst(RANGE, attrs), gen_g.virtual_input_node()}); |
|
|
|
|
|
std::vector<std::pair<AnfNodePtr, int64_t>> input_nodes = {std::make_pair(new_range_op, 1)}; |
|
|
|
|
|
replace_graph_ = std::make_shared<std::pair<std::vector<std::pair<AnfNodePtr, int64_t>>, AnfNodePtr>>( |
|
|
|
|
|
std::make_pair(input_nodes, new_range_op)); |
|
|
|
|
|
|
|
|
|
|
|
return SUCCESS; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
ReplaceGraphPtr RangeInfo::replace_graph(const CNodePtr &cnode) { |
|
|
|
|
|
if (ComputeReplaceGraph(cnode) != SUCCESS) { |
|
|
|
|
|
MS_LOG(EXCEPTION) << name_ << ": ComputeReplaceGraph failed."; |
|
|
|
|
|
} |
|
|
|
|
|
return replace_graph_; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
Status RangeInfo::GenerateStrategies(int64_t stage_id) { |
|
|
Status RangeInfo::GenerateStrategies(int64_t stage_id) { |
|
|
Shape input0_split(inputs_shape_[0].size(), 1); |
|
|
Shape input0_split(inputs_shape_[0].size(), 1); |
|
|
Shapes splittable_inputs = {input0_split}; |
|
|
Shapes splittable_inputs = {input0_split}; |
|
|
|