You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_trainer_paddle.py 2.3 kB

3 years ago
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
  1. import pytest
  2. from dataclasses import dataclass
  3. from fastNLP.core.controllers.trainer import Trainer
  4. from fastNLP.core.metrics.accuracy import Accuracy
  5. from fastNLP.core.callbacks.progress_callback import RichCallback
  6. from paddle.optimizer import Adam
  7. from paddle.io import DataLoader
  8. from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1
  9. from tests.helpers.datasets.paddle_data import PaddleRandomMaxDataset
  10. from tests.helpers.callbacks.helper_callbacks import RecordLossCallback, RecordMetricCallback
  11. from tests.helpers.utils import magic_argv_env_context
  12. @dataclass
  13. class TrainPaddleConfig:
  14. num_labels: int = 10
  15. feature_dimension: int = 10
  16. batch_size: int = 2
  17. shuffle: bool = True
  18. evaluate_every = 2
  19. @pytest.mark.parametrize("driver,device", [("paddle", "cpu"), ("paddle", 1), ("fleet", [0, 1])])
  20. # @pytest.mark.parametrize("driver,device", [("fleet", [0, 1])])
  21. @pytest.mark.parametrize("callbacks", [[RichCallback(5)]])
  22. @pytest.mark.paddle
  23. @magic_argv_env_context
  24. def test_trainer_paddle(
  25. driver,
  26. device,
  27. callbacks,
  28. n_epochs=2,
  29. ):
  30. model = PaddleNormalModel_Classification_1(
  31. num_labels=TrainPaddleConfig.num_labels,
  32. feature_dimension=TrainPaddleConfig.feature_dimension
  33. )
  34. optimizers = Adam(parameters=model.parameters(), learning_rate=0.0001)
  35. train_dataloader = DataLoader(
  36. dataset=PaddleRandomMaxDataset(20, 10),
  37. batch_size=TrainPaddleConfig.batch_size,
  38. shuffle=True
  39. )
  40. val_dataloader = DataLoader(
  41. dataset=PaddleRandomMaxDataset(20, 10),
  42. batch_size=TrainPaddleConfig.batch_size,
  43. shuffle=True
  44. )
  45. train_dataloader = train_dataloader
  46. evaluate_dataloaders = val_dataloader
  47. evaluate_every = TrainPaddleConfig.evaluate_every
  48. metrics = {"acc": Accuracy(backend="paddle")}
  49. trainer = Trainer(
  50. model=model,
  51. driver=driver,
  52. device=device,
  53. optimizers=optimizers,
  54. train_dataloader=train_dataloader,
  55. evaluate_dataloaders=evaluate_dataloaders,
  56. evaluate_every=evaluate_every,
  57. input_mapping=None,
  58. output_mapping=None,
  59. metrics=metrics,
  60. n_epochs=n_epochs,
  61. callbacks=callbacks,
  62. )
  63. trainer.run()