|
|
|
@@ -31,31 +31,46 @@ class ImageClassificationNetwork(nn.Cell): |
|
|
|
Returns: |
|
|
|
Tensor, output tensor. |
|
|
|
""" |
|
|
|
def __init__(self, backbone, head): |
|
|
|
def __init__(self, backbone, head, include_top=True, activation="None"): |
|
|
|
super(ImageClassificationNetwork, self).__init__() |
|
|
|
self.backbone = backbone |
|
|
|
self.head = head |
|
|
|
self.include_top = include_top |
|
|
|
self.need_activation = False |
|
|
|
if self.include_top: |
|
|
|
self.head = head |
|
|
|
if activation != "None": |
|
|
|
self.need_activation = True |
|
|
|
if activation == "Sigmoid": |
|
|
|
self.activation = P.Sigmoid() |
|
|
|
elif activation == "Softmax": |
|
|
|
self.activation = P.Softmax() |
|
|
|
else: |
|
|
|
raise NotImplementedError(f"The activation {activation} not in [Sigmoid, Softmax].") |
|
|
|
|
|
|
|
def construct(self, x): |
|
|
|
x = self.backbone(x) |
|
|
|
x = self.head(x) |
|
|
|
if self.include_top: |
|
|
|
x = self.head(x) |
|
|
|
if self.need_activation: |
|
|
|
x = self.activation(x) |
|
|
|
return x |
|
|
|
|
|
|
|
|
|
|
|
class Resnet(ImageClassificationNetwork): |
|
|
|
""" |
|
|
|
Resnet architecture. |
|
|
|
Args: |
|
|
|
backbone_name (string): backbone. |
|
|
|
num_classes (int): number of classes. |
|
|
|
num_classes (int): number of classes, Default is 1000. |
|
|
|
Returns: |
|
|
|
Resnet. |
|
|
|
""" |
|
|
|
def __init__(self, backbone_name, num_classes, platform="Ascend"): |
|
|
|
def __init__(self, backbone_name, num_classes=1000, platform="Ascend", include_top=True, activation="None"): |
|
|
|
self.backbone_name = backbone_name |
|
|
|
backbone = backbones.__dict__[self.backbone_name](platform=platform) |
|
|
|
out_channels = backbone.get_out_channels() |
|
|
|
head = heads.CommonHead(num_classes=num_classes, out_channels=out_channels) |
|
|
|
super(Resnet, self).__init__(backbone, head) |
|
|
|
super(Resnet, self).__init__(backbone, head, include_top, activation) |
|
|
|
|
|
|
|
default_recurisive_init(self) |
|
|
|
|
|
|
|
@@ -79,7 +94,7 @@ class Resnet(ImageClassificationNetwork): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_network(backbone_name, num_classes, platform="Ascend"): |
|
|
|
def get_network(backbone_name, **kwargs): |
|
|
|
if backbone_name in ['resnext50']: |
|
|
|
return Resnet(backbone_name, num_classes, platform) |
|
|
|
return Resnet(backbone_name, **kwargs) |
|
|
|
return None |