|
|
|
@@ -39,7 +39,8 @@ using OperatorList = std::vector<OperatorC>; |
|
|
|
class RedistributionOperatorInfer { |
|
|
|
public: |
|
|
|
const int NONE = -1; |
|
|
|
explicit RedistributionOperatorInfer(bool construct_op_flag = true) : construct_op_flag_(construct_op_flag) {} |
|
|
|
explicit RedistributionOperatorInfer(bool construct_op_flag = true) |
|
|
|
: construct_op_flag_(construct_op_flag), is_cost_model_(false) {} |
|
|
|
Status Init(const TensorLayout &tensor_layout, const Map &out_tensor_map, RankList dev_list, |
|
|
|
bool is_cost_model = false); |
|
|
|
~RedistributionOperatorInfer() = default; |
|
|
|
|