| @@ -134,12 +134,8 @@ class LossGet(Callback): | |||
| return self._loss | |||
| def train_process(device_id, epoch_size, num_classes, batch_size): | |||
| os.system("mkdir " + str(device_id)) | |||
| os.chdir(str(device_id)) | |||
| def train_process(epoch_size, num_classes, batch_size): | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||
| context.set_context(device_id=device_id) | |||
| context.set_context(mode=context.GRAPH_MODE) | |||
| net = resnet50(batch_size, num_classes) | |||
| loss = CrossEntropyLoss() | |||
| opt = Momentum(filter(lambda x: x.requires_grad, | |||
| @@ -148,34 +144,15 @@ def train_process(device_id, epoch_size, num_classes, batch_size): | |||
| model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'}) | |||
| dataset = create_dataset(epoch_size, training=True, batch_size=batch_size) | |||
| batch_num = dataset.get_dataset_size() | |||
| config_ck = CheckpointConfig(save_checkpoint_steps=batch_num, keep_checkpoint_max=1) | |||
| ckpoint_cb = ModelCheckpoint(prefix="train_resnet_cifar10_device_id_" + str(device_id), directory="./", | |||
| config=config_ck) | |||
| loss_cb = LossGet() | |||
| model.train(epoch_size, dataset, callbacks=[ckpoint_cb, loss_cb]) | |||
| model.train(epoch_size, dataset, callbacks=[loss_cb]) | |||
| def eval(batch_size, num_classes): | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||
| context.set_context(device_id=0) | |||
| net = resnet50(batch_size, num_classes) | |||
| loss = CrossEntropyLoss() | |||
| opt = Momentum(filter(lambda x: x.requires_grad, | |||
| net.get_parameters()), 0.01, 0.9) | |||
| model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'}) | |||
| checkpoint_path = "./train_resnet_cifar10_device_id_0-1_1562.ckpt" | |||
| param_dict = load_checkpoint(checkpoint_path) | |||
| load_param_into_net(net, param_dict) | |||
| net.set_train(False) | |||
| eval_dataset = create_dataset(1, training=False) | |||
| res = model.eval(eval_dataset) | |||
| print("result: ", res) | |||
| return res | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @@ -184,11 +161,7 @@ def test_resnet_cifar_1p(): | |||
| epoch_size = 1 | |||
| num_classes = 10 | |||
| batch_size = 32 | |||
| device_id = 0 | |||
| train_process(device_id, epoch_size, num_classes, batch_size) | |||
| time.sleep(3) | |||
| acc = eval(batch_size, num_classes) | |||
| os.chdir("../") | |||
| os.system("rm -rf " + str(device_id)) | |||
| acc = train_process(epoch_size, num_classes, batch_size) | |||
| os.system("rm -rf kernel_meta") | |||
| print("End training...") | |||
| assert acc['acc'] > 0.35 | |||