|
|
|
@@ -42,7 +42,7 @@ Status TransShapeHwcnToFrazlstm(const DataType &data_type, const std::vector<int |
|
|
|
dst_shape.push_back(src_shape.at(kHwcnN)); |
|
|
|
dst_shape.push_back(cube_size); |
|
|
|
dst_shape.push_back(cube_size); |
|
|
|
if (!CheckShapeValid(dst_shape, kFracZDimsNum)) { |
|
|
|
if (!CheckShapeValid(dst_shape, kFracZnLstmDimsNum)) { |
|
|
|
GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Failed to check dst shape %s", |
|
|
|
ShapeToString(dst_shape).c_str()); |
|
|
|
return ACL_ERROR_GE_SHAPE_INVALID; |
|
|
|
@@ -67,7 +67,7 @@ Status CheckArgsForHwcnToFrazlstm(const TransArgs &args) { |
|
|
|
GELOGE(PARAM_INVALID, "Failed to check src shape %s", ShapeToString(args.src_shape).c_str()); |
|
|
|
return PARAM_INVALID; |
|
|
|
} |
|
|
|
if (!CheckShapeValid(args.dst_shape, kFracZDimsNum)) { |
|
|
|
if (!CheckShapeValid(args.dst_shape, kFracZnLstmDimsNum)) { |
|
|
|
GELOGE(PARAM_INVALID, "Failed to check dst shape %s", ShapeToString(args.dst_shape).c_str()); |
|
|
|
return PARAM_INVALID; |
|
|
|
} |
|
|
|
|