| @@ -92,7 +92,7 @@ Status TransShapeToFzWithGroups(int64_t n, int64_t c, int64_t h, int64_t w, Data | |||||
| } | } | ||||
| int64_t cin_ori = c; | int64_t cin_ori = c; | ||||
| int64_t cout_ori = n / groups; | int64_t cout_ori = n / groups; | ||||
| int64_t cube_k = args.src_data_type == DT_INT8 ? 32 : 16; | |||||
| int64_t cube_k = data_type == DT_INT8 ? 32 : 16; | |||||
| int64_t e_mult = std::min( | int64_t e_mult = std::min( | ||||
| Lcm(Lcm(cin_ori, cube_k) / (cin_ori), Lcm(cout_ori, kCubeN) / (cout_ori)), | Lcm(Lcm(cin_ori, cube_k) / (cin_ori), Lcm(cout_ori, kCubeN) / (cout_ori)), | ||||
| groups); | groups); | ||||
| @@ -100,8 +100,7 @@ Status TransShapeToFzWithGroups(int64_t n, int64_t c, int64_t h, int64_t w, Data | |||||
| int64_t cout_opt = Ceil(e_mult * cout_ori, kCubeN) * kCubeN; | int64_t cout_opt = Ceil(e_mult * cout_ori, kCubeN) * kCubeN; | ||||
| int64_t c1_dim = cin_opt / cube_k; | int64_t c1_dim = cin_opt / cube_k; | ||||
| int64_t g_dim = Ceil(groups, e_mult); | int64_t g_dim = Ceil(groups, e_mult); | ||||
| auto n1 = Ceil(n, kCubeN); | |||||
| auto n1 = Ceil(n , kCubeN); | |||||
| dst_shape.clear(); | dst_shape.clear(); | ||||
| dst_shape.push_back(g_dim * c1_dim * h * w); | dst_shape.push_back(g_dim * c1_dim * h * w); | ||||
| dst_shape.push_back(n1); | dst_shape.push_back(n1); | ||||
| @@ -26,11 +26,11 @@ namespace formats { | |||||
| class FormatTransferFractalZ : public FormatTransfer { | class FormatTransferFractalZ : public FormatTransfer { | ||||
| public: | public: | ||||
| Status TransFormat(const TransArgs &args, TransResult &result) override; | Status TransFormat(const TransArgs &args, TransResult &result) override; | ||||
| Status TransFormat(const TransArgs &args, TransResult &result, int64_t groups) override; | |||||
| Status TransFormat(const TransArgs &args, TransResult &result, int64_t groups); | |||||
| Status TransShape(Format src_format, const std::vector<int64_t> &src_shape, DataType data_type, Format dst_format, | Status TransShape(Format src_format, const std::vector<int64_t> &src_shape, DataType data_type, Format dst_format, | ||||
| std::vector<int64_t> &dst_shape) override; | std::vector<int64_t> &dst_shape) override; | ||||
| Status TransShape(Format src_format, const std::vector<int64_t> &src_shape, DataType data_type, Format dst_format, | Status TransShape(Format src_format, const std::vector<int64_t> &src_shape, DataType data_type, Format dst_format, | ||||
| std::vector<int64_t> &dst_shape, int64_t groups) override; | |||||
| std::vector<int64_t> &dst_shape, int64_t groups); | |||||
| }; | }; | ||||
| } // namespace formats | } // namespace formats | ||||
| } // namespace ge | } // namespace ge | ||||