You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_trainer_paddle.py 2.4 kB

3 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273
  1. import pytest
  2. import os
  3. os.environ["FASTNLP_BACKEND"] = "paddle"
  4. from dataclasses import dataclass
  5. from fastNLP.core.controllers.trainer import Trainer
  6. from fastNLP.core.metrics.accuracy import Accuracy
  7. from fastNLP.core.callbacks.progress_callback import RichCallback
  8. from paddle.optimizer import Adam
  9. from paddle.io import DataLoader
  10. from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1
  11. from tests.helpers.datasets.paddle_data import PaddleRandomMaxDataset
  12. from tests.helpers.callbacks.helper_callbacks import RecordLossCallback, RecordMetricCallback
  13. from tests.helpers.utils import magic_argv_env_context
  14. @dataclass
  15. class TrainPaddleConfig:
  16. num_labels: int = 10
  17. feature_dimension: int = 10
  18. batch_size: int = 2
  19. shuffle: bool = True
  20. evaluate_every = 2
  21. @pytest.mark.parametrize("driver,device", [("paddle", "cpu"), ("paddle", 1)])
  22. # @pytest.mark.parametrize("driver,device", [("fleet", [0, 1])])
  23. @pytest.mark.parametrize("callbacks", [[RecordMetricCallback(monitor="acc#acc", metric_threshold=0.0, larger_better=True),
  24. RichCallback(5)]])
  25. @magic_argv_env_context
  26. def test_trainer_paddle(
  27. driver,
  28. device,
  29. callbacks,
  30. n_epochs=2,
  31. ):
  32. model = PaddleNormalModel_Classification_1(
  33. num_labels=TrainPaddleConfig.num_labels,
  34. feature_dimension=TrainPaddleConfig.feature_dimension
  35. )
  36. optimizers = Adam(parameters=model.parameters(), learning_rate=0.0001)
  37. train_dataloader = DataLoader(
  38. dataset=PaddleRandomMaxDataset(20, 10),
  39. batch_size=TrainPaddleConfig.batch_size,
  40. shuffle=True
  41. )
  42. val_dataloader = DataLoader(
  43. dataset=PaddleRandomMaxDataset(20, 10),
  44. batch_size=TrainPaddleConfig.batch_size,
  45. shuffle=True
  46. )
  47. train_dataloader = train_dataloader
  48. evaluate_dataloaders = val_dataloader
  49. evaluate_every = TrainPaddleConfig.evaluate_every
  50. metrics = {"acc": Accuracy(backend="paddle")}
  51. trainer = Trainer(
  52. model=model,
  53. driver=driver,
  54. device=device,
  55. optimizers=optimizers,
  56. train_dataloader=train_dataloader,
  57. evaluate_dataloaders=evaluate_dataloaders,
  58. evaluate_every=evaluate_every,
  59. input_mapping=None,
  60. output_mapping=None,
  61. metrics=metrics,
  62. n_epochs=n_epochs,
  63. callbacks=callbacks,
  64. )
  65. trainer.run()