| @@ -81,13 +81,20 @@ Status TransShapeToFz(int64_t n, int64_t c, int64_t h, int64_t w, DataType data_ | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status TransShapeToFzWithGroups(int64_t n, int64_t c, int64_t h, int64_t w, DataType data_type, std::vector<int64_t> &dst_shape, | |||||
| int64_t groups) { | |||||
| Status TransShapeToFzWithGroups(int64_t n, int64_t c, int64_t h, int64_t w, DataType data_type, | |||||
| std::vector<int64_t> &dst_shape, int64_t groups) { | |||||
| auto c0 = GetCubeSizeByDataType(data_type); | auto c0 = GetCubeSizeByDataType(data_type); | ||||
| if (c0 < 0) { | if (c0 < 0) { | ||||
| return ACL_ERROR_GE_DATATYPE_INVALID; | return ACL_ERROR_GE_DATATYPE_INVALID; | ||||
| } | } | ||||
| int64_t cin_ori = c; | int64_t cin_ori = c; | ||||
| if (groups == 0) { | |||||
| GELOGE(GRAPH_FAILED, "[Check][Param]Failed, groups must not be equal 0, " | |||||
| "and current groups is %ld ", groups); | |||||
| REPORT_CALL_ERROR("E19999", "Check graph param failed, groups must not be equal 0," | |||||
| "and groups are %ld", groups); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| int64_t cout_ori = n / groups; | int64_t cout_ori = n / groups; | ||||
| int64_t cube_k = GetCubeSizeByDataType(data_type); | int64_t cube_k = GetCubeSizeByDataType(data_type); | ||||
| int64_t e_mult = std::min( | int64_t e_mult = std::min( | ||||
| @@ -100,7 +107,7 @@ Status TransShapeToFzWithGroups(int64_t n, int64_t c, int64_t h, int64_t w, Data | |||||
| dst_shape.clear(); | dst_shape.clear(); | ||||
| dst_shape.push_back(g_dim * c1_dim * h * w); | dst_shape.push_back(g_dim * c1_dim * h * w); | ||||
| dst_shape.push_back(n1); | dst_shape.push_back(n1); | ||||
| dst_shape.push_back(16); | |||||
| dst_shape.push_back(kNiSize); | |||||
| dst_shape.push_back(cube_k); | dst_shape.push_back(cube_k); | ||||
| if (!IsShapeValid(dst_shape)) { | if (!IsShapeValid(dst_shape)) { | ||||
| GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "[Check][Shape]Failed, dst shape %s", | GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "[Check][Shape]Failed, dst shape %s", | ||||