|
|
|
@@ -433,6 +433,12 @@ def check_data_dump(dump_file_path): |
|
|
|
expect = np.array([[8, 10, 12], [14, 16, 18]], np.float32) |
|
|
|
assert np.array_equal(output, expect) |
|
|
|
|
|
|
|
|
|
|
|
def run_train(): |
|
|
|
add = Net() |
|
|
|
add(Tensor(x), Tensor(y)) |
|
|
|
|
|
|
|
|
|
|
|
def run_saved_data_dump_test(scenario, saved_data): |
|
|
|
"""Run e2e dump on scenario, testing statistic dump""" |
|
|
|
if sys.platform != 'linux': |
|
|
|
@@ -445,8 +451,8 @@ def run_saved_data_dump_test(scenario, saved_data): |
|
|
|
dump_file_path = os.path.join(dump_path, 'rank_0', 'Net', '0', '0') |
|
|
|
if os.path.isdir(dump_path): |
|
|
|
shutil.rmtree(dump_path) |
|
|
|
add = Net() |
|
|
|
add(Tensor(x), Tensor(y)) |
|
|
|
exec_network_cmd = 'cd {0}; python -c "from test_data_dump import run_train; run_train()"'.format(os.getcwd()) |
|
|
|
_ = os.system(exec_network_cmd) |
|
|
|
for _ in range(3): |
|
|
|
if not os.path.exists(dump_file_path): |
|
|
|
time.sleep(2) |
|
|
|
|