| @@ -415,7 +415,7 @@ def test_pynative_resnet50(): | |||||
| train_network.set_train() | train_network.set_train() | ||||
| step = 0 | step = 0 | ||||
| max_step = 20 | |||||
| max_step = 21 | |||||
| exceed_num = 0 | exceed_num = 0 | ||||
| data_set = create_dataset(repeat_num=1, training=True, batch_size=batch_size) | data_set = create_dataset(repeat_num=1, training=True, batch_size=batch_size) | ||||
| for element in data_set.create_dict_iterator(): | for element in data_set.create_dict_iterator(): | ||||
| @@ -431,7 +431,7 @@ def test_pynative_resnet50(): | |||||
| end_time = time.time() | end_time = time.time() | ||||
| cost_time = end_time - start_time | cost_time = end_time - start_time | ||||
| print("======step: ", step, " loss: ", loss_output.asnumpy(), " cost time: ", cost_time) | print("======step: ", step, " loss: ", loss_output.asnumpy(), " cost time: ", cost_time) | ||||
| if step > 1 and cost_time > 0.23: | |||||
| if step > 1 and cost_time > 0.25: | |||||
| exceed_num = exceed_num + 1 | exceed_num = exceed_num + 1 | ||||
| assert exceed_num < 10 | |||||
| assert exceed_num < 20 | |||||