You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_trainer_paddle.py 2.3 kB

3 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172
  1. import pytest
  2. from dataclasses import dataclass
  3. from fastNLP.core.controllers.trainer import Trainer
  4. from fastNLP.core.metrics.accuracy import Accuracy
  5. from fastNLP.core.callbacks.progress_callback import RichCallback
  6. from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
  7. if _NEED_IMPORT_PADDLE:
  8. from paddle.optimizer import Adam
  9. from paddle.io import DataLoader
  10. from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1
  11. from tests.helpers.datasets.paddle_data import PaddleRandomMaxDataset
  12. from tests.helpers.utils import magic_argv_env_context
  13. @dataclass
  14. class TrainPaddleConfig:
  15. num_labels: int = 10
  16. feature_dimension: int = 10
  17. batch_size: int = 2
  18. shuffle: bool = True
  19. evaluate_every = 2
  20. @pytest.mark.parametrize("driver,device", [("paddle", "cpu"), ("paddle", 1), ("fleet", [0, 1])])
  21. # @pytest.mark.parametrize("driver,device", [("fleet", [0, 1])])
  22. @pytest.mark.parametrize("callbacks", [[RichCallback(5)]])
  23. @pytest.mark.paddledist
  24. @magic_argv_env_context
  25. def test_trainer_paddle(
  26. driver,
  27. device,
  28. callbacks,
  29. n_epochs=2,
  30. ):
  31. model = PaddleNormalModel_Classification_1(
  32. num_labels=TrainPaddleConfig.num_labels,
  33. feature_dimension=TrainPaddleConfig.feature_dimension
  34. )
  35. optimizers = Adam(parameters=model.parameters(), learning_rate=0.0001)
  36. train_dataloader = DataLoader(
  37. dataset=PaddleRandomMaxDataset(20, 10),
  38. batch_size=TrainPaddleConfig.batch_size,
  39. shuffle=True
  40. )
  41. val_dataloader = DataLoader(
  42. dataset=PaddleRandomMaxDataset(20, 10),
  43. batch_size=TrainPaddleConfig.batch_size,
  44. shuffle=True
  45. )
  46. train_dataloader = train_dataloader
  47. evaluate_dataloaders = val_dataloader
  48. evaluate_every = TrainPaddleConfig.evaluate_every
  49. metrics = {"acc": Accuracy(backend="paddle")}
  50. trainer = Trainer(
  51. model=model,
  52. driver=driver,
  53. device=device,
  54. optimizers=optimizers,
  55. train_dataloader=train_dataloader,
  56. evaluate_dataloaders=evaluate_dataloaders,
  57. evaluate_every=evaluate_every,
  58. input_mapping=None,
  59. output_mapping=None,
  60. metrics=metrics,
  61. n_epochs=n_epochs,
  62. callbacks=callbacks,
  63. )
  64. trainer.run()