Browse Source

fix some bugs

tags/v1.2.0-rc1
liyong 4 years ago
parent
commit
a4ac6f429b
3 changed files with 22 additions and 8 deletions
  1. +8
    -6
      mindspore/dataset/transforms/c_transforms.py
  2. +7
    -1
      mindspore/mindrecord/tools/cifar10.py
  3. +7
    -1
      mindspore/mindrecord/tools/cifar100.py

+ 8
- 6
mindspore/dataset/transforms/c_transforms.py View File

@@ -321,14 +321,16 @@ class Duplicate(cde.DuplicateOp):

class Unique(cde.UniqueOp):
"""
Return an output tensor containing all the unique elements of the input tensor in
the same order that they occur in the input tensor.
Perform the unique operation on the input tensor, only support transform one column each time.

Also return an index tensor that contains the index of each element of the
input tensor in the Unique output tensor.
Return 3 tensor: unique output tensor, index tensor, count tensor.

Finally, return a count tensor that contains the count of each element of
the output tensor in the input tensor.
Unique output tensor contains all the unique elements of the input tensor
in the same order that they occur in the input tensor.

Index tensor that contains the index of each element of the input tensor in the unique output tensor.

Count tensor that contains the count of each element of the output tensor in the input tensor.

Note:
Call batch op before calling this function.


+ 7
- 1
mindspore/mindrecord/tools/cifar10.py View File

@@ -57,7 +57,13 @@ def restricted_loads(s):
if isinstance(s, str):
raise TypeError("can not load pickle from unicode string")
f = io.BytesIO(s)
return RestrictedUnpickler(f, encoding='bytes').load()
try:
return RestrictedUnpickler(f, encoding='bytes').load()
except pickle.UnpicklingError:
raise RuntimeError("Not a valid Cifar10 Dataset.")
else:
raise RuntimeError("Unexpected error while Unpickling Cifar10 Dataset.")


class Cifar10:
"""


+ 7
- 1
mindspore/mindrecord/tools/cifar100.py View File

@@ -56,7 +56,13 @@ def restricted_loads(s):
if isinstance(s, str):
raise TypeError("can not load pickle from unicode string")
f = io.BytesIO(s)
return RestrictedUnpickler(f, encoding='bytes').load()
try:
return RestrictedUnpickler(f, encoding='bytes').load()
except pickle.UnpicklingError:
raise RuntimeError("Not a valid Cifar100 Dataset.")
else:
raise RuntimeError("Unexpected error while Unpickling Cifar100 Dataset.")


class Cifar100:
"""


Loading…
Cancel
Save