| @@ -193,7 +193,7 @@ def train_process_bert_thor(q, device_id, epoch_size, device_num): | |||||
| @pytest.mark.platform_arm_ascend_training | @pytest.mark.platform_arm_ascend_training | ||||
| @pytest.mark.platform_x86_ascend_training | @pytest.mark.platform_x86_ascend_training | ||||
| @pytest.mark.env_single | @pytest.mark.env_single | ||||
| def test_bert_thor_mlperf_8p(): | |||||
| def test_bert_thor_8p(): | |||||
| """test bert thor mlperf 8p""" | """test bert thor mlperf 8p""" | ||||
| q = Queue() | q = Queue() | ||||
| device_num = 8 | device_num = 8 | ||||
| @@ -234,12 +234,12 @@ def test_bert_thor_mlperf_8p(): | |||||
| os.system("rm -rf " + str(i)) | os.system("rm -rf " + str(i)) | ||||
| print("End training...") | print("End training...") | ||||
| assert mean_cost < 78 | |||||
| assert mean_cost < 66 | |||||
| assert mean_loss < 8.125 | assert mean_loss < 8.125 | ||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||
| begin = time.time() | begin = time.time() | ||||
| test_bert_thor_mlperf_8p() | |||||
| test_bert_thor_8p() | |||||
| end = time.time() | end = time.time() | ||||
| print("time span is", end - begin, flush=True) | print("time span is", end - begin, flush=True) | ||||