Browse Source

fix vocab validation

feature/build-system-rewrite
luoyang 4 years ago
parent
commit
c248274d91
3 changed files with 21 additions and 4 deletions
  1. +4
    -0
      mindspore/python/mindspore/dataset/text/utils.py
  2. +5
    -4
      mindspore/python/mindspore/dataset/text/validators.py
  3. +12
    -0
      tests/ut/python/dataset/test_vocab.py

+ 4
- 0
mindspore/python/mindspore/dataset/text/utils.py View File

@@ -72,6 +72,8 @@ class Vocab:
>>> ids = vocab.tokens_to_ids(["w1", "w3"])
"""
check_vocab(self.c_vocab)
if isinstance(tokens, np.ndarray):
tokens = tokens.tolist()
if isinstance(tokens, str):
tokens = [tokens]
return self.c_vocab.tokens_to_ids(tokens)
@@ -93,6 +95,8 @@ class Vocab:
>>> token = vocab.ids_to_tokens(0)
"""
check_vocab(self.c_vocab)
if isinstance(ids, np.ndarray):
ids = ids.tolist()
if isinstance(ids, int):
ids = [ids]
return self.c_vocab.ids_to_tokens(ids)


+ 5
- 4
mindspore/python/mindspore/dataset/text/validators.py View File

@@ -16,6 +16,7 @@
validators for text ops
"""
from functools import wraps
import numpy as np

import mindspore._c_dataengine as cde
import mindspore.common.dtype as mstype
@@ -93,10 +94,10 @@ def check_tokens_to_ids(method):
@wraps(method)
def new_method(self, *args, **kwargs):
[tokens], _ = parse_user_args(method, *args, **kwargs)
type_check(tokens, (str, list), "tokens")
type_check(tokens, (str, list, np.ndarray), "tokens")
if isinstance(tokens, list):
param_names = ["tokens[{0}]".format(i) for i in range(len(tokens))]
type_check_list(tokens, (str,), param_names)
type_check_list(tokens, (str, np.str_), param_names)

return method(self, *args, **kwargs)

@@ -109,12 +110,12 @@ def check_ids_to_tokens(method):
@wraps(method)
def new_method(self, *args, **kwargs):
[ids], _ = parse_user_args(method, *args, **kwargs)
type_check(ids, (int, list), "ids")
type_check(ids, (int, list, np.ndarray), "ids")
if isinstance(ids, int):
check_value(ids, (0, INT32_MAX), "ids")
if isinstance(ids, list):
for index, id_ in enumerate(ids):
type_check(id_, (int,), "ids[{}]".format(index))
type_check(id_, (int, np.int_), "ids[{}]".format(index))
check_value(id_, (0, INT32_MAX), "ids[{}]".format(index))

return method(self, *args, **kwargs)


+ 12
- 0
tests/ut/python/dataset/test_vocab.py View File

@@ -60,6 +60,12 @@ def test_vocab_tokens_to_ids():
ids = vocab.tokens_to_ids("hello")
assert ids == -1

ids = vocab.tokens_to_ids(np.array(["w1", "w3"]))
assert ids == [1, 3]

ids = vocab.tokens_to_ids(np.array("w1"))
assert ids == 1


def test_vocab_ids_to_tokens():
"""
@@ -82,6 +88,12 @@ def test_vocab_ids_to_tokens():
tokens = vocab.ids_to_tokens(7)
assert tokens == ""

tokens = vocab.ids_to_tokens(np.array([2, 3]))
assert tokens == ["w2", "w3"]

tokens = vocab.ids_to_tokens(np.array(2))
assert tokens == "w2"


def test_vocab_exception():
"""


Loading…
Cancel
Save