| @@ -0,0 +1,43 @@ | |||
| __all__ = [ | |||
| 'cache_results', | |||
| 'is_jittor_dataset', | |||
| 'jittor_collate_wraps', | |||
| 'paddle_to', | |||
| 'paddle_move_data_to_device', | |||
| 'get_paddle_device_id', | |||
| 'get_paddle_gpu_str', | |||
| 'is_in_paddle_dist', | |||
| 'is_in_fnlp_paddle_dist', | |||
| 'is_in_paddle_launch_dist', | |||
| 'f_rich_progress', | |||
| 'torch_paddle_move_data_to_device', | |||
| 'torch_move_data_to_device', | |||
| 'get_fn_arg_names', | |||
| 'check_fn_not_empty_params', | |||
| 'auto_param_call', | |||
| 'check_user_specific_params', | |||
| 'dataclass_to_dict', | |||
| 'match_and_substitute_params', | |||
| 'apply_to_collection', | |||
| 'nullcontext', | |||
| 'pretty_table_printer', | |||
| 'Option', | |||
| 'indice_collate_wrapper', | |||
| 'deprecated', | |||
| 'seq_len_to_mask', | |||
| 'synchronize_safe_rm', | |||
| 'synchronize_mkdir' | |||
| ] | |||
| from .cache_results import cache_results | |||
| from .jittor_utils import is_jittor_dataset, jittor_collate_wraps | |||
| from .paddle_utils import paddle_to, paddle_move_data_to_device, get_paddle_device_id, get_paddle_gpu_str, is_in_paddle_dist, \ | |||
| is_in_fnlp_paddle_dist, is_in_paddle_launch_dist | |||
| from .rich_progress import f_rich_progress | |||
| from .torch_paddle_utils import torch_paddle_move_data_to_device | |||
| from .torch_utils import torch_move_data_to_device | |||
| from .utils import get_fn_arg_names, check_fn_not_empty_params, auto_param_call, check_user_specific_params, \ | |||
| dataclass_to_dict, match_and_substitute_params, apply_to_collection, nullcontext, pretty_table_printer, Option, \ | |||
| indice_collate_wrapper, deprecated, seq_len_to_mask, synchronize_safe_rm, synchronize_mkdir | |||
| @@ -0,0 +1,310 @@ | |||
| from datetime import datetime | |||
| import hashlib | |||
| import _pickle | |||
| import functools | |||
| import os | |||
| from typing import Callable, List, Any, Optional | |||
| import inspect | |||
| import ast | |||
| from collections import deque | |||
| __all__ = [ | |||
| 'cache_results' | |||
| ] | |||
| from fastNLP.core.log.logger import logger | |||
| from fastNLP.core.log.highlighter import ColorHighlighter | |||
| class FuncCallVisitor(ast.NodeVisitor): | |||
| # credit to https://gist.github.com/jargnar/0946ab1d985e2b4ab776 | |||
| def __init__(self): | |||
| self._name = deque() | |||
| @property | |||
| def name(self): | |||
| return '.'.join(self._name) | |||
| @name.deleter | |||
| def name(self): | |||
| self._name.clear() | |||
| def visit_Name(self, node): | |||
| self._name.appendleft(node.id) | |||
| def visit_Attribute(self, node): | |||
| try: | |||
| self._name.appendleft(node.attr) | |||
| self._name.appendleft(node.value.id) | |||
| except AttributeError: | |||
| self.generic_visit(node) | |||
| def get_func_calls(tree): | |||
| func_calls = [] | |||
| for node in ast.walk(tree): | |||
| if isinstance(node, ast.Call): | |||
| callvisitor = FuncCallVisitor() | |||
| callvisitor.visit(node.func) | |||
| func_calls.append(callvisitor.name) | |||
| if isinstance(node, ast.FunctionDef): | |||
| if not (node is tree): | |||
| func_calls.extend(get_func_calls(node)) | |||
| return func_calls | |||
| def truncate_start_blanks(source:str)->str: | |||
| """ | |||
| 将source中的每一行按照第一行的indent删掉多余的空格 | |||
| :param source: | |||
| :return: | |||
| """ | |||
| lines = source.split('\n') | |||
| num_blank = 0 | |||
| # get the top blank line | |||
| for line in lines: | |||
| if line: | |||
| num_blank = len(line) - len(line.lstrip()) | |||
| new_lines = [] | |||
| for line in lines: | |||
| i = -1 | |||
| for i in range(min(len(line), num_blank)): | |||
| if line[i] == ' ': | |||
| continue | |||
| else: | |||
| break | |||
| line = line[i:] | |||
| new_lines.append(line) | |||
| return '\n'.join(new_lines) | |||
| def _get_func_and_its_called_func_source_code(func) -> List[str]: | |||
| """ | |||
| 给定一个func,返回在这个函数里面用到的所有函数的源码。 | |||
| :param callable func: | |||
| :return: | |||
| """ | |||
| last_frame = inspect.currentframe().f_back.f_back.f_back | |||
| last_frame_f_local = last_frame.f_locals | |||
| last_frame_loc = {} | |||
| if 'loc' in last_frame_f_local: | |||
| last_frame_loc = last_frame_f_local['loc'] | |||
| func_calls = list(set(get_func_calls(ast.parse(truncate_start_blanks(inspect.getsource(func)))))) | |||
| func_calls.sort() | |||
| sources = [] | |||
| for _func_name in func_calls: | |||
| try: | |||
| if _func_name == 'cache_results': # ignore the decorator | |||
| continue | |||
| if '.' in _func_name: | |||
| _funcs = _func_name.split('.') | |||
| else: | |||
| _funcs = [_func_name] | |||
| if _funcs[0] in last_frame_f_local or _funcs[0] in last_frame_loc: | |||
| tmp = _funcs.pop(0) | |||
| variable = last_frame_f_local.get(tmp, last_frame_loc.get(tmp)) | |||
| while len(_funcs) or variable is not None: | |||
| if hasattr(variable, '__class__') and not inspect.isbuiltin(variable.__class__): | |||
| try: | |||
| sources.append(inspect.getsource(variable.__class__)) | |||
| except TypeError: | |||
| pass | |||
| if callable(variable) or inspect.isclass(variable): | |||
| sources.append(inspect.getsource(variable)) | |||
| if len(_funcs): | |||
| tmp = _funcs.pop(0) | |||
| if hasattr(variable, tmp): | |||
| variable = getattr(variable, tmp) | |||
| else: | |||
| break | |||
| else: | |||
| variable = None | |||
| except: | |||
| # some failure | |||
| pass | |||
| del last_frame # | |||
| sources.append(inspect.getsource(func)) | |||
| return sources | |||
| def _prepare_cache_filepath(filepath:str): | |||
| r""" | |||
| 检查filepath是否可以作为合理的cache文件. 如果可以的话,会自动创造路径 | |||
| :param filepath: str. | |||
| :return: None, if not, this function will raise error | |||
| """ | |||
| _cache_filepath = os.path.abspath(filepath) | |||
| if os.path.isdir(_cache_filepath): | |||
| raise RuntimeError("The cache_file_path must be a file, not a directory.") | |||
| cache_dir = os.path.dirname(_cache_filepath) | |||
| if not os.path.exists(cache_dir): | |||
| os.makedirs(cache_dir, exist_ok=True) | |||
| class Hasher: | |||
| def __init__(self): | |||
| self.m = hashlib.sha1() | |||
| def update(self, value: Any) -> None: | |||
| if isinstance(value, str): | |||
| value = [value] | |||
| for x in value: | |||
| self.m.update(x.encode('utf8')) | |||
| def hexdigest(self) -> str: | |||
| return self.m.hexdigest() | |||
| def cal_fn_hash_code(fn: Optional[Callable] = None, fn_kwargs: Optional[dict] = None): | |||
| if fn_kwargs is None: | |||
| fn_kwargs = {} | |||
| hasher = Hasher() | |||
| try: | |||
| sources = _get_func_and_its_called_func_source_code(fn) | |||
| hasher.update(sources) | |||
| except: | |||
| return "can't be hashed" | |||
| for key in sorted(fn_kwargs): | |||
| hasher.update(key) | |||
| try: | |||
| hasher.update(fn_kwargs[key]) | |||
| except: | |||
| pass | |||
| return hasher.hexdigest() | |||
| def cache_results(_cache_fp, _refresh=False, _verbose=1, _check_hash=True): | |||
| r""" | |||
| cache_results是fastNLP中用于cache数据的装饰器。通过下面的例子看一下如何使用:: | |||
| import time | |||
| import numpy as np | |||
| from fastNLP import cache_results | |||
| @cache_results('cache.pkl') | |||
| def process_data(): | |||
| # 一些比较耗时的工作,比如读取数据,预处理数据等,这里用time.sleep()代替耗时 | |||
| time.sleep(1) | |||
| return np.random.randint(10, size=(5,)) | |||
| start_time = time.time() | |||
| print("res =",process_data()) | |||
| print(time.time() - start_time) | |||
| start_time = time.time() | |||
| print("res =",process_data()) | |||
| print(time.time() - start_time) | |||
| # 输出内容如下,可以看到两次结果相同,且第二次几乎没有花费时间 | |||
| # Save cache to cache.pkl. | |||
| # res = [5 4 9 1 8] | |||
| # 1.0042750835418701 | |||
| # Read cache from cache.pkl. | |||
| # res = [5 4 9 1 8] | |||
| # 0.0040721893310546875 | |||
| 可以看到第二次运行的时候,只用了0.0001s左右,是由于第二次运行将直接从cache.pkl这个文件读取数据,而不会经过再次预处理:: | |||
| # 还是以上面的例子为例,如果需要重新生成另一个cache,比如另一个数据集的内容,通过如下的方式调用即可 | |||
| process_data(_cache_fp='cache2.pkl') # 完全不影响之前的‘cache.pkl' | |||
| 上面的_cache_fp是cache_results会识别的参数,它将从'cache2.pkl'这里缓存/读取数据,即这里的'cache2.pkl'覆盖默认的 | |||
| 'cache.pkl'。如果在你的函数前面加上了@cache_results()则你的函数会增加三个参数[_cache_fp, _refresh, _verbose]。 | |||
| 上面的例子即为使用_cache_fp的情况,这三个参数不会传入到你的函数中,当然你写的函数参数名也不可能包含这三个名称:: | |||
| process_data(_cache_fp='cache2.pkl', _refresh=True) # 这里强制重新生成一份对预处理的cache。 | |||
| # _verbose是用于控制输出信息的,如果为0,则不输出任何内容;如果为1,则会提醒当前步骤是读取的cache还是生成了新的cache | |||
| :param str _cache_fp: 将返回结果缓存到什么位置;或从什么位置读取缓存。如果为None,cache_results没有任何效用,除非在 | |||
| 函数调用的时候传入_cache_fp这个参数。 | |||
| :param bool _refresh: 是否重新生成cache。 | |||
| :param int _verbose: 是否打印cache的信息。 | |||
| :param bool _check_hash: 如果为 True 将尝试对比修饰的函数的源码以及该函数内部调用的函数的源码的hash值。如果发现保存时的hash值 | |||
| 与当前的hash值有差异,会报warning。但该warning可能出现实质上并不影响结果的误报(例如增删空白行);且在修改不涉及源码时,虽然 | |||
| 该修改对结果有影响,但无法做出warning。 | |||
| :return: | |||
| """ | |||
| def wrapper_(func): | |||
| signature = inspect.signature(func) | |||
| for key, _ in signature.parameters.items(): | |||
| if key in ('_cache_fp', '_refresh', '_verbose', '_check_hash'): | |||
| raise RuntimeError("The function decorated by cache_results cannot have keyword `{}`.".format(key)) | |||
| @functools.wraps(func) | |||
| def wrapper(*args, **kwargs): | |||
| fn_param = kwargs.copy() | |||
| if args: | |||
| params = [p.name for p in inspect.signature(func).parameters.values()] | |||
| fn_param.update(zip(params, args)) | |||
| if '_cache_fp' in kwargs: | |||
| cache_filepath = kwargs.pop('_cache_fp') | |||
| assert isinstance(cache_filepath, str), "_cache_fp can only be str." | |||
| else: | |||
| cache_filepath = _cache_fp | |||
| if '_refresh' in kwargs: | |||
| refresh = kwargs.pop('_refresh') | |||
| assert isinstance(refresh, bool), "_refresh can only be bool." | |||
| else: | |||
| refresh = _refresh | |||
| if '_verbose' in kwargs: | |||
| verbose = kwargs.pop('_verbose') | |||
| assert isinstance(verbose, int), "_verbose can only be integer." | |||
| else: | |||
| verbose = _verbose | |||
| if '_check_hash' in kwargs: | |||
| check_hash = kwargs.pop('_check_hash') | |||
| else: | |||
| check_hash = _check_hash | |||
| refresh_flag = True | |||
| new_hash_code = None | |||
| if check_hash: | |||
| new_hash_code = cal_fn_hash_code(func, fn_param) | |||
| if cache_filepath is not None and refresh is False: | |||
| # load data | |||
| if os.path.exists(cache_filepath): | |||
| cache_filepath = os.path.abspath(cache_filepath) | |||
| with open(cache_filepath, 'rb') as f: | |||
| results = _pickle.load(f) | |||
| old_hash_code = results['hash'] | |||
| save_time = results['save_time'] | |||
| results = results['results'] | |||
| if verbose == 1: | |||
| logger.info("Read cache from {} (Saved on {}).".format(cache_filepath, save_time)) | |||
| if check_hash and old_hash_code != new_hash_code: | |||
| logger.warning(f"The function `{func.__name__}` is different from its last cache (Save on {save_time}). The " | |||
| f"difference may caused by the sourcecode change of the functions by this function.", | |||
| extra={'highlighter': ColorHighlighter('red')}) | |||
| refresh_flag = False | |||
| if refresh_flag: | |||
| if new_hash_code is None: | |||
| new_hash_code = cal_fn_hash_code(func, fn_param) | |||
| results = func(*args, **kwargs) | |||
| if cache_filepath is not None: | |||
| if results is None: | |||
| raise RuntimeError("The return value is None. Cannot save None results.") | |||
| cache_filepath = os.path.abspath(cache_filepath) | |||
| _prepare_cache_filepath(cache_filepath) | |||
| _dict = { | |||
| 'results': results, | |||
| 'hash': new_hash_code, | |||
| 'save_time': datetime.now(), | |||
| } | |||
| with open(cache_filepath, 'wb') as f: | |||
| _pickle.dump(_dict, f) | |||
| logger.info("Save cache to {}.".format(cache_filepath)) | |||
| return results | |||
| return wrapper | |||
| return wrapper_ | |||
| @@ -0,0 +1,4 @@ | |||
| class DummyClass: | |||
| pass | |||
| @@ -0,0 +1,51 @@ | |||
| __all__ = [ | |||
| 'is_jittor_dataset', | |||
| 'jittor_collate_wraps' | |||
| ] | |||
| from collections.abc import Mapping, Callable | |||
| from functools import wraps | |||
| from fastNLP.envs.imports import _NEED_IMPORT_JITTOR | |||
| if _NEED_IMPORT_JITTOR: | |||
| import jittor as jt | |||
| from fastNLP.core.dataset import Instance | |||
| def is_jittor_dataset(dataset) -> bool: | |||
| try: | |||
| if isinstance(dataset, jt.dataset.Dataset): | |||
| return True | |||
| else: | |||
| return False | |||
| except BaseException: | |||
| return False | |||
| def jittor_collate_wraps(func, auto_collator: Callable): | |||
| """ | |||
| 对jittor的collate_fn进行wrap封装, 如果数据集为mapping类型,那么采用auto_collator,否则还是采用jittor自带的collate_batch | |||
| :param func: | |||
| :param auto_collator: | |||
| :return: | |||
| """ | |||
| @wraps(func) | |||
| def wrapper(batch): | |||
| if isinstance(batch[0], Instance): | |||
| if auto_collator is not None: | |||
| result = auto_collator(batch) | |||
| else: | |||
| raise ValueError(f"auto_collator is None, but batch exist fastnlp instance!") | |||
| elif isinstance(batch[0], Mapping): | |||
| if auto_collator is not None: | |||
| result = auto_collator(batch) | |||
| else: | |||
| result = func(batch) | |||
| else: | |||
| result = func(batch) | |||
| return result | |||
| return wrapper | |||
| @@ -0,0 +1,89 @@ | |||
| __all__ = [ | |||
| "paddle_to", | |||
| "paddle_move_data_to_device", | |||
| "get_paddle_gpu_str", | |||
| "get_paddle_device_id", | |||
| "is_in_paddle_dist", | |||
| "is_in_fnlp_paddle_dist", | |||
| "is_in_paddle_launch_dist", | |||
| ] | |||
| import os | |||
| from typing import Any, Optional, Union | |||
| from fastNLP.envs.imports import _NEED_IMPORT_PADDLE | |||
| from fastNLP.envs import FASTNLP_DISTRIBUTED_CHECK | |||
| if _NEED_IMPORT_PADDLE: | |||
| import paddle | |||
| from .utils import apply_to_collection | |||
| def paddle_to(data, device: Union[str, int]): | |||
| if device == "cpu": | |||
| return data.cpu() | |||
| else: | |||
| return data.cuda(get_paddle_device_id(device)) | |||
| def get_paddle_gpu_str(device: Union[str, int]): | |||
| """ | |||
| 获得 `gpu:x` 类型的设备名 | |||
| """ | |||
| if isinstance(device, str): | |||
| return device.replace("cuda", "gpu") | |||
| return f"gpu:{device}" | |||
| def get_paddle_device_id(device: Union[str, int]): | |||
| """ | |||
| 获得 gpu 的设备id,注意不要传入 `cpu` 。 | |||
| """ | |||
| if isinstance(device, int): | |||
| return device | |||
| if device == "cpu": | |||
| raise ValueError("Cannot get device id from `cpu`.") | |||
| return paddle.device._convert_to_place(device).get_device_id() | |||
| def paddle_move_data_to_device(batch: Any, device: Optional[str] = None, | |||
| data_device: Optional[str] = None) -> Any: | |||
| r""" | |||
| 将数据集合传输到给定设备。只有paddle.Tensor对象会被传输到设备中,其余保持不变 | |||
| :param batch: | |||
| :param device: `cpu`, `gpu` or `gpu:x` | |||
| :param data_device: | |||
| :return: 相同的集合,但所有包含的张量都驻留在新设备上; | |||
| """ | |||
| if device is None: | |||
| if data_device is not None: | |||
| device = data_device | |||
| else: | |||
| return batch | |||
| def batch_to(data: Any) -> Any: | |||
| return paddle_to(data, device) | |||
| return apply_to_collection(batch, dtype=paddle.Tensor, function=batch_to) | |||
| def is_in_paddle_dist(): | |||
| """ | |||
| 判断是否处于分布式的进程下,使用 global_rank 和 selected_gpus 判断 | |||
| """ | |||
| return ('PADDLE_RANK_IN_NODE' in os.environ and 'FLAGS_selected_gpus' in os.environ) | |||
| def is_in_fnlp_paddle_dist(): | |||
| """ | |||
| 判断是否处于 FastNLP 拉起的分布式进程中 | |||
| """ | |||
| return FASTNLP_DISTRIBUTED_CHECK in os.environ | |||
| def is_in_paddle_launch_dist(): | |||
| """ | |||
| 判断是否处于 launch 启动的分布式进程中 | |||
| """ | |||
| return 'PADDLE_RANK_IN_NODE' in os.environ and \ | |||
| 'FLAGS_selected_gpus' in os.environ and \ | |||
| FASTNLP_DISTRIBUTED_CHECK not in os.environ | |||
| @@ -0,0 +1,214 @@ | |||
| """ | |||
| 该文件用于为fastNLP提供一个统一的progress bar管理,通过共用一个Task对象,trainer中的progress bar和evaluation中的progress bar才能 | |||
| 不冲突 | |||
| """ | |||
| import sys | |||
| from typing import Any, Union, Optional | |||
| from rich.progress import Progress, Console, GetTimeCallable, get_console, TaskID, Live | |||
| from rich.progress import ProgressColumn, TimeRemainingColumn, BarColumn, TimeElapsedColumn, TextColumn | |||
| __all__ = [ | |||
| 'f_rich_progress' | |||
| ] | |||
| from fastNLP.envs import get_global_rank | |||
| class Singleton(type): | |||
| _instances = {} | |||
| def __call__(cls, *args, **kwargs): | |||
| if cls not in cls._instances: | |||
| cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs) | |||
| return cls._instances[cls] | |||
| # 如果不打印的时候,使得整个 progress 没有任何意义 | |||
| class DummyFRichProgress: | |||
| def __getattr__(self, item): | |||
| return DummyFRichProgress() | |||
| def __call__(self, *args, **kwargs): | |||
| # 防止用户通过 DummyFRichProgress.console.print() 这种调用 | |||
| return None | |||
| class FRichProgress(Progress, metaclass=Singleton): | |||
| """ | |||
| fastNLP 使用的 progress bar ,新增了 new_progress 函数,通过此函数即可定制 fastNLP 中所有 progress 的样式。 | |||
| """ | |||
| def new_progess(self, *columns: Union[str, ProgressColumn], | |||
| console: Optional[Console] = None, | |||
| auto_refresh: bool = True, | |||
| refresh_per_second: float = 10, | |||
| speed_estimate_period: float = 30.0, | |||
| transient: bool = True, | |||
| redirect_stdout: bool = True, | |||
| redirect_stderr: bool = True, | |||
| get_time: Optional[GetTimeCallable] = None, | |||
| disable: bool = False, | |||
| expand: bool = False): | |||
| """ | |||
| 重新初始化一个rich bar。如果columns不传入,则继续使用之前的column内容。 | |||
| :param progress: | |||
| :return: | |||
| """ | |||
| for task_id in self.task_ids: # 首先移除已有的 | |||
| self.remove_task(task_id) | |||
| assert ( | |||
| refresh_per_second is None or refresh_per_second > 0 | |||
| ), "refresh_per_second must be > 0" | |||
| # stop previous columns | |||
| self.stop() | |||
| # do not change these variables | |||
| # self._lock = RLock() | |||
| # self._tasks: Dict[TaskID, Task] = {} | |||
| # self._task_index: TaskID = TaskID(0) | |||
| if len(columns) != 0: | |||
| self.columns = columns | |||
| self.speed_estimate_period = speed_estimate_period | |||
| self.disable = disable | |||
| self.expand = expand | |||
| self.live = Live( | |||
| console=console or get_console(), | |||
| auto_refresh=auto_refresh, | |||
| refresh_per_second=refresh_per_second, | |||
| transient=transient, | |||
| redirect_stdout=redirect_stdout, | |||
| redirect_stderr=redirect_stderr, | |||
| get_renderable=self.get_renderable, | |||
| ) | |||
| self.get_time = get_time or self.console.get_time | |||
| self.print = self.console.print | |||
| self.log = self.console.log | |||
| # start new | |||
| self.start() | |||
| return self | |||
| def set_transient(self, transient: bool = True): | |||
| """ | |||
| 设置是否在bar运行结束之后不关闭 | |||
| :param transient: | |||
| :return: | |||
| """ | |||
| self.new_progess(transient=transient) | |||
| def set_disable(self, flag: bool = True): | |||
| """ | |||
| 设置当前 progress bar 的状态,如果为 True ,则不会显示进度条了。 | |||
| :param flag: | |||
| :return: | |||
| """ | |||
| self.disable = flag | |||
| def add_task( | |||
| self, | |||
| description: str, | |||
| start: bool = True, | |||
| total: float = 100.0, | |||
| completed: int = 0, | |||
| visible: bool = True, | |||
| **fields: Any, | |||
| ) -> TaskID: | |||
| if self.live._started is False: | |||
| self.start() | |||
| post_desc = fields.pop('post_desc', '') | |||
| return super().add_task(description=description, | |||
| start=start, | |||
| total=total, | |||
| completed=completed, | |||
| visible=visible, | |||
| post_desc=post_desc, | |||
| **fields) | |||
| def stop_task(self, task_id: TaskID) -> None: | |||
| if task_id in self._tasks: | |||
| super().stop_task(task_id) | |||
| def remove_task(self, task_id: TaskID) -> None: | |||
| if task_id in self._tasks: | |||
| super().remove_task(task_id) | |||
| def destroy_task(self, task_id: TaskID): | |||
| if task_id in self._tasks: | |||
| super().stop_task(task_id) | |||
| super().remove_task(task_id) | |||
| if (sys.stdin and sys.stdin.isatty()) and get_global_rank() == 0: | |||
| f_rich_progress = FRichProgress().new_progess( | |||
| "[progress.description]{task.description}", | |||
| "[progress.percentage]{task.percentage:>3.0f}%", | |||
| BarColumn(), | |||
| TimeElapsedColumn(), | |||
| "/", | |||
| TimeRemainingColumn(), | |||
| TextColumn("{task.fields[post_desc]}", justify="right"), | |||
| transient=True, | |||
| disable=False, | |||
| speed_estimate_period=10 | |||
| ) | |||
| else: | |||
| f_rich_progress = DummyFRichProgress() | |||
| if __name__ == '__main__': | |||
| f = DummyFRichProgress() | |||
| f.console.print('xxx') | |||
| f.console.print.print('xxx') | |||
| # 测试创建 | |||
| import time | |||
| n_steps = 10 | |||
| task_id = f_rich_progress.add_task(description='test', total=n_steps) | |||
| for i in range(n_steps): | |||
| f_rich_progress.update(task_id, description=f'test:{i}', advance=1, refresh=True) | |||
| print(f"test:{i}") | |||
| time.sleep(0.3) | |||
| f_rich_progress.remove_task(task_id) | |||
| # 测试一下 inner/outer | |||
| n_steps = 5 | |||
| f_rich_progress.start() | |||
| outer_task_id = f_rich_progress.add_task(description='Outer:', total=n_steps) | |||
| inner_task_id = f_rich_progress.add_task(description='Inner:', total=n_steps) | |||
| for i in range(n_steps): | |||
| f_rich_progress.reset(inner_task_id, total=n_steps) | |||
| f_rich_progress.update(outer_task_id, description=f'Outer:{i}', advance=1, refresh=True) | |||
| for j in range(n_steps): | |||
| f_rich_progress.update(inner_task_id, description=f'Inner:{j}', advance=1, refresh=True, | |||
| post_desc='Loss: 0.334332323') | |||
| print(f"Outer:{i}, Inner:{j}") | |||
| time.sleep(0.3) | |||
| # 测试一下修改bar | |||
| f_rich_progress = FRichProgress().new_progess( | |||
| BarColumn(), | |||
| "[progress.description]{task.description}", | |||
| "[progress.percentage]{task.percentage:>3.0f}%", | |||
| TimeElapsedColumn(), | |||
| transient=True) | |||
| n_steps = 10 | |||
| task_id = f_rich_progress.add_task(description='test', total=n_steps) | |||
| for i in range(n_steps): | |||
| f_rich_progress.update(task_id, description=f'test:{i}', advance=1) | |||
| print(f"test:{i}") | |||
| time.sleep(0.3) | |||
| f_rich_progress.remove_task(task_id) | |||
| f_rich_progress.stop() | |||
| @@ -0,0 +1,49 @@ | |||
| from typing import Any, Optional | |||
| from fastNLP.envs.imports import _NEED_IMPORT_PADDLE, _NEED_IMPORT_TORCH | |||
| if _NEED_IMPORT_PADDLE: | |||
| import paddle | |||
| if _NEED_IMPORT_TORCH: | |||
| import torch | |||
| __all__ = [ | |||
| "torch_paddle_move_data_to_device", | |||
| ] | |||
| from .utils import apply_to_collection | |||
| from .paddle_utils import paddle_to | |||
| def torch_paddle_move_data_to_device(batch: Any, device: Optional[str] = None, non_blocking: Optional[bool] = True, | |||
| data_device: Optional[str] = None) -> Any: | |||
| r""" | |||
| 将数据集合传输到给定设备。只有paddle.Tensor和torch.Tensor对象会被传输到设备中,其余保持不变 | |||
| :param batch: | |||
| :param device: | |||
| :param non_blocking: | |||
| :param data_device: | |||
| :return: 相同的集合,但所有包含的张量都驻留在新设备上; | |||
| """ | |||
| if device is None: | |||
| if data_device is not None: | |||
| device = data_device | |||
| else: | |||
| return batch | |||
| torch_device = device.replace("gpu", "cuda") | |||
| paddle_device = device.replace("cuda", "gpu") | |||
| def batch_to(data: Any) -> Any: | |||
| if isinstance(data, torch.Tensor): | |||
| data = data.to(torch_device, non_blocking=non_blocking) | |||
| elif isinstance(data, paddle.Tensor): | |||
| data = paddle_to(data, paddle_device) | |||
| return data | |||
| return apply_to_collection(batch, dtype=(paddle.Tensor, torch.Tensor), function=batch_to) | |||
| @@ -0,0 +1,63 @@ | |||
| from abc import ABC | |||
| from typing import Any, Union, Optional | |||
| from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||
| if _NEED_IMPORT_TORCH: | |||
| import torch | |||
| __all__ = [ | |||
| 'torch_move_data_to_device' | |||
| ] | |||
| from .utils import apply_to_collection | |||
| class TorchTransferableDataType(ABC): | |||
| """ | |||
| A custom type for data that can be moved to a torch device via `.to(...)`. | |||
| Example: | |||
| >>> isinstance(dict, TorchTransferableDataType) | |||
| False | |||
| >>> isinstance(torch.rand(2, 3), TorchTransferableDataType) | |||
| True | |||
| >>> class CustomObject: | |||
| ... def __init__(self): | |||
| ... self.x = torch.rand(2, 2) | |||
| ... def to(self, device): | |||
| ... self.x = self.x.to(device) | |||
| ... return self | |||
| >>> isinstance(CustomObject(), TorchTransferableDataType) | |||
| True | |||
| """ | |||
| @classmethod | |||
| def __subclasshook__(cls, subclass: Any) -> Union[bool, Any]: | |||
| if cls is TorchTransferableDataType: | |||
| to = getattr(subclass, "to", None) | |||
| return callable(to) | |||
| return NotImplemented | |||
| def torch_move_data_to_device(batch: Any, device: Optional[Union[str, "torch.device"]] = None, | |||
| non_blocking: Optional[bool] = True) -> Any: | |||
| r""" | |||
| 将数据集合传输到给定设备。任何定义方法 “to(device)” 的对象都将被移动并且集合中的所有其他对象将保持不变; | |||
| :param batch: 应当迁移的数据; | |||
| :param device: 数据应当迁移到的设备;当该参数的值为 None 时,表示迁移数据的操作由用户自己完成,我们不需要经管; | |||
| :param non_blocking: pytorch 的迁移数据方法 `to` 的参数; | |||
| :return: 相同的集合,但所有包含的张量都驻留在新设备上; | |||
| """ | |||
| if device is None: | |||
| return batch | |||
| def batch_to(data: Any) -> Any: | |||
| kwargs = dict(non_blocking=non_blocking) if isinstance(data, torch.Tensor) else {} | |||
| data_output = data.to(device, **kwargs) | |||
| if data_output is not None: | |||
| return data_output | |||
| # user wrongly implemented the `TransferableDataType` and forgot to return `self`. | |||
| return data | |||
| dtype = TorchTransferableDataType | |||
| return apply_to_collection(batch, dtype=dtype, function=batch_to) | |||
| @@ -0,0 +1,591 @@ | |||
| import inspect | |||
| from inspect import Parameter | |||
| import dataclasses | |||
| import warnings | |||
| from dataclasses import is_dataclass | |||
| from copy import deepcopy | |||
| from collections import defaultdict, OrderedDict | |||
| from typing import Callable, List, Any, Dict, AnyStr, Union, Mapping, Sequence, Optional | |||
| from typing import Tuple, Optional | |||
| from time import sleep | |||
| try: | |||
| from typing import Literal, Final | |||
| except ImportError: | |||
| from typing_extensions import Literal, Final | |||
| import os | |||
| from contextlib import contextmanager | |||
| from functools import wraps | |||
| from prettytable import PrettyTable | |||
| import numpy as np | |||
| from pathlib import Path | |||
| from fastNLP.core.log import logger | |||
| from fastNLP.envs import FASTNLP_GLOBAL_RANK | |||
| __all__ = [ | |||
| 'get_fn_arg_names', | |||
| 'check_fn_not_empty_params', | |||
| 'auto_param_call', | |||
| 'check_user_specific_params', | |||
| 'dataclass_to_dict', | |||
| 'match_and_substitute_params', | |||
| 'apply_to_collection', | |||
| 'nullcontext', | |||
| 'pretty_table_printer', | |||
| 'Option', | |||
| 'indice_collate_wrapper', | |||
| 'deprecated', | |||
| 'seq_len_to_mask', | |||
| 'synchronize_safe_rm', | |||
| 'synchronize_mkdir' | |||
| ] | |||
| def get_fn_arg_names(fn: Callable) -> List[str]: | |||
| r""" | |||
| 返回一个函数的所有参数的名字; | |||
| :param fn: 需要查询的函数; | |||
| :return: 一个列表,其中的元素则是查询函数的参数的字符串名字; | |||
| """ | |||
| return list(inspect.signature(fn).parameters) | |||
| def check_fn_not_empty_params(fn: Optional[Callable] = None, param_num: Optional[int] = None) -> bool: | |||
| r""" | |||
| 检查传入的batch_step_fn是否是合法的:(1) 是否是 callable 的; (2) 没有默认值的参数是否只有指定个数; | |||
| 用户也可以传进一个 partial 的函数进来,只要其保证留有 `trainer` 和 `batch` 的参数位置即可; | |||
| :param fn: 传入的用以代替 Loop 中 'step' 函数的函数; | |||
| :param param_num: 检测的函数的应当的没有默认值的参数的个数; | |||
| :return: bool,表示传入的 `batch_step_fn` 是否正确; | |||
| """ | |||
| if fn is None: | |||
| return True | |||
| if not callable(fn): | |||
| return False | |||
| else: | |||
| params = inspect.signature(fn).parameters | |||
| not_default_params = {} | |||
| for _name, _param in params.items(): | |||
| if _param.default == Parameter.empty: | |||
| not_default_params[_name] = _param | |||
| return len(not_default_params) == param_num | |||
| def auto_param_call(fn: Callable, *args, signature_fn: Optional[Callable] = None, | |||
| mapping: Optional[Dict[AnyStr, AnyStr]] = None) -> Any: | |||
| r""" | |||
| 1.该函数用来提供给用户根据字符串匹配从而实现自动计算; | |||
| 2.注意 mapping 默认为 None,如果你希望指定输入和运行函数的参数的对应方式,那么你应当让 mapping 为一个这样的字典传入进来; | |||
| 如果 mapping 不为 None,那么我们一定会先使用 mapping 将输入的字典的 keys 修改过来,因此请务必亲自检查 mapping 的正确性; | |||
| 3.如果输入的函数的参数有默认值,那么如果之后的输入中没有该参数对应的值,我们就会使用该参数对应的默认值,否则也会使用之后的输入的值; | |||
| 4.如果输入的函数是一个 `partial` 函数,情况同 '3.',即和默认参数的情况相同; | |||
| :param fn: 用来进行实际计算的函数,其参数可以包含有默认值; | |||
| :param args: 一系列的位置参数,应当为一系列的字典,我们需要从这些输入中提取 `fn` 计算所需要的实际参数; | |||
| :param signature_fn: 函数,用来替换 `fn` 的函数签名,如果该参数不为 None,那么我们首先会从该函数中提取函数签名,然后通过该函数签名提取 | |||
| 参数值后,再传给 `fn` 进行实际的运算; | |||
| :param mapping: 一个字典,用来更改其前面的字典的键值; | |||
| :return: 返回 `fn` 运行的结果; | |||
| Examples: | |||
| >>> # 1 | |||
| >>> loss_fn = CrossEntropyLoss() # 如果其需要的参数为 def CrossEntropyLoss(y, pred); | |||
| >>> batch = {"x": 20, "y": 1} | |||
| >>> output = {"pred": 0} | |||
| >>> acc = auto_param_call(loss_fn, batch, output) | |||
| >>> # 2 | |||
| >>> def test_fn(x, y, a, b=10): | |||
| >>> return x + y + a + b | |||
| >>> print(auto_param_call(test_fn, {"x": 10}, {"y": 20, "a": 30})) # res: 70 | |||
| >>> print(auto_param_call(partial(test_fn, a=100), {"x": 10}, {"y": 20})) # res: 140 | |||
| >>> print(auto_param_call(partial(test_fn, a=100), {"x": 10}, {"y": 20, "a": 200})) # res: 240 | |||
| """ | |||
| if signature_fn is not None: | |||
| if not callable(signature_fn): | |||
| raise ValueError(f"Parameter `signature_fn` should be `Callable`.") | |||
| _need_params = OrderedDict(inspect.signature(signature_fn).parameters) | |||
| else: | |||
| _need_params = OrderedDict(inspect.signature(fn).parameters) | |||
| _kwargs = None | |||
| for _name, _param in _need_params.items(): | |||
| if _param.kind == Parameter.VAR_POSITIONAL: | |||
| raise ValueError(f"It is not allowed to have parameter `*args` in your function:{fn.__name__}.") | |||
| if _param.kind == Parameter.VAR_KEYWORD: | |||
| _kwargs = (_name, _param) | |||
| if _kwargs is not None: | |||
| _need_params.pop(_kwargs[0]) | |||
| _default_params = {} | |||
| for _name, _param in _need_params.items(): | |||
| if _param.default != Parameter.empty: | |||
| _default_params[_name] = _param.default | |||
| if mapping is not None: | |||
| assert isinstance(mapping, Dict), f"Parameter `mapping` should be of 'Dict' type, instead of {type(mapping)}." | |||
| _has_params = {} | |||
| duplicate_names = [] | |||
| for arg in args: | |||
| assert isinstance(arg, Dict), "The input part of function `auto_param_call` can only be `Dict` type." | |||
| for _name, _value in arg.items(): | |||
| if mapping is not None and _name in mapping: | |||
| _name = mapping[_name] | |||
| if _name not in _has_params: | |||
| if _kwargs is not None or _name in _need_params: | |||
| _has_params[_name] = _value | |||
| # 同一参数对象在两个输入的资源中都出现,造成混淆; | |||
| elif _name in _need_params and not (_has_params[_name] is _value): | |||
| duplicate_names.append(_name) | |||
| if duplicate_names: | |||
| raise ValueError(f"The following key present in several inputs:{duplicate_names}") | |||
| # 将具有默认值但是没有被输入修改过的参数值传进去; | |||
| for _name, _value in _default_params.items(): | |||
| if _name not in _has_params: | |||
| _has_params[_name] = _value | |||
| if len(_has_params)<len(_need_params): | |||
| miss_params = list(set(_need_params.keys()) - set(_has_params.keys())) | |||
| raise ValueError(f"The parameters:`{miss_params}` needed by function:{fn.__name__} are not found in the input.") | |||
| return fn(**_has_params) | |||
| def check_user_specific_params(user_params: Dict, fn: Callable): | |||
| """ | |||
| 该函数使用用户的输入来对指定函数的参数进行赋值; | |||
| 主要用于一些用户无法直接调用函数的情况; | |||
| 该函数主要的作用在于帮助检查用户对使用函数 fn 的参数输入是否有误; | |||
| :param user_params: 用户指定的参数的值,应当是一个字典,其中 key 表示每一个参数的名字,value 为每一个参数应当的值; | |||
| :param fn: 会被调用的函数; | |||
| :return: 返回一个字典,其中为在之后调用函数 fn 时真正会被传进去的参数的值; | |||
| """ | |||
| fn_arg_names = get_fn_arg_names(fn) | |||
| for arg_name, arg_value in user_params.items(): | |||
| if arg_name not in fn_arg_names: | |||
| logger.warning(f"Notice your specific parameter `{arg_name}` is not used by function `{fn.__name__}`.") | |||
| return user_params | |||
| def dataclass_to_dict(data: "dataclass") -> Dict: | |||
| if not is_dataclass(data): | |||
| raise TypeError(f"Parameter `data` can only be `dataclass` type instead of {type(data)}.") | |||
| _dict = dict() | |||
| for _key in data.__dataclass_fields__: | |||
| _dict[_key] = getattr(data, _key) | |||
| return _dict | |||
| def match_and_substitute_params(mapping: Optional[Union[Callable, Dict]] = None, data: Optional[Any] = None) -> Any: | |||
| r""" | |||
| 用来实现将输入:batch,或者输出:outputs,通过 `mapping` 将键值进行更换的功能; | |||
| 该函数应用于 `input_mapping` 和 `output_mapping`; | |||
| 对于 `input_mapping`,该函数会在 `TrainBatchLoop` 中取完数据后立刻被调用; | |||
| 对于 `output_mapping`,该函数会在 `Trainer.train_step` 以及 `Evaluator.train_step` 中得到结果后立刻被调用; | |||
| 转换的逻辑按优先级依次为: | |||
| 1. 如果 `mapping` 是一个函数,那么会直接返回 `mapping(data)`; | |||
| 2. 如果 `mapping` 是一个 `Dict`,那么 `data` 的类型只能为以下三种: [`Dict`, `dataclass`, `Sequence`]; | |||
| 如果 `data` 是 `Dict`,那么该函数会将 `data` 的 key 替换为 mapping[key]; | |||
| 如果 `data` 是 `dataclass`,那么该函数会先使用 `dataclasses.asdict` 函数将其转换为 `Dict`,然后进行转换; | |||
| 如果 `data` 是 `Sequence`,那么该函数会先将其转换成一个对应的 `Dict`:{"_0": list[0], "_1": list[1], ...},然后使用 | |||
| mapping对这个 `Dict` 进行转换,如果没有匹配上mapping中的key则保持"_number"这个形式。 | |||
| :param mapping: 用于转换的字典或者函数;mapping是函数时,返回值必须为字典类型。 | |||
| :param data: 需要被转换的对象; | |||
| :return: 返回转换好的结果; | |||
| """ | |||
| if mapping is None: | |||
| return data | |||
| if callable(mapping): | |||
| # 注意我们在 `Trainer.extract_loss_from_outputs` 函数里会检查 outputs 的输出,outputs 的类型目前只支持 `Dict` 和 `dataclass`; | |||
| return mapping(data) | |||
| if not isinstance(mapping, Dict): | |||
| raise ValueError( | |||
| f"Parameter `mapping` should be of type `Dict` or `Callable`, not `{type(mapping)}`. This is caused" | |||
| f"by your `input_mapping` or `output_mapping` parameter in your `Trainer` or `Evaluator`.") | |||
| if not isinstance(data, Dict) and not is_dataclass(data) and not isinstance(data, Sequence): | |||
| raise ValueError("Parameter `data` should be type `Dict` or `dataclass` when the other parameter `mapping` is " | |||
| "type `Dict`.") | |||
| # 如果 `data` 是一个 dataclass,那么先将其转换为一个 `Dict`; | |||
| if is_dataclass(data): | |||
| data = dataclass_to_dict(data) | |||
| # 如果 `data` 是一个 List,那么我们同样先将其转换为一个 `Dict`,为 {"_0": list[0], "_1": list[1], ...}; | |||
| elif isinstance(data, Sequence): | |||
| data = {"_" + str(i): data[i] for i in range(len(data))} | |||
| _new_data = {} | |||
| for _name, _value in data.items(): | |||
| if _name in mapping: | |||
| _new_data[mapping[_name]] = _value | |||
| else: | |||
| _new_data[_name] = _value | |||
| return _new_data | |||
| def _is_namedtuple(obj: object) -> bool: | |||
| # https://github.com/pytorch/pytorch/blob/v1.8.1/torch/nn/parallel/scatter_gather.py#L4-L8 | |||
| return isinstance(obj, tuple) and hasattr(obj, "_asdict") and hasattr(obj, "_fields") | |||
| def _is_dataclass_instance(obj: object) -> bool: | |||
| # https://docs.python.org/3/library/dataclasses.html#module-level-decorators-classes-and-functions | |||
| return dataclasses.is_dataclass(obj) and not isinstance(obj, type) | |||
| def apply_to_collection( | |||
| data: Any, | |||
| dtype: Union[type, Any, Tuple[Union[type, Any]]], | |||
| function: Callable, | |||
| *args: Any, | |||
| wrong_dtype: Optional[Union[type, Tuple[type]]] = None, | |||
| include_none: bool = True, | |||
| **kwargs: Any, | |||
| ) -> Any: | |||
| """将函数 function 递归地在 data 中的元素执行,但是仅在满足元素为 dtype 时执行。 | |||
| this function credit to: https://github.com/PyTorchLightning/pytorch-lightning | |||
| Args: | |||
| data: the collection to apply the function to | |||
| dtype: the given function will be applied to all elements of this dtype | |||
| function: the function to apply | |||
| *args: positional arguments (will be forwarded to calls of ``function``) | |||
| wrong_dtype: the given function won't be applied if this type is specified and the given collections | |||
| is of the ``wrong_dtype`` even if it is of type ``dtype`` | |||
| include_none: Whether to include an element if the output of ``function`` is ``None``. | |||
| **kwargs: keyword arguments (will be forwarded to calls of ``function``) | |||
| Returns: | |||
| The resulting collection | |||
| """ | |||
| # Breaking condition | |||
| if isinstance(data, dtype) and (wrong_dtype is None or not isinstance(data, wrong_dtype)): | |||
| return function(data, *args, **kwargs) | |||
| elem_type = type(data) | |||
| # Recursively apply to collection items | |||
| if isinstance(data, Mapping): | |||
| out = [] | |||
| for k, v in data.items(): | |||
| v = apply_to_collection( | |||
| v, dtype, function, *args, wrong_dtype=wrong_dtype, include_none=include_none, **kwargs | |||
| ) | |||
| if include_none or v is not None: | |||
| out.append((k, v)) | |||
| if isinstance(data, defaultdict): | |||
| return elem_type(data.default_factory, OrderedDict(out)) | |||
| return elem_type(OrderedDict(out)) | |||
| is_namedtuple = _is_namedtuple(data) | |||
| is_sequence = isinstance(data, Sequence) and not isinstance(data, str) | |||
| if is_namedtuple or is_sequence: | |||
| out = [] | |||
| for d in data: | |||
| v = apply_to_collection( | |||
| d, dtype, function, *args, wrong_dtype=wrong_dtype, include_none=include_none, **kwargs | |||
| ) | |||
| if include_none or v is not None: | |||
| out.append(v) | |||
| return elem_type(*out) if is_namedtuple else elem_type(out) | |||
| if _is_dataclass_instance(data): | |||
| # make a deepcopy of the data, | |||
| # but do not deepcopy mapped fields since the computation would | |||
| # be wasted on values that likely get immediately overwritten | |||
| fields = {} | |||
| memo = {} | |||
| for field in dataclasses.fields(data): | |||
| field_value = getattr(data, field.name) | |||
| fields[field.name] = (field_value, field.init) | |||
| memo[id(field_value)] = field_value | |||
| result = deepcopy(data, memo=memo) | |||
| # apply function to each field | |||
| for field_name, (field_value, field_init) in fields.items(): | |||
| if field_init: | |||
| v = apply_to_collection( | |||
| field_value, | |||
| dtype, | |||
| function, | |||
| *args, | |||
| wrong_dtype=wrong_dtype, | |||
| include_none=include_none, | |||
| **kwargs, | |||
| ) | |||
| if not field_init or (not include_none and v is None): # retain old value | |||
| v = getattr(data, field_name) | |||
| setattr(result, field_name, v) | |||
| return result | |||
| # data is neither of dtype, nor a collection | |||
| return data | |||
| @contextmanager | |||
| def nullcontext(): | |||
| r""" | |||
| 用来实现一个什么 dummy 的 context 上下文环境; | |||
| """ | |||
| yield | |||
| def sub_column(string: str, c: int, c_size: int, title: str) -> str: | |||
| r""" | |||
| :param string: 要被截断的字符串 | |||
| :param c: 命令行列数 | |||
| :param c_size: instance或dataset field数 | |||
| :param title: 列名 | |||
| :return: 对一个过长的列进行截断的结果 | |||
| """ | |||
| avg = max(int(c / c_size / 2), len(title)) | |||
| string = str(string) | |||
| res = "" | |||
| counter = 0 | |||
| for char in string: | |||
| if ord(char) > 255: | |||
| counter += 2 | |||
| else: | |||
| counter += 1 | |||
| res += char | |||
| if counter > avg: | |||
| res = res + "..." | |||
| break | |||
| return res | |||
| def _is_iterable(value): | |||
| # 检查是否是iterable的, duck typing | |||
| try: | |||
| iter(value) | |||
| return True | |||
| except BaseException as e: | |||
| return False | |||
| def pretty_table_printer(dataset_or_ins) -> PrettyTable: | |||
| r""" | |||
| :param dataset_or_ins: 传入一个dataSet或者instance | |||
| ins = Instance(field_1=[1, 1, 1], field_2=[2, 2, 2], field_3=["a", "b", "c"]) | |||
| +-----------+-----------+-----------------+ | |||
| | field_1 | field_2 | field_3 | | |||
| +-----------+-----------+-----------------+ | |||
| | [1, 1, 1] | [2, 2, 2] | ['a', 'b', 'c'] | | |||
| +-----------+-----------+-----------------+ | |||
| :return: 以 pretty table的形式返回根据terminal大小进行自动截断 | |||
| """ | |||
| x = PrettyTable() | |||
| try: | |||
| sz = os.get_terminal_size() | |||
| column = sz.columns | |||
| row = sz.lines | |||
| except OSError: | |||
| column = 144 | |||
| row = 11 | |||
| if type(dataset_or_ins).__name__ == "DataSet": | |||
| x.field_names = list(dataset_or_ins.field_arrays.keys()) | |||
| c_size = len(x.field_names) | |||
| for ins in dataset_or_ins: | |||
| x.add_row([sub_column(ins[k], column, c_size, k) for k in x.field_names]) | |||
| row -= 1 | |||
| if row < 0: | |||
| x.add_row(["..." for _ in range(c_size)]) | |||
| break | |||
| elif type(dataset_or_ins).__name__ == "Instance": | |||
| x.field_names = list(dataset_or_ins.fields.keys()) | |||
| c_size = len(x.field_names) | |||
| x.add_row([sub_column(dataset_or_ins[k], column, c_size, k) for k in x.field_names]) | |||
| else: | |||
| raise Exception("only accept DataSet and Instance") | |||
| x.align = "l" | |||
| return x | |||
| class Option(dict): | |||
| r"""a dict can treat keys as attributes""" | |||
| def __getattr__(self, item): | |||
| try: | |||
| return self.__getitem__(item) | |||
| except KeyError: | |||
| raise AttributeError(item) | |||
| def __setattr__(self, key, value): | |||
| if key.startswith('__') and key.endswith('__'): | |||
| raise AttributeError(key) | |||
| self.__setitem__(key, value) | |||
| def __delattr__(self, item): | |||
| try: | |||
| self.pop(item) | |||
| except KeyError: | |||
| raise AttributeError(item) | |||
| def __getstate__(self): | |||
| return self | |||
| def __setstate__(self, state): | |||
| self.update(state) | |||
| def indice_collate_wrapper(func): | |||
| """ | |||
| 其功能是封装一层collate_fn,将dataset取到的tuple数据分离开,将idx打包为indices。 | |||
| :param func: 需要修饰的函数 | |||
| :return: | |||
| """ | |||
| def wrapper(tuple_data): | |||
| indice, ins_list = [], [] | |||
| for idx, ins in tuple_data: | |||
| indice.append(idx) | |||
| ins_list.append(ins) | |||
| return indice, func(ins_list) | |||
| return wrapper | |||
| _emitted_deprecation_warnings = set() | |||
| def deprecated(help_message: Optional[str] = None): | |||
| """Decorator to mark a function as deprecated. | |||
| Args: | |||
| help_message (`Optional[str]`): An optional message to guide the user on how to | |||
| switch to non-deprecated usage of the library. | |||
| """ | |||
| def decorator(deprecated_function: Callable): | |||
| global _emitted_deprecation_warnings | |||
| warning_msg = ( | |||
| ( | |||
| f"{deprecated_function.__name__} is deprecated and will be removed " | |||
| "in the next major version of datasets." | |||
| ) | |||
| + f" {help_message}" | |||
| if help_message | |||
| else "" | |||
| ) | |||
| @wraps(deprecated_function) | |||
| def wrapper(*args, **kwargs): | |||
| func_hash = hash(deprecated_function) | |||
| if func_hash not in _emitted_deprecation_warnings: | |||
| warnings.warn(warning_msg, category=FutureWarning, stacklevel=2) | |||
| _emitted_deprecation_warnings.add(func_hash) | |||
| return deprecated_function(*args, **kwargs) | |||
| wrapper._decorator_name_ = "deprecated" | |||
| return wrapper | |||
| return decorator | |||
| def seq_len_to_mask(seq_len, max_len=None): | |||
| r""" | |||
| 将一个表示sequence length的一维数组转换为二维的mask,不包含的位置为0。 | |||
| 转变 1-d seq_len到2-d mask. | |||
| .. code-block:: | |||
| >>> seq_len = torch.arange(2, 16) | |||
| >>> mask = seq_len_to_mask(seq_len) | |||
| >>> print(mask.size()) | |||
| torch.Size([14, 15]) | |||
| >>> seq_len = np.arange(2, 16) | |||
| >>> mask = seq_len_to_mask(seq_len) | |||
| >>> print(mask.shape) | |||
| (14, 15) | |||
| >>> seq_len = torch.arange(2, 16) | |||
| >>> mask = seq_len_to_mask(seq_len, max_len=100) | |||
| >>>print(mask.size()) | |||
| torch.Size([14, 100]) | |||
| :param np.ndarray,torch.LongTensor seq_len: shape将是(B,) | |||
| :param int max_len: 将长度pad到这个长度。默认(None)使用的是seq_len中最长的长度。但在nn.DataParallel的场景下可能不同卡的seq_len会有 | |||
| 区别,所以需要传入一个max_len使得mask的长度是pad到该长度。 | |||
| :return: np.ndarray, torch.Tensor 。shape将是(B, max_length), 元素类似为bool或torch.uint8 | |||
| """ | |||
| if isinstance(seq_len, np.ndarray): | |||
| assert len(np.shape(seq_len)) == 1, f"seq_len can only have one dimension, got {len(np.shape(seq_len))}." | |||
| max_len = int(max_len) if max_len else int(seq_len.max()) | |||
| broad_cast_seq_len = np.tile(np.arange(max_len), (len(seq_len), 1)) | |||
| mask = broad_cast_seq_len < seq_len.reshape(-1, 1) | |||
| else: | |||
| raise TypeError("Only support 1-d numpy.ndarray.") | |||
| return mask | |||
| def wait_to_success(fn, no=False): | |||
| while True: | |||
| sleep(0.01) | |||
| if (no and not fn()) or (not no and fn()): | |||
| break | |||
| # 这个是因为在分布式文件系统中可能会发生错误,rank0下发删除成功后就运行走了,但实际的删除需要rank0的机器发送到远程文件系统再去执行,这个时候 | |||
| # 在rank0那里,确实已经删除成功了,但是在远程文件系统那里这个操作还没完成,rank1读取的时候还是读取到存在这个文件; | |||
| def synchronize_safe_rm(path: Optional[Union[str, Path]]): | |||
| if path is None: | |||
| return | |||
| if isinstance(path, str): | |||
| path = Path(path) | |||
| if not path.exists(): | |||
| return | |||
| if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0: | |||
| _recursive_rm(path) | |||
| wait_to_success(path.exists, no=True) | |||
| def _recursive_rm(path: Path): | |||
| if path.is_file() or path.is_symlink(): | |||
| if path.exists(): | |||
| try: | |||
| path.unlink() | |||
| except Exception: | |||
| pass | |||
| return | |||
| for sub_path in list(path.iterdir()): | |||
| _recursive_rm(sub_path) | |||
| path.rmdir() | |||
| def synchronize_mkdir(path: Optional[Union[str, Path]]): | |||
| """ | |||
| 注意该函数是用来创建文件夹,如果需要创建一个文件,不要使用该函数; | |||
| """ | |||
| if path is None: | |||
| return | |||
| if isinstance(path, str): | |||
| path = Path(path) | |||
| if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0: | |||
| path.mkdir(parents=True, exist_ok=True) | |||
| wait_to_success(path.exists) | |||