|
|
@@ -65,16 +65,16 @@ bool SetIOIputSize(const std::shared_ptr<AnfNode> &anf_node, const size_t &input |
|
|
} else { |
|
|
} else { |
|
|
auto type_ptr = TypeIdToType(AnfAlgo::GetInputDeviceDataType(anf_node, i)); |
|
|
auto type_ptr = TypeIdToType(AnfAlgo::GetInputDeviceDataType(anf_node, i)); |
|
|
MS_EXCEPTION_IF_NULL(type_ptr); |
|
|
MS_EXCEPTION_IF_NULL(type_ptr); |
|
|
int size_i = 1; |
|
|
|
|
|
|
|
|
int64_t size_i = 1; |
|
|
for (size_t j = 0; j < shape_i.size(); j++) { |
|
|
for (size_t j = 0; j < shape_i.size(); j++) { |
|
|
IntMulWithOverflowCheck(size_i, static_cast<int>(shape_i[j]), &size_i); |
|
|
|
|
|
|
|
|
size_i = LongMulWithOverflowCheck(size_i, static_cast<int>(shape_i[j])); |
|
|
} |
|
|
} |
|
|
size_t type_byte = GetTypeByte(type_ptr); |
|
|
size_t type_byte = GetTypeByte(type_ptr); |
|
|
if (type_byte == 0) { |
|
|
if (type_byte == 0) { |
|
|
return false; |
|
|
return false; |
|
|
} |
|
|
} |
|
|
IntMulWithOverflowCheck(size_i, SizeToInt(type_byte), &size_i); |
|
|
|
|
|
input_size_list->push_back(IntToSize(size_i)); |
|
|
|
|
|
|
|
|
size_i = LongMulWithOverflowCheck(size_i, SizeToInt(type_byte)); |
|
|
|
|
|
input_size_list->push_back(LongToSize(size_i)); |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
return true; |
|
|
return true; |
|
|
@@ -97,16 +97,16 @@ bool SetIOSize(const std::shared_ptr<AnfNode> &anf_node, const std::shared_ptr<A |
|
|
std::vector<size_t> shape_i = AnfAlgo::GetOutputDeviceShape(anf_node, i); |
|
|
std::vector<size_t> shape_i = AnfAlgo::GetOutputDeviceShape(anf_node, i); |
|
|
TypePtr type_ptr = TypeIdToType(AnfAlgo::GetOutputDeviceDataType(anf_node, i)); |
|
|
TypePtr type_ptr = TypeIdToType(AnfAlgo::GetOutputDeviceDataType(anf_node, i)); |
|
|
MS_EXCEPTION_IF_NULL(type_ptr); |
|
|
MS_EXCEPTION_IF_NULL(type_ptr); |
|
|
int size_i = 1; |
|
|
|
|
|
|
|
|
int64_t size_i = 1; |
|
|
for (size_t j = 0; j < shape_i.size(); j++) { |
|
|
for (size_t j = 0; j < shape_i.size(); j++) { |
|
|
IntMulWithOverflowCheck(size_i, static_cast<int>(shape_i[j]), &size_i); |
|
|
|
|
|
|
|
|
size_i = LongMulWithOverflowCheck(size_i, static_cast<int>(shape_i[j])); |
|
|
} |
|
|
} |
|
|
size_t type_byte = GetTypeByte(type_ptr); |
|
|
size_t type_byte = GetTypeByte(type_ptr); |
|
|
if (type_byte == 0) { |
|
|
if (type_byte == 0) { |
|
|
return false; |
|
|
return false; |
|
|
} |
|
|
} |
|
|
IntMulWithOverflowCheck(size_i, SizeToInt(type_byte), &size_i); |
|
|
|
|
|
output_size_list.push_back(IntToSize(size_i)); |
|
|
|
|
|
|
|
|
size_i = LongMulWithOverflowCheck(size_i, SizeToInt(type_byte)); |
|
|
|
|
|
output_size_list.push_back(LongToSize(size_i)); |
|
|
} |
|
|
} |
|
|
kernel_mod_ptr->SetOutputSizeList(output_size_list); |
|
|
kernel_mod_ptr->SetOutputSizeList(output_size_list); |
|
|
return true; |
|
|
return true; |
|
|
|