Browse Source

solve the problem of sudden increases in losses for fasterrcnn model

pull/14379/head
zhouneng 4 years ago
parent
commit
a7847cb612
5 changed files with 24 additions and 6 deletions
  1. +5
    -0
      model_zoo/official/cv/faster_rcnn/eval.py
  2. +4
    -0
      model_zoo/official/cv/faster_rcnn/src/FasterRcnn/faster_rcnn_r50.py
  3. +1
    -1
      model_zoo/official/cv/faster_rcnn/src/FasterRcnn/rcnn.py
  4. +10
    -5
      model_zoo/official/cv/faster_rcnn/src/FasterRcnn/rpn.py
  5. +4
    -0
      model_zoo/official/cv/faster_rcnn/train.py

+ 5
- 0
model_zoo/official/cv/faster_rcnn/eval.py View File

@@ -19,6 +19,7 @@ import argparse
import time import time
import numpy as np import numpy as np
from pycocotools.coco import COCO from pycocotools.coco import COCO
import mindspore.common.dtype as mstype
from mindspore import context from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.common import set_seed, Parameter from mindspore.common import set_seed, Parameter
@@ -51,7 +52,11 @@ def fasterrcnn_eval(dataset_path, ckpt_path, ann_file):
tensor = value.asnumpy().astype(np.float32) tensor = value.asnumpy().astype(np.float32)
param_dict[key] = Parameter(tensor, key) param_dict[key] = Parameter(tensor, key)
load_param_into_net(net, param_dict) load_param_into_net(net, param_dict)

net.set_train(False) net.set_train(False)
device_type = "Ascend" if context.get_context("device_target") == "Ascend" else "Others"
if device_type == "Ascend":
net.to_float(mstype.float16)


eval_iter = 0 eval_iter = 0
total = ds.get_dataset_size() total = ds.get_dataset_size()


+ 4
- 0
model_zoo/official/cv/faster_rcnn/src/FasterRcnn/faster_rcnn_r50.py View File

@@ -16,6 +16,7 @@


import numpy as np import numpy as np
import mindspore.nn as nn import mindspore.nn as nn
from mindspore import context
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor from mindspore.common.tensor import Tensor
import mindspore.common.dtype as mstype import mindspore.common.dtype as mstype
@@ -144,6 +145,7 @@ class Faster_Rcnn_Resnet50(nn.Cell):


# Init tensor # Init tensor
self.init_tensor(config) self.init_tensor(config)
self.device_type = "Ascend" if context.get_context("device_target") == "Ascend" else "Others"


