|
|
|
@@ -43,9 +43,9 @@ Status TransShapeHwcnToFrazlstm(const DataType &data_type, const std::vector<int |
|
|
|
dst_shape.push_back(cube_size); |
|
|
|
dst_shape.push_back(cube_size); |
|
|
|
if (!CheckShapeValid(dst_shape, kFracZDimsNum)) { |
|
|
|
GELOGE(ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID, "Failed to check dst shape %s", |
|
|
|
GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Failed to check dst shape %s", |
|
|
|
ShapeToString(dst_shape).c_str()); |
|
|
|
return ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID; |
|
|
|
return ACL_ERROR_GE_SHAPE_INVALID; |
|
|
|
} |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
@@ -198,15 +198,15 @@ Status FormatTransferHwcnFractalznlstm::TransShape(Format src_format, const std: |
|
|
|
DataType data_type, Format dst_format, std::vector<int64_t> &dst_shape) { |
|
|
|
if (src_format == FORMAT_HWCN && CheckDataTypeSupported(data_type)) { |
|
|
|
if (!CheckShapeValid(src_shape, kHwcnDimsNum)) { |
|
|
|
GELOGE(ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID, "Failed to check src shape %s", |
|
|
|
GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Failed to check src shape %s", |
|
|
|
ShapeToString(src_shape).c_str()); |
|
|
|
return ACL_ERROR_GE_TRANSSHAPE_SHAPE_INVALID; |
|
|
|
return ACL_ERROR_GE_SHAPE_INVALID; |
|
|
|
} |
|
|
|
return TransShapeHwcnToFrazlstm(data_type, src_shape, dst_shape); |
|
|
|
} else if (src_format != FORMAT_HWCN) { |
|
|
|
return ACL_ERROR_GE_TRANSSHAPE_FORMAT_INVALID; |
|
|
|
return ACL_ERROR_GE_SHAPE_INVALID; |
|
|
|
} else { |
|
|
|
return ACL_ERROR_GE_TRANSSHAPE_DATATYPE_INVALID; |
|
|
|
return ACL_ERROR_GE_SHAPE_INVALID; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|