Browse Source

check memcpy return code

tags/v1.4.0
hesham 4 years ago
parent
commit
f7009083c6
3 changed files with 27 additions and 28 deletions
  1. +22
    -27
      mindspore/ccsrc/minddata/dataset/core/tensor.cc
  2. +2
    -1
      mindspore/ccsrc/minddata/dataset/core/tensor.h
  3. +3
    -0
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/pk_sampler.cc

+ 22
- 27
mindspore/ccsrc/minddata/dataset/core/tensor.cc View File

@@ -20,6 +20,7 @@
#include <iostream>
#include <fstream>
#include <functional>
#include <limits>
#include <memory>
#include <vector>
#include <utility>
@@ -60,29 +61,7 @@ namespace dataset {
out << std::hex << std::setw(2) << std::setfill('0') << o << std::dec << std::setfill(' '); \
break; \
}
/// Copy memory with no max limit since memcpy_s will fail when byte_size > 2^31 - 1 (SECUREC_MEM_MAX_LEN).
/// \param dest Destination buffer.
/// \param destMax Size of the destination buffer.
/// \param src Buffer to copy from.
/// \param count Number of characters to copy
/// \return Error number. Returns 0 for succuss copying.
errno_t memcpy_ss(uchar *dest, size_t destMax, const uchar *src, size_t count) {
uint32_t step = 0;
while (count >= SECUREC_MEM_MAX_LEN) {
int ret_code = memcpy_s(dest + step * SECUREC_MEM_MAX_LEN, destMax - step * SECUREC_MEM_MAX_LEN,
src + step * SECUREC_MEM_MAX_LEN, SECUREC_MEM_MAX_LEN);
if (ret_code != 0) {
return ret_code;
}
count -= SECUREC_MEM_MAX_LEN;
step++;
}
if (count > 0) {
return memcpy_s(dest + step * SECUREC_MEM_MAX_LEN, destMax - step * SECUREC_MEM_MAX_LEN,
src + step * SECUREC_MEM_MAX_LEN, count);
}
return 0;
}

Tensor::Tensor(const TensorShape &shape, const DataType &type) : shape_(shape), type_(type), data_(nullptr) {
// grab the mem pool from global context and create the allocator for char data area
std::shared_ptr<MemoryPool> global_pool = GlobalContext::Instance()->mem_pool();
@@ -134,8 +113,16 @@ Status Tensor::CreateFromMemory(const TensorShape &shape, const DataType &type,
if (src != nullptr) {
// Given the shape/type of this tensor, compute the data size and copy in the input bytes.
int64_t byte_size = (*out)->SizeInBytes();
int ret_code = memcpy_ss((*out)->data_, byte_size, src, byte_size);
CHECK_FAIL_RETURN_UNEXPECTED(ret_code == 0, "Failed to copy data into tensor.");
if (byte_size == 0) {
return Status::OK();
}
if (byte_size < SECUREC_MEM_MAX_LEN) {
int ret_code = memcpy_s((*out)->data_, byte_size, src, byte_size);
CHECK_FAIL_RETURN_UNEXPECTED(ret_code == 0, "Failed to copy data into tensor.");
} else {
auto ret_code = std::memcpy((*out)->data_, src, byte_size);
CHECK_FAIL_RETURN_UNEXPECTED(ret_code == (*out)->data_, "Failed to copy data into tensor.");
}
}
return Status::OK();
}
@@ -156,8 +143,16 @@ Status Tensor::CreateFromMemory(const TensorShape &shape, const DataType &type,
}

RETURN_IF_NOT_OK((*out)->AllocateBuffer(length));
int ret_code = memcpy_ss((*out)->data_, length, src, length);
CHECK_FAIL_RETURN_UNEXPECTED(ret_code == 0, "Failed to copy data into tensor.");
if (length == 0) {
return Status::OK();
}
if (length < SECUREC_MEM_MAX_LEN) {
int ret_code = memcpy_s((*out)->data_, length, src, length);
CHECK_FAIL_RETURN_UNEXPECTED(ret_code == 0, "Failed to copy data into tensor.");
} else {
auto ret_code = std::memcpy((*out)->data_, src, length);
CHECK_FAIL_RETURN_UNEXPECTED(ret_code == (*out)->data_, "Failed to copy data into tensor.");
}

return Status::OK();
}


+ 2
- 1
mindspore/ccsrc/minddata/dataset/core/tensor.h View File

@@ -263,7 +263,8 @@ class Tensor {
if (value.length() != length) {
RETURN_STATUS_UNEXPECTED("Length of the new string does not match the item.");
}
memcpy_s(reinterpret_cast<char *>(ptr), length, value.c_str(), length);
int ret_code = memcpy_s(reinterpret_cast<char *>(ptr), length, value.c_str(), length);
CHECK_FAIL_RETURN_UNEXPECTED(ret_code == 0, "Failed to set data into tensor.");

return Status::OK();
}


+ 3
- 0
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/pk_sampler.cc View File

@@ -43,6 +43,9 @@ Status PKSamplerRT::InitSampler() {
// capture the total number of possible sample ids.
// Compute that here for this case to find the total number of samples that are available to return.
// (in this case, samples per class * total classes).
if (samples_per_class_ > std::numeric_limits<int64_t>::max() / static_cast<int64_t>(labels_.size())) {
RETURN_STATUS_UNEXPECTED("Overflow in counting num_rows");
}
num_rows_ = samples_per_class_ * static_cast<int64_t>(labels_.size());

// The user may have chosen to sample less than the total amount.


Loading…
Cancel
Save