|
|
@@ -126,7 +126,8 @@ std::vector<T> MatrixMultiply(const T A[], const T B[], int M, int N, int K) { |
|
|
|
|
|
|
|
|
template <typename SRC_T, typename DST_T> |
|
|
template <typename SRC_T, typename DST_T> |
|
|
void ConvertConvWeight4DTo7D(void *src, void *dst, size_t CO, size_t KH, size_t KW, size_t CI, size_t OGroup = 1, |
|
|
void ConvertConvWeight4DTo7D(void *src, void *dst, size_t CO, size_t KH, size_t KW, size_t CI, size_t OGroup = 1, |
|
|
size_t CI_TILE = 4, size_t CO_TILE = 4) { |
|
|
|
|
|
|
|
|
const size_t CI_TILE = 4, const size_t CO_TILE = 4) { |
|
|
|
|
|
if (CO_TILE == 0 || CI_TILE == 0) return; |
|
|
auto origin_weight = reinterpret_cast<SRC_T *>(src); |
|
|
auto origin_weight = reinterpret_cast<SRC_T *>(src); |
|
|
auto packed_weight = reinterpret_cast<DST_T *>(dst); |
|
|
auto packed_weight = reinterpret_cast<DST_T *>(dst); |
|
|
auto CI_SLICES = UP_DIV(CI, CI_TILE); |
|
|
auto CI_SLICES = UP_DIV(CI, CI_TILE); |
|
|
|