
dataset.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
  15. """
  16. Data operations, will be used in train.py and eval.py
  17. """
  18. import os
  19. import math
  20. import random
  21. import codecs
  22. from pathlib import Path
  23. import numpy as np
  24. import pandas as pd
  25. import mindspore.dataset as ds
class Generator():
    """Wrap a list of (data, label) pairs for use with GeneratorDataset."""
    def __init__(self, input_list):
        self.input_list = input_list

    def __getitem__(self, item):
        return (np.array(self.input_list[item][0], dtype=np.int32),
                np.array(self.input_list[item][1], dtype=np.int32))

    def __len__(self):
        return len(self.input_list)


class DataProcessor:
    """
    Base class for dataset preprocessing. Subclasses fill in self.train,
    self.test, self.Vocab and self.doConvert before the create_*_dataset
    methods are called.
    """
    def get_dict_len(self):
        """
        Get the number of distinct words in the whole dataset.
        """
        if self.doConvert:
            return len(self.Vocab)
        return -1

    def collect_weight(self, glove_path, embed_size):
        """Collect pre-trained GloVe weights for the words in the vocabulary."""
        vocab_size = self.get_dict_len()
        embedding_index = {}
        # each GloVe line is a word followed by its embedding values
        with open(glove_path, 'r', encoding='utf-8') as f:
            for line in f:
                values = line.split()
                word = values[0]
                vec = np.array(values[1:], dtype='float32')
                embedding_index[word] = vec
        weight_np = np.zeros((vocab_size, embed_size)).astype(np.float32)
        for word, vec in embedding_index.items():
            try:
                index = self.Vocab[word]
            except KeyError:
                continue
            weight_np[index, :] = vec
        return weight_np

    def create_train_dataset(self, epoch_size, batch_size, collect_weight=False, glove_path='', embed_size=50):
        """Build the batched training dataset; epoch_size is accepted but not used inside this method."""
        if collect_weight:
            weight_np = self.collect_weight(glove_path, embed_size)
            np.savetxt('./weight.txt', weight_np)
        dataset = ds.GeneratorDataset(source=Generator(input_list=self.train),
                                      column_names=["data", "label"], shuffle=False)
        dataset = dataset.batch(batch_size=batch_size, drop_remainder=True)
        return dataset

    def create_test_dataset(self, batch_size):
        dataset = ds.GeneratorDataset(source=Generator(input_list=self.test),
                                      column_names=["data", "label"], shuffle=False)
        dataset = dataset.batch(batch_size=batch_size, drop_remainder=True)
        return dataset


class MovieReview(DataProcessor):
    """
    Preprocess the MovieReview (MR) dataset.
    """
    def __init__(self, root_dir, maxlen, split):
        """
        input:
            root_dir: root directory of the MR dataset
            maxlen: maximum sentence length kept after vectorization
            split: fraction of the data used for training
        """
        self.path = root_dir
        self.feelMap = {
            'neg': 0,
            'pos': 1
        }
        self.files = []
        self.doConvert = False

        mypath = Path(self.path)
        if not mypath.exists() or not mypath.is_dir():
            raise ValueError("please check the root_dir: {}".format(self.path))

        # walk through the top level of root_dir only
        for root, _, filename in os.walk(self.path):
            for each in filename:
                self.files.append(os.path.join(root, each))
            break

        # the MR dataset consists of exactly one positive and one negative file
        if len(self.files) != 2:
            raise ValueError("There are {} files in the root_dir, expected 2".format(len(self.files)))

        # begin to read data
        self.word_num = 0
        self.maxlen = float("-inf")
        self.minlen = float("inf")
        self.Pos = []
        self.Neg = []
        for filename in self.files:
            # re-encode the raw file as UTF-8 before parsing it
            with codecs.open(filename, 'r') as f:
                ff = f.read()
            with codecs.open(filename, 'w', 'utf-8') as file_object:
                file_object.write(ff)
            self.read_data(filename)
        self.PosNeg = self.Pos + self.Neg
        self.text2vec(maxlen=maxlen)
        self.split_dataset(split=split)

    def read_data(self, filePath):
        """
        Read the text into memory.
        input:
            filePath: path of the file to read
        """
        with open(filePath, 'r') as f:
            for sentence in f.readlines():
                # strip punctuation, digits and markup before tokenizing
                sentence = sentence.replace('\n', '')\
                    .replace('"', '')\
                    .replace('\'', '')\
                    .replace('.', '')\
                    .replace(',', '')\
                    .replace('[', '')\
                    .replace(']', '')\
                    .replace('(', '')\
                    .replace(')', '')\
                    .replace(':', '')\
                    .replace('--', '')\
                    .replace('-', '')\
                    .replace('\\', '')\
                    .replace('0', '')\
                    .replace('1', '')\
                    .replace('2', '')\
                    .replace('3', '')\
                    .replace('4', '')\
                    .replace('5', '')\
                    .replace('6', '')\
                    .replace('7', '')\
                    .replace('8', '')\
                    .replace('9', '')\
                    .replace('`', '')\
                    .replace('=', '')\
                    .replace('$', '')\
                    .replace('/', '')\
                    .replace('*', '')\
                    .replace(';', '')\
                    .replace('<b>', '')\
                    .replace('%', '')
                sentence = sentence.split(' ')
                sentence = list(filter(lambda x: x, sentence))
                if sentence:
                    self.word_num += len(sentence)
                    self.maxlen = max(self.maxlen, len(sentence))
                    self.minlen = min(self.minlen, len(sentence))
                    if 'pos' in filePath:
                        self.Pos.append([sentence, self.feelMap['pos']])
                    else:
                        self.Neg.append([sentence, self.feelMap['neg']])

    def text2vec(self, maxlen):
        """
        Convert each sentence into a fixed-length vector of word indices.
        input:
            maxlen: max length of the sentence
        """
        # Vocab = {word: index}
        self.Vocab = dict()
        for SentenceLabel in self.Pos + self.Neg:
            vector = [0] * maxlen
            for index, word in enumerate(SentenceLabel[0]):
                if index >= maxlen:
                    break
                if word not in self.Vocab:
                    self.Vocab[word] = len(self.Vocab)
                    vector[index] = len(self.Vocab) - 1
                else:
                    vector[index] = self.Vocab[word]
            SentenceLabel[0] = vector
        self.doConvert = True

    def split_dataset(self, split):
        """
        Split the dataset into a training set and a test set.
        input:
            split: fraction of the data used for training
        """
        trunk_pos_size = math.ceil((1 - split) * len(self.Pos))
        trunk_neg_size = math.ceil((1 - split) * len(self.Neg))
        trunk_num = int(1 / (1 - split))
        pos_temp = list()
        neg_temp = list()
        for index in range(trunk_num):
            pos_temp.append(self.Pos[index * trunk_pos_size:(index + 1) * trunk_pos_size])
            neg_temp.append(self.Neg[index * trunk_neg_size:(index + 1) * trunk_neg_size])
        # hold out the third trunk as the test set and train on the rest
        self.test = pos_temp.pop(2) + neg_temp.pop(2)
        self.train = [i for item in pos_temp + neg_temp for i in item]
        random.shuffle(self.train)


class Subjectivity(DataProcessor):
    """
    Preprocess the Subjectivity (SUBJ) dataset.
    """
    def __init__(self, root_dir, maxlen, split):
        self.path = root_dir
        self.feelMap = {
            'neg': 0,
            'pos': 1
        }
        self.files = []
        self.doConvert = False

        mypath = Path(self.path)
        if not mypath.exists() or not mypath.is_dir():
            raise ValueError("please check the root_dir: {}".format(self.path))

        # walk through the top level of root_dir only
        for root, _, filename in os.walk(self.path):
            for each in filename:
                self.files.append(os.path.join(root, each))
            break

        # begin to read data
        self.word_num = 0
        self.maxlen = float("-inf")
        self.minlen = float("inf")
        self.Pos = []
        self.Neg = []
        for filename in self.files:
            self.read_data(filename)
        self.PosNeg = self.Pos + self.Neg
        self.text2vec(maxlen=maxlen)
        self.split_dataset(split=split)

    def read_data(self, filePath):
        """
        Read the text into memory.
        input:
            filePath: path of the file to read
        """
        with open(filePath, 'r', encoding="ISO-8859-1") as f:
            for sentence in f.readlines():
                # strip punctuation, digits and markup before tokenizing
                sentence = sentence.replace('\n', '')\
                    .replace('"', '')\
                    .replace('\'', '')\
                    .replace('.', '')\
                    .replace(',', '')\
                    .replace('[', '')\
                    .replace(']', '')\
                    .replace('(', '')\
                    .replace(')', '')\
                    .replace(':', '')\
                    .replace('--', '')\
                    .replace('-', '')\
                    .replace('\\', '')\
                    .replace('0', '')\
                    .replace('1', '')\
                    .replace('2', '')\
                    .replace('3', '')\
                    .replace('4', '')\
                    .replace('5', '')\
                    .replace('6', '')\
                    .replace('7', '')\
                    .replace('8', '')\
                    .replace('9', '')\
                    .replace('`', '')\
                    .replace('=', '')\
                    .replace('$', '')\
                    .replace('/', '')\
                    .replace('*', '')\
                    .replace(';', '')\
                    .replace('<b>', '')\
                    .replace('%', '')
                sentence = sentence.split(' ')
                sentence = list(filter(lambda x: x, sentence))
                if sentence:
                    self.word_num += len(sentence)
                    self.maxlen = max(self.maxlen, len(sentence))
                    self.minlen = min(self.minlen, len(sentence))
                    # quote files hold subjective sentences, plot files objective ones
                    if 'quote' in filePath:
                        self.Pos.append([sentence, self.feelMap['pos']])
                    elif 'plot' in filePath:
                        self.Neg.append([sentence, self.feelMap['neg']])

    def text2vec(self, maxlen):
        """
        Convert each sentence into a fixed-length vector of word indices.
        input:
            maxlen: max length of the sentence
        """
        # Vocab = {word: index}
        self.Vocab = dict()
        for SentenceLabel in self.Pos + self.Neg:
            vector = [0] * maxlen
            for index, word in enumerate(SentenceLabel[0]):
                if index >= maxlen:
                    break
                if word not in self.Vocab:
                    self.Vocab[word] = len(self.Vocab)
                    vector[index] = len(self.Vocab) - 1
                else:
                    vector[index] = self.Vocab[word]
            SentenceLabel[0] = vector
        self.doConvert = True

    def split_dataset(self, split):
        """
        Split the dataset into a training set and a test set.
        input:
            split: fraction of the data used for training
        """
        trunk_pos_size = math.ceil((1 - split) * len(self.Pos))
        trunk_neg_size = math.ceil((1 - split) * len(self.Neg))
        trunk_num = int(1 / (1 - split))
        pos_temp = list()
        neg_temp = list()
        for index in range(trunk_num):
            pos_temp.append(self.Pos[index * trunk_pos_size:(index + 1) * trunk_pos_size])
            neg_temp.append(self.Neg[index * trunk_neg_size:(index + 1) * trunk_neg_size])
        # hold out the third trunk as the test set and train on the rest
        self.test = pos_temp.pop(2) + neg_temp.pop(2)
        self.train = [i for item in pos_temp + neg_temp for i in item]
        random.shuffle(self.train)


class SST2(DataProcessor):
    """
    Preprocess the SST2 dataset.
    """
    def __init__(self, root_dir, maxlen, split):
        self.path = root_dir
        self.files = []
        self.train = []
        self.test = []
        self.doConvert = False

        mypath = Path(self.path)
        if not mypath.exists() or not mypath.is_dir():
            raise ValueError("please check the root_dir: {}".format(self.path))

        # walk through the top level of root_dir only
        for root, _, filename in os.walk(self.path):
            for each in filename:
                self.files.append(os.path.join(root, each))
            break

        # begin to read data
        self.word_num = 0
        self.maxlen = float("-inf")
        self.minlen = float("inf")
        for filename in self.files:
            if 'train' in filename or 'dev' in filename:
                # re-encode the raw file as UTF-8 before parsing it
                with codecs.open(filename, 'r') as f:
                    ff = f.read()
                with codecs.open(filename, 'w', 'utf-8') as file_object:
                    file_object.write(ff)
                self.read_data(filename)
        self.text2vec(maxlen=maxlen)
        self.split_dataset(split=split)

    def read_data(self, filePath):
        """
        Read the text into memory.
        input:
            filePath: path of the tsv file to read
        """
        df = pd.read_csv(filePath, delimiter='\t')
        for sentence, label in zip(df['sentence'], df['label']):
            # strip punctuation, digits and markup before tokenizing
            sentence = sentence.replace('\n', '')\
                .replace('"', '')\
                .replace('\'', '')\
                .replace('.', '')\
                .replace(',', '')\
                .replace('[', '')\
                .replace(']', '')\
                .replace('(', '')\
                .replace(')', '')\
                .replace(':', '')\
                .replace('--', '')\
                .replace('-', '')\
                .replace('\\', '')\
                .replace('0', '')\
                .replace('1', '')\
                .replace('2', '')\
                .replace('3', '')\
                .replace('4', '')\
                .replace('5', '')\
                .replace('6', '')\
                .replace('7', '')\
                .replace('8', '')\
                .replace('9', '')\
                .replace('`', '')\
                .replace('=', '')\
                .replace('$', '')\
                .replace('/', '')\
                .replace('*', '')\
                .replace(';', '')\
                .replace('<b>', '')\
                .replace('%', '')
            sentence = sentence.split(' ')
            sentence = list(filter(lambda x: x, sentence))
            if sentence:
                self.word_num += len(sentence)
                self.maxlen = max(self.maxlen, len(sentence))
                self.minlen = min(self.minlen, len(sentence))
                if 'train' in filePath:
                    self.train.append([sentence, label])
                elif 'dev' in filePath:
                    self.test.append([sentence, label])

    def text2vec(self, maxlen):
        """
        Convert each sentence into a fixed-length vector of word indices.
        input:
            maxlen: max length of the sentence
        """
        # Vocab = {word: index}
        self.Vocab = dict()
        for SentenceLabel in self.train + self.test:
            vector = [0] * maxlen
            for index, word in enumerate(SentenceLabel[0]):
                if index >= maxlen:
                    break
                if word not in self.Vocab:
                    self.Vocab[word] = len(self.Vocab)
                    vector[index] = len(self.Vocab) - 1
                else:
                    vector[index] = self.Vocab[word]
            SentenceLabel[0] = vector
        self.doConvert = True

    def split_dataset(self, split):
        """
        SST2 already ships separate train and dev files, so only the training
        set is shuffled here; the split argument is unused.
        """
        random.shuffle(self.train)
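

# Minimal usage sketch (not part of the original train.py/eval.py scripts):
# it shows how one of the preprocessors and the MindSpore datasets built from
# it are typically wired together. The root_dir path, maxlen and split values
# below are illustrative assumptions only.
if __name__ == "__main__":
    mr = MovieReview(root_dir="./data/rt-polaritydata", maxlen=51, split=0.9)
    train_ds = mr.create_train_dataset(epoch_size=1, batch_size=64)
    test_ds = mr.create_test_dataset(batch_size=64)
    print("vocab size:", mr.get_dict_len())
    print("train batches:", train_ds.get_dataset_size())
    print("test batches:", test_ds.get_dataset_size())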