You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

test_trainer_paddle.py 2.5 kB

3 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
  1. import os
  2. from typing import List
  3. import pytest
  4. from dataclasses import dataclass
  5. from fastNLP.core.controllers.trainer import Trainer
  6. from fastNLP.core.metrics.accuracy import Accuracy
  7. from fastNLP.core.callbacks.progress_callback import RichCallback
  8. from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
  9. from fastNLP.envs.env import USER_CUDA_VISIBLE_DEVICES
  10. if _NEED_IMPORT_PADDLE:
  11. from paddle.optimizer import Adam
  12. from paddle.io import DataLoader
  13. from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1
  14. from tests.helpers.datasets.paddle_data import PaddleRandomMaxDataset
  15. from tests.helpers.utils import magic_argv_env_context
  16. @dataclass
  17. class TrainPaddleConfig:
  18. num_labels: int = 10
  19. feature_dimension: int = 10
  20. batch_size: int = 2
  21. shuffle: bool = True
  22. evaluate_every = 2
  23. @pytest.mark.parametrize("device", ["cpu", 1, [0, 1]])
  24. # @pytest.mark.parametrize("driver,device", [("fleet", [0, 1])])
  25. @pytest.mark.parametrize("callbacks", [[RichCallback(5)]])
  26. @pytest.mark.paddledist
  27. @magic_argv_env_context
  28. def test_trainer_paddle(
  29. device,
  30. callbacks,
  31. n_epochs=2,
  32. ):
  33. if isinstance(device, List) and USER_CUDA_VISIBLE_DEVICES not in os.environ:
  34. pytest.skip("Skip test fleet if FASTNLP_BACKEND is not set to paddle.")
  35. model = PaddleNormalModel_Classification_1(
  36. num_labels=TrainPaddleConfig.num_labels,
  37. feature_dimension=TrainPaddleConfig.feature_dimension
  38. )
  39. optimizers = Adam(parameters=model.parameters(), learning_rate=0.0001)
  40. train_dataloader = DataLoader(
  41. dataset=PaddleRandomMaxDataset(20, 10),
  42. batch_size=TrainPaddleConfig.batch_size,
  43. shuffle=True
  44. )
  45. val_dataloader = DataLoader(
  46. dataset=PaddleRandomMaxDataset(20, 10),
  47. batch_size=TrainPaddleConfig.batch_size,
  48. shuffle=True
  49. )
  50. train_dataloader = train_dataloader
  51. evaluate_dataloaders = val_dataloader
  52. evaluate_every = TrainPaddleConfig.evaluate_every
  53. metrics = {"acc": Accuracy(backend="paddle")}
  54. trainer = Trainer(
  55. model=model,
  56. driver="paddle",
  57. device=device,
  58. optimizers=optimizers,
  59. train_dataloader=train_dataloader,
  60. evaluate_dataloaders=evaluate_dataloaders,
  61. evaluate_every=evaluate_every,
  62. input_mapping=None,
  63. output_mapping=None,
  64. metrics=metrics,
  65. n_epochs=n_epochs,
  66. callbacks=callbacks,
  67. )
  68. trainer.run()