# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Loss definition and one-step training cells for OpenPose."""
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.nn.loss.loss import _Loss
from mindspore.ops import functional as F
from mindspore.ops import composite as C
from mindspore.context import ParallelMode
from mindspore import context
from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_gradients_mean
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer

from src.config import params

context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
time_stamp_init = False
time_stamp_first = 0
grad_scale = C.MultitypeFuncGraph("grad_scale")
_grad_overflow = C.MultitypeFuncGraph("_grad_overflow")
reciprocal = P.Reciprocal()

GRADIENT_CLIP_TYPE = params['GRADIENT_CLIP_TYPE']
GRADIENT_CLIP_VALUE = params['GRADIENT_CLIP_VALUE']

clip_grad = C.MultitypeFuncGraph("clip_grad")

@clip_grad.register("Number", "Number", "Tensor")
def _clip_grad(clip_type, clip_value, grad):
    """
    Clip a single gradient.

    Inputs:
        clip_type (int): The way to clip, 0 for 'value', 1 for 'norm'.
        clip_value (float): Specifies how much to clip.
        grad (Tensor): Gradient of one parameter.

    Outputs:
        Tensor, the clipped gradient.
    """
    if clip_type not in (0, 1):
        return grad
    dt = F.dtype(grad)
    if clip_type == 0:
        new_grad = C.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt),
                                   F.cast(F.tuple_to_array((clip_value,)), dt))
    else:
        new_grad = nn.ClipByNorm()(grad, F.cast(F.tuple_to_array((clip_value,)), dt))
    return new_grad
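
# Illustrative sketch (not used by the training pipeline): `clip_grad` is meant to be
# mapped over a whole tuple of gradients with HyperMap + F.partial, the same pattern
# used by TrainOneStepWithClipGradientCell below; this tiny cell isolates that pattern.
class _ClipGradientsExample(nn.Cell):
    def __init__(self):
        super(_ClipGradientsExample, self).__init__()
        self.hyper_map = C.HyperMap()

    def construct(self, grads):
        # Apply the same clip type/value to every gradient in the tuple.
        return self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)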

class openpose_loss(_Loss):
    """OpenPose multi-stage loss: masked MSE over PAF and heatmap predictions."""
    def __init__(self):
        super(openpose_loss, self).__init__()
        self.expand_dims = P.ExpandDims()
        self.tile = P.Tile()
        self.mul = P.Mul()
        self.l2_loss = P.L2Loss()
        self.square = P.Square()
        self.reduceMean = P.ReduceMean()
        self.reduceSum = P.ReduceSum()
        self.print = P.Print()
        self.shape = P.Shape()
        self.maxoftensor = P.ArgMaxWithValue(-1)

    def mean_square_error(self, map1, map2, mask=None):
        """Mean squared error between two maps, optionally weighted by a 0-1 mask."""
        if mask is None:
            mse = self.reduceMean((map1 - map2) ** 2)
            return mse

        squareMap = self.square(map1 - map2)
        squareMap_mask = self.mul(squareMap, mask)
        mse = self.reduceMean(squareMap_mask)
        return mse

    def construct(self, logit_paf, logit_heatmap, gt_paf, gt_heatmap, ignore_mask):
        # ignore_mask must be a 0-1 float array, not a boolean array.
        heatmaps_loss = []
        pafs_loss = []
        total_loss = 0

        # Broadcast the per-pixel ignore mask over the PAF and heatmap channels.
        paf_masks = self.tile(self.expand_dims(ignore_mask, 1), (1, self.shape(gt_paf)[1], 1, 1))
        heatmap_masks = self.tile(self.expand_dims(ignore_mask, 1), (1, self.shape(gt_heatmap)[1], 1, 1))

        # The masks are fixed labels; no gradient should flow through them.
        paf_masks = F.stop_gradient(paf_masks)
        heatmap_masks = F.stop_gradient(heatmap_masks)
        for logit_paf_t, logit_heatmap_t in zip(logit_paf, logit_heatmap):
            # Accumulate the masked MSE of every refinement stage.
            pafs_loss_t = self.mean_square_error(logit_paf_t, gt_paf, paf_masks)
            heatmaps_loss_t = self.mean_square_error(logit_heatmap_t, gt_heatmap, heatmap_masks)

            total_loss += pafs_loss_t + heatmaps_loss_t
            heatmaps_loss.append(heatmaps_loss_t)
            pafs_loss.append(pafs_loss_t)

        return total_loss, heatmaps_loss, pafs_loss
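
# Illustrative sketch (not used by the training pipeline): evaluating openpose_loss on
# dummy tensors. The shapes are assumptions based on the COCO OpenPose layout
# (38 PAF channels, 19 heatmap channels, 46x46 feature maps) with two fake stages.
def _example_openpose_loss():
    import numpy as np
    from mindspore import Tensor

    criterion = openpose_loss()
    gt_paf = Tensor(np.zeros((1, 38, 46, 46), np.float32))
    gt_heatmap = Tensor(np.zeros((1, 19, 46, 46), np.float32))
    ignore_mask = Tensor(np.ones((1, 46, 46), np.float32))
    # The backbone returns one prediction per refinement stage; fake two stages here.
    logit_paf = (Tensor(np.ones((1, 38, 46, 46), np.float32)),) * 2
    logit_heatmap = (Tensor(np.ones((1, 19, 46, 46), np.float32)),) * 2
    total_loss, heatmaps_loss, pafs_loss = criterion(
        logit_paf, logit_heatmap, gt_paf, gt_heatmap, ignore_mask)
    return total_loss, heatmaps_loss, pafs_loss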

class BuildTrainNetwork(nn.Cell):
    """Wrap the backbone and the criterion so the cell returns a single scalar loss."""
    def __init__(self, network, criterion):
        super(BuildTrainNetwork, self).__init__()
        self.network = network
        self.criterion = criterion

    def construct(self, input_data, gt_paf, gt_heatmap, mask):
        logit_pafs, logit_heatmap = self.network(input_data)
        loss, _, _ = self.criterion(logit_pafs, logit_heatmap, gt_paf, gt_heatmap, mask)
        return loss
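
# Illustrative sketch (not used by the training pipeline): computing the loss for one
# batch by pairing a backbone with BuildTrainNetwork. The `src.openposenet.OpenPoseNet`
# import path and its argument-free constructor are assumptions; adjust them to the
# actual model definition in this repository.
def _example_build_train_network(images, gt_paf, gt_heatmap, ignore_mask):
    from src.openposenet import OpenPoseNet  # assumed module/class name
    net = OpenPoseNet()
    train_net = BuildTrainNetwork(net, openpose_loss())
    # Returns the scalar loss summed over all refinement stages.
    return train_net(images, gt_paf, gt_heatmap, ignore_mask)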

class TrainOneStepWithClipGradientCell(nn.Cell):
    """One-step training cell that clips gradients before applying the optimizer."""
    def __init__(self, network, optimizer, sens=1.0):
        super(TrainOneStepWithClipGradientCell, self).__init__(auto_prefix=False)
        self.network = network
        self.network.set_grad()
        self.network.add_flags(defer_inline=True)
        self.weights = optimizer.parameters
        self.optimizer = optimizer
        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
        self.hyper_map = C.HyperMap()
        self.sens = sens
        self.reducer_flag = False
        self.grad_reducer = None
        parallel_mode = _get_parallel_mode()
        if parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
            self.reducer_flag = True
        if self.reducer_flag:
            mean = _get_gradients_mean()
            degree = _get_device_num()
            self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)

    def construct(self, *inputs):
        weights = self.weights
        loss = self.network(*inputs)
        # Use the loss scale `sens` as the initial sensitivity of the backward pass.
        sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
        grads = self.grad(self.network, weights)(*inputs, sens)
        grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
        if self.reducer_flag:
            # Apply the distributed gradient reducer (all-reduce across devices).
            grads = self.grad_reducer(grads)
        return F.depend(loss, self.optimizer(grads))
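
# Minimal wiring sketch (not part of the training script): assembling the clipped
# one-step training cell. The Momentum hyper-parameters below are placeholders, not
# the values of the reference configuration, and `train_net` is assumed to be a
# BuildTrainNetwork instance that returns a scalar loss.
def _example_one_step_training(train_net):
    optimizer = nn.Momentum(train_net.trainable_params(), learning_rate=1e-4, momentum=0.9)
    train_cell = TrainOneStepWithClipGradientCell(train_net, optimizer, sens=1.0)
    train_cell.set_train()
    # Each call runs forward, backward, gradient clipping and the optimizer update,
    # and returns the loss: train_cell(images, gt_paf, gt_heatmap, ignore_mask)
    return train_cell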