| @@ -49,7 +49,7 @@ class CheckpointCallback(Callback): | |||
| :param only_state_dict: 保存模型时是否只保存 state_dict 。当 ``model_save_fn`` 不为 ``None`` 时,该参数无效。 | |||
| :param model_save_fn: 个性化的保存函数,当触发保存操作时,就调用这个函数,这个函数应当接受一个文件夹作为参数,不返回任何东西。 | |||
| 如果传入了 ``model_save_fn`` 函数,fastNLP 将不再进行模型相关的保存。在多卡场景下,我们只在 rank 0 上会运行该函数。 | |||
| :param save_object: 可选 ['trainer', 'model'],表示在保存时的保存对象为 ``trainer+model`` 还是 只是 ``model`` 。如果 | |||
| :param save_object: 可选 ``['trainer', 'model']`` ,表示在保存时的保存对象为 ``trainer+model`` 还是 只是 ``model`` 。如果 | |||
| 保存 ``trainer`` 对象的话,将会保存 :class:`~fastNLP.Trainer` 的相关状态,可以通过 :meth:`Trainer.load_checkpoint` 加载该断 | |||
| 点继续训练。如果保存的是 ``Model`` 对象,则可以通过 :meth:`Trainer.load_model` 加载该模型权重。 | |||
| :param save_evaluate_results: 是否保存 evaluate 的结果。如果为 ``True`` ,在保存 topk 模型的 folder 中还将额外保存一个 | |||
| @@ -10,7 +10,7 @@ from fastNLP.core.utils.exceptions import EarlyStopException | |||
| class EarlyStopCallback(HasMonitorCallback): | |||
| """ | |||
| 用于 early stop 的 callback 。当监控的结果连续多少次没有变好边 raise 一个 EarlyStopException 。 | |||
| 用于 early stop 的 callback 。当监控的结果连续多少次没有变好便 raise 一个 EarlyStopException 。 | |||
| :param monitor: 监控的 metric 值。 | |||
| @@ -10,7 +10,7 @@ class LRSchedCallback(Callback): | |||
| 根据 ``step_on`` 参数在合适的时机调用 scheduler 的 step 函数。 | |||
| :param scheduler: 实现了 :meth:`step` 函数的对象; | |||
| :param step_on: 可选 ['batch', 'epoch'] 表示在何时调用 scheduler 的 step 函数。如果为 ``batch`` 的话在每次更新参数 | |||
| :param step_on: 可选 ``['batch', 'epoch']`` 表示在何时调用 scheduler 的 step 函数。如果为 ``batch`` 的话在每次更新参数 | |||
| 之前调用;如果为 ``epoch`` 则是在一个 epoch 运行结束后调用; | |||
| """ | |||
| def __init__(self, scheduler, step_on:str='batch'): | |||
| @@ -69,9 +69,9 @@ class MoreEvaluateCallback(HasMonitorCallback): | |||
| :param topk_monitor: 如果需要根据当前 callback 中的 evaluate 结果保存。这个参数是指在当前 callback 中的 evaluate 结果寻找 | |||
| :param topk_larger_better: ``topk_monitor`` 的值是否是越大越好。 | |||
| :param folder: 保存的文件夹,fastNLP 将在该文件下以时间戳创建子文件夹,并在里面保存。因此不同次运行可以将被保存到不同的 | |||
| 时间戳文件夹中。如果为 ``··``None`` ,默认使用当前文件夹。 | |||
| 时间戳文件夹中。如果为 ``None`` ,默认使用当前文件夹。 | |||
| :param only_state_dict: 保存模型时是否只保存 state_dict 。当 ``model_save_fn`` 不为 ``None`` 时,该参数无效。 | |||
| :param save_object: 可选 ['trainer', 'model'],表示在保存时的保存对象为 ``trainer+model`` 还是 只是 ``model`` 。如果 | |||
| :param save_object: 可选 ``['trainer', 'model']`` ,表示在保存时的保存对象为 ``trainer+model`` 还是 只是 ``model`` 。如果 | |||
| 保存 ``trainer`` 对象的话,将会保存 :class:`~fastNLP.Trainer` 的相关状态,可以通过 :meth:`Trainer.load_checkpoint` 加载该断 | |||
| 点继续训练。如果保存的是 ``Model`` 对象,则可以通过 :meth:`Trainer.load_model` 加载该模型权重。 | |||
| :param model_save_fn: 个性化的保存函数,当触发保存操作时,就调用这个函数,这个函数应当接受一个文件夹作为参数,不返回任何东西。 | |||
| @@ -24,7 +24,7 @@ class Saver: | |||
| - folder_name # 由 save() 调用时传入。 | |||
| :param folder: 保存在哪个文件夹下,默认为当前 folder 下。 | |||
| :param save_object: 可选 ['trainer', 'model'],表示在保存时的保存对象为 ``trainer+model`` 还是 只是 ``model`` 。如果 | |||
| :param save_object: 可选 ``['trainer', 'model']`` ,表示在保存时的保存对象为 ``trainer+model`` 还是 只是 ``model`` 。如果 | |||
| 保存 ``trainer`` 对象的话,将会保存 :class:~fastNLP.Trainer 的相关状态,可以通过 :meth:`Trainer.load_checkpoint` 加载该断 | |||
| 点继续训练。如果保存的是 ``Model`` 对象,则可以通过 :meth:`Trainer.load_model` 加载该模型权重。 | |||
| :param only_state_dict: 保存时是否仅保存权重,在 model_save_fn 不为 None 时无意义。 | |||
| @@ -191,7 +191,7 @@ class TopkSaver(ResultsMonitor, Saver): | |||
| 的 ``monitor`` 值请返回 ``None`` 。 | |||
| :param larger_better: 该 monitor 是否越大越好。 | |||
| :param folder: 保存在哪个文件夹下,默认为当前 folder 下。 | |||
| :param save_object: 可选 ['trainer', 'model'],表示在保存时的保存对象为 ``trainer+model`` 还是 只是 ``model`` 。如果 | |||
| :param save_object: 可选 ``['trainer', 'model']`` ,表示在保存时的保存对象为 ``trainer+model`` 还是 只是 ``model`` 。如果 | |||
| 保存 ``trainer`` 对象的话,将会保存 :class:`~fastNLP.Trainer` 的相关状态,可以通过 :meth:`Trainer.load_checkpoint` 加载该断 | |||
| 点继续训练。如果保存的是 ``Model`` 对象,则可以通过 :meth:`Trainer.load_model` 加载该模型权重。 | |||
| :param only_state_dict: 保存时是否仅保存权重,在 ``model_save_fn`` 不为 None 时无意义。 | |||
| @@ -85,27 +85,33 @@ def _get_backend() -> str: | |||
| class Collator: | |||
| """ | |||
| 用于 pad 数据的对象。会自动将所有能够 pad (由 fastNLP 根据数据判定能否 pad )的数据都进行 pad 操作,默认 pad 的值为 0。 | |||
| 哦安定一个 field 是否可以 pad 的方式为:(1)当前这个 field 是否所有对象都是一样的数据类型;(因此,如果某 field 的数据有些是float | |||
| 有些是 int 将知道该 field 被判定为不可 pad 类型。)(2)当前这个 field 是否每个 sample 都具有一样的深度;(因此,例如有个 field 的 | |||
| 数据转为 batch 类型后为 [1, [1,2]], 会被判定为不可 pad ,因为第一个 sample 与 第二个 sample 深度不同)(3)当前这个 field 的类 | |||
| 型是否是可以 pad (例如 str 类型的数据)。可以通过设置 logger.setLevel('debug') 来打印是判定不可 pad 的原因。 | |||
| 判定一个 field 是否可以 pad 的方式为: | |||
| 1. 当前这个 field 是否所有对象都是一样的数据类型;比如,如果某 field 的数据有些是 float ,有些是 int ,则该 field 将被 | |||
| 判定为不可 pad 类型; | |||
| 2. 当前这个 field 是否每个 sample 都具有一样的深度;比如,如果某 field 的数据转为 batch 类型后为 ``[1, [1,2]]``, 则会 | |||
| 被判定为不可 pad ,因为第一个 sample 与 第二个 sample 深度不同; | |||
| 3. 当前这个 field 的类型是否是可以 pad (例如 str 类型的数据)。可以通过设置 ``logger.setLevel('debug')`` 来打印是判定不可 | |||
| pad 的原因。 | |||
| .. note:: | |||
| ``Collator`` 的原理是使用第一个 ``batch`` 的数据尝试推断每个``field``应该使用哪种类型的 ``Padder``,如果第一个 ``batch`` | |||
| 的数据刚好比较特殊,可能导致在之后的 pad 中遭遇失败,这种情况请通过 ``set_pad()`` 函数手动设置一下。 | |||
| ``Collator`` 的原理是使用第一个 ``batch`` 的数据尝试推断每个 ``field`` 应该使用哪种类型的 ``Padder``,如果第一个 ``batch`` | |||
| 的数据刚好比较特殊,可能导致在之后的 pad 中遭遇失败,这种情况请通过 :meth:`set_pad` 函数手动设置一下。 | |||
| todo 补充 code example 。 | |||
| .. todo:: | |||
| 补充 code example 。 | |||
| 如果需要将某个本可以 pad 的 field 设置为不可 pad ,则可以通过 :meth:`~fastNLP.Collator.set_pad` 的 pad_val 设置为 None 实现。 | |||
| 如果需要将某个本可以 pad 的 field 设置为不可 pad ,则可以通过 :meth:`~fastNLP.Collator.set_pad` 的 ``pad_val`` 设置为 ``None`` 实现。 | |||
| 如果需要某些 field 不要包含在 pad 之后的结果中,可以使用 :meth:`~fastNLP.Collator.set_ignore` 进行设置。 | |||
| Collator 在第一次进行 pad 的时候自动根据设置以及数据情况,为每个 field 获取一个 padder ,在之后的每次调用中,都将使用对应 | |||
| 的 Padder 给对应的 field 。 | |||
| :param backend: 对于可以 pad 的 field,使用哪种 tensor,支持 ['torch','jittor','paddle','oneflow','numpy','raw', auto, None]。 | |||
| 若为 'auto' ,则在进行 pad 的时候会根据调用的环境决定其 backend 。该参数对不能进行 pad 的数据没用影响,不能 pad | |||
| 的数据返回一定是 list 。 | |||
| :param backend: 对于可以 pad 的 field,使用哪种 tensor,支持 ``['torch','jittor','paddle','oneflow','numpy','raw', 'auto', None]``。 | |||
| 若为 ``'auto'`` ,则在进行 pad 的时候会根据调用的环境决定其 ``backend`` 。该参数对不能进行 pad 的数据没有影响,无法 pad 的数据返回一定 | |||
| 是 :class:`list` 。 | |||
| """ | |||
| def __init__(self, backend='auto'): | |||
| self.unpack_batch_func = None | |||
| @@ -192,20 +198,20 @@ class Collator: | |||
| """ | |||
| 如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。 | |||
| :param field_name: 需要调整的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的 | |||
| field 的 key 来表示,如果是 nested 的 dict,可以使用元组表示多层次的 key,例如 {'a': {'b': 1}} 中的使用 ('a', 'b'); | |||
| 如果 __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。如果该 field 在数据中没 | |||
| 有找到,则报错;如果 __getitem__ 返回的是就是整体内容,请使用 "_single" 。 | |||
| :param pad_val: 这个 field 的默认 pad 值。如果设置为 None,则表示该 field 不需要 pad , fastNLP 默认只会对可以 pad 的 | |||
| field 进行 pad,所以如果对应 field 本身就不是可以 pad 的形式,可以不需要主动设置为 None 。如果 backend 为 None ,该值 | |||
| 无意义。 | |||
| :param dtype: 对于需要 pad 的 field ,该 field 的数据 dtype 应该是什么。 | |||
| :param backend: 可选['raw', 'numpy', 'torch', 'paddle', 'jittor', 'oneflow', 'auto'],分别代表,输出为 list, numpy.ndarray, | |||
| torch.Tensor, paddle.Tensor, jittor.Var oneflow.Tensor 类型。若 pad_val 为 None ,该值无意义 。 | |||
| :param pad_fn: 指定当前 field 的 pad 函数,传入该函数则 pad_val, dtype, backend 等参数失效。pad_fn 的输入为当前 field 的 | |||
| batch 形式。 Collator 将自动 unbatch 数据,然后将各个 field 组成各自的 batch 。pad_func 的输入即为 field 的 batch | |||
| 形式,输出将被直接作为结果输出。 | |||
| :return: 返回 Collator 自身 | |||
| :param field_name: 需要调整的 field 的名称。如果 :meth:`Dataset.__getitem__` 方法返回的是字典类型,则可以直接使用对应的 | |||
| field 的 key 来表示,如果是嵌套字典,可以使用元组表示多层次的 key,例如 ``{'a': {'b': 1}}`` 中可以使用 ``('a', 'b')``; | |||
| 如果 :meth:`Dataset.__getitem__` 返回的是 Sequence 类型,则可以使用 ``'_0'``, ``'_1'`` 表示序列中第 **0** 或 **1** 个元素。 | |||
| 如果该 field 在数据中没有找到,则报错;如果 :meth:`Dataset.__getitem__` 返回的是就是整体内容,请使用 "_single" 。 | |||
| :param pad_val: 这个 field 的默认 pad 值。如果设置为 ``None``,则表示该 field 不需要 pad , fastNLP 默认只会对可以 pad 的 | |||
| field 进行 pad,所以如果对应 field 本身就不是可以 pad 的形式,可以不需要主动设置为 ``None`` 。如果 ``backend`` 为 ``None``, | |||
| 该值无意义。 | |||
| :param dtype: 对于需要 pad 的 field ,该 field 的数据 ``dtype`` 应该是什么。 | |||
| :param backend: 可选 ``['raw', 'numpy', 'torch', 'paddle', 'jittor', 'oneflow', 'auto']`` ,分别代表,输出为 :class:`list`, | |||
| :class:`numpy.ndarray`, :class:`torch.Tensor`, :class:`paddle.Tensor`, :class:`jittor.Var`, :class:`oneflow.Tensor` 类型。 | |||
| 若 ``pad_val`` 为 ``None`` ,该值无意义 。 | |||
| :param pad_fn: 指定当前 field 的 pad 函数,传入该函数则 ``pad_val``, ``dtype``, ``backend`` 等参数失效。``pad_fn`` 的输入为当前 field 的 | |||
| batch 形式。 Collator 将自动 unbatch 数据,然后将各个 field 组成各自的 batch 。 | |||
| :return: 返回 Collator 自身; | |||
| """ | |||
| self._renew() | |||
| @@ -275,8 +281,8 @@ class Collator: | |||
| """ | |||
| 设置可以 pad 的 field 默认 pad 为什么类型的 tensor | |||
| :param backend: 对于可以 pad 的 field,使用哪种 tensor,支持 ['torch','jittor','paddle','oneflow','numpy','raw', 'auto', None], | |||
| 若为 auto ,则在进行 pad 的时候会自动根据调用的环境决定其 backend 。 | |||
| :param backend: 对于可以 pad 的 field,使用哪种 tensor,支持 ``['torch','jittor','paddle','oneflow','numpy','raw', 'auto', None]``, | |||
| 若为 ``'auto'`` ,则在进行 pad 的时候会自动根据调用的环境决定其 ``backend`` ; | |||
| :return: | |||
| """ | |||
| assert backend in SUPPORTED_BACKENDS | |||
| @@ -289,10 +295,10 @@ class Collator: | |||
| >>> collator = Collator().set_ignore('field1', 'field2') | |||
| :param field_names: 需要忽略的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的 | |||
| field 的 key 来表示,如果是 nested 的 dict,可以使用元组来表示,例如 {'a': {'b': 1}} 中的使用 ('a', 'b'); 如果 | |||
| __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。 | |||
| :return: 返回 Collator 自身 | |||
| :param field_names: field_name: 需要调整的 field 的名称。如果 :meth:`Dataset.__getitem__` 方法返回的是字典类型,则可以直接使用对应的 | |||
| field 的 key 来表示,如果是嵌套字典,可以使用元组表示多层次的 key,例如 ``{'a': {'b': 1}}`` 中可以使用 ``('a', 'b')``; | |||
| 如果 :meth:`Dataset.__getitem__` 返回的是 Sequence 类型,则可以使用 ``'_0'``, ``'_1'`` 表示序列中第 **0** 或 **1** 个元素。 | |||
| :return: 返回 Collator 自身; | |||
| """ | |||
| self._renew() | |||
| input_field_names = [(field, field) if isinstance(field, tuple) else ((field,), field) | |||
| @@ -70,7 +70,7 @@ class SequencePackerUnpacker: | |||
| @staticmethod | |||
| def unpack_batch(batch:Sequence[Sequence], ignore_fields, input_fields)->Dict: | |||
| """ | |||
| 将 Sequence[Sequence] 转为 Mapping 。例如 [[[1, 2], 2], [[3], 2]] -> {'_0': [[1, 2], [3]], '_1': [1, 2]} | |||
| 将 Sequence[Sequence] 转为 Mapping 。例如 [[[1, 2], 2], [[3], 2]] -> {'_0': [[1, 2], [3]], '_1': [2, 2]} | |||
| :param batch: 需要 unpack 的 batch 数据。 | |||
| :param ignore_fields: 需要忽略的 field 。 | |||
| @@ -16,13 +16,13 @@ from .exceptions import * | |||
| def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)->Padder: | |||
| """ | |||
| 根据 参数 与 batch_field ,返回适合于当前 batch_field 的 padder 。 | |||
| 根据 参数 与 ``batch_field`` ,返回适合于当前 ``batch_field`` 的 *padder* 。 | |||
| :param batch_field: 将某 field 的内容组合成一个 batch 传入。 | |||
| :param pad_val: | |||
| :param batch_field: 将某 field 的内容组合成一个 batch 传入; | |||
| :param pad_val: | |||
| :param backend: | |||
| :param dtype: | |||
| :param field_name: 方便报错的。 | |||
| :param field_name: field 名称,方便在报错时显示; | |||
| :return: | |||
| """ | |||
| try: | |||
| @@ -84,14 +84,14 @@ def _get_dtype(ele_dtype, dtype, class_name): | |||
| class JittorNumberPadder(Padder): | |||
| def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | |||
| """ | |||
| 可以将形如 [1, 2, 3] 这类的数据转为 jittor.Var([1, 2, 3]) | |||
| """ | |||
| 可以将形如 ``[1, 2, 3]`` 这类的数据转为 ``jittor.Var([1, 2, 3])`` | |||
| :param pad_val: 该值无意义 | |||
| :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 jittor.Var 类型。 | |||
| :param dtype: 输出的数据的 dtype 是什么。如 jittor.long, jittor.float32, int, float 等 | |||
| """ | |||
| :param pad_val: 该值无意义 | |||
| :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 :class:`jittor.Var` 类型; | |||
| :param dtype: 输出的数据的 dtype 是什么。如 :class:`jittor.long`, :class:`jittor.float32`, :class:`int`, :class:`float` 等; | |||
| """ | |||
| def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | |||
| dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__) | |||
| super().__init__(pad_val=pad_val, dtype=dtype) | |||
| @@ -106,23 +106,23 @@ class JittorNumberPadder(Padder): | |||
| class JittorSequencePadder(Padder): | |||
| def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | |||
| """ | |||
| 将类似于 [[1], [1, 2]] 的内容 pad 为 jittor.Var([[1, 0], [1, 2]]) 可以 pad 多重嵌套的数据。 | |||
| """ | |||
| 可以将形如 ``[[1], [1, 2]]`` 这类的数据转为 ``jittor.Var([[1], [1, 2]])`` | |||
| :param pad_val: 需要 pad 的值。 | |||
| :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 jittor.Var 类型。 | |||
| :param dtype: 输出的数据的 dtype 是什么。如 jittor.long, jittor.float32, int, float 等 | |||
| """ | |||
| :param pad_val: 该值无意义 | |||
| :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 :class:`jittor.Var` 类型; | |||
| :param dtype: 输出的数据的 dtype 是什么。如 :class:`jittor.long`, :class:`jittor.float32`, :class:`int`, :class:`float` 等; | |||
| """ | |||
| def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | |||
| dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__) | |||
| super().__init__(pad_val=pad_val, dtype=dtype) | |||
| @staticmethod | |||
| def pad(batch_field, pad_val=0, dtype=None): | |||
| """ | |||
| :param batch_field 输入的某个 field 的 batch 数据。 | |||
| :param pad_val 需要填充的值 | |||
| :dtype 数据的类型 | |||
| :param batch_field: 输入的某个 field 的 batch 数据。 | |||
| :param pad_val: 需要填充的值 | |||
| :param dtype: 数据的类型 | |||
| """ | |||
| tensor = get_padded_jittor_tensor(batch_field, dtype=dtype, pad_val=pad_val) | |||
| return tensor | |||
| @@ -131,11 +131,11 @@ class JittorSequencePadder(Padder): | |||
| class JittorTensorPadder(Padder): | |||
| def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | |||
| """ | |||
| 目前支持 [jittor.Var([3, 2], jittor.Var([1])] 类似的。若内部元素不为 jittor.Var ,则必须含有 tolist() 方法。 | |||
| 目前支持 ``[jittor.Var([3, 2], jittor.Var([1])]`` 类似的输入。若内部元素不为 :class:`jittor.Var` ,则必须含有 :meth:`tolist` 方法。 | |||
| :param pad_val: 需要 pad 的值。 | |||
| :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 jittor.Var 类型。 | |||
| :param dtype: 输出的数据的 dtype 是什么。如 jittor.long, jittor.float32, int, float 等 | |||
| :param pad_val: 需要 pad 的值; | |||
| :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 :class:`jittor.Var` 类型; | |||
| :param dtype: 输出的数据的 dtype 是什么。如 :class:`jittor.long`, :class:`jittor.float32`, :class:`int`, :class:`float` 等 | |||
| """ | |||
| dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__) | |||
| super().__init__(pad_val=pad_val, dtype=dtype) | |||
| @@ -143,11 +143,11 @@ class JittorTensorPadder(Padder): | |||
| @staticmethod | |||
| def pad(batch_field, pad_val=0, dtype=None): | |||
| """ | |||
| 将 batch_field 数据 转为 jittor.Var 并 pad 到相同长度。 | |||
| 将 ``batch_field`` 数据 转为 :class:`jittor.Var` 并 pad 到相同长度。 | |||
| :param batch_field 输入的某个 field 的 batch 数据。 | |||
| :param pad_val 需要填充的值 | |||
| :dtype 数据的类型 | |||
| :param batch_field: 输入的某个 field 的 batch 数据。 | |||
| :param pad_val: 需要填充的值 | |||
| :param dtype: 数据的类型 | |||
| """ | |||
| try: | |||
| if not isinstance(batch_field[0], jittor.Var): | |||
| @@ -38,15 +38,15 @@ def _get_dtype(ele_dtype, dtype, class_name): | |||
| class NumpyNumberPadder(Padder): | |||
| """ | |||
| 可以将形如 [1, 2, 3] 这类的数据转为 np.array([1, 2, 3]) 。可以通过: | |||
| 可以将形如 ``[1, 2, 3]`` 这类的数据转为 ``np.array([1, 2, 3])`` 。可以通过:: | |||
| >>> NumpyNumberPadder.pad([1, 2, 3]) | |||
| 使用。 | |||
| :param pad_val: 该值无意义 | |||
| :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 np.array 类型。 | |||
| :param dtype: 输出的数据的 dtype 是什么 | |||
| :param pad_val: 该值无意义; | |||
| :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 :class:`np.array` 类型; | |||
| :param dtype: 输出的数据的 dtype ; | |||
| """ | |||
| def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | |||
| dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__) | |||
| @@ -54,21 +54,28 @@ class NumpyNumberPadder(Padder): | |||
| @staticmethod | |||
| def pad(batch_field, pad_val=0, dtype=None): | |||
| """ | |||
| 将 ``batch_field`` 数据 转为 :class:`numpy.array` 并 pad 到相同长度。 | |||
| :param batch_field: 输入的某个 field 的 batch 数据。 | |||
| :param pad_val: 需要填充的值 | |||
| :param dtype: 数据的类型 | |||
| """ | |||
| return np.array(batch_field, dtype=dtype) | |||
| class NumpySequencePadder(Padder): | |||
| """ | |||
| 将类似于 [[1], [1, 2]] 的内容 pad 为 np.array([[1, 0], [1, 2]]) 可以 pad 多重嵌套的数据。 | |||
| 将类似于 ``[[1], [1, 2]]`` 的内容 pad 为 ``np.array([[1, 0], [1, 2]])``, 可以 pad 多重嵌套的数据。 | |||
| 可以通过以下的方式直接使用: | |||
| >>> NumpySequencePadder.pad([[1], [1, 2]], pad_val=-100, dtype=float) | |||
| [[ 1. -100.] | |||
| [ 1. 2.]] | |||
| :param pad_val: pad 的值是多少。 | |||
| :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 np.array 类型。 | |||
| :param dtype: 输出的数据的 dtype 是什么 | |||
| :param pad_val: pad 的值是多少; | |||
| :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 :class:`np.array` 类型; | |||
| :param dtype: 输出的数据的 dtype ; | |||
| """ | |||
| def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | |||
| dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__) | |||
| @@ -76,18 +83,25 @@ class NumpySequencePadder(Padder): | |||
| @staticmethod | |||
| def pad(batch_field, pad_val=0, dtype=None): | |||
| """ | |||
| 将 ``batch_field`` 数据 转为 :class:`numpy.array` 并 pad 到相同长度。 | |||
| :param batch_field: 输入的某个 field 的 batch 数据。 | |||
| :param pad_val: 需要填充的值 | |||
| :param dtype: 数据的类型 | |||
| """ | |||
| return get_padded_numpy_array(batch_field, dtype=dtype, pad_val=pad_val) | |||
| class NumpyTensorPadder(Padder): | |||
| """ | |||
| pad 类似于 [np.array([3, 4]), np.array([1])] 的 field 。若内部元素不为 np.ndarray ,则必须含有 tolist() 方法。 | |||
| pad 类似于 ``[np.array([3, 4]), np.array([1])]`` 的 field 。若内部元素不为 :class:`np.ndarray` ,则必须含有 :meth:`tolist` 方法。 | |||
| >>> NumpyTensorPadder.pad([np.array([3, 4]), np.array([1])], pad_val=-100) | |||
| [[ 3. 4.] | |||
| [ 1. -100.]] | |||
| :param pad_val: pad 的值是多少。 | |||
| :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 np.array 类型。 | |||
| :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 :class:`np.array` 类型。 | |||
| :param dtype: 输出的数据的 dtype 是什么 | |||
| """ | |||
| def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | |||
| @@ -96,6 +110,13 @@ class NumpyTensorPadder(Padder): | |||
| @staticmethod | |||
| def pad(batch_field, pad_val=0, dtype=None): | |||
| """ | |||
| 将 ``batch_field`` 数据 转为 :class:`numpy.array` 并 pad 到相同长度。 | |||
| :param batch_field: 输入的某个 field 的 batch 数据。 | |||
| :param pad_val: 需要填充的值 | |||
| :param dtype: 数据的类型 | |||
| """ | |||
| try: | |||
| if not isinstance(batch_field[0], np.ndarray): | |||
| batch_field = [np.array(field.tolist(), dtype=dtype) for field in batch_field] | |||
| @@ -74,11 +74,11 @@ def _get_dtype(ele_dtype, dtype, class_name): | |||
| class OneflowNumberPadder(Padder): | |||
| """ | |||
| 可以将形如 [1, 2, 3] 这类的数据转为 oneflow.Tensor([1, 2, 3]) | |||
| 可以将形如 ``[1, 2, 3]`` 这类的数据转为 ``oneflow.Tensor([1, 2, 3])``。 | |||
| :param pad_val: 该值无意义 | |||
| :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 oneflow.tensor 类型。 | |||
| :param dtype: 输出的数据的 dtype 是什么。如 oneflow.long, oneflow.float32, int, float 等 | |||
| :param pad_val: 该值无意义; | |||
| :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 :class:`oneflow.Tensor` 类型; | |||
| :param dtype: 输出的数据的 dtype,。如 :class:`oneflow.long`, :class:`oneflow.float32`, :class:`int`, :class:`float` 等; | |||
| """ | |||
| def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | |||
| dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__) | |||
| @@ -86,16 +86,23 @@ class OneflowNumberPadder(Padder): | |||
| @staticmethod | |||
| def pad(batch_field, pad_val=0, dtype=None): | |||
| """ | |||
| 将 ``batch_field`` 数据 转为 :class:`oneflow.Tensor` 并 pad 到相同长度。 | |||
| :param batch_field: 输入的某个 field 的 batch 数据。 | |||
| :param pad_val: 需要填充的值 | |||
| :param dtype: 数据的类型 | |||
| """ | |||
| return oneflow.tensor(batch_field, dtype=dtype) | |||
| class OneflowSequencePadder(Padder): | |||
| """ | |||
| 将类似于 [[1], [1, 2]] 的内容 pad 为 oneflow.Tensor([[1, 0], [1, 2]]) 可以 pad 多重嵌套的数据。 | |||
| 将类似于 ``[[1], [1, 2]]`` 的内容 pad 为 ``oneflow.Tensor([[1, 0], [1, 2]])``, 可以 pad 多重嵌套的数据。 | |||
| :param pad_val: 需要 pad 的值。 | |||
| :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 oneflow.tensor 类型。 | |||
| :param dtype: 输出的数据的 dtype 是什么。如 oneflow.long, oneflow.float32, int, float 等 | |||
| :param pad_val: 需要 pad 的值; | |||
| :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 :class:`oneflow.Tensor` 类型; | |||
| :param type: 输出的数据的 dtype,。如 :class:`oneflow.long`, :class:`oneflow.float32`, :class:`int`, :class:`float` 等; | |||
| """ | |||
| def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | |||
| dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__) | |||
| @@ -103,13 +110,20 @@ class OneflowSequencePadder(Padder): | |||
| @staticmethod | |||
| def pad(batch_field, pad_val=0, dtype=None): | |||
| """ | |||
| 将 ``batch_field`` 数据 转为 :class:`oneflow.Tensor` 并 pad 到相同长度。 | |||
| :param batch_field: 输入的某个 field 的 batch 数据。 | |||
| :param pad_val: 需要填充的值 | |||
| :param dtype: 数据的类型 | |||
| """ | |||
| tensor = get_padded_oneflow_tensor(batch_field, dtype=dtype, pad_val=pad_val) | |||
| return tensor | |||
| class OneflowTensorPadder(Padder): | |||
| """ | |||
| 目前支持 [oneflow.tensor([3, 2], oneflow.tensor([1])] 类似的。若内部元素不为 oneflow.tensor ,则必须含有 tolist() 方法。 | |||
| 目前支持 ``[oneflow.tensor([3, 2], oneflow.tensor([1])]`` 类似的输入,若内部元素不为 :class:`oneflow.Tensor` ,则必须含有 :meth:`tolist` 方法。 | |||
| >>> OneflowTensorPadder.pad([np.array([3, 4]), np.array([1])], pad_val=-100) | |||
| [[ 3. 4.] | |||
| @@ -119,8 +133,8 @@ class OneflowTensorPadder(Padder): | |||
| [ 1, -100]]) | |||
| :param pad_val: 需要 pad 的值。 | |||
| :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 oneflow.tensor 类型。 | |||
| :param dtype: 输出的数据的 dtype 是什么。如 oneflow.long, oneflow.float32, int, float 等 | |||
| :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 :class:`oneflow.Tensor` 类型。 | |||
| :param dtype: 输出的数据的 dtype,。如 :class:`oneflow.long`, :class:`oneflow.float32`, :class:`int`, :class:`float` 等; | |||
| """ | |||
| def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | |||
| dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__) | |||
| @@ -128,6 +142,13 @@ class OneflowTensorPadder(Padder): | |||
| @staticmethod | |||
| def pad(batch_field, pad_val=0, dtype=None): | |||
| """ | |||
| 将 ``batch_field`` 数据 转为 :class:`oneflow.Tensor` 并 pad 到相同长度。 | |||
| :param batch_field: 输入的某个 field 的 batch 数据。 | |||
| :param pad_val: 需要填充的值 | |||
| :param dtype: 数据的类型 | |||
| """ | |||
| device = None | |||
| try: | |||
| if not isinstance(batch_field[0], oneflow.Tensor): | |||
| @@ -1,7 +1,7 @@ | |||
| class Padder: | |||
| """ | |||
| 所有 Padder 对象父类,所有的 Padder 对象都会实现 pad(batch_field, pad_val=0, dtype=None) 的静态函数。 | |||
| 所有 **Padder** 对象的父类,所有的 Padder 对象都会实现静态函数 *pad(batch_field, pad_val=0, dtype=None)* 。 | |||
| """ | |||
| def __init__(self, pad_val, dtype): | |||
| @@ -99,11 +99,11 @@ def _get_dtype(ele_dtype, dtype, class_name): | |||
| class PaddleNumberPadder(Padder): | |||
| """ | |||
| 可以将形如 [1, 2, 3] 这类的数据转为 paddle.Tensor([1, 2, 3]) | |||
| 可以将形如 ``[1, 2, 3]`` 这类的数据转为 ``paddle.Tensor([1, 2, 3])`` | |||
| :param pad_val: 该值无意义 | |||
| :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 paddle.tensor 类型。 | |||
| :param dtype: 输出的数据的 dtype 是什么。如 int, float, 'int32' 等 | |||
| :param pad_val: 该值无意义; | |||
| :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 :class:`paddle.Tensor` 类型; | |||
| :param dtype: 输出的数据的 dtype 是什么。如 :class:`int`, :class:`float`, :class:`int32` 等; | |||
| """ | |||
| def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | |||
| # 仅当 ele_dtype 是 python number/ numpy number 或者 tensor | |||
| @@ -112,16 +112,23 @@ class PaddleNumberPadder(Padder): | |||
| @staticmethod | |||
| def pad(batch_field, pad_val=0, dtype=None): | |||
| """ | |||
| 将 ``batch_field`` 数据 转为 :class:`paddle.Tensor` 并 pad 到相同长度。 | |||
| :param batch_field: 输入的某个 field 的 batch 数据。 | |||
| :param pad_val: 需要填充的值 | |||
| :param dtype: 数据的类型 | |||
| """ | |||
| return paddle.to_tensor(batch_field, dtype=dtype) | |||
| class PaddleSequencePadder(Padder): | |||
| """ | |||
| 将类似于 [[1], [1, 2]] 的内容 pad 为 paddle.Tensor([[1, 0], [1, 2]]) 可以 pad 多重嵌套的数据。 | |||
| 将类似于 ``[[1], [1, 2]]`` 的内容 pad 为 ``paddle.Tensor([[1, 0], [1, 2]])`` 可以 pad 多重嵌套的数据。 | |||
| :param pad_val: pad 的值。 | |||
| :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 paddle.tensor 类型。 | |||
| :param dtype: 输出的数据的 dtype 是什么。如 int, float, 'int32' 等 | |||
| :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 :class:`paddle.Tensor` 类型; | |||
| :param dtype: 输出的数据的 dtype 是什么。如 :class:`int`, :class:`float`, :class:`int32` 等; | |||
| """ | |||
| def __init__(self, ele_dtype=None, pad_val=0, dtype=None): | |||
| dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__) | |||
| @@ -129,17 +136,30 @@ class PaddleSequencePadder(Padder): | |||
| @staticmethod | |||
| def pad(batch_field, pad_val=0, dtype=None): | |||
| """ | |||
| 将 ``batch_field`` 数据 转为 :class:`paddle.Tensor` 并 pad 到相同长度。 | |||
| :param batch_field: 输入的某个 field 的 batch 数据。 | |||
| :param pad_val: 需要填充的值 | |||
| :param dtype: 数据的类型 | |||
| """ | |||
| tensor = get_padded_paddle_tensor(batch_field, dtype=dtype, pad_val=pad_val) | |||
| return tensor | |||
| class PaddleTensorPadder(Padder): | |||
| """ | |||
| 目前支持 [paddle.tensor([3, 2], paddle.tensor([2, 1])] 类似的,若内部元素不为 paddle.tensor ,则必须含有 tolist() 方法。 | |||
| 目前支持 ``[paddle.tensor([3, 2], paddle.tensor([2, 1])]`` 类似的输入,若内部元素不为 :class:`paddle.Tensor` ,则必须含有 :meth:`tolist` 方法。 | |||
| >>> PaddleTensorPadder.pad([np.array([3, 4]), np.array([1])], pad_val=-100) | |||
| [[ 3. 4.] | |||
| [ 1. -100.]] | |||
| >>> PaddleTensorPadder.pad([paddle.to_tensor([3, 4]), paddle.to_tensor([1])], pad_val=-100) | |||
| tensor([[ 3, 4], | |||
| [ 1, -100]]) | |||
| :param pad_val: pad 的值。 | |||
| :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 paddle.tensor 类型。 | |||
| :param dtype: 输出的数据的 dtype 是什么。如 int, float, 'int32' 等 | |||
| :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 :class:`paddle.Tensor` 类型; | |||
| :param dtype: 输出的数据的 dtype 是什么。如 :class:`int`, :class:`float`, :class:`int32` 等; | |||
| """ | |||
| def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | |||
| dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__) | |||
| @@ -147,6 +167,13 @@ class PaddleTensorPadder(Padder): | |||
| @staticmethod | |||
| def pad(batch_field, pad_val=0, dtype=None): | |||
| """ | |||
| 将 ``batch_field`` 数据 转为 :class:`paddle.Tensor` 并 pad 到相同长度。 | |||
| :param batch_field: 输入的某个 field 的 batch 数据。 | |||
| :param pad_val: 需要填充的值 | |||
| :param dtype: 数据的类型 | |||
| """ | |||
| try: | |||
| if not isinstance(batch_field[0], paddle.Tensor): | |||
| batch_field = [np.array(field.tolist()) for field in batch_field] | |||
| @@ -34,11 +34,11 @@ def _get_dtype(ele_dtype, dtype, class_name): | |||
| class RawNumberPadder(Padder): | |||
| """ | |||
| 可以将形如 [1, 2, 3] 这类的数据转为 [1, 2, 3] 。实际上该 padder 无意义。 | |||
| 可以将形如 ``[1, 2, 3]`` 这类的数据转为 ``[1, 2, 3]`` 。实际上该 padder 无意义。 | |||
| :param pad_val: 该值无意义 | |||
| :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 np.array 类型。 | |||
| :param dtype: 输出的数据的 dtype 是什么 | |||
| :param pad_val: | |||
| :param ele_dtype: | |||
| :param dtype: | |||
| """ | |||
| def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | |||
| dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__) | |||
| @@ -54,11 +54,11 @@ class RawNumberPadder(Padder): | |||
| class RawSequencePadder(Padder): | |||
| """ | |||
| 将类似于 [[1], [1, 2]] 的内容 pad 为 [[1, 0], [1, 2]] 。可以 pad 多重嵌套的数据。 | |||
| 将类似于 ``[[1], [1, 2]]`` 的内容 pad 为 ``[[1, 0], [1, 2]]`` 。可以 pad 多重嵌套的数据。 | |||
| :param pad_val: pad 的值 | |||
| :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 np.array 类型。 | |||
| :param dtype: 输出的数据的 dtype 是什么 | |||
| :param pad_val: pad 的值; | |||
| :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 :class:`np.array` 类型; | |||
| :param dtype: 输出的数据的 dtype ; | |||
| """ | |||
| def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | |||
| dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__) | |||
| @@ -68,8 +68,8 @@ class RawSequencePadder(Padder): | |||
| def pad(batch_field, pad_val=0, dtype=None): | |||
| """ | |||
| :param batch_field: | |||
| :param pad_val: | |||
| :param batch_field: 输入的某个 field 的 batch 数据。 | |||
| :param pad_val: 需要填充的值 | |||
| :param dtype: 该参数无意义。 | |||
| :return: | |||
| """ | |||
| @@ -78,11 +78,11 @@ class RawSequencePadder(Padder): | |||
| class RawTensorPadder(Padder): | |||
| """ | |||
| 将类似于 [[1], [1, 2]] 的内容 pad 为 [[1, 0], [1, 2]] 。可以 pad 多重嵌套的数据。 | |||
| 将类似于 ``[[1], [1, 2]]`` 的内容 pad 为 ``[[1, 0], [1, 2]]`` 。可以 pad 多重嵌套的数据。 | |||
| :param pad_val: pad 的值 | |||
| :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 np.array 类型。 | |||
| :param dtype: 输出的数据的 dtype 是什么 | |||
| :param pad_val: pad 的值; | |||
| :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 :class:`np.array` 类型; | |||
| :param dtype: 输出的数据的 dtype ; | |||
| """ | |||
| def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | |||
| dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__) | |||
| @@ -92,8 +92,8 @@ class RawTensorPadder(Padder): | |||
| def pad(batch_field, pad_val=0, dtype=None): | |||
| """ | |||
| :param batch_field: | |||
| :param pad_val: | |||
| :param batch_field: 输入的某个 field 的 batch 数据。 | |||
| :param pad_val: 需要填充的值 | |||
| :param dtype: 该参数无意义。 | |||
| :return: | |||
| """ | |||
| @@ -77,11 +77,11 @@ def _get_dtype(ele_dtype, dtype, class_name): | |||
| class TorchNumberPadder(Padder): | |||
| """ | |||
| 可以将形如 [1, 2, 3] 这类的数据转为 torch.Tensor([1, 2, 3]) | |||
| 可以将形如 ``[1, 2, 3]`` 这类的数据转为 ``torch.Tensor([1, 2, 3])`` | |||
| :param pad_val: 该值无意义 | |||
| :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 torch.tensor 类型。 | |||
| :param dtype: 输出的数据的 dtype 是什么。如 torch.long, torch.float32, int, float 等 | |||
| :param pad_val: 该值无意义; | |||
| :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 :class:`torch.Tensor` 类型; | |||
| :param dtype: 输出的数据的 dtype 是什么。如 :class:`torch.long`, :class:`torch.float32`, :class:`int`, :class:`float` 等; | |||
| """ | |||
| def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | |||
| dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__) | |||
| @@ -94,11 +94,11 @@ class TorchNumberPadder(Padder): | |||
| class TorchSequencePadder(Padder): | |||
| """ | |||
| 将类似于 [[1], [1, 2]] 的内容 pad 为 torch.Tensor([[1, 0], [1, 2]]) 可以 pad 多重嵌套的数据。 | |||
| 将类似于 ``[[1], [1, 2]]`` 的内容 pad 为 ``torch.Tensor([[1, 0], [1, 2]])`` 可以 pad 多重嵌套的数据。 | |||
| :param pad_val: 需要 pad 的值。 | |||
| :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 torch.tensor 类型。 | |||
| :param dtype: 输出的数据的 dtype 是什么。如 torch.long, torch.float32, int, float 等 | |||
| :param pad_val: 需要 pad 的值; | |||
| :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 :class:`torch.Tensor` 类型; | |||
| :param dtype: 输出的数据的 dtype 是什么。如 :class:`torch.long`, :class:`torch.float32`, :class:`int`, :class:`float` 等; | |||
| """ | |||
| def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | |||
| dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__) | |||
| @@ -112,7 +112,7 @@ class TorchSequencePadder(Padder): | |||
| class TorchTensorPadder(Padder): | |||
| """ | |||
| 目前支持 [torch.tensor([3, 2], torch.tensor([1])] 类似的。若内部元素不为 torch.tensor ,则必须含有 tolist() 方法。 | |||
| 目前支持 ``[torch.tensor([3, 2], torch.tensor([1])]`` 类似的输入。若内部元素不为 :class:`torch.Tensor` ,则必须含有 :meth:`tolist` 方法。 | |||
| >>> TorchTensorPadder.pad([np.array([3, 4]), np.array([1])], pad_val=-100) | |||
| [[ 3. 4.] | |||
| @@ -121,9 +121,9 @@ class TorchTensorPadder(Padder): | |||
| tensor([[ 3, 4], | |||
| [ 1, -100]]) | |||
| :param pad_val: 需要 pad 的值。 | |||
| :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 torch.tensor 类型。 | |||
| :param dtype: 输出的数据的 dtype 是什么。如 torch.long, torch.float32, int, float 等 | |||
| :param pad_val: 需要 pad 的值; | |||
| :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 :class:`torch.Tensor` 类型; | |||
| :param dtype: 输出的数据的 dtype 是什么。如 :class:`torch.long`, :class:`torch.float32`, :class:`int`, :class:`float` 等; | |||
| """ | |||
| def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | |||
| dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__) | |||
| @@ -5,6 +5,7 @@ from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||
| if _NEED_IMPORT_TORCH: | |||
| import torch | |||
| __all__ = [] | |||
| def is_torch_tensor_dtype(dtype) -> bool: | |||
| """ | |||
| @@ -78,13 +78,12 @@ def fill_array(batch_field:List, padded_batch:np.ndarray): | |||
| def get_padded_numpy_array(batch_field: List, dtype=None, pad_val=0) -> np.ndarray: | |||
| """ | |||
| 例如: | |||
| [[1,2], [3]] -> np.array([[1, 2], [3, 0]]) | |||
| 将输入 pad 为 :class:`numpy.arraay` 类型,如:``[[1,2], [3]] -> np.array([[1, 2], [3, 0]])`` | |||
| :param batch_field: 需要 pad 的对象。需要保证应该是可以进行 pad 的。支持 1d(多为句子长度)/2d(多为文本序列)/3d(多为字符序列) | |||
| /4d(多为图片)。 | |||
| :param dtype: 目标类别是什么 | |||
| :param pad_val: pad 的 value | |||
| :param batch_field: 需要 pad 的对象。需要保证应该是可以进行 pad 的。支持 **1d**(多为句子长度)/ **2d**(多为文本序列)/ **3d**(多为字符序列) | |||
| /4d(多为图片); | |||
| :param dtype: 输出数据的 dtype 类型; | |||
| :param pad_val: 填充值; | |||
| :return: | |||
| """ | |||
| shapes = get_shape(batch_field) | |||