|
|
|
@@ -16,6 +16,7 @@ |
|
|
|
"""train FasterRcnn and get checkpoint files.""" |
|
|
|
|
|
|
|
import os |
|
|
|
import time |
|
|
|
import argparse |
|
|
|
import random |
|
|
|
import numpy as np |
|
|
|
@@ -72,7 +73,7 @@ if __name__ == '__main__': |
|
|
|
prefix = "FasterRcnn.mindrecord" |
|
|
|
mindrecord_dir = config.mindrecord_dir |
|
|
|
mindrecord_file = os.path.join(mindrecord_dir, prefix + "0") |
|
|
|
if not os.path.exists(mindrecord_file): |
|
|
|
if rank == 0 and not os.path.exists(mindrecord_file): |
|
|
|
if not os.path.isdir(mindrecord_dir): |
|
|
|
os.makedirs(mindrecord_dir) |
|
|
|
if args_opt.dataset == "coco": |
|
|
|
@@ -90,6 +91,9 @@ if __name__ == '__main__': |
|
|
|
else: |
|
|
|
print("IMAGE_DIR or ANNO_PATH not exits.") |
|
|
|
|
|
|
|
while not os.path.exists(mindrecord_file + ".db"): |
|
|
|
time.sleep(5) |
|
|
|
|
|
|
|
if not args_opt.only_create_dataset: |
|
|
|
loss_scale = float(config.loss_scale) |
|
|
|
|
|
|
|
|