|
|
|
@@ -67,13 +67,13 @@ bool SetIOIputSize(const std::shared_ptr<AnfNode> &anf_node, const size_t &input |
|
|
|
MS_EXCEPTION_IF_NULL(type_ptr); |
|
|
|
int64_t size_i = 1; |
|
|
|
for (size_t j = 0; j < shape_i.size(); j++) { |
|
|
|
LongMulWithOverflowCheck(size_i, static_cast<int>(shape_i[j]), &size_i); |
|
|
|
size_i = LongMulWithOverflowCheck(size_i, static_cast<int>(shape_i[j])); |
|
|
|
} |
|
|
|
size_t type_byte = GetTypeByte(type_ptr); |
|
|
|
if (type_byte == 0) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
LongMulWithOverflowCheck(size_i, SizeToInt(type_byte), &size_i); |
|
|
|
size_i = LongMulWithOverflowCheck(size_i, SizeToInt(type_byte)); |
|
|
|
input_size_list->push_back(LongToSize(size_i)); |
|
|
|
} |
|
|
|
} |
|
|
|
@@ -99,13 +99,13 @@ bool SetIOSize(const std::shared_ptr<AnfNode> &anf_node, const std::shared_ptr<A |
|
|
|
MS_EXCEPTION_IF_NULL(type_ptr); |
|
|
|
int64_t size_i = 1; |
|
|
|
for (size_t j = 0; j < shape_i.size(); j++) { |
|
|
|
LongMulWithOverflowCheck(size_i, static_cast<int>(shape_i[j]), &size_i); |
|
|
|
size_i = LongMulWithOverflowCheck(size_i, static_cast<int>(shape_i[j])); |
|
|
|
} |
|
|
|
size_t type_byte = GetTypeByte(type_ptr); |
|
|
|
if (type_byte == 0) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
LongMulWithOverflowCheck(size_i, SizeToInt(type_byte), &size_i); |
|
|
|
size_i = LongMulWithOverflowCheck(size_i, SizeToInt(type_byte)); |
|
|
|
output_size_list.push_back(LongToSize(size_i)); |
|
|
|
} |
|
|
|
kernel_mod_ptr->SetOutputSizeList(output_size_list); |
|
|
|
|