You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

field.py 29 kB

7 years ago
6 years ago
6 years ago
6 years ago
6 years ago
7 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682
  1. from numbers import Number
  2. import torch
  3. import numpy as np
  4. from typing import Any
  5. from abc import abstractmethod
  6. from copy import deepcopy
  7. from collections import Counter
  8. class SetInputOrTargetException(Exception):
  9. def __init__(self, msg, index=None, field_name=None):
  10. super().__init__(msg)
  11. self.msg = msg
  12. self.index = index # 标示在哪个数据遭遇到问题了
  13. self.field_name = field_name # 标示当前field的名称
  14. class AppendToTargetOrInputException(Exception):
  15. def __init__(self, msg, index=None, field_name=None):
  16. super().__init__(msg)
  17. self.msg = msg
  18. self.index = index # 标示在哪个数据遭遇到问题了
  19. self.field_name = field_name # 标示当前field的名称
  20. class FieldArray:
  21. def __init__(self, name, content, is_target=False, is_input=False, padder=None, ignore_type=False,
  22. use_1st_ins_infer_dim_type=True):
  23. if len(content)==0:
  24. raise RuntimeError("Empty fieldarray is not allowed.")
  25. _content = content
  26. try:
  27. _content = list(_content)
  28. except BaseException as e:
  29. print(f"Cannot convert content(of type:{type(content)}) into list.")
  30. raise e
  31. self.name = name
  32. self.content = _content
  33. self._ignore_type = ignore_type
  34. # 根据input的情况设置input,target等
  35. self._cell_ndim = None # 多少维度
  36. self.dtype = None # 最内层的element都是什么类型的
  37. self._use_1st_ins_infer_dim_type = bool(use_1st_ins_infer_dim_type)
  38. self._is_input = False
  39. self._is_target = False
  40. if is_input:
  41. self.is_input = is_input
  42. if is_target:
  43. self.is_target = is_target
  44. if padder is None:
  45. padder = AutoPadder(pad_val=0)
  46. else:
  47. assert isinstance(padder, Padder), "padder must be of type fastNLP.Padder."
  48. padder = deepcopy(padder)
  49. self.set_padder(padder)
  50. @property
  51. def ignore_type(self):
  52. return self._ignore_type
  53. @ignore_type.setter
  54. def ignore_type(self, value):
  55. if value:
  56. self._cell_ndim = None
  57. self.dtype = None
  58. self._ignore_type = value
  59. @property
  60. def is_input(self):
  61. return self._is_input
  62. @is_input.setter
  63. def is_input(self, value):
  64. """
  65. 当 field_array.is_input = True / False 时被调用
  66. """
  67. # 如果(value为True)且(_is_input和_is_target都是False)且(ignore_type为False)
  68. if value is True and \
  69. self._is_target is False and \
  70. self._ignore_type is False:
  71. self._check_dtype_and_ndim(only_check_1st_ins_dim_type=self._use_1st_ins_infer_dim_type)
  72. if value is False and self._is_target is False:
  73. self.dtype = None
  74. self._cell_ndim = None
  75. self._is_input = value
  76. @property
  77. def is_target(self):
  78. return self._is_target
  79. @is_target.setter
  80. def is_target(self, value):
  81. """
  82. 当 field_array.is_target = True / False 时被调用
  83. """
  84. if value is True and \
  85. self._is_input is False and \
  86. self._ignore_type is False:
  87. self._check_dtype_and_ndim(only_check_1st_ins_dim_type=self._use_1st_ins_infer_dim_type)
  88. if value is False and self._is_input is False:
  89. self.dtype = None
  90. self._cell_ndim = None
  91. self._is_target = value
  92. def _check_dtype_and_ndim(self, only_check_1st_ins_dim_type=True):
  93. """
  94. 检查当前content所有的element是否是同一个类型,且是否每个元素具有相同的维度。通过的话,设置_cell_ndim与_ele_type属性;没有
  95. 通过将直接报错.
  96. :param bool only_check_1st_ins_dim_type: 是否只检查第一个元素的type和dim
  97. :return:
  98. """
  99. cell_0 = self.content[0]
  100. index = 0
  101. try:
  102. type_0, dim_0 = _get_ele_type_and_dim(cell_0)
  103. if not only_check_1st_ins_dim_type:
  104. for cell in self.content[1:]:
  105. index += 1
  106. type_i, dim_i = _get_ele_type_and_dim(cell)
  107. if type_i!=type_0:
  108. raise SetInputOrTargetException("Type:{} in index {} is different from the first element with type:{}."
  109. ".".format(type_i, index, type_0))
  110. if dim_0!=dim_i:
  111. raise SetInputOrTargetException("Dimension:{} in index {} is different from the first element with "
  112. "dimension:{}.".format(dim_i, index, dim_0))
  113. self._cell_ndim = dim_0
  114. self.dtype = type_0
  115. except SetInputOrTargetException as e:
  116. e.index = index
  117. raise e
  118. def append(self, val:Any):
  119. """
  120. :param val: 把该val append到fieldarray。
  121. :return:
  122. """
  123. if (self._is_target or self._is_input) and self._ignore_type is False and not self._use_1st_ins_infer_dim_type:
  124. type_, dim_ = _get_ele_type_and_dim(val)
  125. if self.dtype!=type_:
  126. raise AppendToTargetOrInputException(f"Value(type:{type_}) are of different types with "
  127. f"previous values(type:{self.dtype}).")
  128. if self._cell_ndim!=dim_:
  129. raise AppendToTargetOrInputException(f"Value(dim:{dim_}) are of different dimensions with "
  130. f"previous values(dim:{self._cell_ndim}).")
  131. self.content.append(val)
  132. else:
  133. self.content.append(val)
  134. def pop(self, index):
  135. """
  136. 删除该field中index处的元素
  137. :param int index: 从0开始的数据下标。
  138. :return:
  139. """
  140. self.content.pop(index)
  141. def __getitem__(self, indices):
  142. return self.get(indices, pad=False)
  143. def __setitem__(self, idx, val):
  144. assert isinstance(idx, int)
  145. if (self._is_target or self._is_input) and self.ignore_type is False: # 需要检测类型
  146. type_, dim_ = _get_ele_type_and_dim(val)
  147. if self.dtype!=type_:
  148. raise RuntimeError(f"Value(type:{type_}) are of different types with "
  149. f"other values(type:{self.dtype}).")
  150. if self._cell_ndim!=dim_:
  151. raise RuntimeError(f"Value(dim:{dim_}) are of different dimensions with "
  152. f"previous values(dim:{self._cell_ndim}).")
  153. self.content[idx] = val
  154. def get(self, indices, pad=True):
  155. """
  156. 根据给定的indices返回内容
  157. :param int,List[int] indices: 获取indices对应的内容。
  158. :param bool pad: 是否对返回的结果进行padding。仅对indices为List[int]时有效
  159. :return: 根据给定的indices返回的内容,可能是单个值或List
  160. """
  161. if isinstance(indices, int):
  162. return self.content[indices]
  163. if self.is_input is False and self.is_target is False:
  164. raise RuntimeError("Please specify either is_input or is_target to True for {}".format(self.name))
  165. contents = [self.content[i] for i in indices]
  166. if self.padder is None or pad is False:
  167. return np.array(contents)
  168. else:
  169. return self.pad(contents)
  170. def pad(self, contents):
  171. return self.padder(contents, field_name=self.name, field_ele_dtype=self.dtype, dim=self._cell_ndim)
  172. def set_padder(self, padder):
  173. """
  174. 设置padder,在这个field进行pad的时候用这个padder进行pad,如果为None则不进行pad。
  175. :param padder: :class:`~fastNLP.Padder` 类型,设置为None即删除padder。
  176. """
  177. if padder is not None:
  178. assert isinstance(padder, Padder), "padder must be of type Padder."
  179. self.padder = deepcopy(padder)
  180. else:
  181. self.padder = None
  182. def set_pad_val(self, pad_val):
  183. """
  184. 修改padder的pad_val.
  185. :param int pad_val: 该field的pad值设置为该值。
  186. """
  187. if self.padder is not None:
  188. self.padder.set_pad_val(pad_val)
  189. return self
  190. def __len__(self):
  191. """
  192. Returns the size of FieldArray.
  193. :return int length:
  194. """
  195. return len(self.content)
  196. def to(self, other):
  197. """
  198. 将other的属性复制给本FieldArray(other必须为FieldArray类型).
  199. 属性包括 is_input, is_target, padder, ignore_type
  200. :param other: :class:`~fastNLP.FieldArray` 从哪个field拷贝属性
  201. :return: :class:`~fastNLP.FieldArray`
  202. """
  203. assert isinstance(other, FieldArray), "Only supports fastNLP.FieldArray type, not {}.".format(type(other))
  204. self.ignore_type = other.ignore_type
  205. self.is_input = other.is_input
  206. self.is_target = other.is_target
  207. self.padder = other.padder
  208. return self
  209. def split(self, sep:str=None, inplace:bool=True):
  210. """
  211. 依次对自身的元素使用.split()方法,应该只有当本field的元素为str时,该方法才有用。将返回值
  212. :param sep: 分割符,如果为None则直接调用str.split()。
  213. :param inplace: 如果为True,则将新生成值替换本field。否则返回list。
  214. :return: List[List[str]] or self
  215. """
  216. new_contents = []
  217. for index, cell in enumerate(self.content):
  218. try:
  219. new_contents.append(cell.split(sep))
  220. except Exception as e:
  221. print(f"Exception happens when process value in index {index}.")
  222. raise e
  223. return self._after_process(new_contents, inplace=inplace)
  224. def int(self, inplace:bool=True):
  225. """
  226. 将本field中的值调用int(cell). 支持field中内容为以下两种情况(1)['1', '2', ...](即field中每个值为str的),
  227. (2) [['1', '2', ..], ['3', ..], ...](即field中每个值为一个list,list中的值会被依次转换。)
  228. :param inplace: 如果为True,则将新生成值替换本field。否则返回list。
  229. :return: List[int], List[List[int]], self
  230. """
  231. new_contents = []
  232. for index, cell in enumerate(self.content):
  233. try:
  234. if isinstance(cell, list):
  235. new_contents.append([int(value) for value in cell])
  236. else:
  237. new_contents.append(int(cell))
  238. except Exception as e:
  239. print(f"Exception happens when process value in index {index}.")
  240. print(e)
  241. return self._after_process(new_contents, inplace=inplace)
  242. def float(self, inplace=True):
  243. """
  244. 将本field中的值调用float(cell). 支持field中内容为以下两种情况(1)['1', '2', ...](即field中每个值为str的),
  245. (2) [['1', '2', ..], ['3', ..], ...](即field中每个值为一个list,list中的值会被依次转换。)
  246. :param inplace: 如果为True,则将新生成值替换本field。否则返回list。
  247. :return:
  248. """
  249. new_contents = []
  250. for index, cell in enumerate(self.content):
  251. try:
  252. if isinstance(cell, list):
  253. new_contents.append([float(value) for value in cell])
  254. else:
  255. new_contents.append(float(cell))
  256. except Exception as e:
  257. print(f"Exception happens when process value in index {index}.")
  258. raise e
  259. return self._after_process(new_contents, inplace=inplace)
  260. def bool(self, inplace=True):
  261. """
  262. 将本field中的值调用bool(cell). 支持field中内容为以下两种情况(1)['1', '2', ...](即field中每个值为str的),
  263. (2) [['1', '2', ..], ['3', ..], ...](即field中每个值为一个list,list中的值会被依次转换。)
  264. :param inplace: 如果为True,则将新生成值替换本field。否则返回list。
  265. :return:
  266. """
  267. new_contents = []
  268. for index, cell in enumerate(self.content):
  269. try:
  270. if isinstance(cell, list):
  271. new_contents.append([bool(value) for value in cell])
  272. else:
  273. new_contents.append(bool(cell))
  274. except Exception as e:
  275. print(f"Exception happens when process value in index {index}.")
  276. raise e
  277. return self._after_process(new_contents, inplace=inplace)
  278. def lower(self, inplace=True):
  279. """
  280. 将本field中的值调用cell.lower(). 支持field中内容为以下两种情况(1)['1', '2', ...](即field中每个值为str的),
  281. (2) [['1', '2', ..], ['3', ..], ...](即field中每个值为一个list,list中的值会被依次转换。)
  282. :param inplace: 如果为True,则将新生成值替换本field。否则返回list。
  283. :return: List[int], List[List[int]], self
  284. """
  285. new_contents = []
  286. for index, cell in enumerate(self.content):
  287. try:
  288. if isinstance(cell, list):
  289. new_contents.append([value.lower() for value in cell])
  290. else:
  291. new_contents.append(cell.lower())
  292. except Exception as e:
  293. print(f"Exception happens when process value in index {index}.")
  294. raise e
  295. return self._after_process(new_contents, inplace=inplace)
  296. def upper(self, inplace=True):
  297. """
  298. 将本field中的值调用cell.lower(). 支持field中内容为以下两种情况(1)['1', '2', ...](即field中每个值为str的),
  299. (2) [['1', '2', ..], ['3', ..], ...](即field中每个值为一个list,list中的值会被依次转换。)
  300. :param inplace: 如果为True,则将新生成值替换本field。否则返回list。
  301. :return: List[int], List[List[int]], self
  302. """
  303. new_contents = []
  304. for index, cell in enumerate(self.content):
  305. try:
  306. if isinstance(cell, list):
  307. new_contents.append([value.upper() for value in cell])
  308. else:
  309. new_contents.append(cell.upper())
  310. except Exception as e:
  311. print(f"Exception happens when process value in index {index}.")
  312. raise e
  313. return self._after_process(new_contents, inplace=inplace)
  314. def value_count(self):
  315. """
  316. 返回该field下不同value的数量。多用于统计label数量
  317. :return: Counter, key是label,value是出现次数
  318. """
  319. count = Counter()
  320. def cum(cell):
  321. if _is_iterable(cell) and not isinstance(cell, str):
  322. for cell_ in cell:
  323. cum(cell_)
  324. else:
  325. count[cell] += 1
  326. for cell in self.content:
  327. cum(cell)
  328. return count
  329. def _after_process(self, new_contents, inplace):
  330. """
  331. 当调用处理函数之后,决定是否要替换field。
  332. :param new_contents:
  333. :param inplace:
  334. :return: self或者生成的content
  335. """
  336. if inplace:
  337. self.content = new_contents
  338. try:
  339. self.is_input = self.is_input
  340. self.is_target = self.is_input
  341. except SetInputOrTargetException as e:
  342. print("The newly generated field cannot be set as input or target.")
  343. raise e
  344. return self
  345. else:
  346. return new_contents
  347. def _get_ele_type_and_dim(cell:Any, dim=0):
  348. """
  349. 识别cell的类别与dimension的数量
  350. numpy scalar type:https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.scalars.html
  351. :param cell:
  352. :param dim:
  353. :return:
  354. """
  355. if isinstance(cell, (str, Number, np.bool_)):
  356. if hasattr(cell, 'dtype'):
  357. return cell.dtype.type, dim
  358. return type(cell), dim
  359. elif isinstance(cell, list):
  360. dim += 1
  361. res = [_get_ele_type_and_dim(cell_i, dim) for cell_i in cell]
  362. types = set([i for i,j in res])
  363. dims = set([j for i,j in res])
  364. if len(types)>1:
  365. raise SetInputOrTargetException("Mixed types detected: {}.".format(list(types)))
  366. elif len(types)==0:
  367. raise SetInputOrTargetException("Empty value encountered.")
  368. if len(dims)>1:
  369. raise SetInputOrTargetException("Mixed dimension detected: {}.".format(list(dims)))
  370. return types.pop(), dims.pop()
  371. elif isinstance(cell, torch.Tensor):
  372. return cell.dtype, cell.dim() + dim # 如果是torch.mean的结果是0
  373. elif isinstance(cell, np.ndarray):
  374. if cell.dtype != np.dtype('O'): # 如果不是object的话说明是well-formatted的了
  375. return cell.dtype.type, cell.ndim + dim # dtype.type返回的会是np.int32, np.float等
  376. # 否则需要继续往下iterate
  377. dim += 1
  378. res = [_get_ele_type_and_dim(cell_i, dim) for cell_i in cell]
  379. types = set([i for i,j in res])
  380. dims = set([j for i,j in res])
  381. if len(types)>1:
  382. raise SetInputOrTargetException("Mixed types detected: {}.".format(list(types)))
  383. elif len(types)==0:
  384. raise SetInputOrTargetException("Empty value encountered.")
  385. if len(dims)>1:
  386. raise SetInputOrTargetException("Mixed dimension detected: {}.".format(list(dims)))
  387. return types.pop(), dims.pop()
  388. else: # 包含tuple, set, dict以及其它的类型
  389. raise SetInputOrTargetException(f"Cannot process type:{type(cell)}.")
  390. def _is_iterable(value):
  391. # 检查是否是iterable的, duck typing
  392. try:
  393. iter(value)
  394. return True
  395. except BaseException as e:
  396. return False
  397. class Padder:
  398. """
  399. 别名::class:`fastNLP.Padder` :class:`fastNLP.core.field.Padder`
  400. 所有padder都需要继承这个类,并覆盖__call__方法。
  401. 用于对batch进行padding操作。传入的element是inplace的,即直接修改element可能导致数据变化,建议inplace修改之前deepcopy一份。
  402. .. py:function:: __call__(self, contents, field_name, field_ele_dtype):
  403. 传入的是List内容。假设有以下的DataSet。
  404. :param List[Any] contents: 传入的element是inplace的,即直接修改element可能导致数据变化,建议inplace修改之前
  405. deepcopy一份。
  406. :param str, field_name: field的名称。
  407. :param np.int64,np.float64,np.str,None, field_ele_dtype: 该field的内层元素的类型。如果该field的ignore_type为True,该这个值为None。
  408. :return: np.array([padded_element])
  409. """
  410. def __init__(self, pad_val=0, **kwargs):
  411. self.pad_val = pad_val
  412. def set_pad_val(self, pad_val):
  413. self.pad_val = pad_val
  414. @abstractmethod
  415. def __call__(self, contents, field_name, field_ele_dtype, dim:int):
  416. """
  417. 传入的是List内容。假设有以下的DataSet。
  418. :param List[Any] contents: 传入的element是inplace的,即直接修改element可能导致数据变化,建议inplace修改之前
  419. deepcopy一份。
  420. :param str, field_name: field的名称。
  421. :param np.int64,np.float64,np.str,None, field_ele_dtype: 该field的内层元素的类型。如果该field的ignore_type为True,
  422. 该这个值为None。
  423. :param dim: 这个field的维度。当ignore_type为True时,该值为None
  424. :return: np.array([padded_element])
  425. Example::
  426. from fastNLP import DataSet
  427. from fastNLP import Instance
  428. dataset = DataSet()
  429. dataset.append(Instance(sent='this is a demo', length=4,
  430. chars=[['t', 'h', 'i', 's'], ['i', 's'], ['a'], ['d', 'e', 'm', 'o']]))
  431. dataset.append(Instance(sent='another one', length=2,
  432. chars=[['a', 'n', 'o', 't', 'h', 'e', 'r'], ['o', 'n', 'e']]))
  433. 如果调用
  434. batch = dataset.get([0,1], pad=True)
  435. sent这个field的padder的__call__会接收到的内容会是
  436. [
  437. 'this is a demo',
  438. 'another one'
  439. ]
  440. length这个field的padder的__call__会接收到的内容会是
  441. [4, 2]
  442. chars这个field的padder的__call__会接收到的内容会是
  443. [
  444. [['t', 'h', 'i', 's'], ['i', 's'], ['a'], ['d', 'e', 'm', 'o']],
  445. [['a', 'n', 'o', 't', 'h', 'e', 'r'], ['o', 'n', 'e']]
  446. ]
  447. 即把每个instance中某个field的内容合成一个List传入
  448. """
  449. raise NotImplementedError
  450. class AutoPadder(Padder):
  451. """
  452. 别名::class:`fastNLP.AutoPadder` :class:`fastNLP.core.field.AutoPadder`
  453. 根据contents的数据自动判定是否需要做padding。
  454. 1 如果元素类型(元素类型是指field中最里层元素的数据类型, 可以通过FieldArray.dtype查看,比如['This', 'is', ...]的元素类
  455. 型为str, [[1,2], ...]的元素类型为int)的数据不为数值类型则不会进行pad
  456. 2 如果元素类型为数值类型,比如np.int64, np.float64, int, float, torch.int64等
  457. 2.1 如果该field的内容为数值类型(包括int, float等),比如为seq_len, 则不进行padding
  458. 2.2 如果该field的内容等价于一维list, 那么会将Batch中的List pad为一样长。
  459. 2.3 如果该field的内容等价于二维list,那么会按照英语character padding的方式进行padding。如果是character padding建议使用
  460. :class: fastNLP.EngChar2DPadder.
  461. 2.4 如果该field的内容等价于三维list,则如果每个instance在每个维度上相等,会组成一个batch的tensor返回,这种情况应该是为图片
  462. 的情况。
  463. 3 其它情况不进行处理,返回一个np.array类型。
  464. """
  465. def __init__(self, pad_val=0):
  466. super().__init__(pad_val=pad_val)
  467. def __call__(self, contents, field_name, field_ele_dtype, dim):
  468. if field_ele_dtype:
  469. if dim>3:
  470. return np.array(contents)
  471. if isinstance(field_ele_dtype, type) and \
  472. (issubclass(field_ele_dtype, np.number) or issubclass(field_ele_dtype, Number)):
  473. if dim==0:
  474. array = np.array(contents, dtype=field_ele_dtype)
  475. elif dim==1:
  476. max_len = max(map(len, contents))
  477. array = np.full((len(contents), max_len), self.pad_val, dtype=field_ele_dtype)
  478. for i, content_i in enumerate(contents):
  479. array[i, :len(content_i)] = content_i
  480. elif dim==2:
  481. max_len = max(map(len, contents))
  482. max_word_len = max([max([len(content_ii) for content_ii in content_i]) for
  483. content_i in contents])
  484. array = np.full((len(contents), max_len, max_word_len), self.pad_val, dtype=field_ele_dtype)
  485. for i, content_i in enumerate(contents):
  486. for j, content_ii in enumerate(content_i):
  487. array[i, j, :len(content_ii)] = content_ii
  488. else:
  489. shape = np.shape(contents)
  490. if len(shape)==4: # 说明各dimension是相同的大小
  491. array = np.array(contents, dtype=field_ele_dtype)
  492. else:
  493. raise RuntimeError(f"Field:{field_name} has 3 dimensions, every sample should have the same shape.")
  494. return array
  495. elif str(field_ele_dtype).startswith('torch'):
  496. if dim==0:
  497. tensor = torch.tensor(contents).to(field_ele_dtype)
  498. elif dim==1:
  499. max_len = max(map(len, contents))
  500. tensor = torch.full((len(contents), max_len), fill_value=self.pad_val, dtype=field_ele_dtype)
  501. for i, content_i in enumerate(contents):
  502. tensor[i, :len(content_i)] = torch.tensor(content_i)
  503. elif dim==2:
  504. max_len = max(map(len, contents))
  505. max_word_len = max([max([len(content_ii) for content_ii in content_i]) for
  506. content_i in contents])
  507. tensor = torch.full((len(contents), max_len, max_word_len), fill_value=self.pad_val,
  508. dtype=field_ele_dtype)
  509. for i, content_i in enumerate(contents):
  510. for j, content_ii in enumerate(content_i):
  511. tensor[i, j, :len(content_ii)] = torch.tensor(content_ii)
  512. else:
  513. shapes = set([np.shape(content_i) for content_i in contents])
  514. if len(shapes)>1:
  515. raise RuntimeError(f"Field:{field_name} has 3 dimensions, every sample should have the same shape.")
  516. shape = shapes.pop()
  517. if len(shape)==3:
  518. tensor = torch.full([len(contents)]+list(shape), fill_value=self.pad_val, dtype=field_ele_dtype)
  519. for i, content_i in enumerate(contents):
  520. tensor[i] = torch.tensor(content_i, dtype=field_ele_dtype)
  521. else:
  522. raise RuntimeError(f"Field:{field_name} has 3 dimensions, every sample should have the same shape.")
  523. return tensor
  524. else:
  525. return np.array(contents) # 不进行任何操作
  526. else:
  527. return np.array(contents)
  528. class EngChar2DPadder(Padder):
  529. """
  530. 别名::class:`fastNLP.EngChar2DPadder` :class:`fastNLP.core.field.EngChar2DPadder`
  531. 用于为英语执行character级别的2D padding操作。对应的field内容应该类似[['T', 'h', 'i', 's'], ['a'], ['d', 'e', 'm', 'o']],
  532. 但这个Padder只能处理index为int的情况。
  533. padded过后的batch内容,形状为(batch_size, max_sentence_length, max_word_length). max_sentence_length为这个batch中最大句
  534. 子长度;max_word_length为这个batch中最长的word的长度::
  535. from fastNLP import DataSet
  536. from fastNLP import EngChar2DPadder
  537. from fastNLP import Vocabulary
  538. dataset = DataSet({'sent': ['This is the first demo', 'This is the second demo']})
  539. dataset.apply(lambda ins:[list(word) for word in ins['sent'].split()], new_field_name='chars')
  540. vocab = Vocabulary()
  541. vocab.from_dataset(dataset, field_name='chars')
  542. vocab.index_dataset(dataset, field_name='chars')
  543. dataset.set_input('chars')
  544. padder = EngChar2DPadder()
  545. dataset.set_padder('chars', padder) # chars这个field的设置为了EnChar2DPadder
  546. """
  547. def __init__(self, pad_val=0, pad_length=0):
  548. """
  549. :param pad_val: int, pad的位置使用该index
  550. :param pad_length: int, 如果为0则取一个batch中最大的单词长度作为padding长度。如果为大于0的数,则将所有单词的长度
  551. 都pad或截取到该长度.
  552. """
  553. super().__init__(pad_val=pad_val)
  554. self.pad_length = pad_length
  555. def __call__(self, contents, field_name, field_ele_dtype, dim):
  556. """
  557. 期望输入类似于
  558. [
  559. [[0, 2], [2, 3, 4], ..],
  560. [[9, 8, 2, 4], [1, 2,], ...],
  561. ....
  562. ]
  563. :param contents:
  564. :param field_name:
  565. :param field_ele_dtype
  566. :return:
  567. """
  568. if field_ele_dtype not in (np.int64, np.float64, int, float):
  569. raise TypeError('dtype of Field:{} should be np.int64 or np.float64 to do 2D padding, get {}.'.format(
  570. field_name, field_ele_dtype
  571. ))
  572. assert dim==2, f"Field:{field_name} has {dim}, EngChar2DPadder only supports input with 2 dimensions."
  573. if self.pad_length < 1:
  574. max_char_length = max([max(len(char_lst) for char_lst in word_lst) for word_lst in contents])
  575. else:
  576. max_char_length = self.pad_length
  577. max_sent_length = max(len(word_lst) for word_lst in contents)
  578. batch_size = len(contents)
  579. dtype = type(contents[0][0][0])
  580. padded_array = np.full((batch_size, max_sent_length, max_char_length), fill_value=self.pad_val,
  581. dtype=dtype)
  582. for b_idx, word_lst in enumerate(contents):
  583. for c_idx, char_lst in enumerate(word_lst):
  584. chars = char_lst[:max_char_length]
  585. padded_array[b_idx, c_idx, :len(chars)] = chars
  586. return padded_array