|
|
|
@@ -18,6 +18,7 @@ import time |
|
|
|
from mindspore.train.callback import Callback |
|
|
|
from mindspore import context |
|
|
|
from mindspore.train import ParallelMode |
|
|
|
from mindspore.communication.management import get_rank |
|
|
|
|
|
|
|
def add_write(file_path, out_str): |
|
|
|
""" |
|
|
|
@@ -52,7 +53,14 @@ class LossCallBack(Callback): |
|
|
|
wide_loss, deep_loss = cb_params.net_outputs[0].asnumpy(), cb_params.net_outputs[1].asnumpy() |
|
|
|
cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1 |
|
|
|
cur_num = cb_params.cur_step_num |
|
|
|
print("===loss===", cb_params.cur_epoch_num, cur_step_in_epoch, wide_loss, deep_loss, flush=True) |
|
|
|
rank_id = 0 |
|
|
|
parallel_mode = context.get_auto_parallel_context("parallel_mode") |
|
|
|
if parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL, |
|
|
|
ParallelMode.DATA_PARALLEL): |
|
|
|
rank_id = get_rank() |
|
|
|
|
|
|
|
print("===loss===", rank_id, cb_params.cur_epoch_num, cur_step_in_epoch, |
|
|
|
wide_loss, deep_loss, flush=True) |
|
|
|
|
|
|
|
# raise ValueError |
|
|
|
if self._per_print_times != 0 and cur_num % self._per_print_times == 0 and self.config is not None: |
|
|
|
@@ -99,13 +107,18 @@ class EvalCallBack(Callback): |
|
|
|
if parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL): |
|
|
|
context.set_auto_parallel_context(strategy_ckpt_save_file="", |
|
|
|
strategy_ckpt_load_file="./strategy_train.ckpt") |
|
|
|
rank_id = 0 |
|
|
|
if parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL, |
|
|
|
ParallelMode.DATA_PARALLEL): |
|
|
|
rank_id = get_rank() |
|
|
|
start_time = time.time() |
|
|
|
out = self.model.eval(self.eval_dataset, dataset_sink_mode=(not self.host_device_mix)) |
|
|
|
end_time = time.time() |
|
|
|
eval_time = int(end_time - start_time) |
|
|
|
|
|
|
|
time_str = time.strftime("%Y-%m-%d %H:%M%S", time.localtime()) |
|
|
|
out_str = "{}==== EvalCallBack model.eval(): {}; eval_time: {}s".format(time_str, out.values(), eval_time) |
|
|
|
out_str = "{} == Rank: {} == EvalCallBack model.eval(): {}; eval_time: {}s".\ |
|
|
|
format(time_str, rank_id, out.values(), eval_time) |
|
|
|
print(out_str) |
|
|
|
self.eval_values = out.values() |
|
|
|
add_write(self.eval_file_name, out_str) |