diff --git a/ge/common/formats/format_transfers/format_transfer_hwcn_fractal_zn_lstm.cc b/ge/common/formats/format_transfers/format_transfer_hwcn_fractal_zn_lstm.cc index b046df1f..12dd57e0 100644 --- a/ge/common/formats/format_transfers/format_transfer_hwcn_fractal_zn_lstm.cc +++ b/ge/common/formats/format_transfers/format_transfer_hwcn_fractal_zn_lstm.cc @@ -34,19 +34,17 @@ bool CheckDataTypeSupported(const DataType &data_type) { Status TransShapeHwcnToFrazlstm(const DataType &data_type, const std::vector &src_shape, std::vector &dst_shape) { - /*auto cube_size = GetCubeSizeByDataType(data_type); + auto cube_size = GetCubeSizeByDataType(data_type); dst_shape.clear(); dst_shape.push_back(Ceil(src_shape.at(kHwcnC), static_cast(cube_size))); - dst_shape.push_back(src_shape.at(kHwcnH)); - dst_shape.push_back(src_shape.at(kHwcnW)); - dst_shape.push_back(src_shape.at(kHwcnN)); + dst_shape.push_back(Ceil(src_shape.at(kHwcnN), static_cast(cube_size))); dst_shape.push_back(cube_size); dst_shape.push_back(cube_size); - if (!CheckShapeValid(dst_shape, kFracZnLstmDimsNum)) { + if (!CheckShapeValid(dst_shape, kFracZDimsNum)) { GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Failed to check dst shape %s", ShapeToString(dst_shape).c_str()); return ACL_ERROR_GE_SHAPE_INVALID; - }*/ + } return SUCCESS; } @@ -59,7 +57,7 @@ Status CheckArgsForHwcnToFrazlstm(const TransArgs &args) { return UNSUPPORTED; } if (!CheckDataTypeSupported(args.src_data_type)) { - GELOGE(UNSUPPORTED, "Failed to trans shape from HWCN to C1HWNCoC0, invalid data type %s", + GELOGE(UNSUPPORTED, "Failed to trans shape from HWCN to FRACTAL_ZN_LSTM, invalid data type %s", TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); return UNSUPPORTED; } @@ -89,7 +87,12 @@ Status CheckArgsForHwcnToFrazlstm(const TransArgs &args) { } Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const int size, const int64_t total_size) { + + auto ret = memcpy_s(dst.get(), static_cast(total_size), args.data, static_cast(total_size)); + result.data = dst; + result.length = static_cast(total_size); return SUCCESS; + std::shared_ptr dst(new (std::nothrow) uint8_t[total_size], std::default_delete()); if (dst == nullptr) { GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld, shape %s", @@ -102,9 +105,10 @@ Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const in auto w = args.src_shape.at(kHwcnW); auto c = args.src_shape.at(kHwcnC); auto n = args.src_shape.at(kHwcnN); - auto c1 = args.dst_shape.at(kC1hwncoc0C1); - auto c0 = args.dst_shape.at(kC1hwncoc0C0); - auto co = args.dst_shape.at(kC1hwncoc0Co); + auto hwc1 = args.dst_shape.at(kFracZHWC1); + auto n0 = args.dst_shape.at(kFracZN0); + auto ni = args.dst_shape.at(kFracZNi); + auto c0 = args.dst_shape.at(kFracZC0); int64_t coc0 = co * c0; int64_t ncoc0 = n * coc0; int64_t wncoc0 = w * ncoc0; diff --git a/ge/common/formats/utils/formats_definitions.h b/ge/common/formats/utils/formats_definitions.h index 01430a9d..62ead019 100755 --- a/ge/common/formats/utils/formats_definitions.h +++ b/ge/common/formats/utils/formats_definitions.h @@ -101,10 +101,6 @@ enum DhwncDimIndex { kDhwncDimsNum }; -enum FracZnLstmIndex { - kFracZnLstmDimsNum = 6, -}; - } // namespace formats } // namespace ge #endif // GE_COMMON_FORMATS_UTILS_FORMATS_DEFINITIONS_H_