|
|
|
@@ -16,6 +16,7 @@ |
|
|
|
validators for text ops |
|
|
|
""" |
|
|
|
from functools import wraps |
|
|
|
import numpy as np |
|
|
|
|
|
|
|
import mindspore._c_dataengine as cde |
|
|
|
import mindspore.common.dtype as mstype |
|
|
|
@@ -93,10 +94,10 @@ def check_tokens_to_ids(method): |
|
|
|
@wraps(method) |
|
|
|
def new_method(self, *args, **kwargs): |
|
|
|
[tokens], _ = parse_user_args(method, *args, **kwargs) |
|
|
|
type_check(tokens, (str, list), "tokens") |
|
|
|
type_check(tokens, (str, list, np.ndarray), "tokens") |
|
|
|
if isinstance(tokens, list): |
|
|
|
param_names = ["tokens[{0}]".format(i) for i in range(len(tokens))] |
|
|
|
type_check_list(tokens, (str,), param_names) |
|
|
|
type_check_list(tokens, (str, np.str_), param_names) |
|
|
|
|
|
|
|
return method(self, *args, **kwargs) |
|
|
|
|
|
|
|
@@ -109,12 +110,12 @@ def check_ids_to_tokens(method): |
|
|
|
@wraps(method) |
|
|
|
def new_method(self, *args, **kwargs): |
|
|
|
[ids], _ = parse_user_args(method, *args, **kwargs) |
|
|
|
type_check(ids, (int, list), "ids") |
|
|
|
type_check(ids, (int, list, np.ndarray), "ids") |
|
|
|
if isinstance(ids, int): |
|
|
|
check_value(ids, (0, INT32_MAX), "ids") |
|
|
|
if isinstance(ids, list): |
|
|
|
for index, id_ in enumerate(ids): |
|
|
|
type_check(id_, (int,), "ids[{}]".format(index)) |
|
|
|
type_check(id_, (int, np.int_), "ids[{}]".format(index)) |
|
|
|
check_value(id_, (0, INT32_MAX), "ids[{}]".format(index)) |
|
|
|
|
|
|
|
return method(self, *args, **kwargs) |
|
|
|
|