Browse Source

add fractal zn lstm

pull/1502/head
dongduo5@huawei.com 4 years ago
parent
commit
7841c4d196
2 changed files with 14 additions and 14 deletions
  1. +14
    -10
      ge/common/formats/format_transfers/format_transfer_hwcn_fractal_zn_lstm.cc
  2. +0
    -4
      ge/common/formats/utils/formats_definitions.h

+ 14
- 10
ge/common/formats/format_transfers/format_transfer_hwcn_fractal_zn_lstm.cc View File

@@ -34,19 +34,17 @@ bool CheckDataTypeSupported(const DataType &data_type) {

Status TransShapeHwcnToFrazlstm(const DataType &data_type, const std::vector<int64_t> &src_shape,
std::vector<int64_t> &dst_shape) {
/*auto cube_size = GetCubeSizeByDataType(data_type);
auto cube_size = GetCubeSizeByDataType(data_type);
dst_shape.clear();
dst_shape.push_back(Ceil(src_shape.at(kHwcnC), static_cast<int64_t>(cube_size)));
dst_shape.push_back(src_shape.at(kHwcnH));
dst_shape.push_back(src_shape.at(kHwcnW));
dst_shape.push_back(src_shape.at(kHwcnN));
dst_shape.push_back(Ceil(src_shape.at(kHwcnN), static_cast<int64_t>(cube_size)));
dst_shape.push_back(cube_size);
dst_shape.push_back(cube_size);
if (!CheckShapeValid(dst_shape, kFracZnLstmDimsNum)) {
if (!CheckShapeValid(dst_shape, kFracZDimsNum)) {
GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Failed to check dst shape %s",
ShapeToString(dst_shape).c_str());
return ACL_ERROR_GE_SHAPE_INVALID;
}*/
}
return SUCCESS;
}

@@ -59,7 +57,7 @@ Status CheckArgsForHwcnToFrazlstm(const TransArgs &args) {
return UNSUPPORTED;
}
if (!CheckDataTypeSupported(args.src_data_type)) {
GELOGE(UNSUPPORTED, "Failed to trans shape from HWCN to C1HWNCoC0, invalid data type %s",
GELOGE(UNSUPPORTED, "Failed to trans shape from HWCN to FRACTAL_ZN_LSTM, invalid data type %s",
TypeUtils::DataTypeToSerialString(args.src_data_type).c_str());
return UNSUPPORTED;
}
@@ -89,7 +87,12 @@ Status CheckArgsForHwcnToFrazlstm(const TransArgs &args) {
}

Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const int size, const int64_t total_size) {

auto ret = memcpy_s(dst.get(), static_cast<size_t>(total_size), args.data, static_cast<size_t>(total_size));
result.data = dst;
result.length = static_cast<size_t>(total_size);
return SUCCESS;
std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[total_size], std::default_delete<uint8_t[]>());
if (dst == nullptr) {
GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld, shape %s",
@@ -102,9 +105,10 @@ Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const in
auto w = args.src_shape.at(kHwcnW);
auto c = args.src_shape.at(kHwcnC);
auto n = args.src_shape.at(kHwcnN);
auto c1 = args.dst_shape.at(kC1hwncoc0C1);
auto c0 = args.dst_shape.at(kC1hwncoc0C0);
auto co = args.dst_shape.at(kC1hwncoc0Co);
auto hwc1 = args.dst_shape.at(kFracZHWC1);
auto n0 = args.dst_shape.at(kFracZN0);
auto ni = args.dst_shape.at(kFracZNi);
auto c0 = args.dst_shape.at(kFracZC0);
int64_t coc0 = co * c0;
int64_t ncoc0 = n * coc0;
int64_t wncoc0 = w * ncoc0;


+ 0
- 4
ge/common/formats/utils/formats_definitions.h View File

@@ -101,10 +101,6 @@ enum DhwncDimIndex {
kDhwncDimsNum
};

enum FracZnLstmIndex {
kFracZnLstmDimsNum = 6,
};

} // namespace formats
} // namespace ge
#endif // GE_COMMON_FORMATS_UTILS_FORMATS_DEFINITIONS_H_

Loading…
Cancel
Save