|
|
|
@@ -67,18 +67,17 @@ bool HcomUtil::GetHcomDataType(const AnfNodePtr &anf_node, vector<hcclDataType_t |
|
|
|
|
|
|
|
bool HcomUtil::GetHcclOpSize(const hcclDataType_t &data_type, const vector<size_t> &shape, size_t *size) { |
|
|
|
MS_EXCEPTION_IF_NULL(size); |
|
|
|
int tmp_size = 1; |
|
|
|
size_t tmp_size = 1; |
|
|
|
uint32_t type_size = 4; |
|
|
|
for (size_t i = 0; i < shape.size(); i++) { |
|
|
|
IntMulWithOverflowCheck(tmp_size, SizeToInt(shape[i]), &tmp_size); |
|
|
|
tmp_size = SizetMulWithOverflowCheck(tmp_size, shape[i]); |
|
|
|
} |
|
|
|
|
|
|
|
if (!GetHcomTypeSize(data_type, &type_size)) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
IntMulWithOverflowCheck(tmp_size, UintToInt(type_size), &tmp_size); |
|
|
|
*size = IntToSize(tmp_size); |
|
|
|
*size = SizetMulWithOverflowCheck(tmp_size, type_size); |
|
|
|
|
|
|
|
MS_LOG(INFO) << "size[" << *size << "]"; |
|
|
|
return true; |
|
|
|
|