Browse Source

!14171 Fix FaceAttribute net bug

From: @zhanghuiyao
Reviewed-by: @c_34,@oacjiewen
Signed-off-by: @c_34
pull/14171/MERGE
mindspore-ci-bot Gitee 4 years ago
parent
commit
44d6da63c2
1 changed files with 11 additions and 13 deletions
  1. +11
    -13
      model_zoo/research/cv/FaceAttribute/train.py

+ 11
- 13
model_zoo/research/cv/FaceAttribute/train.py View File

@@ -18,6 +18,7 @@ import time
import datetime
import argparse

import mindspore
import mindspore.nn as nn
from mindspore import context
from mindspore import Tensor
@@ -43,16 +44,16 @@ context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs

class BuildTrainNetwork(nn.Cell):
'''Build train network.'''
def __init__(self, network, criterion):
def __init__(self, my_network, my_criterion):
super(BuildTrainNetwork, self).__init__()
self.network = network
self.criterion = criterion
self.network = my_network
self.criterion = my_criterion
self.print = P.Print()

def construct(self, input_data, label):
logit0, logit1, logit2 = self.network(input_data)
loss = self.criterion(logit0, logit1, logit2, label)
return loss
loss0 = self.criterion(logit0, logit1, logit2, label)
return loss0


def parse_args():
@@ -64,13 +65,14 @@ def parse_args():
parser.add_argument('--local_rank', type=int, default=0, help='current rank to support distributed')
parser.add_argument('--world_size', type=int, default=8, help='current process number to support distributed')

args, _ = parser.parse_known_args()
arg, _ = parser.parse_known_args()

return args
return arg


def train():
'''train function.'''
if __name__ == "__main__":
mindspore.set_seed(1)

# logger
args = parse_args()

@@ -226,7 +228,3 @@ def train():
i += 1

args.logger.info('--------- trains out ---------')


if __name__ == "__main__":
train()

Loading…
Cancel
Save