You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

yelpLoader.py 2.6 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
  1. import ast
  2. from fastNLP import DataSet, Instance, Vocabulary
  3. from fastNLP.core.vocabulary import VocabularyOption
  4. from fastNLP.io import JsonLoader
  5. from fastNLP.io.base_loader import DataInfo
  6. from fastNLP.io.embed_loader import EmbeddingOption
  7. from fastNLP.io.file_reader import _read_json
  8. from typing import Union, Dict
  9. from reproduction.Star_transformer.datasets import EmbedLoader
  10. from reproduction.utils import check_dataloader_paths
  11. class yelpLoader(JsonLoader):
  12. """
  13. 读取Yelp数据集, DataSet包含fields:
  14. review_id: str, 22 character unique review id
  15. user_id: str, 22 character unique user id
  16. business_id: str, 22 character business id
  17. useful: int, number of useful votes received
  18. funny: int, number of funny votes received
  19. cool: int, number of cool votes received
  20. date: str, date formatted YYYY-MM-DD
  21. words: list(str), 需要分类的文本
  22. target: str, 文本的标签
  23. 数据来源: https://www.yelp.com/dataset/download
  24. :param fine_grained: 是否使用SST-5标准,若 ``False`` , 使用SST-2。Default: ``False``
  25. """
  26. def __init__(self, fine_grained=False):
  27. super(yelpLoader, self).__init__()
  28. tag_v = {'1.0': 'very negative', '2.0': 'negative', '3.0': 'neutral',
  29. '4.0': 'positive', '5.0': 'very positive'}
  30. if not fine_grained:
  31. tag_v['1.0'] = tag_v['2.0']
  32. tag_v['5.0'] = tag_v['4.0']
  33. self.fine_grained = fine_grained
  34. self.tag_v = tag_v
  35. def _load(self, path):
  36. ds = DataSet()
  37. for idx, d in _read_json(path, fields=self.fields_list, dropna=self.dropna):
  38. d = ast.literal_eval(d)
  39. d["words"] = d.pop("text").split()
  40. d["target"] = self.tag_v[str(d.pop("stars"))]
  41. ds.append(Instance(**d))
  42. return ds
  43. def process(self, paths: Union[str, Dict[str, str]], vocab_opt: VocabularyOption = None,
  44. embed_opt: EmbeddingOption = None):
  45. paths = check_dataloader_paths(paths)
  46. datasets = {}
  47. info = DataInfo()
  48. vocab = Vocabulary(min_freq=2) if vocab_opt is None else Vocabulary(**vocab_opt)
  49. for name, path in paths.items():
  50. dataset = self.load(path)
  51. datasets[name] = dataset
  52. vocab.from_dataset(dataset, field_name="words")
  53. info.vocabs = vocab
  54. info.datasets = datasets
  55. if embed_opt is not None:
  56. embed = EmbedLoader.load_with_vocab(**embed_opt, vocab=vocab)
  57. info.embeddings['words'] = embed
  58. return info