Browse Source

add fractal zn index

pull/1502/head
dongduo5@huawei.com 4 years ago
parent
commit
db8bf4e589
2 changed files with 7 additions and 2 deletions
  1. +2
    -2
      ge/common/formats/format_transfers/format_transfer_hwcn_fractal_zn_lstm.cc
  2. +5
    -0
      ge/common/formats/utils/formats_definitions.h

+ 2
- 2
ge/common/formats/format_transfers/format_transfer_hwcn_fractal_zn_lstm.cc View File

@@ -42,7 +42,7 @@ Status TransShapeHwcnToFrazlstm(const DataType &data_type, const std::vector<int
dst_shape.push_back(src_shape.at(kHwcnN)); dst_shape.push_back(src_shape.at(kHwcnN));
dst_shape.push_back(cube_size); dst_shape.push_back(cube_size);
dst_shape.push_back(cube_size); dst_shape.push_back(cube_size);
if (!CheckShapeValid(dst_shape, kFracZDimsNum)) {
if (!CheckShapeValid(dst_shape, kFracZnLstmDimsNum)) {
GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Failed to check dst shape %s", GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Failed to check dst shape %s",
ShapeToString(dst_shape).c_str()); ShapeToString(dst_shape).c_str());
return ACL_ERROR_GE_SHAPE_INVALID; return ACL_ERROR_GE_SHAPE_INVALID;
@@ -67,7 +67,7 @@ Status CheckArgsForHwcnToFrazlstm(const TransArgs &args) {
GELOGE(PARAM_INVALID, "Failed to check src shape %s", ShapeToString(args.src_shape).c_str()); GELOGE(PARAM_INVALID, "Failed to check src shape %s", ShapeToString(args.src_shape).c_str());
return PARAM_INVALID; return PARAM_INVALID;
} }
if (!CheckShapeValid(args.dst_shape, kFracZDimsNum)) {
if (!CheckShapeValid(args.dst_shape, kFracZnLstmDimsNum)) {
GELOGE(PARAM_INVALID, "Failed to check dst shape %s", ShapeToString(args.dst_shape).c_str()); GELOGE(PARAM_INVALID, "Failed to check dst shape %s", ShapeToString(args.dst_shape).c_str());
return PARAM_INVALID; return PARAM_INVALID;
} }


+ 5
- 0
ge/common/formats/utils/formats_definitions.h View File

@@ -100,6 +100,11 @@ enum DhwncDimIndex {
kDhwncC, kDhwncC,
kDhwncDimsNum kDhwncDimsNum
}; };

enum FracZnLstmIndex {
kFracZnLstmDimsNum = 6,
};

} // namespace formats } // namespace formats
} // namespace ge } // namespace ge
#endif // GE_COMMON_FORMATS_UTILS_FORMATS_DEFINITIONS_H_ #endif // GE_COMMON_FORMATS_UTILS_FORMATS_DEFINITIONS_H_

Loading…
Cancel
Save