|
- # Copyright 2020 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
'''
Train the MusicTaggerCNN model.

Usage:
    python train.py [--device_id DEVICE_ID]
'''
- import argparse
- from mindspore import context, nn
- from mindspore.train import Model
- from mindspore.common import set_seed
- from mindspore.train.loss_scale_manager import FixedLossScaleManager
- from mindspore.train.serialization import load_checkpoint, load_param_into_net
- from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
- from src.dataset import create_dataset
- from src.musictagger import MusicTaggerCNN
- from src.loss import BCELoss
- from src.config import music_cfg as cfg
-
def train(model, dataset_direct, filename, columns_list, num_consumer=4,
          batch=16, epoch=50, save_checkpoint_steps=2172, keep_checkpoint_max=50,
          prefix="model", directory='./', per_print_times=181,
          dataset_sink_mode=True):
    """
    Train ``model`` on a dataset built by ``create_dataset`` and save checkpoints.

    Args:
        model: object exposing ``train(epoch, dataset, callbacks=..., dataset_sink_mode=...)``;
            the script entry point passes a ``mindspore.train.Model``.
        dataset_direct: first argument forwarded to ``create_dataset``
            (presumably the dataset directory -- confirm in src.dataset).
        filename: second argument forwarded to ``create_dataset``
            (presumably the dataset file name -- confirm in src.dataset).
        columns_list: column names forwarded to ``create_dataset``
            (the entry point passes ['feature', 'label']).
        num_consumer: forwarded to ``create_dataset``
            (presumably the number of parallel workers -- confirm).
        batch: forwarded to ``create_dataset`` (presumably the batch size).
        epoch: number of epochs passed to ``model.train``.
        save_checkpoint_steps: step interval between checkpoints (CheckpointConfig).
        keep_checkpoint_max: maximum number of checkpoint files kept (CheckpointConfig).
        prefix: file-name prefix for saved checkpoints (ModelCheckpoint).
        directory: directory checkpoints are written to (ModelCheckpoint).
        per_print_times: interval passed to LossMonitor; defaults to the
            previously hard-coded value 181.
        dataset_sink_mode: forwarded to ``model.train``; defaults to True as before.
    """
    config_ck = CheckpointConfig(save_checkpoint_steps=save_checkpoint_steps,
                                 keep_checkpoint_max=keep_checkpoint_max)
    ckpoint_cb = ModelCheckpoint(prefix=prefix,
                                 directory=directory,
                                 config=config_ck)
    # Positional order must match create_dataset's signature in src.dataset.
    data_train = create_dataset(dataset_direct, filename, batch, columns_list,
                                num_consumer)

    model.train(epoch,
                data_train,
                callbacks=[
                    ckpoint_cb,
                    LossMonitor(per_print_times=per_print_times),
                    TimeMonitor(),
                ],
                dataset_sink_mode=dataset_sink_mode)
-
-
if __name__ == "__main__":
    import os

    set_seed(1)
    parser = argparse.ArgumentParser(description='Train model')
    parser.add_argument('--device_id',
                        type=int,
                        help='device ID',
                        default=None)

    args = parser.parse_args()

    # Resolve the device once: an explicit --device_id overrides the config.
    # The same value is used for the checkpoint directory below, so runs on
    # different devices do not write checkpoints into the same directory.
    device_id = args.device_id if args.device_id is not None else cfg.device_id

    context.set_context(device_target='Ascend',
                        mode=context.GRAPH_MODE,
                        device_id=device_id)
    context.set_context(enable_auto_mixed_precision=cfg.mixed_precision)

    network = MusicTaggerCNN(in_classes=[1, 128, 384, 768, 2048],
                             kernel_size=[3, 3, 3, 3, 3],
                             padding=[0] * 5,
                             maxpool=[(2, 4), (4, 5), (3, 8), (4, 8)],
                             has_bias=True)

    # Optionally warm-start from an existing checkpoint.
    if cfg.pre_trained:
        param_dict = load_checkpoint(os.path.join(cfg.checkpoint_path,
                                                  cfg.model_name))
        load_param_into_net(network, param_dict)

    net_loss = BCELoss()

    network.set_train(True)
    net_opt = nn.Adam(params=network.trainable_params(),
                      learning_rate=cfg.lr,
                      loss_scale=cfg.loss_scale)

    # The optimizer and the manager share cfg.loss_scale; with
    # drop_overflow_update=False, updates are applied even on overflow steps.
    loss_scale_manager = FixedLossScaleManager(loss_scale=cfg.loss_scale,
                                               drop_overflow_update=False)
    net_model = Model(network, net_loss, net_opt,
                      loss_scale_manager=loss_scale_manager)

    train(model=net_model,
          dataset_direct=cfg.data_dir,
          filename=cfg.train_filename,
          columns_list=['feature', 'label'],
          num_consumer=cfg.num_consumer,
          batch=cfg.batch_size,
          epoch=cfg.epoch_size,
          save_checkpoint_steps=cfg.save_step,
          keep_checkpoint_max=cfg.keep_checkpoint_max,
          prefix=cfg.prefix,
          directory=cfg.checkpoint_path + "_{}".format(device_id))
    print("train success")
|