Browse Source

Slice Bug

tags/v0.5.0-beta
hesham 5 years ago
parent
commit
68030e6a4b
2 changed files with 19 additions and 2 deletions
  1. +2
    -2
      mindspore/ccsrc/dataset/kernels/data/slice_op.cc
  2. +17
    -0
      tests/ut/python/dataset/test_slice_op.py

+ 2
- 2
mindspore/ccsrc/dataset/kernels/data/slice_op.cc View File

@@ -33,8 +33,8 @@ Status SliceOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Te
// if slice object was provided, indices should be empty. Generate indices from the slice object.
if (slice_.valid() && indices_.empty()) {
dsize_t len = input->shape()[0];
indices_ = slice_.Indices(len);
return input->Slice(output, indices_);
std::vector<dsize_t> indices = slice_.Indices(len);
return input->Slice(output, indices);
}

// if indices are not empty, slices should be invalid, use indices_ to slice


+ 17
- 0
tests/ut/python/dataset/test_slice_op.py View File

@@ -80,6 +80,22 @@ def test_slice_slice_obj_3s():
slice_compare([1, 2, 3, 4, 5], slice(2, 5, 3))


def test_slice_multiple_rows():
dataset = [[1, 2], [3, 4, 5], [1], [1, 2, 3, 4, 5, 6, 7]]

def gen():
for row in dataset:
yield (np.array(row),)

data = ds.GeneratorDataset(gen, column_names=["col"])
indexing = slice(0, 4)
data = data.map(operations=ops.Slice(indexing))
for i, d in enumerate(data):
array = np.array(dataset[i])
array = array[indexing]
np.testing.assert_array_equal(array, d[0])


def test_slice_slice_obj_3s_double():
slice_compare([1., 2., 3., 4., 5.], slice(0, 2, 1))
slice_compare([1., 2., 3., 4., 5.], slice(0, 4, 1))
@@ -217,3 +233,4 @@ if __name__ == "__main__":
test_slice_slice_obj_1s_str()
test_slice_slice_obj_neg_str()
test_slice_exceptions_str()
test_slice_multiple_rows()

Loading…
Cancel
Save