|
|
|
@@ -92,7 +92,7 @@ Status TransShapeToFzWithGroups(int64_t n, int64_t c, int64_t h, int64_t w, Data |
|
|
|
} |
|
|
|
int64_t cin_ori = c; |
|
|
|
int64_t cout_ori = n / groups; |
|
|
|
int64_t cube_k = args.src_data_type == DT_INT8 ? 32 : 16; |
|
|
|
int64_t cube_k = data_type == DT_INT8 ? 32 : 16; |
|
|
|
int64_t e_mult = std::min( |
|
|
|
Lcm(Lcm(cin_ori, cube_k) / (cin_ori), Lcm(cout_ori, kCubeN) / (cout_ori)), |
|
|
|
groups); |
|
|
|
@@ -100,8 +100,7 @@ Status TransShapeToFzWithGroups(int64_t n, int64_t c, int64_t h, int64_t w, Data |
|
|
|
int64_t cout_opt = Ceil(e_mult * cout_ori, kCubeN) * kCubeN; |
|
|
|
int64_t c1_dim = cin_opt / cube_k; |
|
|
|
int64_t g_dim = Ceil(groups, e_mult); |
|
|
|
auto n1 = Ceil(n, kCubeN); |
|
|
|
|
|
|
|
auto n1 = Ceil(n , kCubeN); |
|
|
|
dst_shape.clear(); |
|
|
|
dst_shape.push_back(g_dim * c1_dim * h * w); |
|
|
|
dst_shape.push_back(n1); |
|
|
|
|