|
|
|
@@ -18,6 +18,7 @@ Functional Cells used in Bert finetune and evaluation. |
|
|
|
""" |
|
|
|
|
|
|
|
import os |
|
|
|
import math |
|
|
|
import numpy as np |
|
|
|
import mindspore.nn as nn |
|
|
|
from mindspore import log as logger |
|
|
|
@@ -90,15 +91,14 @@ class LossCallBack(Callback): |
|
|
|
Args: |
|
|
|
per_print_times (int): Print loss every times. Default: 1. |
|
|
|
""" |
|
|
|
def __init__(self, per_print_times=1): |
|
|
|
def __init__(self, dataset_size=1): |
|
|
|
super(LossCallBack, self).__init__() |
|
|
|
if not isinstance(per_print_times, int) or per_print_times < 0: |
|
|
|
raise ValueError("print_step must be int and >= 0") |
|
|
|
self._per_print_times = per_print_times |
|
|
|
self._dataset_size = dataset_size |
|
|
|
def step_end(self, run_context): |
|
|
|
cb_params = run_context.original_args() |
|
|
|
print("epoch: {}, step: {}, outputs are {}".format(cb_params.cur_epoch_num, cb_params.cur_step_num, |
|
|
|
str(cb_params.net_outputs))) |
|
|
|
percent, epoch_num = math.modf(cb_params.cur_step_num / self._dataset_size) |
|
|
|
print("epoch: {}, current epoch percent: {}, step: {}, outputs are {}" |
|
|
|
.format(epoch_num, "%.3f" % percent, cb_params.cur_step_num, str(cb_params.net_outputs))) |
|
|
|
|
|
|
|
def LoadNewestCkpt(load_finetune_checkpoint_dir, steps_per_epoch, epoch_num, prefix): |
|
|
|
""" |
|
|
|
|