# Conflicts: # test/core/test_trainer.pytags/v0.2.0^2
| @@ -11,7 +11,7 @@ class FieldArray(object): | |||||
| """ | """ | ||||
| :param str name: the name of the FieldArray | :param str name: the name of the FieldArray | ||||
| :param list content: a list of int, float, or a list of list. | |||||
| :param list content: a list of int, float, str or np.ndarray, or a list of list of one. | |||||
| :param int padding_val: the integer for padding. Default: 0. | :param int padding_val: the integer for padding. Default: 0. | ||||
| :param bool is_target: If True, this FieldArray is used to compute loss. | :param bool is_target: If True, this FieldArray is used to compute loss. | ||||
| :param bool is_input: If True, this FieldArray is used to the model input. | :param bool is_input: If True, this FieldArray is used to the model input. | ||||
| @@ -27,35 +27,46 @@ class FieldArray(object): | |||||
| self.padding_val = padding_val | self.padding_val = padding_val | ||||
| self.is_target = is_target | self.is_target = is_target | ||||
| self.is_input = is_input | self.is_input = is_input | ||||
| self.BASIC_TYPES = (int, float, str, np.ndarray) | |||||
| self.is_2d_list = False | |||||
| self.pytype = self._type_detection(content) | self.pytype = self._type_detection(content) | ||||
| self.dtype = self._map_to_np_type(self.pytype) | self.dtype = self._map_to_np_type(self.pytype) | ||||
| @staticmethod | |||||
| def _type_detection(content): | |||||
| def _type_detection(self, content): | |||||
| """ | |||||
| :param content: a list of int, float, str or np.ndarray, or a list of list of one. | |||||
| :return type: one of int, float, str, np.ndarray | |||||
| """ | |||||
| if isinstance(content, list) and len(content) > 0 and isinstance(content[0], list): | if isinstance(content, list) and len(content) > 0 and isinstance(content[0], list): | ||||
| # 2-D list | |||||
| # TODO: refactor | |||||
| type_set = set([type(item) for item in content[0]]) | |||||
| else: | |||||
| # 1-D list | |||||
| # content is a 2-D list | |||||
| type_set = set([self._type_detection(x) for x in content]) | |||||
| if len(type_set) > 1: | |||||
| raise RuntimeError("Cannot create FieldArray with more than one type. Provided {}".format(type_set)) | |||||
| self.is_2d_list = True | |||||
| return type_set.pop() | |||||
| elif isinstance(content, list): | |||||
| # content is a 1-D list | |||||
| if len(content) == 0: | if len(content) == 0: | ||||
| raise RuntimeError("Cannot create FieldArray with an empty list.") | raise RuntimeError("Cannot create FieldArray with an empty list.") | ||||
| type_set = set([type(item) for item in content]) | type_set = set([type(item) for item in content]) | ||||
| if len(type_set) == 1 and any(basic_type in type_set for basic_type in (str, int, float)): | |||||
| return type_set.pop() | |||||
| elif len(type_set) == 2 and float in type_set and int in type_set: | |||||
| # up-cast int to float | |||||
| for idx, _ in enumerate(content): | |||||
| content[idx] = float(content[idx]) | |||||
| return float | |||||
| if len(type_set) == 1 and tuple(type_set)[0] in self.BASIC_TYPES: | |||||
| return type_set.pop() | |||||
| elif len(type_set) == 2 and float in type_set and int in type_set: | |||||
| # up-cast int to float | |||||
| return float | |||||
| else: | |||||
| raise RuntimeError("Cannot create FieldArray with type {}".format(*type_set)) | |||||
| else: | else: | ||||
| raise ValueError("Unsupported type conversion detected in FieldArray: {}".format(*type_set)) | |||||
| raise RuntimeError("Cannot create FieldArray with type {}".format(type(content))) | |||||
| @staticmethod | @staticmethod | ||||
| def _map_to_np_type(basic_type): | def _map_to_np_type(basic_type): | ||||
| type_mapping = {int: np.int64, float: np.float64, str: np.str} | |||||
| type_mapping = {int: np.int64, float: np.float64, str: np.str, np.ndarray: np.ndarray} | |||||
| return type_mapping[basic_type] | return type_mapping[basic_type] | ||||
| def __repr__(self): | def __repr__(self): | ||||
| @@ -64,29 +75,35 @@ class FieldArray(object): | |||||
| def append(self, val): | def append(self, val): | ||||
| """Add a new item to the tail of FieldArray. | """Add a new item to the tail of FieldArray. | ||||
| :param val: int, float, str, or a list of them. | |||||
| :param val: int, float, str, or a list of one. | |||||
| """ | """ | ||||
| val_type = type(val) | val_type = type(val) | ||||
| if val_type is int and self.pytype is float: | |||||
| # up-cast the appended value | |||||
| val = float(val) | |||||
| elif val_type is float and self.pytype is int: | |||||
| # up-cast all other values in the content | |||||
| for idx, _ in enumerate(self.content): | |||||
| self.content[idx] = float(self.content[idx]) | |||||
| self.pytype = float | |||||
| self.dtype = self._map_to_np_type(self.pytype) | |||||
| elif val_type is list: | |||||
| if val_type == list: # shape check | |||||
| if self.is_2d_list is False: | |||||
| raise RuntimeError("Cannot append a list into a 1-D FieldArray. Please provide an element.") | |||||
| if len(val) == 0: | if len(val) == 0: | ||||
| raise ValueError("Cannot append an empty list.") | |||||
| raise RuntimeError("Cannot append an empty list.") | |||||
| val_list_type = set([type(_) for _ in val]) # type check | |||||
| if len(val_list_type) == 2 and int in val_list_type and float in val_list_type: | |||||
| # up-cast int to float | |||||
| val_type = float | |||||
| elif len(val_list_type) == 1: | |||||
| val_type = val_list_type.pop() | |||||
| else: | else: | ||||
| if type(val[0]) != self.pytype: | |||||
| raise ValueError( | |||||
| "Cannot append a list of {}-type value into a {}-tpye FieldArray.". | |||||
| format(type(val[0]), self.pytype)) | |||||
| elif val_type != self.pytype: | |||||
| raise ValueError("Cannot append a {}-type value into a {}-tpye FieldArray.".format(val_type, self.pytype)) | |||||
| raise RuntimeError("Cannot append a list of {}".format(val_list_type)) | |||||
| else: | |||||
| if self.is_2d_list is True: | |||||
| raise RuntimeError("Cannot append a non-list into a 2-D list. Please provide a list.") | |||||
| if val_type == float and self.pytype == int: | |||||
| # up-cast | |||||
| self.pytype = float | |||||
| self.dtype = self._map_to_np_type(self.pytype) | |||||
| elif val_type == int and self.pytype == float: | |||||
| pass | |||||
| elif val_type == self.pytype: | |||||
| pass | |||||
| else: | |||||
| raise RuntimeError("Cannot append type {} into type {}".format(val_type, self.pytype)) | |||||
| self.content.append(val) | self.content.append(val) | ||||
| def __getitem__(self, indices): | def __getitem__(self, indices): | ||||
| @@ -102,7 +119,6 @@ class FieldArray(object): | |||||
| :param indices: an int, or a list of int. | :param indices: an int, or a list of int. | ||||
| :return: | :return: | ||||
| """ | """ | ||||
| # TODO: 返回行为不一致,有隐患 | |||||
| if isinstance(indices, int): | if isinstance(indices, int): | ||||
| return self.content[indices] | return self.content[indices] | ||||
| assert self.is_input is True or self.is_target is True | assert self.is_input is True or self.is_target is True | ||||
| @@ -126,6 +126,7 @@ class LossBase(object): | |||||
| for keys, val in target_dict.items(): | for keys, val in target_dict.items(): | ||||
| param_val_dict.update({keys: val}) | param_val_dict.update({keys: val}) | ||||
| # TODO: use the origin key to raise error | |||||
| if not self._checked: | if not self._checked: | ||||
| for keys in args: | for keys in args: | ||||
| if param_map[keys] not in param_val_dict.keys(): | if param_map[keys] not in param_val_dict.keys(): | ||||
| @@ -1,3 +1,4 @@ | |||||
| numpy>=1.14.2 | numpy>=1.14.2 | ||||
| torch>=0.4.0 | torch>=0.4.0 | ||||
| tensorboardX | tensorboardX | ||||
| tqdm | |||||
| @@ -24,19 +24,31 @@ class TestFieldArray(unittest.TestCase): | |||||
| def test_type_conversion(self): | def test_type_conversion(self): | ||||
| fa = FieldArray("x", [1.2, 2.2, 3, 4, 5], is_input=True) | fa = FieldArray("x", [1.2, 2.2, 3, 4, 5], is_input=True) | ||||
| self.assertEqual(fa.pytype, float) | self.assertEqual(fa.pytype, float) | ||||
| self.assertEqual(fa.dtype, np.double) | |||||
| self.assertEqual(fa.dtype, np.float64) | |||||
| fa = FieldArray("x", [1, 2, 3, 4, 5], is_input=True) | fa = FieldArray("x", [1, 2, 3, 4, 5], is_input=True) | ||||
| fa.append(1.3333) | fa.append(1.3333) | ||||
| self.assertEqual(fa.pytype, float) | self.assertEqual(fa.pytype, float) | ||||
| self.assertEqual(fa.dtype, np.double) | |||||
| self.assertEqual(fa.dtype, np.float64) | |||||
| fa = FieldArray("y", [1.1, 2.2, 3.3, 4.4, 5.5], is_input=False) | fa = FieldArray("y", [1.1, 2.2, 3.3, 4.4, 5.5], is_input=False) | ||||
| fa.append(10) | fa.append(10) | ||||
| self.assertEqual(fa.pytype, float) | self.assertEqual(fa.pytype, float) | ||||
| self.assertEqual(fa.dtype, np.double) | |||||
| self.assertEqual(fa.dtype, np.float64) | |||||
| fa = FieldArray("y", ["a", "b", "c", "d"], is_input=False) | fa = FieldArray("y", ["a", "b", "c", "d"], is_input=False) | ||||
| fa.append("e") | fa.append("e") | ||||
| self.assertEqual(fa.dtype, np.str) | self.assertEqual(fa.dtype, np.str) | ||||
| self.assertEqual(fa.pytype, str) | self.assertEqual(fa.pytype, str) | ||||
| def test_support_np_array(self): | |||||
| fa = FieldArray("y", [np.array([1.1, 2.2, 3.3, 4.4, 5.5])], is_input=False) | |||||
| self.assertEqual(fa.dtype, np.ndarray) | |||||
| fa.append(np.array([1.1, 2.2, 3.3, 4.4, 5.5])) | |||||
| self.assertEqual(fa.pytype, np.ndarray) | |||||
| def test_nested_list(self): | |||||
| fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1.1, 2.2, 3.3, 4.4, 5.5]], is_input=False) | |||||
| self.assertEqual(fa.pytype, float) | |||||
| self.assertEqual(fa.dtype, np.float64) | |||||
| @@ -1,8 +1,8 @@ | |||||
| import unittest | import unittest | ||||
| import numpy as np | import numpy as np | ||||
| from torch import nn | |||||
| import torch.nn.functional as F | import torch.nn.functional as F | ||||
| from torch import nn | |||||
| from fastNLP.core.dataset import DataSet | from fastNLP.core.dataset import DataSet | ||||
| from fastNLP.core.instance import Instance | from fastNLP.core.instance import Instance | ||||
| @@ -27,6 +27,7 @@ def prepare_fake_dataset(): | |||||
| [Instance(x=[float(item[0]), float(item[1])], y=[1.0]) for item in class_B]) | [Instance(x=[float(item[0]), float(item[1])], y=[1.0]) for item in class_B]) | ||||
| return data_set | return data_set | ||||
| def prepare_fake_dataset2(*args, size=100): | def prepare_fake_dataset2(*args, size=100): | ||||
| ys = np.random.randint(4, size=100) | ys = np.random.randint(4, size=100) | ||||
| data = {'y': ys} | data = {'y': ys} | ||||
| @@ -34,6 +35,7 @@ def prepare_fake_dataset2(*args, size=100): | |||||
| data[arg] = np.random.randn(size, 5) | data[arg] = np.random.randn(size, 5) | ||||
| return DataSet(data=data) | return DataSet(data=data) | ||||
| class TrainerTestGround(unittest.TestCase): | class TrainerTestGround(unittest.TestCase): | ||||
| def test_case(self): | def test_case(self): | ||||
| data_set = prepare_fake_dataset() | data_set = prepare_fake_dataset() | ||||
| @@ -56,15 +58,20 @@ class TrainerTestGround(unittest.TestCase): | |||||
| check_code_level=2, | check_code_level=2, | ||||
| use_tqdm=True) | use_tqdm=True) | ||||
| trainer.train() | trainer.train() | ||||
| """ | |||||
| # 应该正确运行 | |||||
| """ | |||||
| def test_trainer_suggestion1(self): | def test_trainer_suggestion1(self): | ||||
| # 检查报错提示能否正确提醒用户。 | # 检查报错提示能否正确提醒用户。 | ||||
| # 这里没有传入forward需要的数据。需要trainer提醒用户如何设置。 | # 这里没有传入forward需要的数据。需要trainer提醒用户如何设置。 | ||||
| dataset = prepare_fake_dataset2('x') | dataset = prepare_fake_dataset2('x') | ||||
| class Model(nn.Module): | class Model(nn.Module): | ||||
| def __init__(self): | def __init__(self): | ||||
| super().__init__() | super().__init__() | ||||
| self.fc = nn.Linear(5, 4) | self.fc = nn.Linear(5, 4) | ||||
| def forward(self, x1, x2, y): | def forward(self, x1, x2, y): | ||||
| x1 = self.fc(x1) | x1 = self.fc(x1) | ||||
| x2 = self.fc(x2) | x2 = self.fc(x2) | ||||
| @@ -73,10 +80,12 @@ class TrainerTestGround(unittest.TestCase): | |||||
| return {'loss': loss} | return {'loss': loss} | ||||
| model = Model() | model = Model() | ||||
| trainer = Trainer( | |||||
| train_data=dataset, | |||||
| model=model | |||||
| ) | |||||
| with self.assertRaises(NameError): | |||||
| trainer = Trainer( | |||||
| train_data=dataset, | |||||
| model=model | |||||
| ) | |||||
| """ | """ | ||||
| # 应该获取到的报错提示 | # 应该获取到的报错提示 | ||||
| NameError: | NameError: | ||||
| @@ -92,10 +101,12 @@ class TrainerTestGround(unittest.TestCase): | |||||
| # 这里传入forward需要的数据,看是否可以运行 | # 这里传入forward需要的数据,看是否可以运行 | ||||
| dataset = prepare_fake_dataset2('x1', 'x2') | dataset = prepare_fake_dataset2('x1', 'x2') | ||||
| dataset.set_input('x1', 'x2', 'y', flag=True) | dataset.set_input('x1', 'x2', 'y', flag=True) | ||||
| class Model(nn.Module): | class Model(nn.Module): | ||||
| def __init__(self): | def __init__(self): | ||||
| super().__init__() | super().__init__() | ||||
| self.fc = nn.Linear(5, 4) | self.fc = nn.Linear(5, 4) | ||||
| def forward(self, x1, x2, y): | def forward(self, x1, x2, y): | ||||
| x1 = self.fc(x1) | x1 = self.fc(x1) | ||||
| x2 = self.fc(x2) | x2 = self.fc(x2) | ||||
| @@ -120,10 +131,12 @@ class TrainerTestGround(unittest.TestCase): | |||||
| # 这里传入forward需要的数据,但是forward没有返回loss这个key | # 这里传入forward需要的数据,但是forward没有返回loss这个key | ||||
| dataset = prepare_fake_dataset2('x1', 'x2') | dataset = prepare_fake_dataset2('x1', 'x2') | ||||
| dataset.set_input('x1', 'x2', 'y', flag=True) | dataset.set_input('x1', 'x2', 'y', flag=True) | ||||
| class Model(nn.Module): | class Model(nn.Module): | ||||
| def __init__(self): | def __init__(self): | ||||
| super().__init__() | super().__init__() | ||||
| self.fc = nn.Linear(5, 4) | self.fc = nn.Linear(5, 4) | ||||
| def forward(self, x1, x2, y): | def forward(self, x1, x2, y): | ||||
| x1 = self.fc(x1) | x1 = self.fc(x1) | ||||
| x2 = self.fc(x2) | x2 = self.fc(x2) | ||||
| @@ -221,7 +234,6 @@ class TrainerTestGround(unittest.TestCase): | |||||
| print_every=2 | print_every=2 | ||||
| ) | ) | ||||
| def test_case2(self): | def test_case2(self): | ||||
| # check metrics Wrong | # check metrics Wrong | ||||
| data_set = prepare_fake_dataset2('x1', 'x2') | data_set = prepare_fake_dataset2('x1', 'x2') | ||||