You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_trainer_paddle.py 3.3 kB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
  1. import pytest
  2. import os
  3. os.environ["FASTNLP_BACKEND"] = "paddle"
  4. from typing import Any
  5. from dataclasses import dataclass
  6. from fastNLP.core.controllers.trainer import Trainer
  7. from fastNLP.core.metrics.accuracy import Accuracy
  8. from fastNLP.core.callbacks.progress_callback import RichCallback
  9. from fastNLP.envs import FASTNLP_DISTRIBUTED_CHECK
  10. from paddle.optimizer import Adam
  11. from paddle.io import DataLoader
  12. from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1
  13. from tests.helpers.datasets.paddle_data import PaddleRandomMaxDataset
  14. from tests.helpers.callbacks.helper_callbacks import RecordLossCallback, RecordMetricCallback
  15. from tests.helpers.utils import magic_argv_env_context
  16. @dataclass
  17. class MNISTTrainPaddleConfig:
  18. num_labels: int = 10
  19. feature_dimension: int = 784
  20. batch_size: int = 32
  21. shuffle: bool = True
  22. validate_every = -5
  23. driver: str = "paddle"
  24. device = "gpu"
  25. @dataclass
  26. class MNISTTrainFleetConfig:
  27. num_labels: int = 10
  28. feature_dimension: int = 784
  29. batch_size: int = 32
  30. shuffle: bool = True
  31. validate_every = -5
  32. @dataclass
  33. class TrainerParameters:
  34. model: Any = None
  35. optimizers: Any = None
  36. train_dataloader: Any = None
  37. validate_dataloaders: Any = None
  38. input_mapping: Any = None
  39. output_mapping: Any = None
  40. metrics: Any = None
  41. @pytest.mark.parametrize("driver,device", [("paddle", "cpu")("paddle", 1)])
  42. # @pytest.mark.parametrize("driver,device", [("fleet", [0, 1])])
  43. @pytest.mark.parametrize("callbacks", [[RecordMetricCallback(monitor="acc#acc", metric_threshold=0.7, larger_better=True),
  44. RichCallback(5), RecordLossCallback(loss_threshold=0.3)]])
  45. @magic_argv_env_context
  46. def test_trainer_paddle(
  47. driver,
  48. device,
  49. callbacks,
  50. n_epochs=2,
  51. ):
  52. trainer_params = TrainerParameters()
  53. trainer_params.model = PaddleNormalModel_Classification_1(
  54. num_labels=MNISTTrainPaddleConfig.num_labels,
  55. feature_dimension=MNISTTrainPaddleConfig.feature_dimension
  56. )
  57. trainer_params.optimizers = Adam(parameters=trainer_params.model.parameters(), learning_rate=0.0001)
  58. train_dataloader = DataLoader(
  59. dataset=PaddleRandomMaxDataset(6400, 10),
  60. batch_size=MNISTTrainPaddleConfig.batch_size,
  61. shuffle=True
  62. )
  63. val_dataloader = DataLoader(
  64. dataset=PaddleRandomMaxDataset(1000, 10),
  65. batch_size=MNISTTrainPaddleConfig.batch_size,
  66. shuffle=True
  67. )
  68. trainer_params.train_dataloader = train_dataloader
  69. trainer_params.validate_dataloaders = val_dataloader
  70. trainer_params.validate_every = MNISTTrainPaddleConfig.validate_every
  71. trainer_params.metrics = {"acc": Accuracy(backend="paddle")}
  72. trainer = Trainer(
  73. model=trainer_params.model,
  74. driver=driver,
  75. device=device,
  76. optimizers=trainer_params.optimizers,
  77. train_dataloader=trainer_params.train_dataloader,
  78. validate_dataloaders=trainer_params.validate_dataloaders,
  79. validate_every=trainer_params.validate_every,
  80. input_mapping=trainer_params.input_mapping,
  81. output_mapping=trainer_params.output_mapping,
  82. metrics=trainer_params.metrics,
  83. n_epochs=n_epochs,
  84. callbacks=callbacks,
  85. )
  86. trainer.run()