def roi_init(self, config): def roi_init(self, config):
self.roi_align = SingleRoIExtractor(config, self.roi_align = SingleRoIExtractor(config,
@@ -267,6 +269,8 @@ class Faster_Rcnn_Resnet50(nn.Cell):
bboxes_all = self.concat(bboxes_tuple) bboxes_all = self.concat(bboxes_tuple)
else: else:
bboxes_all = bboxes_tuple[0] bboxes_all = bboxes_tuple[0]
if self.device_type == "Ascend":
bboxes_all = self.cast(bboxes_all, mstype.float16)
rois = self.concat_1((self.roi_align_index_test_tensor, bboxes_all)) rois = self.concat_1((self.roi_align_index_test_tensor, bboxes_all))


rois = self.cast(rois, mstype.float32) rois = self.cast(rois, mstype.float32)


+ 1
- 1
model_zoo/official/cv/faster_rcnn/src/FasterRcnn/rcnn.py View File

@@ -40,7 +40,7 @@ class DenseNoTranpose(nn.Cell):
if self.device_type == "Ascend": if self.device_type == "Ascend":
x = self.cast(x, mstype.float16) x = self.cast(x, mstype.float16)
weight = self.cast(self.weight, mstype.float16) weight = self.cast(self.weight, mstype.float16)
output = self.bias_add(self.cast(self.matmul(x, weight), mstype.float32), self.bias)
output = self.bias_add(self.matmul(x, weight), self.bias)
else: else:
output = self.bias_add(self.matmul(x, self.weight), self.bias) output = self.bias_add(self.matmul(x, self.weight), self.bias)
return output return output


+ 10
- 5
model_zoo/official/cv/faster_rcnn/src/FasterRcnn/rpn.py View File

@@ -16,7 +16,7 @@
import numpy as np import numpy as np
import mindspore.nn as nn import mindspore.nn as nn
import mindspore.common.dtype as mstype import mindspore.common.dtype as mstype
from mindspore import Tensor
from mindspore import context, Tensor
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.ops import functional as F from mindspore.ops import functional as F
from mindspore.common.initializer import initializer from mindspore.common.initializer import initializer
@@ -102,6 +102,7 @@ class RPN(nn.Cell):
cfg_rpn = config cfg_rpn = config
self.dtype = np.float32 self.dtype = np.float32
self.ms_type = mstype.float32 self.ms_type = mstype.float32
self.device_type = "Ascend" if context.get_context("device_target") == "Ascend" else "Others"
self.num_bboxes = cfg_rpn.num_bboxes self.num_bboxes = cfg_rpn.num_bboxes
self.slice_index = () self.slice_index = ()
self.feature_anchor_shape = () self.feature_anchor_shape = ()
@@ -180,9 +181,12 @@ class RPN(nn.Cell):
bias_reg = initializer(0, shape=shp_bias_reg, dtype=self.ms_type).to_tensor() bias_reg = initializer(0, shape=shp_bias_reg, dtype=self.ms_type).to_tensor()


for i in range(num_layers): for i in range(num_layers):
rpn_layer.append(RpnRegClsBlock(in_channels, feat_channels, num_anchors, cls_out_channels, \
rpn_reg_cls_block = RpnRegClsBlock(in_channels, feat_channels, num_anchors, cls_out_channels, \
weight_conv, bias_conv, weight_cls, \ weight_conv, bias_conv, weight_cls, \
bias_cls, weight_reg, bias_reg))
bias_cls, weight_reg, bias_reg)
if self.device_type == "Ascend":
rpn_reg_cls_block.to_float(mstype.float16)
rpn_layer.append(rpn_reg_cls_block)


for i in range(1, num_layers): for i in range(1, num_layers):
rpn_layer[i].rpn_conv.weight = rpn_layer[0].rpn_conv.weight rpn_layer[i].rpn_conv.weight = rpn_layer[0].rpn_conv.weight
@@ -250,6 +254,7 @@ class RPN(nn.Cell):
mstype.bool_), mstype.bool_),
anchor_using_list, gt_valids_i) anchor_using_list, gt_valids_i)


bbox_target = self.cast(bbox_target, self.ms_type)
bbox_weight = self.cast(bbox_weight, self.ms_type) bbox_weight = self.cast(bbox_weight, self.ms_type)
label = self.cast(label, self.ms_type) label = self.cast(label, self.ms_type)
label_weight = self.cast(label_weight, self.ms_type) label_weight = self.cast(label_weight, self.ms_type)
@@ -286,8 +291,8 @@ class RPN(nn.Cell):
label_ = F.stop_gradient(label_with_batchsize) label_ = F.stop_gradient(label_with_batchsize)
label_weight_ = F.stop_gradient(label_weight_with_batchsize) label_weight_ = F.stop_gradient(label_weight_with_batchsize)


cls_score_i = rpn_cls_score[i]
reg_score_i = rpn_bbox_pred[i]
cls_score_i = self.cast(rpn_cls_score[i], self.ms_type)
reg_score_i = self.cast(rpn_bbox_pred[i], self.ms_type)


loss_cls = self.loss_cls(cls_score_i, label_) loss_cls = self.loss_cls(cls_score_i, label_)
loss_cls_item = loss_cls * label_weight_ loss_cls_item = loss_cls * label_weight_


+ 4
- 0
model_zoo/official/cv/faster_rcnn/train.py View File

@@ -152,6 +152,10 @@ if __name__ == '__main__':
param_dict[key] = Parameter(tensor, key) param_dict[key] = Parameter(tensor, key)
load_param_into_net(net, param_dict) load_param_into_net(net, param_dict)


device_type = "Ascend" if context.get_context("device_target") == "Ascend" else "Others"
if device_type == "Ascend":
net.to_float(mstype.float16)

loss = LossNet() loss = LossNet()
lr = Tensor(dynamic_lr(config, dataset_size), mstype.float32) lr = Tensor(dynamic_lr(config, dataset_size), mstype.float32)




Loading…
Cancel
Save