Browse Source

!28247 Fix Python Vocab validation

Merge pull request !28247 from luoyang/vocab
tags/v1.6.0
i-robot Gitee 4 years ago
parent
commit
bff8d6ae01
3 changed files with 47 additions and 31 deletions
  1. +27
    -11
      mindspore/python/mindspore/dataset/text/utils.py
  2. +12
    -16
      mindspore/python/mindspore/dataset/text/validators.py
  3. +8
    -4
      tests/ut/python/dataset/test_vocab.py

+ 27
- 11
mindspore/python/mindspore/dataset/text/utils.py View File

@@ -38,11 +38,21 @@ class Vocab:
It contains a map that maps each word(str) to an id(int) or reverse.
"""

@check_vocab
def __init__(self, vocab):
self.c_vocab = vocab
def __init__(self):
self.c_vocab = None

def vocab(self):
"""
Get the vocabory table in dict type.

Returns:
A vocabulary consisting of word and id pairs.

Examples:
>>> vocab = text.Vocab.from_list(["word_1", "word_2", "word_3", "word_4"])
>>> vocabory_dict = vocab.vocab()
"""
check_vocab(self.c_vocab)
return self.c_vocab.vocab()

@check_tokens_to_ids
@@ -61,6 +71,7 @@ class Vocab:
>>> vocab = text.Vocab.from_list(["w1", "w2", "w3"], special_tokens=["<unk>"], special_first=True)
>>> ids = vocab.tokens_to_ids(["w1", "w3"])
"""
check_vocab(self.c_vocab)
if isinstance(tokens, str):
tokens = [tokens]
return self.c_vocab.tokens_to_ids(tokens)
@@ -81,6 +92,7 @@ class Vocab:
>>> vocab = text.Vocab.from_list(["w1", "w2", "w3"], special_tokens=["<unk>"], special_first=True)
>>> token = vocab.ids_to_tokens(0)
"""
check_vocab(self.c_vocab)
if isinstance(ids, int):
ids = [ids]
return self.c_vocab.ids_to_tokens(ids)
@@ -123,8 +135,9 @@ class Vocab:
... special_first=True)
>>> dataset = dataset.map(operations=text.Lookup(vocab, "<unk>"), input_columns=["text"])
"""
c_vocab = dataset.build_vocab(columns, freq_range, top_k, special_tokens, special_first)
return Vocab(c_vocab)
vocab = Vocab()
vocab.c_vocab = dataset.build_vocab(columns, freq_range, top_k, special_tokens, special_first)
return vocab

@classmethod
@check_from_list
@@ -147,8 +160,9 @@ class Vocab:
"""
if special_tokens is None:
special_tokens = []
c_vocab = cde.Vocab.from_list(word_list, special_tokens, special_first)
return Vocab(c_vocab)
vocab = Vocab()
vocab.c_vocab = cde.Vocab.from_list(word_list, special_tokens, special_first)
return vocab

@classmethod
@check_from_file
@@ -177,8 +191,9 @@ class Vocab:
vocab_size = -1
if special_tokens is None:
special_tokens = []
c_vocab = cde.Vocab.from_file(file_path, delimiter, vocab_size, special_tokens, special_first)
return Vocab(c_vocab)
vocab = Vocab()
vocab.c_vocab = cde.Vocab.from_file(file_path, delimiter, vocab_size, special_tokens, special_first)
return vocab

@classmethod
@check_from_dict
@@ -196,8 +211,9 @@ class Vocab:
Examples:
>>> vocab = text.Vocab.from_dict({"home": 3, "behind": 2, "the": 4, "world": 5, "<unk>": 6})
"""
c_vocab = cde.Vocab.from_dict(word_dict)
return Vocab(c_vocab)
vocab = Vocab()
vocab.c_vocab = cde.Vocab.from_dict(word_dict)
return vocab


class SentencePieceVocab(cde.SentencePieceVocab):


+ 12
- 16
mindspore/python/mindspore/dataset/text/validators.py View File

@@ -77,21 +77,14 @@ def check_from_file(method):
return new_method


def check_vocab(method):
"""A wrapper that wraps a parameter checker to the original function."""

@wraps(method)
def new_method(self, *args, **kwargs):
[vocab], _ = parse_user_args(method, *args, **kwargs)
if not isinstance(vocab, cde.Vocab):
type_error = "Input vocab is not an instance of cde.Vocab, got type {0}. ".format(type(vocab))
suggestion = "Use Vocab.from_dataset(), Vocab.from_list(), Vocab.from_file() or Vocab.from_dict() " \
"to build a vocab."
raise TypeError(type_error + suggestion)
def check_vocab(c_vocab):
"""Check the c_vocab of Vocab is initialized or not"""

return method(self, *args, **kwargs)

return new_method
if not isinstance(c_vocab, cde.Vocab):
error = "The Vocab has not built yet, got type {0}. ".format(type(c_vocab))
suggestion = "Use Vocab.from_dataset(), Vocab.from_list(), Vocab.from_file() or Vocab.from_dict() " \
"to build a Vocab."
raise RuntimeError(error + suggestion)


def check_tokens_to_ids(method):
@@ -117,9 +110,12 @@ def check_ids_to_tokens(method):
def new_method(self, *args, **kwargs):
[ids], _ = parse_user_args(method, *args, **kwargs)
type_check(ids, (int, list), "ids")
if isinstance(ids, int):
check_value(ids, (0, INT32_MAX), "ids")
if isinstance(ids, list):
param_names = ["ids[{0}]".format(i) for i in range(len(ids))]
type_check_list(ids, (int,), param_names)
for index, id_ in enumerate(ids):
type_check(id_, (int,), "ids[{}]".format(index))
check_value(id_, (0, INT32_MAX), "ids[{}]".format(index))

return method(self, *args, **kwargs)



+ 8
- 4
tests/ut/python/dataset/test_vocab.py View File

@@ -87,17 +87,21 @@ def test_vocab_exception():
"""
Feature: Python text.Vocab class
Description: test exceptions of text.Vocab
Expectation: raise TypeError when input is wrong.
Expectation: raise RuntimeError when vocab is not initialized, raise TypeError when input is wrong.
"""
with pytest.raises(TypeError):
text.Vocab(1)
vocab = text.Vocab()
with pytest.raises(RuntimeError):
vocab.ids_to_tokens(2)
with pytest.raises(RuntimeError):
vocab.tokens_to_ids(["w3"])

vocab = text.Vocab.from_list(["w1", "w2", "w3"], special_tokens=["<unk>"], special_first=True)

with pytest.raises(TypeError):
vocab.ids_to_tokens("abc")
with pytest.raises(TypeError):
vocab.ids_to_tokens([2, 1.2, "abc"])
with pytest.raises(ValueError):
vocab.ids_to_tokens(-2)

with pytest.raises(TypeError):
vocab.tokens_to_ids([1, "w3"])


Loading…
Cancel
Save