You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_trainer_paddle.py 2.2 kB

3 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  1. import pytest
  2. from dataclasses import dataclass
  3. from fastNLP.core.controllers.trainer import Trainer
  4. from fastNLP.core.metrics.accuracy import Accuracy
  5. from fastNLP.core.callbacks.progress_callback import RichCallback
  6. from paddle.optimizer import Adam
  7. from paddle.io import DataLoader
  8. from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1
  9. from tests.helpers.datasets.paddle_data import PaddleRandomMaxDataset
  10. from tests.helpers.utils import magic_argv_env_context
  11. @dataclass
  12. class TrainPaddleConfig:
  13. num_labels: int = 10
  14. feature_dimension: int = 10
  15. batch_size: int = 2
  16. shuffle: bool = True
  17. evaluate_every = 2
  18. @pytest.mark.parametrize("driver,device", [("paddle", "cpu"), ("paddle", 1), ("fleet", [0, 1])])
  19. # @pytest.mark.parametrize("driver,device", [("fleet", [0, 1])])
  20. @pytest.mark.parametrize("callbacks", [[RichCallback(5)]])
  21. @pytest.mark.paddle
  22. @magic_argv_env_context
  23. def test_trainer_paddle(
  24. driver,
  25. device,
  26. callbacks,
  27. n_epochs=2,
  28. ):
  29. model = PaddleNormalModel_Classification_1(
  30. num_labels=TrainPaddleConfig.num_labels,
  31. feature_dimension=TrainPaddleConfig.feature_dimension
  32. )
  33. optimizers = Adam(parameters=model.parameters(), learning_rate=0.0001)
  34. train_dataloader = DataLoader(
  35. dataset=PaddleRandomMaxDataset(20, 10),
  36. batch_size=TrainPaddleConfig.batch_size,
  37. shuffle=True
  38. )
  39. val_dataloader = DataLoader(
  40. dataset=PaddleRandomMaxDataset(20, 10),
  41. batch_size=TrainPaddleConfig.batch_size,
  42. shuffle=True
  43. )
  44. train_dataloader = train_dataloader
  45. evaluate_dataloaders = val_dataloader
  46. evaluate_every = TrainPaddleConfig.evaluate_every
  47. metrics = {"acc": Accuracy(backend="paddle")}
  48. trainer = Trainer(
  49. model=model,
  50. driver=driver,
  51. device=device,
  52. optimizers=optimizers,
  53. train_dataloader=train_dataloader,
  54. evaluate_dataloaders=evaluate_dataloaders,
  55. evaluate_every=evaluate_every,
  56. input_mapping=None,
  57. output_mapping=None,
  58. metrics=metrics,
  59. n_epochs=n_epochs,
  60. callbacks=callbacks,
  61. )
  62. trainer.run()