Browse Source

increase the max size of tensor.

tags/v0.6.0-beta
jzg 5 years ago
parent
commit
fb90ff164b
3 changed files with 13 additions and 13 deletions
  1. +4
    -4
      mindspore/ccsrc/kernel/aicpu/aicpu_kernel_build.cc
  2. +1
    -1
      mindspore/ccsrc/operator/prim_structures.cc
  3. +8
    -8
      mindspore/ccsrc/utils/convert_utils_base.h

+ 4
- 4
mindspore/ccsrc/kernel/aicpu/aicpu_kernel_build.cc View File

@@ -67,13 +67,13 @@ bool SetIOIputSize(const std::shared_ptr<AnfNode> &anf_node, const size_t &input
MS_EXCEPTION_IF_NULL(type_ptr);
int64_t size_i = 1;
for (size_t j = 0; j < shape_i.size(); j++) {
LongMulWithOverflowCheck(size_i, static_cast<int>(shape_i[j]), &size_i);
size_i = LongMulWithOverflowCheck(size_i, static_cast<int>(shape_i[j]));
}
size_t type_byte = GetTypeByte(type_ptr);
if (type_byte == 0) {
return false;
}
LongMulWithOverflowCheck(size_i, SizeToInt(type_byte), &size_i);
size_i = LongMulWithOverflowCheck(size_i, SizeToInt(type_byte));
input_size_list->push_back(LongToSize(size_i));
}
}
@@ -99,13 +99,13 @@ bool SetIOSize(const std::shared_ptr<AnfNode> &anf_node, const std::shared_ptr<A
MS_EXCEPTION_IF_NULL(type_ptr);
int64_t size_i = 1;
for (size_t j = 0; j < shape_i.size(); j++) {
LongMulWithOverflowCheck(size_i, static_cast<int>(shape_i[j]), &size_i);
size_i = LongMulWithOverflowCheck(size_i, static_cast<int>(shape_i[j]));
}
size_t type_byte = GetTypeByte(type_ptr);
if (type_byte == 0) {
return false;
}
LongMulWithOverflowCheck(size_i, SizeToInt(type_byte), &size_i);
size_i = LongMulWithOverflowCheck(size_i, SizeToInt(type_byte));
output_size_list.push_back(LongToSize(size_i));
}
kernel_mod_ptr->SetOutputSizeList(output_size_list);


+ 1
- 1
mindspore/ccsrc/operator/prim_structures.cc View File

@@ -587,7 +587,7 @@ AbstractBasePtr InferImplShapeMul(const AnalysisEnginePtr &, const PrimitivePtr
int result = 1;
for (size_t i = 0; i < shpx_data.size(); i++) {
int value = GetValue<int>(shpx_data[i]);
IntMulWithOverflowCheck(result, value, &result);
result = IntMulWithOverflowCheck(result, value);
}

auto result_v = MakeValue(result);


+ 8
- 8
mindspore/ccsrc/utils/convert_utils_base.h View File

@@ -91,26 +91,26 @@ inline unsigned int UlongToUint(size_t u) {
return static_cast<unsigned int>(u);
}

inline void IntMulWithOverflowCheck(int a, int b, int *c) {
inline int IntMulWithOverflowCheck(int a, int b) {
int out = a * b;
if (a != 0) {
bool ok = ((out / a) != b);
if (ok) {
bool overflow = ((out / a) != b);
if (overflow) {
MS_LOG(EXCEPTION) << "Mul: a(" << a << ") * b(" << b << ") result is overflow";
}
}
*c = out;
return out;
}

inline void LongMulWithOverflowCheck(int64_t a, int64_t b, int64_t *c) {
inline int64_t LongMulWithOverflowCheck(int64_t a, int64_t b) {
int64_t out = a * b;
if (a != 0) {
bool ok = ((out / a) != b);
if (ok) {
bool overflow = ((out / a) != b);
if (overflow) {
MS_LOG(EXCEPTION) << "Mul: a(" << a << ") * b(" << b << ") result is overflow";
}
}
*c = out;
return out;
}

inline size_t SizetMulWithOverflowCheck(size_t a, size_t b) {


Loading…
Cancel
Save