You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

utils.py 12 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361
  1. import torch.nn.functional as F
  2. import torch
  3. import random
  4. import numpy as np
  5. from fastNLP import Const
  6. from fastNLP import CrossEntropyLoss
  7. from fastNLP import AccuracyMetric
  8. from fastNLP import Tester
  9. import os
  10. from fastNLP import logger
  11. def should_mask(name, t=''):
  12. if 'bias' in name:
  13. return False
  14. if 'embedding' in name:
  15. splited = name.split('.')
  16. if splited[-1]!='weight':
  17. return False
  18. if 'embedding' in splited[-2]:
  19. return False
  20. if 'c0' in name:
  21. return False
  22. if 'h0' in name:
  23. return False
  24. if 'output' in name and t not in name:
  25. return False
  26. return True
  27. def get_init_mask(model):
  28. init_masks = {}
  29. for name, param in model.named_parameters():
  30. if should_mask(name):
  31. init_masks[name+'.mask'] = torch.ones_like(param)
  32. # logger.info(init_masks[name+'.mask'].requires_grad)
  33. return init_masks
  34. def set_seed(seed):
  35. random.seed(seed)
  36. np.random.seed(seed+100)
  37. torch.manual_seed(seed+200)
  38. torch.cuda.manual_seed_all(seed+300)
  39. def get_parameters_size(model):
  40. result = {}
  41. for name,p in model.state_dict().items():
  42. result[name] = p.size()
  43. return result
  44. def prune_by_proportion_model(model,proportion,task):
  45. # print('this time prune to ',proportion*100,'%')
  46. for name, p in model.named_parameters():
  47. # print(name)
  48. if not should_mask(name,task):
  49. continue
  50. tensor = p.data.cpu().numpy()
  51. index = np.nonzero(model.mask[task][name+'.mask'].data.cpu().numpy())
  52. # print(name,'alive count',len(index[0]))
  53. alive = tensor[index]
  54. # print('p and mask size:',p.size(),print(model.mask[task][name+'.mask'].size()))
  55. percentile_value = np.percentile(abs(alive), (1 - proportion) * 100)
  56. # tensor = p
  57. # index = torch.nonzero(model.mask[task][name+'.mask'])
  58. # # print('nonzero len',index)
  59. # alive = tensor[index]
  60. # print('alive size:',alive.shape)
  61. # prune_by_proportion_model()
  62. # percentile_value = torch.topk(abs(alive), int((1-proportion)*len(index[0]))).values
  63. # print('the',(1-proportion)*len(index[0]),'th big')
  64. # print('threshold:',percentile_value)
  65. prune_by_threshold_parameter(p, model.mask[task][name+'.mask'],percentile_value)
  66. # for
  67. def prune_by_proportion_model_global(model,proportion,task):
  68. # print('this time prune to ',proportion*100,'%')
  69. alive = None
  70. for name, p in model.named_parameters():
  71. # print(name)
  72. if not should_mask(name,task):
  73. continue
  74. tensor = p.data.cpu().numpy()
  75. index = np.nonzero(model.mask[task][name+'.mask'].data.cpu().numpy())
  76. # print(name,'alive count',len(index[0]))
  77. if alive is None:
  78. alive = tensor[index]
  79. else:
  80. alive = np.concatenate([alive,tensor[index]],axis=0)
  81. percentile_value = np.percentile(abs(alive), (1 - proportion) * 100)
  82. for name, p in model.named_parameters():
  83. if should_mask(name,task):
  84. prune_by_threshold_parameter(p, model.mask[task][name+'.mask'],percentile_value)
  85. def prune_by_threshold_parameter(p, mask, threshold):
  86. p_abs = torch.abs(p)
  87. new_mask = (p_abs > threshold).float()
  88. # print(mask)
  89. mask[:]*=new_mask
  90. def one_time_train_and_prune_single_task(trainer,PRUNE_PER,
  91. optimizer_init_state_dict=None,
  92. model_init_state_dict=None,
  93. is_global=None,
  94. ):
  95. from fastNLP import Trainer
  96. trainer.optimizer.load_state_dict(optimizer_init_state_dict)
  97. trainer.model.load_state_dict(model_init_state_dict)
  98. # print('metrics:',metrics.__dict__)
  99. # print('loss:',loss.__dict__)
  100. # print('trainer input:',task.train_set.get_input_name())
  101. # trainer = Trainer(model=model, train_data=task.train_set, dev_data=task.dev_set, loss=loss, metrics=metrics,
  102. # optimizer=optimizer, n_epochs=EPOCH, batch_size=BATCH, device=device,callbacks=callbacks)
  103. trainer.train(load_best_model=True)
  104. # tester = Tester(task.train_set, model, metrics, BATCH, device=device, verbose=1,use_tqdm=False)
  105. # print('FOR DEBUG: test train_set:',tester.test())
  106. # print('**'*20)
  107. # if task.test_set:
  108. # tester = Tester(task.test_set, model, metrics, BATCH, device=device, verbose=1)
  109. # tester.test()
  110. if is_global:
  111. prune_by_proportion_model_global(trainer.model, PRUNE_PER, trainer.model.now_task)
  112. else:
  113. prune_by_proportion_model(trainer.model, PRUNE_PER, trainer.model.now_task)
  114. # def iterative_train_and_prune_single_task(get_trainer,ITER,PRUNE,is_global=False,save_path=None):
  115. def iterative_train_and_prune_single_task(get_trainer,args,model,train_set,dev_set,test_set,device,save_path=None):
  116. '''
  117. :param trainer:
  118. :param ITER:
  119. :param PRUNE:
  120. :param is_global:
  121. :param save_path: should be a dictionary which will be filled with mask and state dict
  122. :return:
  123. '''
  124. from fastNLP import Trainer
  125. import torch
  126. import math
  127. import copy
  128. PRUNE = args.prune
  129. ITER = args.iter
  130. trainer = get_trainer(args,model,train_set,dev_set,test_set,device)
  131. optimizer_init_state_dict = copy.deepcopy(trainer.optimizer.state_dict())
  132. model_init_state_dict = copy.deepcopy(trainer.model.state_dict())
  133. if save_path is not None:
  134. if not os.path.exists(save_path):
  135. os.makedirs(save_path)
  136. # if not os.path.exists(os.path.join(save_path, 'model_init.pkl')):
  137. # f = open(os.path.join(save_path, 'model_init.pkl'), 'wb')
  138. # torch.save(trainer.model.state_dict(),f)
  139. mask_count = 0
  140. model = trainer.model
  141. task = trainer.model.now_task
  142. for name, p in model.mask[task].items():
  143. mask_count += torch.sum(p).item()
  144. init_mask_count = mask_count
  145. logger.info('init mask count:{}'.format(mask_count))
  146. # logger.info('{}th traning mask count: {} / {} = {}%'.format(i, mask_count, init_mask_count,
  147. # mask_count / init_mask_count * 100))
  148. prune_per_iter = math.pow(PRUNE, 1 / ITER)
  149. for i in range(ITER):
  150. trainer = get_trainer(args,model,train_set,dev_set,test_set,device)
  151. one_time_train_and_prune_single_task(trainer,prune_per_iter,optimizer_init_state_dict,model_init_state_dict)
  152. if save_path is not None:
  153. f = open(os.path.join(save_path,task+'_mask_'+str(i)+'.pkl'),'wb')
  154. torch.save(model.mask[task],f)
  155. mask_count = 0
  156. for name, p in model.mask[task].items():
  157. mask_count += torch.sum(p).item()
  158. logger.info('{}th traning mask count: {} / {} = {}%'.format(i,mask_count,init_mask_count,mask_count/init_mask_count*100))
  159. def get_appropriate_cuda(task_scale='s'):
  160. if task_scale not in {'s','m','l'}:
  161. logger.info('task scale wrong!')
  162. exit(2)
  163. import pynvml
  164. pynvml.nvmlInit()
  165. total_cuda_num = pynvml.nvmlDeviceGetCount()
  166. for i in range(total_cuda_num):
  167. logger.info(i)
  168. handle = pynvml.nvmlDeviceGetHandleByIndex(i) # 这里的0是GPU id
  169. memInfo = pynvml.nvmlDeviceGetMemoryInfo(handle)
  170. utilizationInfo = pynvml.nvmlDeviceGetUtilizationRates(handle)
  171. logger.info(i, 'mem:', memInfo.used / memInfo.total, 'util:',utilizationInfo.gpu)
  172. if memInfo.used / memInfo.total < 0.15 and utilizationInfo.gpu <0.2:
  173. logger.info(i,memInfo.used / memInfo.total)
  174. return 'cuda:'+str(i)
  175. if task_scale=='s':
  176. max_memory=2000
  177. elif task_scale=='m':
  178. max_memory=6000
  179. else:
  180. max_memory = 9000
  181. max_id = -1
  182. for i in range(total_cuda_num):
  183. handle = pynvml.nvmlDeviceGetHandleByIndex(0) # 这里的0是GPU id
  184. memInfo = pynvml.nvmlDeviceGetMemoryInfo(handle)
  185. utilizationInfo = pynvml.nvmlDeviceGetUtilizationRates(handle)
  186. if max_memory < memInfo.free:
  187. max_memory = memInfo.free
  188. max_id = i
  189. if id == -1:
  190. logger.info('no appropriate gpu, wait!')
  191. exit(2)
  192. return 'cuda:'+str(max_id)
  193. # if memInfo.used / memInfo.total < 0.5:
  194. # return
  195. def print_mask(mask_dict):
  196. def seq_mul(*X):
  197. res = 1
  198. for x in X:
  199. res*=x
  200. return res
  201. for name,p in mask_dict.items():
  202. total_size = seq_mul(*p.size())
  203. unmasked_size = len(np.nonzero(p))
  204. print(name,':',unmasked_size,'/',total_size,'=',unmasked_size/total_size*100,'%')
  205. print()
  206. def check_words_same(dataset_1,dataset_2,field_1,field_2):
  207. if len(dataset_1[field_1]) != len(dataset_2[field_2]):
  208. logger.info('CHECK: example num not same!')
  209. return False
  210. for i, words in enumerate(dataset_1[field_1]):
  211. if len(dataset_1[field_1][i]) != len(dataset_2[field_2][i]):
  212. logger.info('CHECK {} th example length not same'.format(i))
  213. logger.info('1:{}'.format(dataset_1[field_1][i]))
  214. logger.info('2:'.format(dataset_2[field_2][i]))
  215. return False
  216. # for j,w in enumerate(words):
  217. # if dataset_1[field_1][i][j] != dataset_2[field_2][i][j]:
  218. # print('CHECK', i, 'th example has words different!')
  219. # print('1:',dataset_1[field_1][i])
  220. # print('2:',dataset_2[field_2][i])
  221. # return False
  222. logger.info('CHECK: totally same!')
  223. return True
  224. def get_now_time():
  225. import time
  226. from datetime import datetime, timezone, timedelta
  227. dt = datetime.utcnow()
  228. # print(dt)
  229. tzutc_8 = timezone(timedelta(hours=8))
  230. local_dt = dt.astimezone(tzutc_8)
  231. result = ("_{}_{}_{}__{}_{}_{}".format(local_dt.year, local_dt.month, local_dt.day, local_dt.hour, local_dt.minute,
  232. local_dt.second))
  233. return result
  234. def get_bigrams(words):
  235. result = []
  236. for i,w in enumerate(words):
  237. if i!=len(words)-1:
  238. result.append(words[i]+words[i+1])
  239. else:
  240. result.append(words[i]+'<end>')
  241. return result
  242. def print_info(*inp,islog=False,sep=' '):
  243. from fastNLP import logger
  244. if islog:
  245. print(*inp,sep=sep)
  246. else:
  247. inp = sep.join(map(str,inp))
  248. logger.info(inp)
  249. def better_init_rnn(rnn,coupled=False):
  250. import torch.nn as nn
  251. if coupled:
  252. repeat_size = 3
  253. else:
  254. repeat_size = 4
  255. # print(list(rnn.named_parameters()))
  256. if hasattr(rnn,'num_layers'):
  257. for i in range(rnn.num_layers):
  258. nn.init.orthogonal(getattr(rnn,'weight_ih_l'+str(i)).data)
  259. weight_hh_data = torch.eye(rnn.hidden_size)
  260. weight_hh_data = weight_hh_data.repeat(1, repeat_size)
  261. with torch.no_grad():
  262. getattr(rnn,'weight_hh_l'+str(i)).set_(weight_hh_data)
  263. nn.init.constant(getattr(rnn,'bias_ih_l'+str(i)).data, val=0)
  264. nn.init.constant(getattr(rnn,'bias_hh_l'+str(i)).data, val=0)
  265. if rnn.bidirectional:
  266. for i in range(rnn.num_layers):
  267. nn.init.orthogonal(getattr(rnn, 'weight_ih_l' + str(i)+'_reverse').data)
  268. weight_hh_data = torch.eye(rnn.hidden_size)
  269. weight_hh_data = weight_hh_data.repeat(1, repeat_size)
  270. with torch.no_grad():
  271. getattr(rnn, 'weight_hh_l' + str(i)+'_reverse').set_(weight_hh_data)
  272. nn.init.constant(getattr(rnn, 'bias_ih_l' + str(i)+'_reverse').data, val=0)
  273. nn.init.constant(getattr(rnn, 'bias_hh_l' + str(i)+'_reverse').data, val=0)
  274. else:
  275. nn.init.orthogonal(rnn.weight_ih.data)
  276. weight_hh_data = torch.eye(rnn.hidden_size)
  277. weight_hh_data = weight_hh_data.repeat(repeat_size,1)
  278. with torch.no_grad():
  279. rnn.weight_hh.set_(weight_hh_data)
  280. # The bias is just set to zero vectors.
  281. print('rnn param size:{},{}'.format(rnn.weight_hh.size(),type(rnn)))
  282. if rnn.bias:
  283. nn.init.constant(rnn.bias_ih.data, val=0)
  284. nn.init.constant(rnn.bias_hh.data, val=0)
  285. # print(list(rnn.named_parameters()))