You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

offload.py 12 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Offload Support.
  16. """
  17. import json
  18. import numpy as np
  19. import mindspore.common.dtype as mstype
  20. from mindspore.common.tensor import Tensor
  21. import mindspore.nn as nn
  22. import mindspore.ops.composite as C
  23. from mindspore.ops import operations as P
  24. def check_concat_zip_dataset(dataset):
  25. """
  26. Check if dataset is concatenated or zipped.
  27. """
  28. while dataset:
  29. if len(dataset.children) > 1:
  30. return True
  31. if dataset.children:
  32. dataset = dataset.children[0]
  33. continue
  34. dataset = dataset.children
  35. return False
  36. def check_map_offload(dataset):
  37. """
  38. Check if offload flag is set in data pipeline map ops.
  39. """
  40. offload_check = False
  41. concat_zip_check = check_concat_zip_dataset(dataset)
  42. while dataset:
  43. if hasattr(dataset, 'offload'):
  44. if dataset.offload is True:
  45. offload_check = True
  46. break
  47. if dataset.children:
  48. dataset = dataset.children[0]
  49. else:
  50. dataset = []
  51. if offload_check and concat_zip_check:
  52. raise RuntimeError("Offload module currently does not support concatenated or zipped datasets.")
  53. return offload_check
  54. def apply_offload_iterators(data, offload_model):
  55. """
  56. Apply offload for non sink mode pipeline.
  57. """
  58. if len(data) != 2:
  59. # A temporary solution to ensure there are two columns in dataset.
  60. raise RuntimeError("Offload can currently only use datasets with two columns.")
  61. if isinstance(data[0], Tensor) is True:
  62. data[0] = offload_model(data[0])
  63. else:
  64. data[0] = Tensor(data[0], dtype=mstype.float32)
  65. data[0] = offload_model(data[0]).asnumpy()
  66. return data
  67. class ApplyPreTransform(nn.Cell):
  68. """
  69. Concatenates offload model with network.
  70. """
  71. def __init__(self, transform, model):
  72. super(ApplyPreTransform, self).__init__(auto_prefix=False, flags=model.get_flags())
  73. self.transform = transform
  74. self.model = model
  75. def construct(self, x, label):
  76. x = self.transform(x)
  77. x = self.model(x, label)
  78. return x
  79. class IdentityCell(nn.Cell):
  80. """
  81. Applies identity transform on given input tensors.
  82. """
  83. def __init__(self):
  84. super(IdentityCell, self).__init__()
  85. self.identity = P.Identity()
  86. def construct(self, x):
  87. return self.identity(x)
  88. class RandomHorizontalFlip(nn.Cell):
  89. """
  90. Applies Random Horizontal Flip transform on given input tensors.
  91. """
  92. def __init__(self, prob):
  93. super(RandomHorizontalFlip, self).__init__()
  94. self.prob = Tensor(prob, dtype=mstype.float32)
  95. self.cast = P.Cast()
  96. self.shape = P.Shape()
  97. self.uniformReal = P.UniformReal()
  98. self.reshape = P.Reshape()
  99. self.h_flip = P.ReverseV2(axis=[2])
  100. self.mul = P.Mul()
  101. def construct(self, x):
  102. x = self.cast(x, mstype.float32)
  103. bs, h, w, c = self.shape(x)
  104. flip_rand_factor = self.uniformReal((bs, 1))
  105. flip_rand_factor = self.cast((self.prob > flip_rand_factor), mstype.float32)
  106. flip_rand_factor = self.reshape(C.repeat_elements(flip_rand_factor, rep=(h*w*c)), (bs, h, w, c))
  107. x_flip = self.h_flip(x)
  108. x = self.mul(x_flip, flip_rand_factor) + self.mul((1 - flip_rand_factor), x)
  109. return x
  110. class RandomVerticalFlip(nn.Cell):
  111. """
  112. Applies Random Vertical Flip transform on given input tensors.
  113. """
  114. def __init__(self, prob):
  115. super(RandomVerticalFlip, self).__init__()
  116. self.prob = Tensor(prob, dtype=mstype.float32)
  117. self.cast = P.Cast()
  118. self.shape = P.Shape()
  119. self.uniformReal = P.UniformReal()
  120. self.reshape = P.Reshape()
  121. self.h_flip = P.ReverseV2(axis=[1])
  122. self.mul = P.Mul()
  123. def construct(self, x):
  124. x = self.cast(x, mstype.float32)
  125. bs, h, w, c = self.shape(x)
  126. flip_rand_factor = self.uniformReal((bs, 1))
  127. flip_rand_factor = self.cast((self.prob > flip_rand_factor), mstype.float32)
  128. flip_rand_factor = self.reshape(C.repeat_elements(flip_rand_factor, rep=(h*w*c)), (bs, h, w, c))
  129. x_flip = self.h_flip(x)
  130. x = self.mul(x_flip, flip_rand_factor) + self.mul((1 - flip_rand_factor), x)
  131. return x
  132. class RandomColorAdjust(nn.Cell):
  133. """
  134. Applies Random Color Adjust transform on given input tensors.
  135. """
  136. def __init__(self, brightness, saturation):
  137. super(RandomColorAdjust, self).__init__()
  138. if isinstance(brightness, (list, tuple)):
  139. self.br_min = brightness[0]
  140. self.br_max = brightness[1]
  141. else:
  142. self.br_min = max(0, 1 - brightness)
  143. self.br_max = 1 + brightness
  144. if isinstance(saturation, (list, tuple)):
  145. self.sa_min = saturation[0]
  146. self.sa_max = saturation[1]
  147. else:
  148. self.sa_min = max(0, 1 - saturation)
  149. self.sa_max = 1 + saturation
  150. self.cast = P.Cast()
  151. self.shape = P.Shape()
  152. self.uniformReal = P.UniformReal()
  153. self.reshape = P.Reshape()
  154. self.unstack = P.Unstack(axis=-1)
  155. self.expand_dims = P.ExpandDims()
  156. self.mul = P.Mul()
  157. def construct(self, x):
  158. x = self.cast(x, mstype.float32)
  159. bs, h, w, c = self.shape(x)
  160. br_rand_factor = self.br_min + (self.br_max - self.br_min)*self.uniformReal((bs, 1))
  161. br_rand_factor = self.reshape(C.repeat_elements(br_rand_factor, rep=(h*w*c)), (bs, h, w, c))
  162. sa_rand_factor = self.sa_min + (self.sa_max - self.sa_min)*self.uniformReal((bs, 1))
  163. sa_rand_factor = self.reshape(C.repeat_elements(sa_rand_factor, rep=(h*w*c)), (bs, h, w, c))
  164. r, g, b = self.unstack(x)
  165. x_gray = C.repeat_elements(self.expand_dims((0.2989 * r + 0.587 * g + 0.114 * b), -1), rep=c, axis=-1)
  166. x = self.mul(x, br_rand_factor)
  167. x = C.clip_by_value(x, 0.0, 255.0)
  168. x = self.mul(x, sa_rand_factor) + self.mul((1 - sa_rand_factor), x_gray)
  169. x = C.clip_by_value(x, 0.0, 255.0)
  170. return x
  171. class RandomSharpness(nn.Cell):
  172. """
  173. Applies Random Sharpness transform on given input tensors.
  174. """
  175. def __init__(self, degrees):
  176. super(RandomSharpness, self).__init__()
  177. if isinstance(degrees, (list, tuple)):
  178. self.degree_min = degrees[0]
  179. self.degree_max = degrees[1]
  180. else:
  181. self.degree_min = max(0, 1 - degrees)
  182. self.degree_max = 1 + degrees
  183. self.cast = P.Cast()
  184. self.shape = P.Shape()
  185. self.uniformReal = P.UniformReal()
  186. self.reshape = P.Reshape()
  187. self.expand_dims = P.ExpandDims()
  188. self.mul = P.Mul()
  189. self.transpose = P.Transpose()
  190. self.weight = np.array([[1, 1, 1], [1, 5, 1], [1, 1, 1]])/13.0
  191. self.weight = np.repeat(self.weight[np.newaxis, :, :], 3, axis=0)
  192. self.weight = np.repeat(self.weight[np.newaxis, :, :], 3, axis=0)
  193. self.weight = Tensor(self.weight, mstype.float32)
  194. self.filter = P.Conv2D(out_channel=3, kernel_size=(3, 3), pad_mode='same')
  195. def construct(self, x):
  196. x = self.cast(x, mstype.float32)
  197. bs, h, w, c = self.shape(x)
  198. degree_rand_factor = self.degree_min + (self.degree_max - self.degree_min)*self.uniformReal((bs, 1))
  199. degree_rand_factor = self.reshape(C.repeat_elements(degree_rand_factor, rep=(h*w*c)), (bs, h, w, c))
  200. x_sharp = self.filter(self.transpose(x, (0, 3, 1, 2)), self.weight)
  201. x_sharp = self.transpose(x_sharp, (0, 2, 3, 1))
  202. x = self.mul(x, degree_rand_factor) + self.mul((1 - degree_rand_factor), x_sharp)
  203. x = C.clip_by_value(x, 0.0, 255.0)
  204. return x
  205. class Rescale(nn.Cell):
  206. """
  207. Applies Rescale transform on given input tensors.
  208. """
  209. def __init__(self, rescale, shift):
  210. super(Rescale, self).__init__()
  211. self.rescale = Tensor(rescale, dtype=mstype.float32)
  212. self.shift = Tensor(shift, dtype=mstype.float32)
  213. self.cast = P.Cast()
  214. self.mul = P.Mul()
  215. def construct(self, x):
  216. x = self.cast(x, mstype.float32)
  217. x = x * self.rescale + self.shift
  218. return x
  219. class HwcToChw(nn.Cell):
  220. """
  221. Applies Channel Swap transform on given input tensors.
  222. """
  223. def __init__(self):
  224. super(HwcToChw, self).__init__()
  225. self.trans = P.Transpose()
  226. def construct(self, x):
  227. return self.trans(x, (0, 3, 1, 2))
  228. class Normalize(nn.Cell):
  229. """
  230. Applies Normalize transform on given input tensors.
  231. """
  232. def __init__(self, mean, std):
  233. super(Normalize, self).__init__()
  234. self.mean = Tensor(mean, mstype.float32)
  235. self.std = Tensor(std, mstype.float32)
  236. self.sub = P.Sub()
  237. self.div = P.Div()
  238. self.cast = P.Cast()
  239. def construct(self, x):
  240. x = self.cast(x, mstype.float32)
  241. x = self.sub(x, self.mean)
  242. x = self.div(x, self.std)
  243. return x
  244. class OffloadModel():
  245. def __init__(self, func, args_names=None):
  246. self.func = func
  247. self.args_names = args_names
  248. # Dictionary connecting operation name to model
  249. op_to_model = {
  250. "HWC2CHW": OffloadModel(HwcToChw),
  251. "HwcToChw": OffloadModel(HwcToChw),
  252. "Normalize": OffloadModel(Normalize, ["std", "mean"]),
  253. "RandomColorAdjust": OffloadModel(RandomColorAdjust, ["brightness", "saturation"]),
  254. "RandomHorizontalFlip": OffloadModel(RandomHorizontalFlip, ["prob"]),
  255. "RandomSharpness": OffloadModel(RandomSharpness, ["degrees"]),
  256. "RandomVerticalFlip": OffloadModel(RandomVerticalFlip, ["prob"]),
  257. "Rescale": OffloadModel(Rescale, ["rescale", "shift"])
  258. }
  259. class GetModelFromJson2Col(nn.Cell):
  260. """
  261. Generates offload ME model from offload JSON file for a single map op.
  262. """
  263. def __init__(self, json_offload):
  264. super(GetModelFromJson2Col, self).__init__()
  265. self.me_ops = []
  266. if json_offload is not None:
  267. offload_ops = json_offload["operations"]
  268. for op in offload_ops:
  269. name = op["tensor_op_name"]
  270. args = op["tensor_op_params"]
  271. op_model = op_to_model[name]
  272. op_model_inputs = []
  273. if op_model.args_names is not None:
  274. for arg_key in op_model.args_names:
  275. op_model_inputs.append(args[arg_key])
  276. self.me_ops.append(op_model.func(*op_model_inputs))
  277. else:
  278. raise RuntimeError("Offload hardware accelarator cannot be applied for this pipeline.")
  279. self.cell = nn.SequentialCell(self.me_ops)
  280. def construct(self, x):
  281. return self.cell(x)
  282. class GetOffloadModel(nn.Cell):
  283. """
  284. Generates offload ME model.
  285. """
  286. def __init__(self, dataset_consumer):
  287. super(GetOffloadModel, self).__init__()
  288. self.transform_list = []
  289. json_offload = json.loads(dataset_consumer.GetOffload())
  290. if json_offload is not None:
  291. for node in json_offload:
  292. if node["op_type"] == 'Map':
  293. self.transform_list.append(GetModelFromJson2Col(node))
  294. self.transform_list.reverse()
  295. def construct(self, x):
  296. for transform in self.transform_list:
  297. x = transform(x)
  298. return x