|
|
|
@@ -193,7 +193,7 @@ def train_process_bert_thor(q, device_id, epoch_size, device_num): |
|
|
|
@pytest.mark.platform_arm_ascend_training |
|
|
|
@pytest.mark.platform_x86_ascend_training |
|
|
|
@pytest.mark.env_single |
|
|
|
def test_bert_thor_mlperf_8p(): |
|
|
|
def test_bert_thor_8p(): |
|
|
|
"""test bert thor mlperf 8p""" |
|
|
|
q = Queue() |
|
|
|
device_num = 8 |
|
|
|
@@ -234,12 +234,12 @@ def test_bert_thor_mlperf_8p(): |
|
|
|
os.system("rm -rf " + str(i)) |
|
|
|
|
|
|
|
print("End training...") |
|
|
|
assert mean_cost < 78 |
|
|
|
assert mean_cost < 66 |
|
|
|
assert mean_loss < 8.125 |
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
|
begin = time.time() |
|
|
|
test_bert_thor_mlperf_8p() |
|
|
|
test_bert_thor_8p() |
|
|
|
end = time.time() |
|
|
|
print("time span is", end - begin, flush=True) |