You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

classify_by_textcnn.py 10 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256
  1. # !/usr/bin/env python
  2. # -*- coding:utf-8 -*-
  3. """
  4. Copyright 2020 Tianshu AI Platform. All Rights Reserved.
  5. Licensed under the Apache License, Version 2.0 (the "License");
  6. you may not use this file except in compliance with the License.
  7. You may obtain a copy of the License at
  8. http://www.apache.org/licenses/LICENSE-2.0
  9. Unless required by applicable law or agreed to in writing, software
  10. distributed under the License is distributed on an "AS IS" BASIS,
  11. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. See the License for the specific language governing permissions and
  13. limitations under the License.
  14. =============================================================
  15. """
  16. import json
  17. import os
  18. import re
  19. import six
  20. import numpy as np
  21. from typing import Tuple
  22. # import requests # 在 nfs 没有挂载 时使用 url 访问
  23. import oneflow as flow
  24. import oneflow.typing as tp
  25. import logging
  26. logging.basicConfig(format='%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s',
  27. level=logging.DEBUG)
  28. BATCH_SIZE = 16
  29. current_dir = os.path.dirname(os.path.abspath(__file__))
  30. label_to_name_file = current_dir + os.sep + "label.names"
  31. label_2_name = []
  32. with open(label_to_name_file, 'r') as f:
  33. label_2_name = f.readlines()
  34. class TextCNN:
  35. def __init__(self, emb_sz, emb_dim, ksize_list, n_filters_list, n_classes, dropout):
  36. self.initializer = flow.random_normal_initializer(stddev=0.1)
  37. self.emb_sz = emb_sz
  38. self.emb_dim = emb_dim
  39. self.ksize_list = ksize_list
  40. self.n_filters_list = n_filters_list
  41. self.n_classes = n_classes
  42. self.dropout = dropout
  43. self.total_n_filters = sum(self.n_filters_list)
  44. def get_logits(self, inputs, is_train):
  45. emb_weight = flow.get_variable(
  46. 'embedding-weight',
  47. shape=(self.emb_sz, self.emb_dim),
  48. dtype=flow.float32,
  49. trainable=is_train,
  50. reuse=False,
  51. initializer=self.initializer,
  52. )
  53. data = flow.gather(emb_weight, inputs, axis=0)
  54. data = flow.transpose(data, [0, 2, 1]) # BLH -> BHL
  55. data = flow.reshape(data, list(data.shape) + [1])
  56. seq_length = data.shape[2]
  57. pooled_list = []
  58. for i in range(len(self.n_filters_list)):
  59. ksz = self.ksize_list[i]
  60. n_filters = self.n_filters_list[i]
  61. conv = flow.layers.conv2d(data, n_filters, [ksz, 1], data_format="NCHW",
  62. kernel_initializer=self.initializer, name='conv-{}'.format(i)) # NCHW
  63. # conv = flow.layers.layer_norm(conv, name='ln-{}'.format(i))
  64. conv = flow.nn.relu(conv)
  65. pooled = flow.nn.max_pool2d(conv, [seq_length - ksz + 1, 1], strides=1, padding='VALID', data_format="NCHW")
  66. pooled_list.append(pooled)
  67. pooled = flow.concat(pooled_list, 3)
  68. pooled = flow.reshape(pooled, [-1, self.total_n_filters])
  69. if is_train:
  70. pooled = flow.nn.dropout(pooled, rate=self.dropout)
  71. pooled = flow.layers.dense(pooled, self.total_n_filters, use_bias=True,
  72. kernel_initializer=self.initializer, name='dense-1')
  73. pooled = flow.nn.relu(pooled)
  74. logits = flow.layers.dense(pooled, self.n_classes, use_bias=True,
  75. kernel_initializer=self.initializer, name='dense-2')
  76. return logits
  77. def get_eval_config():
  78. config = flow.function_config()
  79. config.default_data_type(flow.float)
  80. return config
  81. def pad_sequences(sequences, maxlen=None, dtype='int32',
  82. padding='pre', truncating='pre', value=0.):
  83. """Pads sequences to the same length.
  84. This function transforms a list of
  85. `num_samples` sequences (lists of integers)
  86. into a 2D Numpy array of shape `(num_samples, num_timesteps)`.
  87. `num_timesteps` is either the `maxlen` argument if provided,
  88. or the length of the longest sequence otherwise.
  89. Sequences that are shorter than `num_timesteps`
  90. are padded with `value` at the beginning or the end
  91. if padding='post.
  92. Sequences longer than `num_timesteps` are truncated
  93. so that they fit the desired length.
  94. The position where padding or truncation happens is determined by
  95. the arguments `padding` and `truncating`, respectively.
  96. Pre-padding is the default.
  97. # Arguments
  98. sequences: List of lists, where each element is a sequence.
  99. maxlen: Int, maximum length of all sequences.
  100. dtype: Type of the output sequences.
  101. To pad sequences with variable length strings, you can use `object`.
  102. padding: String, 'pre' or 'post':
  103. pad either before or after each sequence.
  104. truncating: String, 'pre' or 'post':
  105. remove values from sequences larger than
  106. `maxlen`, either at the beginning or at the end of the sequences.
  107. value: Float or String, padding value.
  108. # Returns
  109. x: Numpy array with shape `(len(sequences), maxlen)`
  110. # Raises
  111. ValueError: In case of invalid values for `truncating` or `padding`,
  112. or in case of invalid shape for a `sequences` entry.
  113. """
  114. if not hasattr(sequences, '__len__'):
  115. raise ValueError('`sequences` must be iterable.')
  116. num_samples = len(sequences)
  117. lengths = []
  118. sample_shape = ()
  119. flag = True
  120. # take the sample shape from the first non empty sequence
  121. # checking for consistency in the main loop below.
  122. for x in sequences:
  123. try:
  124. lengths.append(len(x))
  125. if flag and len(x):
  126. sample_shape = np.asarray(x).shape[1:]
  127. flag = False
  128. except TypeError:
  129. raise ValueError('`sequences` must be a list of iterables. '
  130. 'Found non-iterable: ' + str(x))
  131. if maxlen is None:
  132. maxlen = np.max(lengths)
  133. is_dtype_str = np.issubdtype(dtype, np.str_) or np.issubdtype(dtype, np.unicode_)
  134. if isinstance(value, six.string_types) and dtype != object and not is_dtype_str:
  135. raise ValueError("`dtype` {} is not compatible with `value`'s type: {}\n"
  136. "You should set `dtype=object` for variable length strings."
  137. .format(dtype, type(value)))
  138. x = np.full((num_samples, maxlen) + sample_shape, value, dtype=dtype)
  139. for idx, s in enumerate(sequences):
  140. if not len(s):
  141. continue # empty list/array was found
  142. if truncating == 'pre':
  143. trunc = s[-maxlen:]
  144. elif truncating == 'post':
  145. trunc = s[:maxlen]
  146. else:
  147. raise ValueError('Truncating type "%s" '
  148. 'not understood' % truncating)
  149. # check `trunc` has expected shape
  150. trunc = np.asarray(trunc, dtype=dtype)
  151. if trunc.shape[1:] != sample_shape:
  152. raise ValueError('Shape of sample %s of sequence at position %s '
  153. 'is different from expected shape %s' %
  154. (trunc.shape[1:], idx, sample_shape))
  155. if padding == 'post':
  156. x[idx, :len(trunc)] = trunc
  157. elif padding == 'pre':
  158. x[idx, -len(trunc):] = trunc
  159. else:
  160. raise ValueError('Padding type "%s" not understood' % padding)
  161. return x
  162. @flow.global_function('predict', get_eval_config())
  163. def predict_job(text: tp.Numpy.Placeholder((BATCH_SIZE, 150), dtype=flow.int32),
  164. ) -> Tuple[tp.Numpy, tp.Numpy]:
  165. with flow.scope.placement("gpu", "0:0"):
  166. model = TextCNN(50000, 100, ksize_list=[2, 3, 4, 5], n_filters_list=[100] * 4, n_classes=2, dropout=0.5)
  167. logits = model.get_logits(text, is_train=False)
  168. logits = flow.nn.softmax(logits)
  169. label = flow.math.argmax(logits)
  170. return label, logits
  171. class TextCNNClassifier:
  172. def __init__(self):
  173. model_load_dir = current_dir + os.sep + "model/textcnn_imdb_of_best_model/"
  174. word_index_dir = current_dir + os.sep + "model/imdb_word_index/imdb_word_index.json"
  175. checkpoint = flow.train.CheckPoint()
  176. checkpoint.init()
  177. checkpoint.load(model_load_dir)
  178. with open(word_index_dir) as f:
  179. word_index = json.load(f)
  180. word_index = {k: (v + 2) for k, v in word_index.items()}
  181. word_index["<PAD>"] = 0
  182. word_index["<START>"] = 1
  183. word_index["<UNK>"] = 2
  184. self.word_index = word_index
  185. def inference(self, text_path_list, id_list, label_list):
  186. logging.info("infer")
  187. logging.info(label_list)
  188. classifications = []
  189. batch_text = []
  190. for i, text_path in enumerate(text_path_list):
  191. text = open(text_path, "r").read()
  192. """
  193. # 在 nfs 没有挂载 时使用 url 访问 MinIO 进行测试
  194. url = "http://10.5.29.100:9000/" + text_path
  195. print(url)
  196. text = requests.get(url).text # .encode('utf-8').decode('utf-8')
  197. """
  198. text = re.sub("[^a-zA-Z']", " ", text)
  199. text = list(map(lambda x: x.lower(), text.split()))
  200. text.insert(0, "<START>")
  201. batch_text.append(
  202. list(map(lambda x: self.word_index[x] if x in self.word_index else self.word_index["<UNK>"], text))
  203. )
  204. if i % BATCH_SIZE == BATCH_SIZE - 1:
  205. text = pad_sequences(batch_text, value=self.word_index["<PAD>"], padding='post', maxlen=150)
  206. text = np.array(text, dtype=np.int32)
  207. label, logits = predict_job(text)
  208. label = label.tolist()
  209. logits = logits.tolist()
  210. for k in range(BATCH_SIZE):
  211. temp = {}
  212. temp['annotation'] = []
  213. temp['annotation'] = json.dumps(temp['annotation'])
  214. temp['id'] = id_list[i - BATCH_SIZE + 1 + k]
  215. if label[k] in label_list:
  216. temp['annotation'] = json.dumps([{'category_id': label_2_name[label[k]].rstrip('\n'), 'score': round(logits[k][label[k]], 4)}])
  217. classifications.append(temp)
  218. batch_text = []
  219. return classifications

一站式算法开发平台、高性能分布式深度学习框架、先进算法模型库、视觉模型炼知平台、数据可视化分析平台等一系列平台及工具,在模型高效分布式训练、数据处理和可视分析、模型炼知和轻量化等技术上形成独特优势,目前已在产学研等各领域近千家单位及个人提供AI应用赋能