
data_helpers.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""dataset helpers api"""
import argparse
import os

import numpy as np

parser = argparse.ArgumentParser(description='textrcnn')
parser.add_argument('--task', type=str, help='the data preprocess task, including dataset_split.')
parser.add_argument('--data_dir', type=str, help='the source dataset directory.', default='./data_src')
parser.add_argument('--out_dir', type=str, help='the target dataset directory.', default='./data')
args = parser.parse_args()


def dataset_split(label):
    """dataset_split api"""
    # label can be 'pos' or 'neg'
    pos_samples = []
    pos_file = os.path.join(args.data_dir, "rt-polaritydata", "rt-polarity." + label)
    pfhand = open(pos_file, encoding='utf-8')
    pos_samples += pfhand.readlines()
    pfhand.close()
    # shuffle the line indices, then use the first 90% for training
    # and the remaining 10% for testing
    perm = np.random.permutation(len(pos_samples))
    perm_train = perm[0:int(len(pos_samples) * 0.9)]
    perm_test = perm[int(len(pos_samples) * 0.9):]
    pos_samples_train = []
    pos_samples_test = []
    for pt in perm_train:
        pos_samples_train.append(pos_samples[pt])
    for pt in perm_test:
        pos_samples_test.append(pos_samples[pt])
    # write the shuffled splits to <out_dir>/train/<label> and <out_dir>/test/<label>
    f = open(os.path.join(args.out_dir, 'train', label), "w")
    f.write(''.join(pos_samples_train))
    f.close()
    f = open(os.path.join(args.out_dir, 'test', label), "w")
    f.write(''.join(pos_samples_test))
    f.close()


if __name__ == '__main__':
    if args.task == "dataset_split":
        dataset_split('pos')
        dataset_split('neg')
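
The script writes the split files to <out_dir>/train/<label> and <out_dir>/test/<label> but does not create those directories itself, so they must exist before it runs. Below is a minimal driver sketch; the file name run_split.py and the directory layout are illustrative assumptions, while the --task, --data_dir and --out_dir flags come from the argparse definitions above.

# run_split.py -- hypothetical helper, not part of the repository.
# Creates the train/ and test/ output directories that dataset_split()
# expects, then invokes data_helpers.py with the dataset_split task.
import os
import subprocess

OUT_DIR = "./data"       # matches the --out_dir default above
DATA_DIR = "./data_src"  # matches the --data_dir default above

for split in ("train", "test"):
    os.makedirs(os.path.join(OUT_DIR, split), exist_ok=True)

subprocess.run(
    ["python", "data_helpers.py",
     "--task", "dataset_split",
     "--data_dir", DATA_DIR,
     "--out_dir", OUT_DIR],
    check=True,
)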