# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Xception."""
import mindspore.nn as nn
import mindspore.ops.operations as P


class SeparableConv2d(nn.Cell):
    """
    Depthwise-separable 2-D convolution.

    A depthwise convolution (``group=in_channels``, one filter per input
    channel) followed by a 1x1 pointwise convolution that mixes channels.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        kernel_size (int): Kernel size of the depthwise convolution. Default: 1.
        stride (int): Stride of the depthwise convolution. Default: 1.
        padding (int): Explicit zero padding of the depthwise convolution
            (``pad_mode='pad'``). Default: 0.
    """

    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0):
        super(SeparableConv2d, self).__init__()
        # Depthwise step: group == in_channels, so every channel is convolved on its own.
        self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size, stride, group=in_channels,
                               pad_mode='pad', padding=padding, weight_init='xavier_uniform')
        # Pointwise step: 1x1 convolution projecting to out_channels.
        self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1,
                                   pad_mode='valid', weight_init='xavier_uniform')

    def construct(self, x):
        """Apply the depthwise convolution, then the pointwise convolution."""
        x = self.conv1(x)
        x = self.pointwise(x)
        return x


class Block(nn.Cell):
    """
    Xception residual block: a stack of ReLU -> SeparableConv2d -> BatchNorm units
    plus a shortcut connection that is added to the stack's output.

    Args:
        in_filters (int): Number of input channels.
        out_filters (int): Number of output channels.
        reps (int): Number of SeparableConv2d + BatchNorm units in the main path.
        strides (int): When not 1, the main path ends with a 3x3 max-pool of this
            stride and the shortcut becomes a strided 1x1 convolution. Default: 1.
        start_with_relu (bool): When False, the leading ReLU of the main path is
            removed. Default: True.
        grow_first (bool): When True, the first separable convolution changes the
            channel count from ``in_filters`` to ``out_filters``; otherwise the last
            one does. Default: True.
    """

    def __init__(self, in_filters, out_filters, reps, strides=1, start_with_relu=True, grow_first=True):
        super(Block, self).__init__()

        # A projection shortcut is needed whenever the main path changes the channel
        # count or the spatial size; otherwise the input is added unchanged.
        if out_filters != in_filters or strides != 1:
            self.skip = nn.Conv2d(in_filters, out_filters, 1, stride=strides, pad_mode='valid',
                                  has_bias=False, weight_init='xavier_uniform')
            self.skipbn = nn.BatchNorm2d(out_filters, momentum=0.9)
        else:
            self.skip = None

        # NOTE(review): not used by construct(); kept so existing attribute access keeps working.
        self.relu = nn.ReLU()

        rep = []
        filters = in_filters
        if grow_first:
            rep.append(nn.ReLU())
            rep.append(SeparableConv2d(in_filters, out_filters, kernel_size=3, stride=1, padding=1))
            rep.append(nn.BatchNorm2d(out_filters, momentum=0.9))
            filters = out_filters

        for _ in range(reps - 1):
            rep.append(nn.ReLU())
            rep.append(SeparableConv2d(filters, filters, kernel_size=3, stride=1, padding=1))
            rep.append(nn.BatchNorm2d(filters, momentum=0.9))

        if not grow_first:
            rep.append(nn.ReLU())
            rep.append(SeparableConv2d(in_filters, out_filters, kernel_size=3, stride=1, padding=1))
            rep.append(nn.BatchNorm2d(out_filters, momentum=0.9))

        # Every path above puts a fresh nn.ReLU() at rep[0], so the only adjustment
        # needed is to drop it when the block must start directly with a convolution.
        if not start_with_relu:
            rep = rep[1:]

        if strides != 1:
            rep.append(nn.MaxPool2d(3, strides, pad_mode="same"))
        self.rep = nn.SequentialCell(*rep)
        self.add = P.Add()

    def construct(self, inp):
        """Return ``rep(inp) + shortcut(inp)``."""
        x = self.rep(inp)

        if self.skip is not None:
            skip = self.skip(inp)
            skip = self.skipbn(skip)
        else:
            skip = inp

        x = self.add(x, skip)
        return x


class Xception(nn.Cell):
    """
    Xception optimized for the ImageNet dataset, as specified in
    https://arxiv.org/abs/1610.02357

    The final ``AvgPool2d(10)`` assumes the feature map after ``block12`` is 10x10.
    That holds for a 299x299 input: conv1 (3x3, stride 2, valid) -> 149,
    conv2 (3x3, valid) -> 147, then four stride-2 'same' max-pools in
    block1/2/3/12 -> 74, 37, 19, 10. Other input sizes give a different number of
    flattened features than the 2048 expected by ``fc``.
    """

    def __init__(self, num_classes=1000):
        """
        Constructor

        Args:
            num_classes: number of classes.
        """
        super(Xception, self).__init__()
        self.num_classes = num_classes

        self.conv1 = nn.Conv2d(3, 32, 3, 2, pad_mode='valid', weight_init='xavier_uniform')
        self.bn1 = nn.BatchNorm2d(32, momentum=0.9)
        self.relu = nn.ReLU()

        self.conv2 = nn.Conv2d(32, 64, 3, pad_mode='valid', weight_init='xavier_uniform')
        self.bn2 = nn.BatchNorm2d(64, momentum=0.9)

        # Entry flow
        self.block1 = Block(64, 128, 2, 2, start_with_relu=False, grow_first=True)
        self.block2 = Block(128, 256, 2, 2, start_with_relu=True, grow_first=True)
        self.block3 = Block(256, 728, 2, 2, start_with_relu=True, grow_first=True)

        # Middle flow. Kept as individually named attributes so that parameter
        # names (and therefore existing checkpoints) stay unchanged.
        self.block4 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
        self.block5 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
        self.block6 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
        self.block7 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
        self.block8 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
        self.block9 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
        self.block10 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
        self.block11 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)

        # Exit flow
        self.block12 = Block(728, 1024, 2, 2, start_with_relu=True, grow_first=False)

        self.conv3 = SeparableConv2d(1024, 1536, 3, 1, 1)
        self.bn3 = nn.BatchNorm2d(1536, momentum=0.9)

        self.conv4 = SeparableConv2d(1536, 2048, 3, 1, 1)
        self.bn4 = nn.BatchNorm2d(2048, momentum=0.9)

        self.avg_pool = nn.AvgPool2d(10)
        self.dropout = nn.Dropout()
        self.fc = nn.Dense(2048, num_classes)

        # Parameter-free operators used to flatten before the classifier; created
        # once here rather than on every forward pass inside construct().
        self.shape = P.Shape()
        self.reshape = P.Reshape()

    def construct(self, x):
        """Run the entry, middle and exit flows, then pool, flatten and classify."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)

        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        x = self.block5(x)
        x = self.block6(x)
        x = self.block7(x)
        x = self.block8(x)
        x = self.block9(x)
        x = self.block10(x)
        x = self.block11(x)
        x = self.block12(x)

        x = self.conv3(x)
        x = self.bn3(x)
        x = self.relu(x)

        x = self.conv4(x)
        x = self.bn4(x)
        x = self.relu(x)

        x = self.avg_pool(x)
        x = self.dropout(x)
        # Keep the leading (batch) dimension and flatten everything else.
        x = self.reshape(x, (self.shape(x)[0], -1))
        x = self.fc(x)
        return x


def xception(class_num=1000):
    """
    Get Xception neural network.

    Args:
        class_num (int): Class number.

    Returns:
        Cell, cell instance of Xception neural network.

    Examples:
        >>> net = xception(1000)
    """
    return Xception(class_num)