Browse Source

!14171 Fix FaceAttribute net bug

From: @zhanghuiyao
Reviewed-by: @c_34,@oacjiewen
Signed-off-by: @c_34
pull/14171/MERGE
mindspore-ci-bot Gitee 4 years ago
parent
commit
44d6da63c2
1 changed files with 11 additions and 13 deletions
  1. +11
    -13
      model_zoo/research/cv/FaceAttribute/train.py

+ 11
- 13
model_zoo/research/cv/FaceAttribute/train.py View File

@@ -18,6 +18,7 @@ import time
import datetime import datetime
import argparse import argparse


import mindspore
import mindspore.nn as nn import mindspore.nn as nn
from mindspore import context from mindspore import context
from mindspore import Tensor from mindspore import Tensor
@@ -43,16 +44,16 @@ context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs


class BuildTrainNetwork(nn.Cell): class BuildTrainNetwork(nn.Cell):
'''Build train network.''' '''Build train network.'''
def __init__(self, network, criterion):
def __init__(self, my_network, my_criterion):
super(BuildTrainNetwork, self).__init__() super(BuildTrainNetwork, self).__init__()
self.network = network
self.criterion = criterion
self.network = my_network
self.criterion = my_criterion
self.print = P.Print() self.print = P.Print()


def construct(self, input_data, label): def construct(self, input_data, label):
logit0, logit1, logit2 = self.network(input_data) logit0, logit1, logit2 = self.network(input_data)
loss = self.criterion(logit0, logit1, logit2, label)
return loss
loss0 = self.criterion(logit0, logit1, logit2, label)
return loss0




def parse_args(): def parse_args():
@@ -64,13 +65,14 @@ def parse_args():
parser.add_argument('--local_rank', type=int, default=0, help='current rank to support distributed') parser.add_argument('--local_rank', type=int, default=0, help='current rank to support distributed')
parser.add_argument('--world_size', type=int, default=8, help='current process number to support distributed') parser.add_argument('--world_size', type=int, default=8, help='current process number to support distributed')


args, _ = parser.parse_known_args()
arg, _ = parser.parse_known_args()


return args
return arg




def train():
'''train function.'''
if __name__ == "__main__":
mindspore.set_seed(1)

# logger # logger
args = parse_args() args = parse_args()


@@ -226,7 +228,3 @@ def train():
i += 1 i += 1


args.logger.info('--------- trains out ---------') args.logger.info('--------- trains out ---------')


if __name__ == "__main__":
train()

Loading…
Cancel
Save