|
|
|
@@ -19,6 +19,7 @@ |
|
|
|
#include <vector> |
|
|
|
#include <memory> |
|
|
|
#include <utility> |
|
|
|
#include <string> |
|
|
|
#include "tools/converter/converter_flags.h" |
|
|
|
#include "backend/optimizer/common/pass.h" |
|
|
|
#include "include/errorcode.h" |
|
|
|
@@ -40,11 +41,54 @@ class SlicePreposePass : public Pass { |
|
|
|
void SetFmkType(FmkType fmkType) { this->fmk_type = fmkType; } |
|
|
|
|
|
|
|
private: |
|
|
|
schema::SliceT *GetSliceT(const CNodePtr &cnode); |
|
|
|
bool DoPrepose(const FuncGraphPtr &graph, const CNodePtr &slice_cnode, const CNodePtr &preceed_cnode); |
|
|
|
void ClearCNodeAbstractValue(const CNodePtr &cnode); |
|
|
|
STATUS SwapSliceWithPreceed(const FuncGraphPtr &graph, const CNodePtr &slice_cnode, const CNodePtr &preceed_cnode, |
|
|
|
const int index, const TransactionPtr &tr = nullptr); |
|
|
|
ValueNodePtr CreateSliceValueNode(const FuncGraphPtr &graph, const std::vector<int32_t> &axes, |
|
|
|
const std::vector<int32_t> &begin, const std::vector<int32_t> &size); |
|
|
|
ValueNodePtr CopySliceValueNode(const FuncGraphPtr &graph, const CNodePtr &slice_cnode); |
|
|
|
CNodePtr InsertSlice(const FuncGraphPtr &graph, const ValueNodePtr &slice_vnode, const CNodePtr &preceed_cnode, |
|
|
|
const int index, const TransactionPtr &tr); |
|
|
|
STATUS VerifySliceAttrs(const CNodePtr &slice_cnode, const int dim = -1); |
|
|
|
STATUS SliceParamDeBroadcast(const CNodePtr &slice_cnode, const std::vector<int32_t> &ref_shape, |
|
|
|
std::vector<int32_t> *axes, std::vector<int32_t> *begin, std::vector<int32_t> *size); |
|
|
|
CNodePtr CreateReshapeCNode(const FuncGraphPtr &graph, const std::vector<int64_t> &shape, |
|
|
|
const AbstractBasePtr &abstract, const CNodePtr &preceed_cnode); |
|
|
|
bool SiblingsAreSameSlice(const FuncGraphPtr &graph, const NodeUsedListPtr &output_node_list, |
|
|
|
const std::vector<int32_t> &ref_shape = {}); |
|
|
|
int GetReshapeAbnormalAxeIn(const std::vector<int> &shape_in, const std::vector<int> &shape_out, |
|
|
|
std::vector<int> *mapped_axe); |
|
|
|
int GetReshapeAbnormalIndexOut(const CNodePtr &slice_cnode, const std::vector<int> &mapped_axe, |
|
|
|
const std::vector<int> &shape_out, std::vector<int> *shape_out_copy, |
|
|
|
bool *is_normal_mode, bool *support_abnormal_mode); |
|
|
|
bool PreposeWithNormalReshape(const FuncGraphPtr &graph, const CNodePtr &slice_cnode, const CNodePtr &reshape_cnode, |
|
|
|
const std::vector<int> &shape_in, const std::vector<int> &shape_out_copy, |
|
|
|
const std::vector<int> &mapped_axe); |
|
|
|
CNodePtr CreateSlice1ForReshapePrepose(const FuncGraphPtr &graph, const CNodePtr &slice_cnode, |
|
|
|
const CNodePtr &matmul_cnode, const std::vector<int> &shape_in, |
|
|
|
const int abnormal_axe_in, const int count_sliced_axe_in, |
|
|
|
const bool slice_at_front); |
|
|
|
CNodePtr CreateSlice2ForReshapePrepose(const FuncGraphPtr &graph, const CNodePtr &slice_cnode, |
|
|
|
const CNodePtr &new_reshape1_cnode, const std::vector<int64_t> &new_shape1, |
|
|
|
const int abnormal_axe_in, const int count_sliced_axe_in, |
|
|
|
const int count_sliced2, const bool slice_at_front); |
|
|
|
bool PreposeWithAbnormalReshape(const FuncGraphPtr &graph, const CNodePtr &slice_cnode, const CNodePtr &reshape_cnode, |
|
|
|
const CNodePtr &matmul_cnode, const std::vector<int> &shape_in, |
|
|
|
const std::vector<int> &shape_out, const int abnormal_axe_in, |
|
|
|
const int abnormal_index_out); |
|
|
|
bool GetArithmeticInputInfo(const CNodePtr &arithmetic_cnode, std::vector<AnfNodePtr> *inputs, |
|
|
|
std::vector<std::vector<int32_t>> *shapes, std::vector<bool> *is_default_params); |
|
|
|
|
|
|
|
bool DoPrepose(const FuncGraphPtr &graph, const CNodePtr &slice_cnode, const CNodePtr &preceed_cnode); |
|
|
|
|
|
|
|
bool PreposeWithSoftmax(const FuncGraphPtr &graph, const CNodePtr &slice_cnode, const CNodePtr &softmax_cnode); |
|
|
|
bool PreposeWithReshape(const FuncGraphPtr &graph, const CNodePtr &slice_cnode, const CNodePtr &reshape_cnode); |
|
|
|
bool PreposeWithMatmul(const FuncGraphPtr &graph, const CNodePtr &slice_cnode, const CNodePtr &matmul_cnode); |
|
|
|
bool PreposeWithFullConnection(const FuncGraphPtr &graph, const CNodePtr &slice_cnode, const CNodePtr &fc_cnode); |
|
|
|
bool PreposeWithTranspose(const FuncGraphPtr &graph, const CNodePtr &slice_cnode, const CNodePtr &transpose_cnode); |
|
|
|
bool PreposeWithArithmetic(const FuncGraphPtr &graph, const CNodePtr &slice_cnode, const CNodePtr &arithmetic_cnode); |
|
|
|
bool MergeSequentialSlice(const FuncGraphPtr &graph, const CNodePtr &slice1_cnode, const CNodePtr &slice2_cnode); |
|
|
|
bool MergeParallelSlice(const FuncGraphPtr &graph, const NodeUsedListPtr &slices); |
|
|
|
|
|
|
|
private: |
|
|
|
FmkType fmk_type = lite::converter::FmkType_ONNX; |
|
|
|
|