Browse Source

!272 Increase the max tensor size

Merge pull request !272 from jiangzhenguang/Increase-the-max-size-of-tensor
tags/v0.6.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
3faac4d1be
3 changed files with 24 additions and 13 deletions
  1. +8
    -8
      mindspore/ccsrc/kernel/aicpu/aicpu_kernel_build.cc
  2. +1
    -1
      mindspore/ccsrc/operator/prim_structures.cc
  3. +15
    -4
      mindspore/ccsrc/utils/convert_utils_base.h

+ 8
- 8
mindspore/ccsrc/kernel/aicpu/aicpu_kernel_build.cc View File

@@ -65,16 +65,16 @@ bool SetIOIputSize(const std::shared_ptr<AnfNode> &anf_node, const size_t &input
} else {
auto type_ptr = TypeIdToType(AnfAlgo::GetInputDeviceDataType(anf_node, i));
MS_EXCEPTION_IF_NULL(type_ptr);
int size_i = 1;
int64_t size_i = 1;
for (size_t j = 0; j < shape_i.size(); j++) {
IntMulWithOverflowCheck(size_i, static_cast<int>(shape_i[j]), &size_i);
size_i = LongMulWithOverflowCheck(size_i, static_cast<int>(shape_i[j]));
}
size_t type_byte = GetTypeByte(type_ptr);
if (type_byte == 0) {
return false;
}
IntMulWithOverflowCheck(size_i, SizeToInt(type_byte), &size_i);
input_size_list->push_back(IntToSize(size_i));
size_i = LongMulWithOverflowCheck(size_i, SizeToInt(type_byte));
input_size_list->push_back(LongToSize(size_i));
}
}
return true;
@@ -97,16 +97,16 @@ bool SetIOSize(const std::shared_ptr<AnfNode> &anf_node, const std::shared_ptr<A
std::vector<size_t> shape_i = AnfAlgo::GetOutputDeviceShape(anf_node, i);
TypePtr type_ptr = TypeIdToType(AnfAlgo::GetOutputDeviceDataType(anf_node, i));
MS_EXCEPTION_IF_NULL(type_ptr);
int size_i = 1;
int64_t size_i = 1;
for (size_t j = 0; j < shape_i.size(); j++) {
IntMulWithOverflowCheck(size_i, static_cast<int>(shape_i[j]), &size_i);
size_i = LongMulWithOverflowCheck(size_i, static_cast<int>(shape_i[j]));
}
size_t type_byte = GetTypeByte(type_ptr);
if (type_byte == 0) {
return false;
}
IntMulWithOverflowCheck(size_i, SizeToInt(type_byte), &size_i);
output_size_list.push_back(IntToSize(size_i));
size_i = LongMulWithOverflowCheck(size_i, SizeToInt(type_byte));
output_size_list.push_back(LongToSize(size_i));
}
kernel_mod_ptr->SetOutputSizeList(output_size_list);
return true;


+ 1
- 1
mindspore/ccsrc/operator/prim_structures.cc View File

@@ -582,7 +582,7 @@ AbstractBasePtr InferImplShapeMul(const AnalysisEnginePtr &, const PrimitivePtr
int result = 1;
for (size_t i = 0; i < shpx_data.size(); i++) {
int value = GetValue<int>(shpx_data[i]);
IntMulWithOverflowCheck(result, value, &result);
result = IntMulWithOverflowCheck(result, value);
}

auto result_v = MakeValue(result);


+ 15
- 4
mindspore/ccsrc/utils/convert_utils_base.h View File

@@ -91,15 +91,26 @@ inline unsigned int UlongToUint(size_t u) {
return static_cast<unsigned int>(u);
}

inline void IntMulWithOverflowCheck(int a, int b, int *c) {
inline int IntMulWithOverflowCheck(int a, int b) {
int out = a * b;
if (a != 0) {
bool ok = ((out / a) != b);
if (ok) {
bool overflow = ((out / a) != b);
if (overflow) {
MS_LOG(EXCEPTION) << "Mul: a(" << a << ") * b(" << b << ") result is overflow";
}
}
*c = out;
return out;
}

inline int64_t LongMulWithOverflowCheck(int64_t a, int64_t b) {
int64_t out = a * b;
if (a != 0) {
bool overflow = ((out / a) != b);
if (overflow) {
MS_LOG(EXCEPTION) << "Mul: a(" << a << ") * b(" << b << ") result is overflow";
}
}
return out;
}

inline size_t SizetMulWithOverflowCheck(size_t a, size_t b) {


Loading…
Cancel
Save