|
|
|
@@ -34,19 +34,17 @@ bool CheckDataTypeSupported(const DataType &data_type) { |
|
|
|
|
|
|
|
Status TransShapeHwcnToFrazlstm(const DataType &data_type, const std::vector<int64_t> &src_shape, |
|
|
|
std::vector<int64_t> &dst_shape) { |
|
|
|
/*auto cube_size = GetCubeSizeByDataType(data_type); |
|
|
|
auto cube_size = GetCubeSizeByDataType(data_type); |
|
|
|
dst_shape.clear(); |
|
|
|
dst_shape.push_back(Ceil(src_shape.at(kHwcnC), static_cast<int64_t>(cube_size))); |
|
|
|
dst_shape.push_back(src_shape.at(kHwcnH)); |
|
|
|
dst_shape.push_back(src_shape.at(kHwcnW)); |
|
|
|
dst_shape.push_back(src_shape.at(kHwcnN)); |
|
|
|
dst_shape.push_back(Ceil(src_shape.at(kHwcnN), static_cast<int64_t>(cube_size))); |
|
|
|
dst_shape.push_back(cube_size); |
|
|
|
dst_shape.push_back(cube_size); |
|
|
|
if (!CheckShapeValid(dst_shape, kFracZnLstmDimsNum)) { |
|
|
|
if (!CheckShapeValid(dst_shape, kFracZDimsNum)) { |
|
|
|
GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Failed to check dst shape %s", |
|
|
|
ShapeToString(dst_shape).c_str()); |
|
|
|
return ACL_ERROR_GE_SHAPE_INVALID; |
|
|
|
}*/ |
|
|
|
} |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
@@ -59,7 +57,7 @@ Status CheckArgsForHwcnToFrazlstm(const TransArgs &args) { |
|
|
|
return UNSUPPORTED; |
|
|
|
} |
|
|
|
if (!CheckDataTypeSupported(args.src_data_type)) { |
|
|
|
GELOGE(UNSUPPORTED, "Failed to trans shape from HWCN to C1HWNCoC0, invalid data type %s", |
|
|
|
GELOGE(UNSUPPORTED, "Failed to trans shape from HWCN to FRACTAL_ZN_LSTM, invalid data type %s", |
|
|
|
TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); |
|
|
|
return UNSUPPORTED; |
|
|
|
} |
|
|
|
@@ -89,7 +87,12 @@ Status CheckArgsForHwcnToFrazlstm(const TransArgs &args) { |
|
|
|
} |
|
|
|
|
|
|
|
Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const int size, const int64_t total_size) { |
|
|
|
|
|
|
|
auto ret = memcpy_s(dst.get(), static_cast<size_t>(total_size), args.data, static_cast<size_t>(total_size)); |
|
|
|
result.data = dst; |
|
|
|
result.length = static_cast<size_t>(total_size); |
|
|
|
return SUCCESS; |
|
|
|
|
|
|
|
std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[total_size], std::default_delete<uint8_t[]>()); |
|
|
|
if (dst == nullptr) { |
|
|
|
GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld, shape %s", |
|
|
|
@@ -102,9 +105,10 @@ Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const in |
|
|
|
auto w = args.src_shape.at(kHwcnW); |
|
|
|
auto c = args.src_shape.at(kHwcnC); |
|
|
|
auto n = args.src_shape.at(kHwcnN); |
|
|
|
auto c1 = args.dst_shape.at(kC1hwncoc0C1); |
|
|
|
auto c0 = args.dst_shape.at(kC1hwncoc0C0); |
|
|
|
auto co = args.dst_shape.at(kC1hwncoc0Co); |
|
|
|
auto hwc1 = args.dst_shape.at(kFracZHWC1); |
|
|
|
auto n0 = args.dst_shape.at(kFracZN0); |
|
|
|
auto ni = args.dst_shape.at(kFracZNi); |
|
|
|
auto c0 = args.dst_shape.at(kFracZC0); |
|
|
|
int64_t coc0 = co * c0; |
|
|
|
int64_t ncoc0 = n * coc0; |
|
|
|
int64_t wncoc0 = w * ncoc0; |
|
|
|
|