Browse Source

!3912 fix numpyslice bug

Merge pull request !3912 from luoyang/son_r0.6
pull/3912/MERGE
mindspore-ci-bot Gitee 5 years ago
parent
commit
801660ef08
2 changed files with 14 additions and 0 deletions
  1. +4
    -0
      mindspore/ccsrc/minddata/dataset/core/tensor.cc
  2. +10
    -0
      tests/ut/python/dataset/test_dataset_numpy_slices.py

+ 4
- 0
mindspore/ccsrc/minddata/dataset/core/tensor.cc View File

@@ -268,6 +268,10 @@ Status Tensor::CreateTensor(std::shared_ptr<Tensor> *ptr, py::array arr) {
std::shared_ptr<MemoryPool> global_pool = GlobalContext::Instance()->mem_pool();
(*ptr)->data_allocator_ = std::make_unique<Allocator<unsigned char>>(global_pool);
int64_t byte_size = (*ptr)->SizeInBytes();
if (byte_size == 0) {
return Status::OK();
}

RETURN_IF_NOT_OK((*ptr)->AllocateBuffer(byte_size));

unsigned char *data = static_cast<unsigned char *>(arr.request().ptr);


+ 10
- 0
tests/ut/python/dataset/test_dataset_numpy_slices.py View File

@@ -82,6 +82,16 @@ def test_numpy_slices_dict_1():
assert data[0] == res[i][0]
assert data[1] == res[i][1]

def test_numpy_slices_dict_2():
logger.info("Test Dictionary empty data.")

np_data = {"a": [[]], "b": [[4]]}
ds = de.NumpySlicesDataset(np_data, shuffle=False)
res = [[], [4]]

for _, data in enumerate(ds):
np.testing.assert_array_almost_equal(data[0], np.array(res[0]))
np.testing.assert_array_almost_equal(data[1], np.array(res[1]))

def test_numpy_slices_tuple_1():
logger.info("Test slicing a list of tuple.")


Loading…
Cancel
Save