Browse Source

trandata

pull/1211/head
zk 4 years ago
parent
commit
27b678d81e
1 changed files with 30 additions and 31 deletions
  1. +30
    -31
      ge/common/formats/format_transfers/format_transfer_fractal_z.cc

+ 30
- 31
ge/common/formats/format_transfers/format_transfer_fractal_z.cc View File

@@ -65,7 +65,7 @@ Status CheckDataTypeSupport(DataType data_type) { return GetSizeByDataType(data_
Status TransShapeToFz(int64_t n, int64_t c, int64_t h, int64_t w, DataType data_type, std::vector<int64_t> &dst_shape) {
auto c0 = GetCubeSizeByDataType(data_type);
if (c0 < 0) {
return ACL_ERROR_GE_TRANSSHAPE_DATATYPE_INVALID;
return ACL_ERROR_GE_DATATYPE_INVALID;
}

auto c1 = Ceil(c, c0);
@@ -77,9 +77,9 @@ Status TransShapeToFz(int64_t n, int64_t c, int64_t h, int64_t w, DataType data_
dst_shape.push_back(kNiSize);
dst_shape.push_back(c0);
if (!IsShapeValid(dst_shape)) {
GELOGE(ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID, "Failed to check dst shape %s",
GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Failed to check dst shape %s",
ShapeToString(dst_shape).c_str());
return ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID;
return ACL_ERROR_GE_SHAPE_INVALID;
}
return SUCCESS;
}
@@ -88,7 +88,7 @@ Status TransShapeToFzWithGroups(int64_t n, int64_t c, int64_t h, int64_t w, Data
, int64_t groups) {
auto c0 = GetCubeSizeByDataType(data_type);
if (c0 < 0) {
return ACL_ERROR_GE_TRANSSHAPE_DATATYPE_INVALID;
return ACL_ERROR_GE_DATATYPE_INVALID;
}
int64_t cin_ori = c;
int64_t cout_ori = n / groups;
@@ -106,9 +106,9 @@ Status TransShapeToFzWithGroups(int64_t n, int64_t c, int64_t h, int64_t w, Data
dst_shape.push_back(16);
dst_shape.push_back(cube_k);
if (!IsShapeValid(dst_shape)) {
GELOGE(ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID, "Failed to check dst shape %s",
GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Failed to check dst shape %s",
ShapeToString(dst_shape).c_str());
return ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID;
return ACL_ERROR_GE_SHAPE_INVALID;
}
return SUCCESS;
}
@@ -127,21 +127,20 @@ Status TransShapeNchwToFz(const std::vector<int64_t> &src_shape, DataType data_t

Status TransShapeHwcnToFz(const std::vector<int64_t> &src_shape, DataType data_type, std::vector<int64_t> &dst_shape) {
if (!CheckShapeValid(src_shape, kHwcnDimsNum)) {
return ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID;
return ACL_ERROR_GE_SHAPE_INVALID;
}

auto h = src_shape.at(kHwcnH);
auto w = src_shape.at(kHwcnW);
auto c = src_shape.at(kHwcnC);
auto n = src_shape.at(kHwcnN);

return TransShapeToFz(n, c, h, w, data_type, dst_shape);
}

Status TransShapeHwcnToFzWithGroups(const std::vector<int64_t> &src_shape, DataType data_type, std::vector<int64_t> &dst_shape
, int64_t groups){
if (!CheckShapeValid(src_shape, kHwcnDimsNum)) {
return ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID;
return ACL_ERROR_GE_SHAPE_INVALID;
}

auto h = src_shape.at(kHwcnH);
@@ -155,7 +154,7 @@ Status TransShapeHwcnToFzWithGroups(const std::vector<int64_t> &src_shape, DataT

Status TransShapeNhwcToFz(const std::vector<int64_t> &src_shape, DataType data_type, std::vector<int64_t> &dst_shape) {
if (!CheckShapeValid(src_shape, kNhwcDimsNum)) {
return ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID;
return ACL_ERROR_GE_SHAPE_INVALID;
}

auto n = src_shape.at(kNhwcN);
@@ -194,10 +193,10 @@ Status TransFormatFromNchwToFz(const TransArgs &args, TransResult &result) {
std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>());
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
dst == nullptr,
GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld",
GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld",
TypeUtils::FormatToSerialString(args.src_format).c_str(),
TypeUtils::FormatToSerialString(args.dst_format).c_str(), dst_size);
return OUT_OF_MEMORY;);
return ACL_ERROR_GE_MEMORY_ALLOCATION;);

for (int64_t vfi = 0; vfi < vf_cnt; vfi++) {
// vertical fractal matrix base index
@@ -230,8 +229,8 @@ Status TransFormatFromNchwToFz(const TransArgs &args, TransResult &result) {
if (protected_size < size) {
std::string error = "Failed to operate the dst memory, protected_size is " +
FmtToStr(protected_size) + " and size is " + FmtToStr(size);
GE_ERRORLOG_AND_ERRORMSG(INTERNAL_ERROR, error.c_str());
return INTERNAL_ERROR;
GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_PARAM_INVALID, error.c_str());
return ACL_ERROR_GE_PARAM_INVALID;
}
char *dst_data = reinterpret_cast<char *>(dst.get() + offset);
const char *src_data = reinterpret_cast<const char *>(args.data + src_offset * size);
@@ -240,9 +239,9 @@ Status TransFormatFromNchwToFz(const TransArgs &args, TransResult &result) {
}
}
if (ret != EOK) {
GELOGE(INTERNAL_ERROR, "Failed to operate the dst memory at offset %ld, error-code %d pad mode %d", offset,
GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, "Failed to operate the dst memory at offset %ld, error-code %d pad mode %d", offset,
ret, need_pad_zero);
return INTERNAL_ERROR;
return ACL_ERROR_GE_MEMORY_OPERATE_FAILED;
}
}
}
@@ -285,10 +284,10 @@ Status TransFormatHwcnToFzWithGroups(const TransArgs &args, TransResult &result,
std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[size_output_data], std::default_delete<uint8_t[]>());
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
dst == nullptr,
GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld",
GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld",
TypeUtils::FormatToSerialString(args.src_format).c_str(),
TypeUtils::FormatToSerialString(args.dst_format).c_str(), size_output_data);
return OUT_OF_MEMORY;);
return ACL_ERROR_GE_MEMORY_ALLOCATION;);
for (int64_t g = 0; g < groups; g++) {
for (int64_t d = 0; d < kDim; d++) {
for (int64_t c = 0; c < c_dim; c++) {
@@ -352,10 +351,10 @@ Status TransFormatHwcnToFz(const TransArgs &args, TransResult &result) {
std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>());
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
dst == nullptr,
GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld",
GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld",
TypeUtils::FormatToSerialString(args.src_format).c_str(),
TypeUtils::FormatToSerialString(args.dst_format).c_str(), dst_size);
return OUT_OF_MEMORY;);
return ACL_ERROR_GE_MEMORY_ALLOCATION;);

