# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Train NAML."""
import time

from mindspore import nn, load_checkpoint
import mindspore.common.dtype as mstype
from mindspore.common import set_seed
from mindspore.train.model import Model
from mindspore.train.loss_scale_manager import DynamicLossScaleManager

from src.naml import NAML, NAMLWithLossCell
from src.option import get_args
from src.dataset import create_dataset, MINDPreprocess
from src.utils import process_data
from src.callback import Monitor

if __name__ == '__main__':
    args = get_args("train")
    set_seed(args.seed)

    # Build the word-embedding table, the NAML network, and its loss wrapper.
    word_embedding = process_data(args)
    net = NAML(args, word_embedding)
    net_with_loss = NAMLWithLossCell(net)
    if args.checkpoint_path is not None:
        load_checkpoint(args.checkpoint_path, net_with_loss)

    # Create the training dataset from the preprocessed MIND data.
    mindpreprocess_train = MINDPreprocess(vars(args), dataset_path=args.train_dataset_path)
    dataset = create_dataset(mindpreprocess_train, batch_size=args.batch_size, rank=args.rank,
                             group_size=args.device_num)
    args.dataset_size = dataset.get_dataset_size()
    args.print_times = min(args.dataset_size, args.print_times)

    if args.weight_decay:
        # Apply weight decay only to weight parameters; biases and other parameters are excluded.
        weight_params = list(filter(lambda x: 'weight' in x.name, net.trainable_params()))
        other_params = list(filter(lambda x: 'weight' not in x.name, net.trainable_params()))
        group_params = [{'params': weight_params, 'weight_decay': 1e-3},
                        {'params': other_params, 'weight_decay': 0.0},
                        {'order_params': net.trainable_params()}]
        opt = nn.AdamWeightDecay(group_params, args.lr, beta1=args.beta1, beta2=args.beta2, eps=args.epsilon)
    else:
        opt = nn.Adam(net.trainable_params(), args.lr, beta1=args.beta1, beta2=args.beta2, eps=args.epsilon)

    if args.mixed:
        # Mixed precision: cast the network to float16 with dynamic loss scaling,
        # while keeping precision-sensitive cells in float32.
        loss_scale_manager = DynamicLossScaleManager(init_loss_scale=128.0, scale_factor=2, scale_window=10000)
        net_with_loss.to_float(mstype.float16)
        for _, cell in net_with_loss.cells_and_names():
            if isinstance(cell, (nn.Embedding, nn.Softmax, nn.SoftmaxCrossEntropyWithLogits)):
                cell.to_float(mstype.float32)
        model = Model(net_with_loss, optimizer=opt, loss_scale_manager=loss_scale_manager)
    else:
        model = Model(net_with_loss, optimizer=opt)

    cb = [Monitor(args)]
    epochs = args.epochs
    if args.sink_mode:
        # With data sinking, each call to model.train covers sink_size steps per "epoch",
        # so scale the epoch count to keep the total number of training steps unchanged.
        epochs = int(args.epochs * args.dataset_size / args.print_times)

    start_time = time.time()
    print("======================= Start Train ==========================", flush=True)
    model.train(epochs, dataset, callbacks=cb, dataset_sink_mode=args.sink_mode, sink_size=args.print_times)
    end_time = time.time()
    print("==============================================================")
    print("processor_name: {}".format(args.platform))
    print("test_name: NAML")
    print(f"model_name: NAML MIND{args.dataset}")
    print("batch_size: {}".format(args.batch_size))
    print("latency: {} s".format(end_time - start_time))