|
|
|
@@ -18,6 +18,7 @@ |
|
|
|
import os |
|
|
|
import argparse |
|
|
|
import random |
|
|
|
import ast |
|
|
|
import numpy as np |
|
|
|
|
|
|
|
import mindspore.common.dtype as mstype |
|
|
|
@@ -30,7 +31,7 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net |
|
|
|
from mindspore.nn import SGD |
|
|
|
import mindspore.dataset.engine as de |
|
|
|
|
|
|
|
from src.MaskRcnn.mask_rcnn_r50 import Mask_Rcnn_Resnet50 |
|
|
|
from src.maskrcnn.mask_rcnn_r50 import Mask_Rcnn_Resnet50 |
|
|
|
from src.network_define import LossCallBack, WithLossCell, TrainOneStepCell, LossNet |
|
|
|
from src.config import config |
|
|
|
from src.dataset import data_to_mindrecord_byte_image, create_maskrcnn_dataset |
|
|
|
@@ -41,11 +42,11 @@ np.random.seed(1) |
|
|
|
de.config.set_seed(1) |
|
|
|
|
|
|
|
parser = argparse.ArgumentParser(description="MaskRcnn training") |
|
|
|
parser.add_argument("--only_create_dataset", type=bool, default=False, help="If set it true, only create " |
|
|
|
"Mindrecord, default is false.") |
|
|
|
parser.add_argument("--run_distribute", type=bool, default=False, help="Run distribute, default is false.") |
|
|
|
parser.add_argument("--do_train", type=bool, default=True, help="Do train or not, default is true.") |
|
|
|
parser.add_argument("--do_eval", type=bool, default=False, help="Do eval or not, default is false.") |
|
|
|
parser.add_argument("--only_create_dataset", type=ast.literal_eval, default=False, help="If set it true, only create " |
|
|
|
"Mindrecord, default is false.") |
|
|
|
parser.add_argument("--run_distribute", type=ast.literal_eval, default=False, help="Run distribute, default is false.") |
|
|
|
parser.add_argument("--do_train", type=ast.literal_eval, default=True, help="Do train or not, default is true.") |
|
|
|
parser.add_argument("--do_eval", type=ast.literal_eval, default=False, help="Do eval or not, default is false.") |
|
|
|
parser.add_argument("--dataset", type=str, default="coco", help="Dataset, default is coco.") |
|
|
|
parser.add_argument("--pre_trained", type=str, default="", help="Pretrain file path.") |
|
|
|
parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.") |
|
|
|
|