|
|
|
@@ -32,8 +32,6 @@ from src.nets import net_factory |
|
|
|
from src.utils import learning_rates
|
|
|
|
|
|
|
|
set_seed(1)
|
|
|
|
context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True, save_graphs=False,
|
|
|
|
device_target="Ascend", device_id=int(os.getenv('DEVICE_ID')))
|
|
|
|
|
|
|
|
|
|
|
|
class BuildTrainNetwork(nn.Cell):
|
|
|
|
@@ -77,6 +75,8 @@ def parse_args(): |
|
|
|
parser.add_argument('--ckpt_pre_trained', type=str, default='', help='pretrained model')
|
|
|
|
|
|
|
|
# train
|
|
|
|
parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'CPU'],
|
|
|
|
help='device where the code will be implemented. (Default: Ascend)')
|
|
|
|
parser.add_argument('--is_distributed', action='store_true', help='distributed training')
|
|
|
|
parser.add_argument('--rank', type=int, default=0, help='local rank of distributed')
|
|
|
|
parser.add_argument('--group_size', type=int, default=1, help='world size of distributed')
|
|
|
|
@@ -90,6 +90,12 @@ def parse_args(): |
|
|
|
def train():
|
|
|
|
args = parse_args()
|
|
|
|
|
|
|
|
if args.device_target == "CPU":
|
|
|
|
context.set_context(mode=context.GRAPH_MODE, save_graphs=False, device_target="CPU")
|
|
|
|
else:
|
|
|
|
context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True, save_graphs=False,
|
|
|
|
device_target="Ascend", device_id=int(os.getenv('DEVICE_ID')))
|
|
|
|
|
|
|
|
# init multicards training
|
|
|
|
if args.is_distributed:
|
|
|
|
init()
|
|
|
|
@@ -150,7 +156,8 @@ def train(): |
|
|
|
|
|
|
|
# loss scale
|
|
|
|
manager_loss_scale = FixedLossScaleManager(args.loss_scale, drop_overflow_update=False)
|
|
|
|
model = Model(train_net, optimizer=opt, amp_level="O3", loss_scale_manager=manager_loss_scale)
|
|
|
|
amp_level = "O0" if args.device_target == "CPU" else "O3"
|
|
|
|
model = Model(train_net, optimizer=opt, amp_level=amp_level, loss_scale_manager=manager_loss_scale)
|
|
|
|
|
|
|
|
# callback for saving ckpts
|
|
|
|
time_cb = TimeMonitor(data_size=iters_per_epoch)
|
|
|
|
|