
!9163 Remove useless network parameters.

From: @linqingke
Reviewed-by: @c_34, @oacjiewen
Signed-off-by: @c_34
tags/v1.1.0
mindspore-ci-bot committed 5 years ago
commit 7c4e74ac00
8 changed files with 15 additions and 24 deletions
  1. model_zoo/official/cv/faster_rcnn/src/FasterRcnn/faster_rcnn_r50.py (+1, -1)
  2. model_zoo/official/cv/faster_rcnn/src/FasterRcnn/roi_align.py (+2, -1)
  3. model_zoo/official/cv/psenet/src/config.py (+11, -11)
  4. model_zoo/official/nlp/mass/config/config.json (+1, -2)
  5. model_zoo/official/nlp/mass/config/config.py (+0, -4)
  6. model_zoo/official/nlp/mass/src/transformer/create_attn_mask.py (+0, -3)
  7. model_zoo/official/nlp/mass/src/transformer/transformer.py (+0, -1)
  8. model_zoo/official/nlp/mass/src/transformer/transformer_for_infer.py (+0, -1)

model_zoo/official/cv/faster_rcnn/src/FasterRcnn/faster_rcnn_r50.py (+1, -1)

@@ -111,7 +111,7 @@ class Faster_Rcnn_Resnet50(nn.Cell):
         # Assign and sampler stage two
         self.bbox_assigner_sampler_for_rcnn = BboxAssignSampleForRcnn(config, self.train_batch_size,
                                                                       config.num_bboxes_stage2, True)
-        self.decode = P.BoundingBoxDecode(max_shape=(768, 1280), means=self.target_means, \
+        self.decode = P.BoundingBoxDecode(max_shape=(config.img_height, config.img_width), means=self.target_means, \
                                           stds=self.target_stds)

         # Roi

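For context on the faster_rcnn change above: the decode operator's clip shape now follows the configured input resolution instead of a hard-coded (768, 1280). A minimal standalone sketch of that pattern; DemoConfig and the means/stds values below are placeholders, not the real FasterRcnn config:

    from mindspore.ops import operations as P

    class DemoConfig:          # placeholder; the real values come from the FasterRcnn config module
        img_height = 768
        img_width = 1280

    config = DemoConfig()

    # Decoded boxes are clipped to the configured resolution, so changing
    # img_height/img_width no longer requires editing the network definition.
    decode = P.BoundingBoxDecode(max_shape=(config.img_height, config.img_width),
                                 means=(0., 0., 0., 0.),
                                 stds=(0.1, 0.1, 0.2, 0.2))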
model_zoo/official/cv/faster_rcnn/src/FasterRcnn/roi_align.py (+2, -1)

@@ -172,7 +172,8 @@ class SingleRoIExtractor(nn.Cell):
             mask = self.equal(target_lvls, P.ScalarToArray()(i))
             mask = P.Reshape()(mask, (-1, 1, 1, 1))
             roi_feats_t = self.roi_layers[i](feats[i], rois)
-            mask = self.cast(P.Tile()(self.cast(mask, mstype.int32), (1, 256, 7, 7)), mstype.bool_)
+            mask = self.cast(P.Tile()(self.cast(mask, mstype.int32),\
+                                      (1, 256, self.out_size, self.out_size)), mstype.bool_)
             res = self.select(mask, roi_feats_t, res)

         return res

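The roi_align change generalizes the mask-broadcast step in the same spirit: the output size is read from the extractor's configuration (self.out_size) rather than a literal 7. A rough sketch of that step in isolation; the tensor values and the channel count of 256 are illustrative only:

    import numpy as np
    from mindspore import Tensor
    from mindspore.ops import operations as P
    from mindspore.common import dtype as mstype

    out_size = 7        # previously hard-coded; now taken from the config
    num_channels = 256

    # Per-RoI level mask, shape (num_rois, 1, 1, 1): True where the RoI belongs to this FPN level.
    mask = Tensor(np.array([True, False]).reshape(2, 1, 1, 1), mstype.bool_)

    cast = P.Cast()
    tile = P.Tile()

    # Tile the mask up to the RoI feature shape (num_rois, num_channels, out_size, out_size)
    # so it can drive P.Select between this level's features and the accumulated result.
    mask_full = cast(tile(cast(mask, mstype.int32), (1, num_channels, out_size, out_size)),
                     mstype.bool_)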
model_zoo/official/cv/psenet/src/config.py (+11, -11)

@@ -17,17 +17,17 @@
 from easydict import EasyDict as ed

 config = ed({
-    'INFER_LONG_SIZE': 1920,
-    'KERNEL_NUM': 7,
-    'INFERENCE': True,  # INFER MODE\TRAIN MODE
+    "INFER_LONG_SIZE": 1920,
+    "KERNEL_NUM": 7,
+    "INFERENCE": True,  # INFER MODE\TRAIN MODE

     # backbone
-    'BACKBONE_LAYER_NUMS': [3, 4, 6, 3],
-    'BACKBONE_IN_CHANNELS': [64, 256, 512, 1024],
-    'BACKBONE_OUT_CHANNELS': [256, 512, 1024, 2048],
+    "BACKBONE_LAYER_NUMS": [3, 4, 6, 3],
+    "BACKBONE_IN_CHANNELS": [64, 256, 512, 1024],
+    "BACKBONE_OUT_CHANNELS": [256, 512, 1024, 2048],

     # neck
-    'NECK_OUT_CHANNEL': 256,
+    "NECK_OUT_CHANNEL": 256,

     # lr
     "BASE_LR": 2e-3,
@@ -36,20 +36,20 @@ config = ed({
"WARMUP_RATIO": 1/3, "WARMUP_RATIO": 1/3,
# dataset for train # dataset for train
"TRAIN_ROOT_DIR": 'psenet/ic15/',
"TRAIN_ROOT_DIR": "psenet/ic15/",
"TRAIN_IS_TRANSFORM": True, "TRAIN_IS_TRANSFORM": True,
"TRAIN_LONG_SIZE": 640, "TRAIN_LONG_SIZE": 640,
"TRAIN_MIN_SCALE": 0.4, "TRAIN_MIN_SCALE": 0.4,
"TRAIN_BATCH_SIZE": 4, "TRAIN_BATCH_SIZE": 4,
"TRAIN_REPEAT_NUM": 1800, "TRAIN_REPEAT_NUM": 1800,
"TRAIN_DROP_REMAINDER": True, "TRAIN_DROP_REMAINDER": True,
"TRAIN_MODEL_SAVE_PATH": './checkpoints/',
"TRAIN_MODEL_SAVE_PATH": "./checkpoints/",
# dataset for test # dataset for test
"TEST_ROOT_DIR": 'psenet/ic15/',
"TEST_ROOT_DIR": "psenet/ic15/",
"TEST_BUFFER_SIZE": 4, "TEST_BUFFER_SIZE": 4,
"TEST_DROP_REMAINDER": False, "TEST_DROP_REMAINDER": False,
# air config # air config
'air_filename': 'psenet_bs_1.air',
"air_filename": "psenet_bs_1.air",
}) })

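The psenet change is purely stylistic: all keys and string values in the EasyDict now use double quotes, so the file is internally consistent. A minimal sketch of how such a config is consumed (assuming the easydict package is available):

    from easydict import EasyDict as ed

    config = ed({
        "KERNEL_NUM": 7,
        "BASE_LR": 2e-3,
    })

    # EasyDict exposes entries both as attributes and as dict items.
    assert config.KERNEL_NUM == config["KERNEL_NUM"] == 7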
model_zoo/official/nlp/mass/config/config.json (+1, -2)

@@ -26,8 +26,7 @@
"label_smoothing": 0.1, "label_smoothing": 0.1,
"beam_width": 4, "beam_width": 4,
"length_penalty_weight": 1.0, "length_penalty_weight": 1.0,
"max_decode_length": 64,
"input_mask_from_dataset": true
"max_decode_length": 64
}, },
"loss_scale_config": { "loss_scale_config": {
"loss_scale_mode": "dynamic", "loss_scale_mode": "dynamic",


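The remaining four files all concern the same flag: input_mask_from_dataset is dropped from the JSON config above, from the TransformerConfig class, and from the three network cells that copied it into an attribute the rest of the code never read, which is why it can be removed without touching anything else.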
model_zoo/official/nlp/mass/config/config.py (+0, -4)

@@ -100,8 +100,6 @@ class TransformerConfig:
         beam_width (int): Beam width for beam search in inferring. Default: 4.
         length_penalty_weight (float): Penalty for sentence length. Default: 1.0.
         label_smoothing (float): Label smoothing setting. Default: 0.1.
-        input_mask_from_dataset (bool): Specifies whether to use the input mask that loaded from
-            dataset. Default: True.
         save_graphs (bool): Whether to save graphs, please set to True if mindinsight
             is wanted.
         dtype (mstype): Data type of the input. Default: mstype.float32.
@@ -148,7 +146,6 @@ class TransformerConfig:
                  beam_width=5,
                  length_penalty_weight=1.0,
                  label_smoothing=0.1,
-                 input_mask_from_dataset=True,
                  save_graphs=False,
                  dtype=mstype.float32,
                  max_decode_length=64):
@@ -190,7 +187,6 @@ class TransformerConfig:
         self.beam_width = beam_width
         self.length_penalty_weight = length_penalty_weight
         self.max_decode_length = max_decode_length
-        self.input_mask_from_dataset = input_mask_from_dataset
         self.compute_type = mstype.float16
         self.dtype = dtype




model_zoo/official/nlp/mass/src/transformer/create_attn_mask.py (+0, -3)

@@ -33,11 +33,8 @@ class CreateAttentionMaskFromInputMask(nn.Cell):


     def __init__(self, config):
         super(CreateAttentionMaskFromInputMask, self).__init__()
-        self.input_mask_from_dataset = config.input_mask_from_dataset
         self.input_mask = None

-        assert self.input_mask_from_dataset
-
         self.cast = P.Cast()
         self.shape = P.Shape()
         self.reshape = P.Reshape()


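Since the cell previously asserted input_mask_from_dataset to be True, removing the flag does not change behaviour: the attention mask is still built from the dataset-provided input mask. A rough sketch of that general pattern (shapes and values are illustrative, not the exact MASS implementation):

    import numpy as np
    from mindspore import Tensor
    from mindspore.ops import operations as P
    from mindspore.common import dtype as mstype

    # Hypothetical padding mask from the dataset: 1 for real tokens, 0 for padding. Shape (batch, seq_len).
    input_mask = Tensor(np.array([[1, 1, 1, 0]]), mstype.float32)

    cast = P.Cast()
    reshape = P.Reshape()
    shape = P.Shape()

    batch_size, seq_len = shape(input_mask)

    # Broadcast to (batch, 1, seq_len): each query position may attend to every non-padded key position.
    attention_mask = reshape(cast(input_mask, mstype.float32), (batch_size, 1, seq_len))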
model_zoo/official/nlp/mass/src/transformer/transformer.py (+0, -1)

@@ -59,7 +59,6 @@ class Transformer(nn.Cell):
         config.hidden_dropout_prob = 0.0
         config.attention_dropout_prob = 0.0

-        self.input_mask_from_dataset = config.input_mask_from_dataset
         self.batch_size = config.batch_size
         self.max_positions = config.seq_length
         self.attn_embed_dim = config.hidden_size


model_zoo/official/nlp/mass/src/transformer/transformer_for_infer.py (+0, -1)

@@ -212,7 +212,6 @@ class TransformerInferModel(nn.Cell):
         config.hidden_dropout_prob = 0.0
         config.attention_dropout_prob = 0.0

-        self.input_mask_from_dataset = config.input_mask_from_dataset
         self.batch_size = config.batch_size
         self.seq_length = config.seq_length
         self.hidden_size = config.hidden_size

