|
|
@@ -18,6 +18,7 @@ import time |
|
|
import datetime |
|
|
import datetime |
|
|
import argparse |
|
|
import argparse |
|
|
|
|
|
|
|
|
|
|
|
import mindspore |
|
|
import mindspore.nn as nn |
|
|
import mindspore.nn as nn |
|
|
from mindspore import context |
|
|
from mindspore import context |
|
|
from mindspore import Tensor |
|
|
from mindspore import Tensor |
|
|
@@ -43,16 +44,16 @@ context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs |
|
|
|
|
|
|
|
|
class BuildTrainNetwork(nn.Cell): |
|
|
class BuildTrainNetwork(nn.Cell): |
|
|
'''Build train network.''' |
|
|
'''Build train network.''' |
|
|
def __init__(self, network, criterion): |
|
|
|
|
|
|
|
|
def __init__(self, my_network, my_criterion): |
|
|
super(BuildTrainNetwork, self).__init__() |
|
|
super(BuildTrainNetwork, self).__init__() |
|
|
self.network = network |
|
|
|
|
|
self.criterion = criterion |
|
|
|
|
|
|
|
|
self.network = my_network |
|
|
|
|
|
self.criterion = my_criterion |
|
|
self.print = P.Print() |
|
|
self.print = P.Print() |
|
|
|
|
|
|
|
|
def construct(self, input_data, label): |
|
|
def construct(self, input_data, label): |
|
|
logit0, logit1, logit2 = self.network(input_data) |
|
|
logit0, logit1, logit2 = self.network(input_data) |
|
|
loss = self.criterion(logit0, logit1, logit2, label) |
|
|
|
|
|
return loss |
|
|
|
|
|
|
|
|
loss0 = self.criterion(logit0, logit1, logit2, label) |
|
|
|
|
|
return loss0 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def parse_args(): |
|
|
def parse_args(): |
|
|
@@ -64,13 +65,14 @@ def parse_args(): |
|
|
parser.add_argument('--local_rank', type=int, default=0, help='current rank to support distributed') |
|
|
parser.add_argument('--local_rank', type=int, default=0, help='current rank to support distributed') |
|
|
parser.add_argument('--world_size', type=int, default=8, help='current process number to support distributed') |
|
|
parser.add_argument('--world_size', type=int, default=8, help='current process number to support distributed') |
|
|
|
|
|
|
|
|
args, _ = parser.parse_known_args() |
|
|
|
|
|
|
|
|
arg, _ = parser.parse_known_args() |
|
|
|
|
|
|
|
|
return args |
|
|
|
|
|
|
|
|
return arg |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def train(): |
|
|
|
|
|
'''train function.''' |
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
|
|
mindspore.set_seed(1) |
|
|
|
|
|
|
|
|
# logger |
|
|
# logger |
|
|
args = parse_args() |
|
|
args = parse_args() |
|
|
|
|
|
|
|
|
@@ -226,7 +228,3 @@ def train(): |
|
|
i += 1 |
|
|
i += 1 |
|
|
|
|
|
|
|
|
args.logger.info('--------- trains out ---------') |
|
|
args.logger.info('--------- trains out ---------') |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
|
|
train() |
|
|
|