
classify_by_textcnn.py

  1. """
  2. /**
  3. * Copyright 2020 Tianshu AI Platform. All Rights Reserved.
  4. *
  5. * Licensed under the Apache License, Version 2.0 (the "License");
  6. * you may not use this file except in compliance with the License.
  7. * You may obtain a copy of the License at
  8. *
  9. * http://www.apache.org/licenses/LICENSE-2.0
  10. *
  11. * Unless required by applicable law or agreed to in writing, software
  12. * distributed under the License is distributed on an "AS IS" BASIS,
  13. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. * See the License for the specific language governing permissions and
  15. * limitations under the License.
  16. * =============================================================
  17. */
  18. """
import json
import re
import sys

import six
import numpy as np
from typing import Tuple
# import requests  # used to fetch texts over HTTP when NFS is not mounted

sys.path.append("../../")
import oneflow as flow
import oneflow.typing as tp

BATCH_SIZE = 16
class TextCNN:
    def __init__(self, emb_sz, emb_dim, ksize_list, n_filters_list, n_classes, dropout):
        self.initializer = flow.random_normal_initializer(stddev=0.1)
        self.emb_sz = emb_sz
        self.emb_dim = emb_dim
        self.ksize_list = ksize_list
        self.n_filters_list = n_filters_list
        self.n_classes = n_classes
        self.dropout = dropout
        self.total_n_filters = sum(self.n_filters_list)

    def get_logits(self, inputs, is_train):
        emb_weight = flow.get_variable(
            'embedding-weight',
            shape=(self.emb_sz, self.emb_dim),
            dtype=flow.float32,
            trainable=is_train,
            reuse=False,
            initializer=self.initializer,
        )
        data = flow.gather(emb_weight, inputs, axis=0)
        data = flow.transpose(data, [0, 2, 1])  # BLH -> BHL
        data = flow.reshape(data, list(data.shape) + [1])  # NCHW with W == 1
        seq_length = data.shape[2]
        pooled_list = []
        for i in range(len(self.n_filters_list)):
            ksz = self.ksize_list[i]
            n_filters = self.n_filters_list[i]
            # Convolve along the sequence dimension with kernel height `ksz`.
            conv = flow.layers.conv2d(data, n_filters, [ksz, 1], data_format="NCHW",
                                      kernel_initializer=self.initializer,
                                      name='conv-{}'.format(i))  # NCHW
            # conv = flow.layers.layer_norm(conv, name='ln-{}'.format(i))
            conv = flow.nn.relu(conv)
            # Max-pool over all valid positions, keeping one value per filter.
            pooled = flow.nn.max_pool2d(conv, [seq_length - ksz + 1, 1], strides=1,
                                        padding='VALID', data_format="NCHW")
            pooled_list.append(pooled)
        pooled = flow.concat(pooled_list, 3)
        pooled = flow.reshape(pooled, [-1, self.total_n_filters])
        if is_train:
            pooled = flow.nn.dropout(pooled, rate=self.dropout)
        pooled = flow.layers.dense(pooled, self.total_n_filters, use_bias=True,
                                   kernel_initializer=self.initializer, name='dense-1')
        pooled = flow.nn.relu(pooled)
        logits = flow.layers.dense(pooled, self.n_classes, use_bias=True,
                                   kernel_initializer=self.initializer, name='dense-2')
        return logits
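
# Shape sketch for get_logits (assuming batch B, sequence length L == 150,
# emb_dim == 100 as used by predict_job below): gather -> (B, L, 100);
# transpose + reshape -> (B, 100, L, 1); each conv-i -> (B, n_filters, L - ksz + 1, 1);
# max_pool2d -> (B, n_filters, 1, 1); concat + reshape -> (B, total_n_filters);
# dense-2 -> (B, n_classes).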
def get_eval_config():
    config = flow.function_config()
    config.default_data_type(flow.float)
    return config
def pad_sequences(sequences, maxlen=None, dtype='int32',
                  padding='pre', truncating='pre', value=0.):
    """Pads sequences to the same length.

    This function transforms a list of `num_samples` sequences
    (lists of integers) into a 2D Numpy array of shape
    `(num_samples, num_timesteps)`. `num_timesteps` is either the
    `maxlen` argument if provided, or the length of the longest
    sequence otherwise.

    Sequences that are shorter than `num_timesteps` are padded with
    `value`, either at the beginning (`padding='pre'`) or at the end
    (`padding='post'`). Sequences longer than `num_timesteps` are
    truncated so that they fit the desired length. The position where
    padding or truncation happens is determined by the arguments
    `padding` and `truncating`, respectively. Pre-padding is the default.

    # Arguments
        sequences: List of lists, where each element is a sequence.
        maxlen: Int, maximum length of all sequences.
        dtype: Type of the output sequences.
            To pad sequences with variable length strings, you can use `object`.
        padding: String, 'pre' or 'post':
            pad either before or after each sequence.
        truncating: String, 'pre' or 'post':
            remove values from sequences larger than
            `maxlen`, either at the beginning or at the end of the sequences.
        value: Float or String, padding value.

    # Returns
        x: Numpy array with shape `(len(sequences), maxlen)`

    # Raises
        ValueError: In case of invalid values for `truncating` or `padding`,
            or in case of invalid shape for a `sequences` entry.
    """
    if not hasattr(sequences, '__len__'):
        raise ValueError('`sequences` must be iterable.')
    num_samples = len(sequences)

    lengths = []
    sample_shape = ()
    flag = True
    # Take the sample shape from the first non-empty sequence,
    # checking for consistency in the main loop below.
    for x in sequences:
        try:
            lengths.append(len(x))
            if flag and len(x):
                sample_shape = np.asarray(x).shape[1:]
                flag = False
        except TypeError:
            raise ValueError('`sequences` must be a list of iterables. '
                             'Found non-iterable: ' + str(x))

    if maxlen is None:
        maxlen = np.max(lengths)

    is_dtype_str = np.issubdtype(dtype, np.str_) or np.issubdtype(dtype, np.unicode_)
    if isinstance(value, six.string_types) and dtype != object and not is_dtype_str:
        raise ValueError("`dtype` {} is not compatible with `value`'s type: {}\n"
                         "You should set `dtype=object` for variable length strings."
                         .format(dtype, type(value)))

    x = np.full((num_samples, maxlen) + sample_shape, value, dtype=dtype)
    for idx, s in enumerate(sequences):
        if not len(s):
            continue  # empty list/array was found
        if truncating == 'pre':
            trunc = s[-maxlen:]
        elif truncating == 'post':
            trunc = s[:maxlen]
        else:
            raise ValueError('Truncating type "%s" '
                             'not understood' % truncating)

        # Check that `trunc` has the expected shape.
        trunc = np.asarray(trunc, dtype=dtype)
        if trunc.shape[1:] != sample_shape:
            raise ValueError('Shape of sample %s of sequence at position %s '
                             'is different from expected shape %s' %
                             (trunc.shape[1:], idx, sample_shape))

        if padding == 'post':
            x[idx, :len(trunc)] = trunc
        elif padding == 'pre':
            x[idx, -len(trunc):] = trunc
        else:
            raise ValueError('Padding type "%s" not understood' % padding)
    return x
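
# A minimal sketch of the padding behavior documented above:
#   pad_sequences([[1, 2, 3], [4, 5]], maxlen=4, padding='post', value=0)
#   -> array([[1, 2, 3, 0],
#             [4, 5, 0, 0]], dtype=int32)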
@flow.global_function('predict', get_eval_config())
def predict_job(text: tp.Numpy.Placeholder((BATCH_SIZE, 150), dtype=flow.int32),
                ) -> Tuple[tp.Numpy, tp.Numpy]:
    with flow.scope.placement("gpu", "0:0"):
        model = TextCNN(50000, 100, ksize_list=[2, 3, 4, 5],
                        n_filters_list=[100] * 4, n_classes=2, dropout=0.5)
        logits = model.get_logits(text, is_train=False)
        logits = flow.nn.softmax(logits)
        label = flow.math.argmax(logits)
    return label, logits
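
# Calling sketch (assumes a checkpoint has already been loaded, as done in
# TextCNNClassifier.__init__ below; the all-zero batch is purely illustrative):
#   dummy = np.zeros((BATCH_SIZE, 150), dtype=np.int32)
#   label, logits = predict_job(dummy)  # label.shape == (BATCH_SIZE,)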
class TextCNNClassifier:
    def __init__(self):
        model_load_dir = "../of_model/textcnn_imdb_of_best_model/"
        word_index_dir = "../of_model/imdb_word_index/imdb_word_index.json"
        checkpoint = flow.train.CheckPoint()
        checkpoint.init()
        checkpoint.load(model_load_dir)

        with open(word_index_dir) as f:
            word_index = json.load(f)
        # Shift all indices by 2 to reserve slots for the special tokens.
        word_index = {k: (v + 2) for k, v in word_index.items()}
        word_index["<PAD>"] = 0
        word_index["<START>"] = 1
        word_index["<UNK>"] = 2
        self.word_index = word_index

    def inference(self, text_path_list, id_list, label_list):
        print("infer")
        classifications = []
        batch_text = []
        for i, text_path in enumerate(text_path_list):
            text = open('/nfs/' + text_path, "r").read()
            """
            # Fetch from MinIO over HTTP for testing when NFS is not mounted.
            url = "http://10.5.29.100:9000/" + text_path
            print(url)
            text = requests.get(url).text  # .encode('utf-8').decode('utf-8')
            """
            # Keep only letters and apostrophes, lowercase, and tokenize.
            text = re.sub("[^a-zA-Z']", " ", text)
            text = list(map(lambda x: x.lower(), text.split()))
            text.insert(0, "<START>")
            batch_text.append(
                list(map(lambda x: self.word_index[x] if x in self.word_index else self.word_index["<UNK>"], text))
            )
            # Run the job once a full batch has accumulated; note that trailing
            # texts that do not fill a complete batch are not scored.
            if i % BATCH_SIZE == BATCH_SIZE - 1:
                text = pad_sequences(batch_text, value=self.word_index["<PAD>"],
                                     padding='post', maxlen=150)
                text = np.array(text, dtype=np.int32)
                label, logits = predict_job(text)
                label = label.tolist()
                logits = logits.tolist()
                for k in range(BATCH_SIZE):
                    classifications.append({
                        'id': id_list[i - BATCH_SIZE + 1 + k],
                        'annotation': json.dumps(
                            [{'category_id': label_list[label[k]], 'score': round(logits[k][label[k]], 4)}])
                    })
                batch_text = []
        return classifications
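
# A minimal usage sketch. The file names below are hypothetical; the class
# itself only assumes texts readable under /nfs/ plus the checkpoint and
# word-index files referenced in __init__.
if __name__ == "__main__":
    classifier = TextCNNClassifier()
    text_paths = ["reviews/{}.txt".format(n) for n in range(BATCH_SIZE)]  # hypothetical paths
    ids = list(range(BATCH_SIZE))
    label_map = ["neg", "pos"]  # assumed mapping from model class index to category id
    print(classifier.inference(text_paths, ids, label_map))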

A one-stop algorithm development platform, a high-performance distributed deep learning framework, an advanced algorithm model library, a vision model knowledge-amalgamation platform, a data visualization and analysis platform, and a series of related platforms and tools, with distinctive strengths in efficient distributed model training, data processing and visual analysis, and model knowledge amalgamation and lightweighting; it currently provides AI application enablement to nearly a thousand organizations and individuals across industry, academia, and research.
