| @@ -57,7 +57,7 @@ def choose_progress_callback(progress_bar: Union[str, ProgressCallback]) -> Prog | |||
| class RichCallback(ProgressCallback): | |||
| """ | |||
| 在训练过程中打印 rich progress bar 的 callback 。在 Trainer 中,默认就会使用这个 callback 来显示进度。如果需要定制这个 Callback 的 | |||
| 参数,请通过实例化本 Callback 并传入到 Trainer 中实现。 | |||
| 参数,请通过实例化本 Callback 并传入到 Trainer 中实现。在打印 evaluate 的结果时,不会打印名称以 "_" 开头的内容。 | |||
| :param print_every: 多少个 batch 更新一次显示。 | |||
| :param loss_round_ndigit: 显示的 loss 保留多少位有效数字 | |||
| @@ -144,8 +144,10 @@ class RichCallback(ProgressCallback): | |||
| self.progress_bar.console.rule(text_style+f"Eval. results on Epoch:{trainer.cur_epoch_idx}, " | |||
| f"Batch:{trainer.batch_idx_in_epoch}", | |||
| style=rule_style, characters=characters) | |||
| results = {key:trainer.driver.tensor_to_numeric(value) for key, value in results.items() if | |||
| not key.startswith('_')} | |||
| if self.format_json: | |||
| self.progress_bar.console.print_json(json.dumps(trainer.driver.tensor_to_numeric(results))) | |||
| self.progress_bar.console.print_json(json.dumps(results)) | |||
| else: | |||
| self.progress_bar.print(results) | |||
| @@ -165,7 +167,7 @@ class RawTextCallback(ProgressCallback): | |||
| def __init__(self, print_every:int = 1, loss_round_ndigit:int = 6, monitor:str=None, larger_better:bool=True, | |||
| format_json=True): | |||
| """ | |||
| 通过向命令行打印进度的方式显示 | |||
| 通过向命令行打印进度的方式显示。在打印 evaluate 的结果时,不会打印名称以 "_" 开头的内容。 | |||
| :param print_every: 多少个 batch 更新一次显示。 | |||
| :param loss_round_ndigit: 显示的 loss 保留多少位有效数字 | |||
| @@ -222,8 +224,10 @@ class RawTextCallback(ProgressCallback): | |||
| text = '-'*self.num_signs + base_text + '-'*self.num_signs | |||
| logger.info(text) | |||
| results = {key:trainer.driver.tensor_to_numeric(value) for key, value in results.items() if | |||
| not key.startswith('_')} | |||
| if self.format_json: | |||
| logger.info(json.dumps(trainer.driver.tensor_to_numeric(results))) | |||
| logger.info(json.dumps(results)) | |||
| else: | |||
| logger.info(results) | |||
| @@ -235,7 +239,7 @@ class RawTextCallback(ProgressCallback): | |||
| class TqdmCallback(ProgressCallback): | |||
| """ | |||
| 在训练过程中打印 tqdm progress bar 的 callback 。在 Trainer 中,默认就会使用这个 callback 来显示进度。如果需要定制这个 Callback 的 | |||
| 参数,请通过实例化本 Callback 并传入到 Trainer 中实现。 | |||
| 参数,请通过实例化本 Callback 并传入到 Trainer 中实现。在打印 evaluate 的结果时,不会打印名称以 "_" 开头的内容。 | |||
| :param print_every: 多少个 batch 更新一次显示。 | |||
| :param loss_round_ndigit: 显示的 loss 保留多少位有效数字 | |||
| @@ -309,8 +313,10 @@ class TqdmCallback(ProgressCallback): | |||
| text = '-'*self.num_signs + base_text + '-'*self.num_signs | |||
| logger.info(text) | |||
| results = {key:trainer.driver.tensor_to_numeric(value) for key, value in results.items() if | |||
| not key.startswith('_')} | |||
| if self.format_json: | |||
| logger.info(json.dumps(trainer.driver.tensor_to_numeric(results))) | |||
| logger.info(json.dumps(results)) | |||
| else: | |||
| logger.info(results) | |||
| @@ -630,7 +630,7 @@ def is_notebook(): | |||
| def flat_nest_dict(d:Dict, separator:str='#', compress_none_key:bool=True, top_down:bool=False) -> Dict: | |||
| """ | |||
| 讲一个 nested 的 dict 转成 flat 的 dict,例如 | |||
| 将一个 nested 的 dict 转成 flat 的 dict,例如 | |||
| ex:: | |||
| d = {'test': {'f1': {'f': 0.2, 'rec': 0.1}}} -> {'f#f1#test':0.2, 'rec#f1#test':0.1} | |||
| @@ -245,8 +245,9 @@ class DataBundle: | |||
| """ | |||
| _progress_desc = progress_desc | |||
| for name, dataset in self.datasets.items(): | |||
| if _progress_desc: | |||
| progress_desc = _progress_desc + f' for `{name}`' | |||
| if len(_progress_desc) == 0: | |||
| _progress_desc = 'Processing' | |||
| progress_desc = _progress_desc + f' for `{name}`' | |||
| if dataset.has_field(field_name=field_name): | |||
| dataset.apply_field(func=func, field_name=field_name, new_field_name=new_field_name, num_proc=num_proc, | |||
| progress_desc=progress_desc, progress_bar=progress_bar) | |||
| @@ -284,8 +285,9 @@ class DataBundle: | |||
| res = {} | |||
| _progress_desc = progress_desc | |||
| for name, dataset in self.datasets.items(): | |||
| if _progress_desc: | |||
| progress_desc = _progress_desc + f' for `{name}`' | |||
| if len(_progress_desc) == 0: | |||
| _progress_desc = 'Processing' | |||
| progress_desc = _progress_desc + f' for `{name}`' | |||
| if dataset.has_field(field_name=field_name): | |||
| res[name] = dataset.apply_field_more(func=func, field_name=field_name, num_proc=num_proc, | |||
| modify_fields=modify_fields, | |||
| @@ -317,8 +319,9 @@ class DataBundle: | |||
| """ | |||
| _progress_desc = progress_desc | |||
| for name, dataset in self.datasets.items(): | |||
| if _progress_desc: | |||
| progress_desc = _progress_desc + f' for `{name}`' | |||
| if len(_progress_desc) == 0: | |||
| _progress_desc = 'Processing' | |||
| progress_desc = _progress_desc + f' for `{name}`' | |||
| dataset.apply(func, new_field_name=new_field_name, num_proc=num_proc, progress_bar=progress_bar, | |||
| progress_desc=progress_desc) | |||
| return self | |||
| @@ -349,8 +352,9 @@ class DataBundle: | |||
| res = {} | |||
| _progress_desc = progress_desc | |||
| for name, dataset in self.datasets.items(): | |||
| if _progress_desc: | |||
| progress_desc = _progress_desc + f' for `{name}`' | |||
| if len(_progress_desc) == 0: | |||
| _progress_desc = 'Processing' | |||
| progress_desc = _progress_desc + f' for `{name}`' | |||
| res[name] = dataset.apply_more(func, modify_fields=modify_fields, num_proc=num_proc, | |||
| progress_bar=progress_bar, progress_desc=progress_desc) | |||
| return res | |||