Browse Source

Pre Merge pull request !1211 from zhukun2020/ge_master

pull/1211/MERGE
zhukun2020 Gitee 4 years ago
parent
commit
d4fe59a22a
4 changed files with 165 additions and 3 deletions
  1. +135
    -1
      ge/common/formats/format_transfers/format_transfer_fractal_z.cc
  2. +3
    -0
      ge/common/formats/format_transfers/format_transfer_fractal_z.h
  3. +12
    -2
      ge/host_kernels/transdata_kernel.cc
  4. +15
    -0
      tests/ut/ge/common/format_transfer_fracz_hwcn_unittest.cc

+ 135
- 1
ge/common/formats/format_transfers/format_transfer_fractal_z.cc View File

@@ -29,6 +29,29 @@
namespace ge {
namespace formats {
namespace {
constexpr int64_t kCubeN = 16;
constexpr int64_t kGroupNum = 1;
constexpr int64_t kDim = 1;

// Returns the greatest common divisor of x and y using the Euclidean
// algorithm: the remainder is taken repeatedly until it becomes zero, and the
// last non-zero remainder (or y itself when y already divides x) is returned.
//
// When y is 0 the result is x. Without that guard the loop condition would
// evaluate x % 0, which is undefined behaviour.
// NOTE(review): operands are presumably non-negative dimension values; the
// sign of the result for negative operands follows the % operator.
static int64_t Measure(int64_t x, int64_t y) {
  if (y == 0) {
    return x;
  }
  int64_t z = y;
  while (x % y != 0) {
    z = x % y;
    x = y;
    y = z;
  }
  return z;
}

// Returns the least common multiple of a and b, or -1 when b is 0.
//
// The divisor is Measure(a, b), which divides a exactly. Dividing a by it
// before multiplying by b gives the same value as (a * b) / Measure(a, b) but
// keeps the intermediate within int64_t whenever the final result fits; the
// raw product a * b can overflow even when the least common multiple does not.
static int64_t Lcm(int64_t a, int64_t b) {
  if (b == 0) {
    return -1;
  }
  return (a / Measure(a, b)) * b;
}

// Reports whether data_type is usable here: SUCCESS when GetSizeByDataType
// yields a positive size for it, UNSUPPORTED otherwise.
Status CheckDataTypeSupport(DataType data_type) {
  if (GetSizeByDataType(data_type) > 0) {
    return SUCCESS;
  }
  return UNSUPPORTED;
}

/**
@@ -82,10 +105,24 @@ Status TransShapeHwcnToFz(const std::vector<int64_t> &src_shape, DataType data_t
auto w = src_shape.at(kHwcnW);
auto c = src_shape.at(kHwcnC);
auto n = src_shape.at(kHwcnN);

return TransShapeToFz(n, c, h, w, data_type, dst_shape);
}

// Computes the grouped Fz destination shape for an HWCN source shape.
//
// src_shape is first validated with CheckShapeValid against kHwcnDimsNum; a
// failure yields ACL_ERROR_GE_SHAPE_INVALID. Otherwise the H, W, C and N
// extents are read from src_shape and handed to TransShapeToFzWithGroups in
// (n, c, h, w) order together with data_type, dst_shape and groups, and its
// status is returned unchanged.
Status TransShapeHwcnToFzWithGroups(const std::vector<int64_t> &src_shape, DataType data_type,
                                    std::vector<int64_t> &dst_shape, int64_t groups) {
  const bool shape_ok = CheckShapeValid(src_shape, kHwcnDimsNum);
  if (!shape_ok) {
    return ACL_ERROR_GE_SHAPE_INVALID;
  }

  auto height = src_shape.at(kHwcnH);
  auto width = src_shape.at(kHwcnW);
  auto channels = src_shape.at(kHwcnC);
  auto batch = src_shape.at(kHwcnN);

  return TransShapeToFzWithGroups(batch, channels, height, width, data_type, dst_shape, groups);
}


Status TransShapeNhwcToFz(const std::vector<int64_t> &src_shape, DataType data_type, std::vector<int64_t> &dst_shape) {
if (!CheckShapeValid(src_shape, kNhwcDimsNum)) {
return ACL_ERROR_GE_SHAPE_INVALID;
@@ -187,6 +224,78 @@ Status TransFormatFromNchwToFz(const TransArgs &args, TransResult &result) {
return SUCCESS;
}

// Repacks grouped HWCN data from args.data into a newly allocated buffer and
// publishes the buffer and its byte length through `result`.
//
// Source elements are addressed as [d][h][w][c][src_co], where the last axis
// has extent n_dim (the N extent of args.src_shape) and src_co selects output
// channel n of group g. Destination elements are addressed as
// [g / e_mult][d][dst_ci / cube_k][h][w][dst_co][dst_ci % cube_k], i.e. e_mult
// consecutive groups share one destination group and are laid side by side on
// both channel axes.
//
// Returns GRAPH_FAILED when groups is not positive or when the per-group
// input/output channel count is 0, ACL_ERROR_GE_MEMORY_ALLOCATION when the
// destination buffer cannot be allocated, and SUCCESS otherwise (including the
// case of an empty destination, where only result.length is set).
//
// NOTE(review): args.src_shape is indexed without a size check; the caller is
// presumably expected to have validated it as a 4-D HWCN shape beforehand.
// NOTE(review): when n_dim is not a multiple of groups, the trailing
// n_dim % groups output channels are never copied -- confirm that callers
// guarantee divisibility.
Status TransFormatHwcnToFzWithGroups(const TransArgs &args, TransResult &result, int64_t groups) {
  // groups is a divisor and a loop bound below; reject non-positive values
  // up front instead of dividing by zero.
  if (groups <= 0) {
    GELOGE(GRAPH_FAILED, "Groups must be greater than 0, and current groups is %ld", groups);
    return GRAPH_FAILED;
  }

  const int64_t h_dim = args.src_shape[kHwcnH];
  const int64_t w_dim = args.src_shape[kHwcnW];
  const int64_t c_dim = args.src_shape[kHwcnC];
  const int64_t n_dim = args.src_shape[kHwcnN];
  const int64_t cin_ori = c_dim;
  const int64_t cout_ori = n_dim / groups;
  if (cin_ori == 0 || cout_ori == 0) {
    // The values are int64_t, so they are printed with %ld like the other log
    // in this function; %d would read them with the wrong width.
    GELOGE(GRAPH_FAILED,
           "Cin_ori, cout_ori must not be equal 0, "
           "and current cin_ori, cout_ori, groups are %ld %ld %ld",
           cin_ori, cout_ori, groups);
    return GRAPH_FAILED;
  }

  const int64_t cube_k = args.src_data_type == DT_INT8 ? 32 : 16;
  // Number of original groups merged into one destination group: the smallest
  // count that makes both merged channel totals multiples of their cube size,
  // capped at the number of groups available.
  const int64_t e_mult =
      std::min(Lcm(Lcm(cin_ori, cube_k) / (cin_ori), Lcm(cout_ori, kCubeN) / (cout_ori)), groups);
  // Merged channel totals rounded up to whole cubes.
  const int64_t cin_opt = Ceil(e_mult * cin_ori, cube_k) * cube_k;
  const int64_t cout_opt = Ceil(e_mult * cout_ori, kCubeN) * kCubeN;
  const int64_t c1_dim = cin_opt / cube_k;
  const int64_t g_dim = Ceil(groups, e_mult);
  // NOTE(review): data_size is used below as the number of bytes per element
  // (it scales both byte offsets and the per-element copy length), yet it is
  // taken from GetCubeSizeByDataType rather than GetSizeByDataType. Confirm
  // that this helper really returns the element byte size; if it returns the
  // cube size, the source buffer is read out of bounds.
  const int64_t data_size = GetCubeSizeByDataType(args.src_data_type);
  const int64_t size_output_data = g_dim * kDim * c1_dim * h_dim * w_dim * cout_opt * cube_k * data_size;
  // Empty destination: report the zero length and succeed (the macro
  // presumably runs its statements when the condition is false -- TODO confirm).
  GE_CHK_BOOL_EXEC_NOLOG(size_output_data != 0, result.length = static_cast<size_t>(size_output_data);
                         return SUCCESS;);

  // The trailing "()" value-initializes the buffer to zero. The loops below
  // write only the real channels, so without it the padding cells introduced
  // by rounding up to cin_opt/cout_opt/g_dim would hold indeterminate bytes.
  std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[size_output_data](), std::default_delete<uint8_t[]>());
  GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
      dst == nullptr,
      GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld",
             TypeUtils::FormatToSerialString(args.src_format).c_str(),
             TypeUtils::FormatToSerialString(args.dst_format).c_str(), size_output_data);
      return ACL_ERROR_GE_MEMORY_ALLOCATION;);

  // Element strides of the destination axes (innermost two axes are dst_co
  // with stride cube_k and dst_ci % cube_k with stride 1).
  const int64_t dst_w_stride = cout_opt * cube_k;
  const int64_t dst_h_stride = w_dim * dst_w_stride;
  const int64_t dst_c1_stride = h_dim * dst_h_stride;
  const int64_t dst_d_stride = c1_dim * dst_c1_stride;
  const int64_t dst_g_stride = kDim * dst_d_stride;
  // Element strides of the source axes (innermost axis is src_co, stride 1).
  const int64_t src_c_stride = n_dim;
  const int64_t src_w_stride = c_dim * src_c_stride;
  const int64_t src_h_stride = w_dim * src_w_stride;
  const int64_t src_d_stride = h_dim * src_h_stride;

  for (int64_t g = 0; g < groups; g++) {
    // Position of this group inside its merged destination group.
    const int64_t e_val = g % e_mult;
    const int64_t dst_g_base = (g / e_mult) * dst_g_stride;
    for (int64_t d = 0; d < kDim; d++) {
      for (int64_t c = 0; c < c_dim; c++) {
        const int64_t dst_ci = e_val * cin_ori + c;
        const int64_t dst_c_base =
            dst_g_base + d * dst_d_stride + (dst_ci / cube_k) * dst_c1_stride + dst_ci % cube_k;
        for (int64_t h = 0; h < h_dim; h++) {
          for (int64_t w = 0; w < w_dim; w++) {
            for (int64_t n = 0; n < cout_ori; n++) {
              const int64_t dst_co = e_val * cout_ori + n;
              const int64_t src_co = g * cout_ori + n;
              const int64_t dst_inx = dst_c_base + h * dst_h_stride + w * dst_w_stride + dst_co * cube_k;
              const int64_t srx_inx =
                  d * src_d_stride + h * src_h_stride + w * src_w_stride + c * src_c_stride + src_co;
              // Copy one element, data_size bytes, from source to destination.
              char *dst_data = reinterpret_cast<char *>(dst.get() + dst_inx * data_size);
              const char *src_data = reinterpret_cast<const char *>(args.data + srx_inx * data_size);
              for (int64_t index = 0; index < data_size; index++) {
                *dst_data++ = *src_data++;
              }
            }
          }
        }
      }
    }
  }
  result.data = dst;
  result.length = static_cast<size_t>(size_output_data);
  return SUCCESS;
}
Status TransFormatHwcnToFz(const TransArgs &args, TransResult &result) {
int64_t h = args.src_shape[kHwcnH];
int64_t w = args.src_shape[kHwcnW];
@@ -338,6 +447,31 @@ Status TransFormatNhwcToFz(const TransArgs &args, TransResult &result) {
}
} // namespace

// Grouped variant of TransFormat: converts HWCN data to FRACTAL_Z with the
// given number of groups.
//
// The source/destination formats are checked first because the grouped path
// below always interprets args.src_shape and args.data as HWCN; any other
// format pair is rejected with UNSUPPORTED instead of being silently
// misinterpreted. The expected destination shape is then derived through the
// grouped TransShape overload and compared with args via
// IsTransShapeDstCorrect (ACL_ERROR_GE_SHAPE_INVALID on mismatch) before the
// data is repacked by TransFormatHwcnToFzWithGroups.
Status FormatTransferFractalZ::TransFormat(const TransArgs &args, TransResult &result, int64_t groups) {
  GELOGD("Begin to trans format from %s to %s, src shape %s, data type %s, dst shape %s",
         TypeUtils::FormatToSerialString(args.src_format).c_str(),
         TypeUtils::FormatToSerialString(args.dst_format).c_str(), ShapeToString(args.src_shape).c_str(),
         TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), ShapeToString(args.dst_shape).c_str());
  if ((args.src_format != FORMAT_HWCN) || (args.dst_format != FORMAT_FRACTAL_Z)) {
    GELOGE(UNSUPPORTED, "Trans format with groups from %s to %s is not supported",
           TypeUtils::FormatToSerialString(args.src_format).c_str(),
           TypeUtils::FormatToSerialString(args.dst_format).c_str());
    return UNSUPPORTED;
  }
  std::vector<int64_t> expect_shape;
  auto ret = TransShape(args.src_format, args.src_shape, args.src_data_type, args.dst_format, expect_shape, groups);
  if (ret != SUCCESS) {
    return ret;
  }
  if (!IsTransShapeDstCorrect(args, expect_shape)) {
    return ACL_ERROR_GE_SHAPE_INVALID;
  }
  return TransFormatHwcnToFzWithGroups(args, result, groups);
}

// Grouped variant of TransShape: derives the FRACTAL_Z shape of an HWCN
// tensor split into `groups` groups and stores it in dst_shape.
//
// Returns ACL_ERROR_GE_TRANSSHAPE_DATATYPE_INVALID when data_type fails
// CheckDataTypeSupport. The shape helper below always reads src_shape as
// HWCN, so any format pair other than HWCN -> FRACTAL_Z is rejected with
// UNSUPPORTED rather than producing a shape for the wrong layout. Otherwise
// the status of TransShapeHwcnToFzWithGroups is returned.
Status FormatTransferFractalZ::TransShape(Format src_format, const std::vector<int64_t> &src_shape,
                                          DataType data_type, Format dst_format,
                                          std::vector<int64_t> &dst_shape, int64_t groups) {
  if (CheckDataTypeSupport(data_type) != SUCCESS) {
    return ACL_ERROR_GE_TRANSSHAPE_DATATYPE_INVALID;
  }
  if ((src_format != FORMAT_HWCN) || (dst_format != FORMAT_FRACTAL_Z)) {
    return UNSUPPORTED;
  }
  return TransShapeHwcnToFzWithGroups(src_shape, data_type, dst_shape, groups);
}

Status FormatTransferFractalZ::TransFormat(const TransArgs &args, TransResult &result) {
GELOGD("Begin to trans format from %s to %s, src shape %s, data type %s, dst shape %s",
TypeUtils::FormatToSerialString(args.src_format).c_str(),


+ 3
- 0
ge/common/formats/format_transfers/format_transfer_fractal_z.h View File

@@ -26,8 +26,11 @@ namespace formats {
// Format transfer registered for the FRACTAL_Z destination layout. Besides
// the two FormatTransfer overrides it offers overloads that take an explicit
// group count.
class FormatTransferFractalZ : public FormatTransfer {
 public:
  // FormatTransfer interface.
  Status TransFormat(const TransArgs &args, TransResult &result) override;
  Status TransShape(Format src_format,
                    const std::vector<int64_t> &src_shape,
                    DataType data_type,
                    Format dst_format,
                    std::vector<int64_t> &dst_shape) override;

  // Group-aware overloads; `groups` is the number of groups the tensor is
  // split into. These are not part of the FormatTransfer interface.
  Status TransFormat(const TransArgs &args, TransResult &result, int64_t groups);
  Status TransShape(Format src_format,
                    const std::vector<int64_t> &src_shape,
                    DataType data_type,
                    Format dst_format,
                    std::vector<int64_t> &dst_shape,
                    int64_t groups);
};
} // namespace formats
} // namespace ge


+ 12
- 2
ge/host_kernels/transdata_kernel.cc View File

@@ -82,16 +82,17 @@ Status TransdataKernel::Compute(const OpDescPtr op_desc_ptr, const std::vector<C
const auto &data_shape = op_desc->GetShape().GetDims();
const auto &data_format = op_desc->GetFormat();
const auto &data_type = op_desc->GetDataType();
const in64_t groups = op_desc_ptr->GetAttr("groups", groups)
GELOGD(
"current node %s, format %s, input shape %s, data type %s, weight format %s, shape %s, data type %s. "
"output format %s, shape %s, data type %s",
"output format %s, shape %s, data type %s, groups %d",
op_desc_ptr->GetName().c_str(), TypeUtils::FormatToSerialString(src_format).c_str(),
formats::ShapeToString(src_shape).c_str(), TypeUtils::DataTypeToSerialString(src_data_type).c_str(),
TypeUtils::FormatToSerialString(const_weight_ptr->GetTensorDesc().GetFormat()).c_str(),
formats::ShapeToString(const_weight_ptr->GetTensorDesc().GetShape()).c_str(),
TypeUtils::DataTypeToSerialString(const_weight_ptr->GetTensorDesc().GetDataType()).c_str(),
TypeUtils::FormatToSerialString(data_format).c_str(), formats::ShapeToString(data_shape).c_str(),
TypeUtils::DataTypeToSerialString(data_type).c_str());
TypeUtils::DataTypeToSerialString(data_type).c_str(), groups);

const uint8_t *src_data = const_weight_ptr->GetData().data();
const formats::TransArgs trans_args{src_data, src_format, data_format, src_shape, data_shape, src_data_type};
@@ -113,6 +114,15 @@ Status TransdataKernel::Compute(const OpDescPtr op_desc_ptr, const std::vector<C
GELOGI("CheckSize failed, input size is not equal to weight size");
return NOT_CHANGED;
}
if((src_format == FOMAT_HWCN) && (data_format == FORMAT_FRACTAL_Z)) {
if (formats::TransFormat(trans_args, trans_result , groups) != SUCCESS) {
GELOGW("Failed to trans formats from %s to %s, shape %s to %s, data type %s",
TypeUtils::FormatToSerialString(src_format).c_str(), TypeUtils::FormatToSerialString(data_format).c_str(),
formats::ShapeToString(src_shape).c_str(), formats::ShapeToString(data_shape).c_str(),
TypeUtils::DataTypeToSerialString(src_data_type).c_str());
return NOT_CHANGED;
}
}
if (formats::TransFormat(trans_args, trans_result) != SUCCESS) {
GELOGW("Failed to trans formats from %s to %s, shape %s to %s, data type %s",
TypeUtils::FormatToSerialString(src_format).c_str(), TypeUtils::FormatToSerialString(data_format).c_str(),


+ 15
- 0
tests/ut/ge/common/format_transfer_fracz_hwcn_unittest.cc View File

@@ -6897,5 +6897,20 @@ TEST_F(UtestFormatTransferFracZHwcn, fp16_1c_1n_pad_cn) {
EXPECT_EQ((reinterpret_cast<uint16_t *>(result.data.get()))[i], ret[i]);
}
}
// Runs FormatTransferFracZHwcn::TransFormat on an fp16 tensor described with
// a group count of 1 and compares the produced bytes with `data`.
//
// NOTE(review): several fixture values look inconsistent and should be
// confirmed by the author:
//  - the source shape ends in 6 where the neighbouring extents are 16;
//  - data_4d (16 elements) is passed as the source buffer while the source
//    shape describes far more elements, and the expected length is
//    sizeof(data) although the destination shape is {2, 2, 2, 2};
//  - `= {1}` / `= {1024}` set only the first element, the rest are zero;
//  - TransArgs is built with a seventh initializer (groups); confirm the
//    struct actually declares such a member.
TEST_F(UtestFormatTransferFracZHwcn, fracz_to_hwcn_fp16_success_with_groups) {
  uint16_t data_4d[2 * 2 * 2 * 2] = {1};
  uint16_t data[1 * 1 * 1 * 2 * 2 * 16 * 16] = {1024};
  int64_t groups = 1;
  TransArgs args{reinterpret_cast<uint8_t *>(data_4d),
                 FORMAT_FRACTAL_Z,
                 FORMAT_HWCN,
                 {1, 1, 1, 2, 2, 16, 6},
                 {2, 2, 2, 2},
                 DT_FLOAT16,
                 groups};
  TransResult result;

  FormatTransferFracZHwcn transfer;
  EXPECT_EQ(transfer.TransFormat(args, result), SUCCESS);
  EXPECT_EQ(result.length, sizeof(data));
  // size_t index: the bound is an unsigned sizeof expression, so a signed
  // index would be a signed/unsigned comparison.
  for (size_t i = 0; i < sizeof(data) / sizeof(data[0]); ++i) {
    EXPECT_EQ((reinterpret_cast<uint16_t *>(result.data.get()))[i], data[i]);
  }
}
} // namespace formats
} // namespace ge

Loading…
Cancel
Save