|
|
|
@@ -66,6 +66,31 @@ class SaveCheckpoint(ModelCheckpoint): |
|
|
|
|
|
|
|
if self.check_monitor_top_k(trainer, current): |
|
|
|
self._update_best_and_save(current, trainer, monitor_candidates) |
|
|
|
best_model_value = max([float(item) for item in list(self.best_k_models.values())]) |
|
|
|
# 保存版本信息(准确率等)到txt中 |
|
|
|
if not os.path.exists('./logs/default/version_info.txt'): |
|
|
|
with open('./logs/default/version_info.txt', 'w', encoding='utf-8') as f: |
|
|
|
f.write(self.dirpath.split('\\')[1] + ' ' + str(best_model_value) + '\n') |
|
|
|
else: |
|
|
|
with open('./logs/default/version_info.txt', 'r', encoding='utf-8') as f: |
|
|
|
info_list = f.readlines() |
|
|
|
info_list = [item.strip('\n').split(' ') for item in info_list] |
|
|
|
# 对list进行转置, 现在行为版本号和其数据, 列为不同的版本 |
|
|
|
info_list = list(map(list, zip(*info_list))) |
|
|
|
if self.dirpath.split('\\')[1] in info_list[0]: |
|
|
|
for cou in range(len(info_list[0])): |
|
|
|
if self.dirpath.split('\\')[1] == info_list[0][cou]: |
|
|
|
info_list[1][cou] = str(best_model_value) |
|
|
|
else: |
|
|
|
info_list[0].append(self.dirpath.split('\\')[1]) |
|
|
|
info_list[1].append(str(best_model_value)) |
|
|
|
# 对list进行转置 |
|
|
|
info_list = list(map(list, zip(*info_list))) |
|
|
|
with open('./logs/default/version_info.txt', 'w', encoding='utf-8') as f: |
|
|
|
for line in info_list: |
|
|
|
line = " ".join(line) |
|
|
|
f.write(line + '\n') |
|
|
|
# 每次更新ckpt文件后, 将其存放到另一个位置 |
|
|
|
if self.path_final_save is not None: |
|
|
|
zip_dir('./logs', './logs.zip') |
|
|
|
if os.path.exists(self.path_final_save + '/logs.zip'): |
|
|
|
|