You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_trainer_fleet.py 2.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. """
  2. 这个文件测试用户以python -m paddle.distributed.launch 启动的情况
  3. 看看有没有用pytest执行的机会
  4. python -m paddle.distributed.launch --gpus=0,2,3 test_trainer_fleet.py
  5. """
  6. import os
  7. os.environ["FASTNLP_BACKEND"] = "paddle"
  8. import sys
  9. sys.path.append("../../../")
  10. from dataclasses import dataclass
  11. from fastNLP.core.controllers.trainer import Trainer
  12. from fastNLP.core.metrics.accuracy import Accuracy
  13. from fastNLP.core.callbacks.progress_callback import RichCallback
  14. from fastNLP.core.callbacks import Callback
  15. import paddle
  16. from paddle.optimizer import Adam
  17. from paddle.io import DataLoader
  18. from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1
  19. from tests.helpers.datasets.paddle_data import PaddleRandomMaxDataset
  20. from tests.helpers.callbacks.helper_callbacks import RecordMetricCallback
  21. @dataclass
  22. class MNISTTrainFleetConfig:
  23. num_labels: int = 10
  24. feature_dimension: int = 10
  25. batch_size: int = 32
  26. shuffle: bool = True
  27. validate_every = -1
  28. def test_trainer_fleet(
  29. driver,
  30. device,
  31. callbacks,
  32. n_epochs,
  33. ):
  34. model = PaddleNormalModel_Classification_1(
  35. num_labels=MNISTTrainFleetConfig.num_labels,
  36. feature_dimension=MNISTTrainFleetConfig.feature_dimension
  37. )
  38. optimizers = Adam(parameters=model.parameters(), learning_rate=0.0001)
  39. train_dataloader = DataLoader(
  40. dataset=PaddleRandomMaxDataset(6400, MNISTTrainFleetConfig.feature_dimension),
  41. batch_size=MNISTTrainFleetConfig.batch_size,
  42. shuffle=True
  43. )
  44. val_dataloader = DataLoader(
  45. dataset=PaddleRandomMaxDataset(1280, MNISTTrainFleetConfig.feature_dimension),
  46. batch_size=MNISTTrainFleetConfig.batch_size,
  47. shuffle=True
  48. )
  49. train_dataloader = train_dataloader
  50. validate_dataloaders = val_dataloader
  51. validate_every = MNISTTrainFleetConfig.validate_every
  52. metrics = {"acc": Accuracy()}
  53. trainer = Trainer(
  54. model=model,
  55. driver=driver,
  56. device=device,
  57. optimizers=optimizers,
  58. train_dataloader=train_dataloader,
  59. evaluate_dataloaders=validate_dataloaders,
  60. evaluate_every=validate_every,
  61. input_mapping=None,
  62. output_mapping=None,
  63. metrics=metrics,
  64. n_epochs=n_epochs,
  65. callbacks=callbacks,
  66. output_from_new_proc="logs",
  67. )
  68. trainer.run()
  69. if __name__ == "__main__":
  70. driver = "fleet"
  71. device = [0,2,3]
  72. # driver = "paddle"
  73. # device = 2
  74. callbacks = [
  75. # RecordMetricCallback(monitor="acc#acc", metric_threshold=0.0, larger_better=True),
  76. RichCallback(5),
  77. ]
  78. test_trainer_fleet(
  79. driver=driver,
  80. device=device,
  81. callbacks=callbacks,
  82. n_epochs=5,
  83. )