Browse Source

modify

tags/v0.5.0-beta
unknown 5 years ago
parent
commit
913b5b03df
5 changed files with 20 additions and 36 deletions
  1. +8
    -0
      model_zoo/deeplabv3/src/backbone/__init__.py
  2. +3
    -0
      model_zoo/deeplabv3/src/backbone/resnet_deeplab.py
  3. +1
    -24
      model_zoo/deeplabv3/src/ei_datasest.py
  4. +2
    -6
      model_zoo/deeplabv3/src/md_dataset.py
  5. +6
    -6
      model_zoo/deeplabv3/train.py

+ 8
- 0
model_zoo/deeplabv3/src/backbone/__init__.py View File

@@ -0,0 +1,8 @@
from .resnet_deeplab import Subsample, DepthwiseConv2dNative, SpaceToBatch, BatchToSpace, ResNetV1, \
RootBlockBeta, resnet50_dl
__all__= [
"Subsample", "DepthwiseConv2dNative", "SpaceToBatch", "BatchToSpace", "ResNetV1", "RootBlockBeta",
"resnet50_dl"
]

+ 3
- 0
model_zoo/deeplabv3/src/backbone/resnet_deeplab.py View File

@@ -532,3 +532,6 @@ class RootBlockBeta(nn.Cell):
x = self.conv2(x)
x = self.conv3(x)
return x
class resnet50_dl(fine_tune_batch_norm=False):
return ResNetV1(fine_tune_batch_norm)

+ 1
- 24
model_zoo/deeplabv3/src/ei_datasest.py View File

@@ -17,7 +17,7 @@ import abc
import os
import time

from .utils.adapter import get_manifest_samples, get_raw_samples, read_image
from .utils.adapter import get_raw_samples, read_image


class BaseDataset(object):
@@ -62,29 +62,6 @@ class BaseDataset(object):
pass


class HwVocManifestDataset(BaseDataset):
"""
Create dataset with manifest data.

Args:
data_url (str): The path of data.
usage (str): Whether to use train or eval (default='train').

Returns:
Dataset.
"""

def __init__(self, data_url, usage="train"):
super().__init__(data_url, usage)

def _load_samples(self):
try:
self.samples = get_manifest_samples(self.data_url, self.usage)
except Exception as e:
print("load HwVocManifestDataset samples failed!!!")
raise e


class HwVocRawDataset(BaseDataset):
"""
Create dataset with raw data.


+ 2
- 6
model_zoo/deeplabv3/src/md_dataset.py View File

@@ -17,7 +17,7 @@ from PIL import Image
import mindspore.dataset as de
import mindspore.dataset.transforms.vision.c_transforms as C

from .ei_dataset import HwVocManifestDataset, HwVocRawDataset
from .ei_dataset import HwVocRawDataset
from .utils import custom_transforms as tr


@@ -77,10 +77,7 @@ def create_dataset(args, data_url, epoch_num=1, batch_size=1, usage="train"):
Dataset.
"""
# create iter dataset
if data_url.endswith(".manifest"):
dataset = HwVocManifestDataset(data_url, usage=usage)
else:
dataset = HwVocRawDataset(data_url, usage=usage)
dataset = HwVocRawDataset(data_url, usage=usage)
dataset_len = len(dataset)

# wrapped with GeneratorDataset
@@ -100,5 +97,4 @@ def create_dataset(args, data_url, epoch_num=1, batch_size=1, usage="train"):
dataset = dataset.repeat(count=epoch_num)
dataset.map_model = 4

dataset.__loop_size__ = 1
return dataset

+ 6
- 6
model_zoo/deeplabv3/train.py View File

@@ -87,13 +87,13 @@ if __name__ == "__main__":
keep_checkpoint_max=args_opt.save_checkpoint_num)
ckpoint_cb = ModelCheckpoint(prefix='checkpoint_deeplabv3', config=config_ck)
callback.append(ckpoint_cb)
net = deeplabv3_resnet50(crop_size.seg_num_classes, [args_opt.batch_size,3,args_opt.crop_size,args_opt.crop_size],
infer_scale_sizes=crop_size.eval_scales, atrous_rates=crop_size.atrous_rates,
decoder_output_stride=crop_size.decoder_output_stride, output_stride = crop_size.output_stride,
fine_tune_batch_norm=crop_size.fine_tune_batch_norm, image_pyramid = crop_size.image_pyramid)
net = deeplabv3_resnet50(config.seg_num_classes, [args_opt.batch_size,3,args_opt.crop_size,args_opt.crop_size],
infer_scale_sizes=config.eval_scales, atrous_rates=config.atrous_rates,
decoder_output_stride=config.decoder_output_stride, output_stride = config.output_stride,
fine_tune_batch_norm=config.fine_tune_batch_norm, image_pyramid = config.image_pyramid)
net.set_train()
model_fine_tune(args_opt, net, 'layer')
loss = OhemLoss(crop_size.seg_num_classes, crop_size.ignore_label)
opt = Momentum(filter(lambda x: 'beta' not in x.name and 'gamma' not in x.name and 'depth' not in x.name and 'bias' not in x.name, net.trainable_params()), learning_rate=args_opt.learning_rate, momentum=args_opt.momentum, weight_decay=args_opt.weight_decay)
loss = OhemLoss(config.seg_num_classes, config.ignore_label)
opt = Momentum(filter(lambda x: 'beta' not in x.name and 'gamma' not in x.name and 'depth' not in x.name and 'bias' not in x.name, net.trainable_params()), learning_rate=config.learning_rate, momentum=config.momentum, weight_decay=config.weight_decay)
model = Model(net, loss, opt)
model.train(args_opt.epoch_size, train_dataset, callback)

Loading…
Cancel
Save