Browse Source

Transdata

pull/1211/head
zk 4 years ago
parent
commit
70b4500345
2 changed files with 4 additions and 5 deletions
  1. +2
    -3
      ge/common/formats/format_transfers/format_transfer_fractal_z.cc
  2. +2
    -2
      ge/common/formats/format_transfers/format_transfer_fractal_z.h

+ 2
- 3
ge/common/formats/format_transfers/format_transfer_fractal_z.cc View File

@@ -92,7 +92,7 @@ Status TransShapeToFzWithGroups(int64_t n, int64_t c, int64_t h, int64_t w, Data
}
int64_t cin_ori = c;
int64_t cout_ori = n / groups;
int64_t cube_k = args.src_data_type == DT_INT8 ? 32 : 16;
int64_t cube_k = data_type == DT_INT8 ? 32 : 16;
int64_t e_mult = std::min(
Lcm(Lcm(cin_ori, cube_k) / (cin_ori), Lcm(cout_ori, kCubeN) / (cout_ori)),
groups);
@@ -100,8 +100,7 @@ Status TransShapeToFzWithGroups(int64_t n, int64_t c, int64_t h, int64_t w, Data
int64_t cout_opt = Ceil(e_mult * cout_ori, kCubeN) * kCubeN;
int64_t c1_dim = cin_opt / cube_k;
int64_t g_dim = Ceil(groups, e_mult);
auto n1 = Ceil(n, kCubeN);

auto n1 = Ceil(n , kCubeN);
dst_shape.clear();
dst_shape.push_back(g_dim * c1_dim * h * w);
dst_shape.push_back(n1);


+ 2
- 2
ge/common/formats/format_transfers/format_transfer_fractal_z.h View File

@@ -26,11 +26,11 @@ namespace formats {
class FormatTransferFractalZ : public FormatTransfer {
public:
Status TransFormat(const TransArgs &args, TransResult &result) override;
Status TransFormat(const TransArgs &args, TransResult &result, int64_t groups) override;
Status TransFormat(const TransArgs &args, TransResult &result, int64_t groups);
Status TransShape(Format src_format, const std::vector<int64_t> &src_shape, DataType data_type, Format dst_format,
std::vector<int64_t> &dst_shape) override;
Status TransShape(Format src_format, const std::vector<int64_t> &src_shape, DataType data_type, Format dst_format,
std::vector<int64_t> &dst_shape, int64_t groups) override;
std::vector<int64_t> &dst_shape, int64_t groups);
};
} // namespace formats
} // namespace ge


Loading…
Cancel
Save