Browse Source

!14170 Fix FaceAttribute net bug

From: @zhanghuiyao
Reviewed-by: @c_34,@oacjiewen
Signed-off-by: @c_34
tags/v1.2.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
55db218a79
1 changed files with 10 additions and 13 deletions
  1. +10
    -13
      model_zoo/research/cv/FaceAttribute/train.py

+ 10
- 13
model_zoo/research/cv/FaceAttribute/train.py View File

@@ -18,6 +18,7 @@ import time
import datetime
import argparse

import mindspore
import mindspore.nn as nn
from mindspore import context
from mindspore import Tensor
@@ -43,16 +44,16 @@ context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs

class BuildTrainNetwork(nn.Cell):
'''Build train network.'''
def __init__(self, network, criterion):
def __init__(self, my_network, my_criterion):
super(BuildTrainNetwork, self).__init__()
self.network = network
self.criterion = criterion
self.network = my_network
self.criterion = my_criterion
self.print = P.Print()

def construct(self, input_data, label):
logit0, logit1, logit2 = self.network(input_data)
loss = self.criterion(logit0, logit1, logit2, label)
return loss
loss0 = self.criterion(logit0, logit1, logit2, label)
return loss0


def parse_args():
@@ -64,13 +65,13 @@ def parse_args():
parser.add_argument('--local_rank', type=int, default=0, help='current rank to support distributed')
parser.add_argument('--world_size', type=int, default=8, help='current process number to support distributed')

args, _ = parser.parse_known_args()
arg, _ = parser.parse_known_args()

return args
return arg

if __name__ == "__main__":
mindspore.set_seed(1)

def train():
'''train function.'''
# logger
args = parse_args()

@@ -226,7 +227,3 @@ def train():
i += 1

args.logger.info('--------- trains out ---------')


if __name__ == "__main__":
train()

Loading…
Cancel
Save