@@ -75,7 +75,7 @@ def model_and_optimizers(request):
 @pytest.mark.torch
-@pytest.mark.parametrize("driver,device", [("torch", [4, 5])]) # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1)
+@pytest.mark.parametrize("driver,device", [("torch", [0, 1])]) # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1)
 @magic_argv_env_context(timeout=100)
 def test_model_checkpoint_callback_1(
     model_and_optimizers: TrainerParameters,
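
This hunk repins the distributed checkpoint test from CUDA devices 4 and 5 to devices 0 and 1, which exist on any machine with at least two GPUs; indices 4 and 5 require at least six. As a minimal sketch, not part of this patch (the `require_gpus` helper is hypothetical), one way to keep such a test from erroring on smaller machines is to derive a skipif mark from the requested device indices:

import pytest
import torch

def require_gpus(*indices):
    # Skip the decorated test unless every requested CUDA index exists.
    needed = max(indices) + 1
    available = torch.cuda.device_count()
    return pytest.mark.skipif(
        available < needed,
        reason=f"needs CUDA devices {list(indices)}, only {available} visible",
    )

@require_gpus(0, 1)
def test_runs_on_two_gpus():
    assert torch.cuda.device_count() >= 2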
@@ -11,12 +11,8 @@ from ...helpers.utils import Capturing
 def _assert_equal(d1, d2):
     try:
         if 'torch' in str(type(d1)):
-            if 'float64' in str(d2.dtype):
-                print(d2.dtype)
             assert (d1 == d2).all().item()
-        if 'oneflow' in str(type(d1)):
-            if 'float64' in str(d2.dtype):
-                print(d2.dtype)
+        elif 'oneflow' in str(type(d1)):
             assert (d1 == d2).all().item()
         else:
             assert all(d1 == d2)
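
Besides dropping the leftover float64 debug prints, the substantive change here is `if` to `elif`: with two independent `if` statements, a torch tensor that passed the first branch also fell through to the trailing `else`, where `all(d1 == d2)` can raise for multi-dimensional tensors ("Boolean value of Tensor with more than one element is ambiguous"). A standalone sketch of the fixed control flow (torch only; the oneflow branch follows the same pattern, and the original file's try/except wrapper is omitted):

import numpy as np
import torch

def _assert_equal(d1, d2):
    if 'torch' in str(type(d1)):
        # Element-wise compare, then reduce to a single Python bool.
        assert (d1 == d2).all().item()
    # elif 'oneflow' in str(type(d1)): same pattern for oneflow tensors.
    else:
        # Fallback for element-wise comparable containers such as numpy arrays.
        assert all(d1 == d2)

t = torch.ones(2, 3)
_assert_equal(t, t.clone())  # torch branch; under the old double-`if`, the
                             # else branch would also run here and raise
_assert_equal(np.array([1, 2]), np.array([1, 2]))  # fallback branch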
@@ -67,7 +67,7 @@ def model_and_optimizers(request):
     return trainer_params

-@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_12, "fsdp 需要 torch 版本在 1.12 及以上")
+@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_12, reason="fsdp 需要 torch 版本在 1.12 及以上")
 @pytest.mark.torch
 @magic_argv_env_context
 def test_trainer_torch_without_evaluator(
@@ -97,7 +97,7 @@ def test_trainer_torch_without_evaluator(
     dist.destroy_process_group()

-@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_12, "fsdp 需要 torch 版本在 1.12 及以上")
+@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_12, reason="fsdp 需要 torch 版本在 1.12 及以上")
 @pytest.mark.torch
 @pytest.mark.parametrize("save_on_rank0", [True, False])
 @magic_argv_env_context(timeout=100)
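
Both skipif hunks fix the same mistake: the skip message (Chinese for "fsdp requires torch 1.12 or newer") was passed positionally, but pytest treats every positional argument to `skipif` as a condition, so the string would be evaluated as a Python expression and the mark would error at run time instead of skipping; the message must go in the `reason=` keyword. A minimal, runnable illustration of the corrected form (self-contained; `_TORCH_GREATER_EQUAL_1_12` is stubbed here):

import pytest

_TORCH_GREATER_EQUAL_1_12 = False  # stub: pretend torch < 1.12

# With reason= as a keyword, pytest skips and reports the message. Passing
# the message as a second positional argument would make pytest treat it as
# another (string) condition and try to evaluate it, failing the mark.
@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_12,
                    reason="fsdp requires torch >= 1.12")
def test_fsdp_placeholder():
    raise AssertionError("never runs; pytest shows the skip reason instead")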