You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

test_dist_trainer.py 5.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167
  1. import unittest
  2. import numpy as np
  3. import torch.cuda
  4. from fastNLP import DataSet
  5. from fastNLP import Instance
  6. from fastNLP import CrossEntropyLoss, BCELoss
  7. from fastNLP import SGD
  8. from fastNLP.core.dist_trainer import DistTrainer, get_local_rank
  9. from fastNLP.models.base_model import NaiveClassifier
  10. import shutil
  11. import os
  12. import subprocess
  13. from argparse import ArgumentParser
  14. from fastNLP.core.callback import EchoCallback
  15. from fastNLP import AccuracyMetric
  16. def prepare_fake_dataset():
  17. mean = np.array([-3, -3])
  18. cov = np.array([[1, 0], [0, 1]])
  19. class_A = np.random.multivariate_normal(mean, cov, size=(1000,))
  20. mean = np.array([3, 3])
  21. cov = np.array([[1, 0], [0, 1]])
  22. class_B = np.random.multivariate_normal(mean, cov, size=(1000,))
  23. data_set = DataSet([Instance(x=[float(item[0]), float(item[1])], y=0) for item in class_A] +
  24. [Instance(x=[float(item[0]), float(item[1])], y=1) for item in class_B])
  25. return data_set
  26. def prepare_fake_dataset2(*args, size=100):
  27. ys = np.random.randint(4, size=100, dtype=np.int64)
  28. data = {'y': ys}
  29. for arg in args:
  30. data[arg] = np.random.randn(size, 5)
  31. return DataSet(data=data)
  32. def set_rng_seed(seed):
  33. np.random.seed(seed)
  34. def prepare_env():
  35. def prepare_fake_dataset():
  36. mean = np.array([-3, -3])
  37. cov = np.array([[1, 0], [0, 1]])
  38. class_A = np.random.multivariate_normal(mean, cov, size=(1000,))
  39. mean = np.array([3, 3])
  40. cov = np.array([[1, 0], [0, 1]])
  41. class_B = np.random.multivariate_normal(mean, cov, size=(1000,))
  42. data_set = DataSet([Instance(x=[float(item[0]), float(item[1])], y=[0.0]) for item in class_A] +
  43. [Instance(x=[float(item[0]), float(item[1])], y=[1.0]) for item in class_B])
  44. return data_set
  45. data_set = prepare_fake_dataset()
  46. data_set.set_input("x")
  47. data_set.set_target("y")
  48. model = NaiveClassifier(2, 1)
  49. return data_set, model
  50. class TestDistTrainer(unittest.TestCase):
  51. save_path = './save_cp'
  52. def run1(self):
  53. # test distributed training
  54. print('local rank', get_local_rank())
  55. set_rng_seed(100)
  56. data_set = prepare_fake_dataset()
  57. data_set.set_input("x", flag=True)
  58. data_set.set_target("y", flag=True)
  59. model = NaiveClassifier(2, 2)
  60. trainer = DistTrainer(
  61. model=model, train_data=data_set, optimizer=SGD(lr=0.1),
  62. loss=CrossEntropyLoss(pred="predict", target="y"),
  63. batch_size_per_gpu=8, n_epochs=3, print_every=50, save_path=self.save_path,
  64. )
  65. trainer.train()
  66. """
  67. # 应该正确运行
  68. """
  69. if trainer.is_master and os.path.exists(self.save_path):
  70. shutil.rmtree(self.save_path)
  71. def run2(self):
  72. # test fp16 with distributed training
  73. print('local rank', get_local_rank())
  74. set_rng_seed(100)
  75. data_set = prepare_fake_dataset()
  76. data_set.set_input("x", flag=True)
  77. data_set.set_target("y", flag=True)
  78. model = NaiveClassifier(2, 2)
  79. trainer = DistTrainer(
  80. model=model, train_data=data_set, optimizer=SGD(lr=0.1),
  81. loss=CrossEntropyLoss(pred="predict", target="y"),
  82. batch_size_per_gpu=8, n_epochs=3, print_every=50, save_path=self.save_path,
  83. fp16='O1'
  84. )
  85. trainer.train()
  86. """
  87. # 应该正确运行
  88. """
  89. if trainer.is_master and os.path.exists(self.save_path):
  90. shutil.rmtree(self.save_path)
  91. def run3(self):
  92. set_rng_seed(100)
  93. data_set, model = prepare_env()
  94. trainer = DistTrainer(
  95. data_set, model, optimizer=None,
  96. loss=BCELoss(pred="predict", target="y"),
  97. n_epochs=3, print_every=50,
  98. callbacks_all=[EchoCallback('callbacks_all')],
  99. callbacks_master=[EchoCallback('callbacks_master')]
  100. )
  101. trainer.train()
  102. def run4(self):
  103. set_rng_seed(100)
  104. data_set, model = prepare_env()
  105. train_set, dev_set = data_set.split(0.3)
  106. model = NaiveClassifier(2, 1)
  107. trainer = DistTrainer(
  108. train_set, model, optimizer=SGD(lr=0.1),
  109. loss=BCELoss(pred="predict", target="y"),
  110. batch_size_per_gpu=32, n_epochs=3, print_every=50, dev_data=dev_set,
  111. metrics=AccuracyMetric(pred="predict", target="y"), validate_every=-1, save_path=None,
  112. )
  113. trainer.train()
  114. """
  115. # 应该正确运行
  116. """
  117. def run_dist(self, run_id):
  118. if torch.cuda.is_available():
  119. ngpu = min(2, torch.cuda.device_count())
  120. path = __file__
  121. cmd = ['python', '-m', 'torch.distributed.launch',
  122. '--nproc_per_node', str(ngpu), path, '--test', str(run_id)]
  123. print(' '.join(cmd))
  124. subprocess.check_call(cmd)
  125. def test_normal_run(self):
  126. self.run_dist(1)
  127. def no_test_fp16(self):
  128. self.run_dist(2)
  129. def test_callback(self):
  130. self.run_dist(3)
  131. def test_dev_data(self):
  132. self.run_dist(4)
  133. if __name__ == '__main__':
  134. runner = TestDistTrainer()
  135. parser = ArgumentParser()
  136. parser.add_argument('--test', type=int)
  137. args, _ = parser.parse_known_args()
  138. if args.test and hasattr(runner, 'run%s'%args.test):
  139. getattr(runner, 'run%s'%args.test)()