
loss.py 6.6 kB

# Copyright 2020 Huawei Technologies Co., Ltd
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.nn.loss.loss import _Loss
from mindspore.ops import functional as F
from mindspore.ops import composite as C
from mindspore.context import ParallelMode
from mindspore import context
from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_gradients_mean
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from src.config import params

context.set_context(mode=context.GRAPH_MODE, save_graphs=True)

time_stamp_init = False
time_stamp_first = 0
grad_scale = C.MultitypeFuncGraph("grad_scale")
_grad_overflow = C.MultitypeFuncGraph("_grad_overflow")
reciprocal = P.Reciprocal()

GRADIENT_CLIP_TYPE = params['GRADIENT_CLIP_TYPE']
GRADIENT_CLIP_VALUE = params['GRADIENT_CLIP_VALUE']

clip_grad = C.MultitypeFuncGraph("clip_grad")


@clip_grad.register("Number", "Number", "Tensor")
def _clip_grad(clip_type, clip_value, grad):
    """
    Clip gradients.

    Inputs:
        clip_type (int): The way to clip, 0 for 'value', 1 for 'norm'.
        clip_value (float): Specifies how much to clip.
        grad (tuple[Tensor]): Gradients.

    Outputs:
        tuple[Tensor]: clipped gradients.
    """
    if clip_type not in (0, 1):
        return grad
    dt = F.dtype(grad)
    if clip_type == 0:
        new_grad = C.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt),
                                   F.cast(F.tuple_to_array((clip_value,)), dt))
    else:
        new_grad = nn.ClipByNorm()(grad, F.cast(F.tuple_to_array((clip_value,)), dt))
    return new_grad


class openpose_loss(_Loss):
    def __init__(self):
        super(openpose_loss, self).__init__()
        self.expand_dims = P.ExpandDims()
        self.tile = P.Tile()
        self.mul = P.Mul()
        self.l2_loss = P.L2Loss()
        self.square = P.Square()
        self.reduceMean = P.ReduceMean()
        self.reduceSum = P.ReduceSum()
        self.print = P.Print()
        self.shape = P.Shape()
        self.maxoftensor = P.ArgMaxWithValue(-1)

    def mean_square_error(self, map1, map2, mask=None):
        if mask is None:
            mse = self.reduceMean((map1 - map2) ** 2)
            return mse
        squareMap = self.square(map1 - map2)
        squareMap_mask = self.mul(squareMap, mask)
        mse = self.reduceMean(squareMap_mask)
        return mse

    def construct(self, logit_paf, logit_heatmap, gt_paf, gt_heatmap, ignore_mask):
        # Note: ignore_mask must be a 0-1 float array, not a boolean array.
        heatmaps_loss = []
        pafs_loss = []
        total_loss = 0

        # Broadcast the per-image ignore mask over the PAF and heatmap channels.
        paf_masks = self.tile(self.expand_dims(ignore_mask, 1), (1, self.shape(gt_paf)[1], 1, 1))
        heatmap_masks = self.tile(self.expand_dims(ignore_mask, 1), (1, self.shape(gt_heatmap)[1], 1, 1))
        paf_masks = F.stop_gradient(paf_masks)
        heatmap_masks = F.stop_gradient(heatmap_masks)

        # Accumulate the masked MSE of every intermediate stage.
        for logit_paf_t, logit_heatmap_t in zip(logit_paf, logit_heatmap):
            pafs_loss_t = self.mean_square_error(logit_paf_t, gt_paf, paf_masks)
            heatmaps_loss_t = self.mean_square_error(logit_heatmap_t, gt_heatmap, heatmap_masks)
            total_loss += pafs_loss_t + heatmaps_loss_t
            heatmaps_loss.append(heatmaps_loss_t)
            pafs_loss.append(pafs_loss_t)

        return total_loss, heatmaps_loss, pafs_loss
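
# Illustrative usage sketch (not part of the original file): the criterion takes
# per-stage prediction lists and single ground-truth tensors. The stage count (6),
# channel counts (38 PAFs, 19 heatmaps) and the 46x46 map size below are assumptions
# based on the standard OpenPose setup, not values read from this repository.
#
#   import numpy as np
#   from mindspore import Tensor
#
#   criterion = openpose_loss()
#   logit_paf = [Tensor(np.zeros((1, 38, 46, 46), np.float32))] * 6
#   logit_heatmap = [Tensor(np.zeros((1, 19, 46, 46), np.float32))] * 6
#   gt_paf = Tensor(np.zeros((1, 38, 46, 46), np.float32))
#   gt_heatmap = Tensor(np.zeros((1, 19, 46, 46), np.float32))
#   ignore_mask = Tensor(np.ones((1, 46, 46), np.float32))  # 0-1 mask, not bool
#   total_loss, heatmaps_loss, pafs_loss = criterion(
#       logit_paf, logit_heatmap, gt_paf, gt_heatmap, ignore_mask)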


class BuildTrainNetwork(nn.Cell):
    def __init__(self, network, criterion):
        super(BuildTrainNetwork, self).__init__()
        self.network = network
        self.criterion = criterion

    def construct(self, input_data, gt_paf, gt_heatmap, mask):
        logit_pafs, logit_heatmap = self.network(input_data)
        loss, _, _ = self.criterion(logit_pafs, logit_heatmap, gt_paf, gt_heatmap, mask)
        return loss


class TrainOneStepWithClipGradientCell(nn.Cell):
    """Training wrapper: runs one forward/backward step, clips the gradients by
    value or norm, and then applies the optimizer update."""

    def __init__(self, network, optimizer, sens=1.0):
        super(TrainOneStepWithClipGradientCell, self).__init__(auto_prefix=False)
        self.network = network
        self.network.set_grad()
        self.network.add_flags(defer_inline=True)
        self.weights = optimizer.parameters
        self.optimizer = optimizer
        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
        self.hyper_map = C.HyperMap()
        self.sens = sens
        self.reducer_flag = False
        self.grad_reducer = None
        parallel_mode = _get_parallel_mode()
        if parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
            self.reducer_flag = True
        if self.reducer_flag:
            mean = _get_gradients_mean()
            degree = _get_device_num()
            self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)

    def construct(self, *inputs):
        weights = self.weights
        loss = self.network(*inputs)
        sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
        grads = self.grad(self.network, weights)(*inputs, sens)
        grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
        if self.reducer_flag:
            # apply grad reducer on grads
            grads = self.grad_reducer(grads)
        return F.depend(loss, self.optimizer(grads))
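
The three cells above compose into a training step as in the following sketch. This is a minimal illustration only: the OpenPoseNet import path, the Adam hyper-parameters and the commented call signature are assumptions, not part of this file.

    import mindspore.nn as nn
    from src.openposenet import OpenPoseNet  # assumed location of the backbone

    net = OpenPoseNet()
    criterion = openpose_loss()
    train_net = BuildTrainNetwork(net, criterion)
    optimizer = nn.Adam(train_net.trainable_params(), learning_rate=1e-4)
    train_step = TrainOneStepWithClipGradientCell(train_net, optimizer)
    train_step.set_train()
    # one call = forward pass, backward pass, gradient clipping, optimizer update
    # loss = train_step(input_data, gt_paf, gt_heatmap, ignore_mask)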