You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

loss.py 3.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Face Quality Assessment loss."""
  16. import mindspore.nn as nn
  17. from mindspore.ops import operations as P
  18. from mindspore.ops import functional as F
  19. from mindspore.common import dtype as mstype
  20. from mindspore.nn.loss.loss import _Loss
  21. from mindspore import Tensor
  22. eps = 1e-24
  23. class CEWithIgnoreIndex3D(_Loss):
  24. '''CEWithIgnoreIndex3D'''
  25. def __init__(self):
  26. super(CEWithIgnoreIndex3D, self).__init__()
  27. self.exp = P.Exp()
  28. self.sum = P.ReduceSum()
  29. self.reshape = P.Reshape()
  30. self.log = P.Log()
  31. self.cast = P.Cast()
  32. self.eps_const = Tensor(eps, dtype=mstype.float32)
  33. self.ones = P.OnesLike()
  34. self.onehot = P.OneHot()
  35. self.on_value = Tensor(1.0, mstype.float32)
  36. self.off_value = Tensor(0.0, mstype.float32)
  37. self.relu = P.ReLU()
  38. self.maximum = P.Maximum()
  39. self.resum = P.ReduceSum(keep_dims=False)
  40. def construct(self, logit, label):
  41. '''construct'''
  42. mask = self.reshape(label, (F.shape(label)[0], F.shape(label)[1], 1))
  43. mask = self.cast(mask, mstype.float32)
  44. mask = mask + F.scalar_to_array(0.00001)
  45. mask = self.relu(mask) / (mask)
  46. logit = logit * mask
  47. exp = self.exp(logit)
  48. exp_sum = self.sum(exp, -1)
  49. exp_sum = self.reshape(exp_sum, (F.shape(exp_sum)[0], F.shape(exp_sum)[1], 1))
  50. softmax_result = self.log(exp / exp_sum + self.eps_const)
  51. one_hot_label = self.onehot(
  52. self.cast(label, mstype.int32), F.shape(logit)[2], self.on_value, self.off_value)
  53. loss = (softmax_result * self.cast(one_hot_label, mstype.float32) * self.cast(F.scalar_to_array(-1),
  54. mstype.float32))
  55. loss = self.sum(loss, -1)
  56. loss = self.sum(loss, -1)
  57. loss = self.sum(loss, 0)
  58. loss = loss
  59. return loss
  60. class CriterionsFaceQA(nn.Cell):
  61. '''CriterionsFaceQA'''
  62. def __init__(self):
  63. super(CriterionsFaceQA, self).__init__()
  64. self.gatherv2 = P.Gather()
  65. self.squeeze = P.Squeeze(axis=1)
  66. self.shape = P.Shape()
  67. self.reshape = P.Reshape()
  68. self.euler_label_list = Tensor([0, 1, 2], dtype=mstype.int32)
  69. self.mse_loss = nn.MSELoss(reduction='sum')
  70. self.kp_label_list = Tensor([3, 4, 5, 6, 7], dtype=mstype.int32)
  71. self.kps_loss = CEWithIgnoreIndex3D()
  72. def construct(self, x1, x2, label):
  73. '''construct'''
  74. # euler
  75. euler_label = self.gatherv2(label, self.euler_label_list, 1)
  76. loss_euler = self.mse_loss(x1, euler_label)
  77. # key points
  78. b, _, _, _ = self.shape(x2)
  79. x2 = self.reshape(x2, (b, 5, 48 * 48))
  80. kps_label = self.gatherv2(label, self.kp_label_list, 1)
  81. loss_kps = self.kps_loss(x2, kps_label)
  82. loss_tot = (loss_kps + loss_euler) / b
  83. return loss_tot