You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

test_trainer_paddle.py 5.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151
  1. import pytest
  2. import os
  3. from typing import Any
  4. from dataclasses import dataclass
  5. from paddle.optimizer import Adam
  6. from paddle.io import DataLoader
  7. from fastNLP.core.controllers.trainer import Trainer
  8. from fastNLP.core.metrics.accuracy import Accuracy
  9. from fastNLP.core.callbacks.progress_callback import RichCallback
  10. from fastNLP.envs import FASTNLP_DISTRIBUTED_CHECK
  11. from tests.helpers.models.paddle_model import PaddleNormalModel_Classification
  12. from tests.helpers.datasets.paddle_data import PaddleDataset_MNIST
  13. from tests.helpers.callbacks.helper_callbacks import RecordLossCallback, RecordMetricCallback
  14. from tests.helpers.utils import magic_argv_env_context
  15. @dataclass
  16. class MNISTTrainPaddleConfig:
  17. num_labels: int = 10
  18. feature_dimension: int = 784
  19. batch_size: int = 32
  20. shuffle: bool = True
  21. validate_every = -5
  22. driver: str = "paddle"
  23. device = "gpu"
  24. @dataclass
  25. class MNISTTrainFleetConfig:
  26. num_labels: int = 10
  27. feature_dimension: int = 784
  28. batch_size: int = 32
  29. shuffle: bool = True
  30. validate_every = -5
  31. @dataclass
  32. class TrainerParameters:
  33. model: Any = None
  34. optimizers: Any = None
  35. train_dataloader: Any = None
  36. validate_dataloaders: Any = None
  37. input_mapping: Any = None
  38. output_mapping: Any = None
  39. metrics: Any = None
# @pytest.fixture(params=[0], autouse=True)
# def model_and_optimizers(request):
#     """
#     Initialize the model and optimizer for single-card mode.
#     """
#     trainer_params = TrainerParameters()
#     print(paddle.device.get_device())
#     if request.param == 0:
#         trainer_params.model = PaddleNormalModel_Classification(
#             num_labels=MNISTTrainPaddleConfig.num_labels,
#             feature_dimension=MNISTTrainPaddleConfig.feature_dimension
#         )
#         trainer_params.optimizers = Adam(parameters=trainer_params.model.parameters(), learning_rate=0.0001)
#         train_dataloader = DataLoader(
#             dataset=PaddleDataset_MNIST("train"),
#             batch_size=MNISTTrainPaddleConfig.batch_size,
#             shuffle=True
#         )
#         val_dataloader = DataLoader(
#             dataset=PaddleDataset_MNIST(mode="test"),
#             batch_size=MNISTTrainPaddleConfig.batch_size,
#             shuffle=True
#         )
#         trainer_params.train_dataloader = train_dataloader
#         trainer_params.validate_dataloaders = val_dataloader
#         trainer_params.validate_every = MNISTTrainPaddleConfig.validate_every
#         trainer_params.metrics = {"acc": Accuracy()}
#     return trainer_params
  68. @pytest.mark.parametrize("driver,device", [("paddle", "cpu"), ("paddle", 1)])
  69. # @pytest.mark.parametrize("driver,device", [("fleet", [0, 1])])
  70. @pytest.mark.parametrize("callbacks", [[RecordMetricCallback(monitor="acc#acc", metric_threshold=0.7, larger_better=True),
  71. RichCallback(5), RecordLossCallback(loss_threshold=0.3)]])
  72. @magic_argv_env_context
  73. def test_trainer_paddle(
  74. # model_and_optimizers: TrainerParameters,
  75. driver,
  76. device,
  77. callbacks,
  78. n_epochs=15,
  79. ):
  80. trainer_params = TrainerParameters()
  81. trainer_params.model = PaddleNormalModel_Classification(
  82. num_labels=MNISTTrainPaddleConfig.num_labels,
  83. feature_dimension=MNISTTrainPaddleConfig.feature_dimension
  84. )
  85. trainer_params.optimizers = Adam(parameters=trainer_params.model.parameters(), learning_rate=0.0001)
  86. train_dataloader = DataLoader(
  87. dataset=PaddleDataset_MNIST("train"),
  88. batch_size=MNISTTrainPaddleConfig.batch_size,
  89. shuffle=True
  90. )
  91. val_dataloader = DataLoader(
  92. dataset=PaddleDataset_MNIST(mode="test"),
  93. batch_size=MNISTTrainPaddleConfig.batch_size,
  94. shuffle=True
  95. )
  96. trainer_params.train_dataloader = train_dataloader
  97. trainer_params.validate_dataloaders = val_dataloader
  98. trainer_params.validate_every = MNISTTrainPaddleConfig.validate_every
  99. trainer_params.metrics = {"acc": Accuracy(backend="paddle")}
  100. if not isinstance(device, (int, str)) and len(device) > 1 and FASTNLP_DISTRIBUTED_CHECK not in os.environ:
  101. with pytest.raises(SystemExit) as exc:
  102. trainer = Trainer(
  103. model=trainer_params.model,
  104. driver=driver,
  105. device=device,
  106. optimizers=trainer_params.optimizers,
  107. train_dataloader=trainer_params.train_dataloader,
  108. validate_dataloaders=trainer_params.validate_dataloaders,
  109. validate_every=trainer_params.validate_every,
  110. input_mapping=trainer_params.input_mapping,
  111. output_mapping=trainer_params.output_mapping,
  112. metrics=trainer_params.metrics,
  113. n_epochs=n_epochs,
  114. callbacks=callbacks,
  115. )
  116. assert exc.value.code == 0
  117. return
  118. else:
  119. trainer = Trainer(
  120. model=trainer_params.model,
  121. driver=driver,
  122. device=device,
  123. optimizers=trainer_params.optimizers,
  124. train_dataloader=trainer_params.train_dataloader,
  125. validate_dataloaders=trainer_params.validate_dataloaders,
  126. validate_every=trainer_params.validate_every,
  127. input_mapping=trainer_params.input_mapping,
  128. output_mapping=trainer_params.output_mapping,
  129. metrics=trainer_params.metrics,
  130. n_epochs=n_epochs,
  131. callbacks=callbacks,
  132. )
  133. trainer.run()