for (int64_t c1i = 0; c1i < c1; c1i++) {
for (int64_t hi = 0; hi < h; hi++) {
@@ -386,9 +385,9 @@ Status TransFormatHwcnToFz(const TransArgs &args, TransResult &result) {
}
}
if (ret != EOK) {
GELOGE(INTERNAL_ERROR, "Failed to operate the dst memory at offset %ld, error-code %d, pad mode %d",
GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, "Failed to operate the dst memory at offset %ld, error-code %d, pad mode %d",
dst_offset, ret, pad_zero);
return INTERNAL_ERROR;
return ACL_ERROR_GE_MEMORY_OPERATE_FAILED;
}
}
}
@@ -427,10 +426,10 @@ Status TransFormatNhwcToFz(const TransArgs &args, TransResult &result) {
std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>());
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
dst == nullptr,
GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld",
GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld",
TypeUtils::FormatToSerialString(args.src_format).c_str(),
TypeUtils::FormatToSerialString(args.dst_format).c_str(), dst_size);
return OUT_OF_MEMORY;);
return ACL_ERROR_GE_MEMORY_ALLOCATION;);

for (int64_t c1i = 0; c1i < c1; c1i++) {
for (int64_t hi = 0; hi < h; hi++) {
@@ -449,9 +448,9 @@ Status TransFormatNhwcToFz(const TransArgs &args, TransResult &result) {
static_cast<size_t>(data_size));
} else {
if (protected_size < data_size) {
GELOGE(INTERNAL_ERROR, "Failed to operate the dst memory, protected_size is %ld and size is %ld",
GELOGE(ACL_ERROR_GE_PARAM_INVALID, "Failed to operate the dst memory, protected_size is %ld and size is %ld",
protected_size, data_size);
return INTERNAL_ERROR;
return ACL_ERROR_GE_PARAM_INVALID;
}
int64_t src_idx = n1n0i * hwc + hi * wc + wi * c + (c1i * c0 + c0i);
char *dst_data = reinterpret_cast<char *>(dst.get() + dst_offset);
@@ -461,9 +460,9 @@ Status TransFormatNhwcToFz(const TransArgs &args, TransResult &result) {
}
}
if (ret != EOK) {
GELOGE(INTERNAL_ERROR, "Failed to operate the dst memory at offset %ld, error-code %d, pad mode %d",
GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, "Failed to operate the dst memory at offset %ld, error-code %d, pad mode %d",
dst_offset, ret, pad_zero);
return INTERNAL_ERROR;
return ACL_ERROR_GE_MEMORY_OPERATE_FAILED;
}
}
}
@@ -489,7 +488,7 @@ Status FormatTransferFractalZ::TransFormat(const TransArgs &args, TransResult &r
return ret;
}
if (!IsTransShapeDstCorrect(args, expect_shape)) {
return PARAM_INVALID;
return ACL_ERROR_GE_SHAPE_INVALID;
}
return TransFormatHwcnToFzWithGroups(args, result, groups);
}
@@ -513,7 +512,7 @@ Status FormatTransferFractalZ::TransFormat(const TransArgs &args, TransResult &r
return ret;
}
if (!IsTransShapeDstCorrect(args, expect_shape)) {
return PARAM_INVALID;
return ACL_ERROR_GE_SHAPE_INVALID;
}

if (args.src_format == FORMAT_NHWC && args.dst_format == FORMAT_FRACTAL_Z) {
@@ -527,7 +526,7 @@ Status FormatTransferFractalZ::TransFormat(const TransArgs &args, TransResult &r
if (args.src_format == FORMAT_NCHW && args.dst_format == FORMAT_FRACTAL_Z) {
return TransFormatFromNchwToFz(args, result);
}
return UNSUPPORTED;
return ACL_ERROR_GE_FORMAT_INVALID;
}

Status FormatTransferFractalZ::TransShape(Format src_format, const std::vector<int64_t> &src_shape, DataType data_type,


Loading…
Cancel
Save