| @@ -40,6 +40,8 @@ class ConvertNetUntils(): | |||||
| if subcell.activation_flag: | if subcell.activation_flag: | ||||
| act_class = subcell.activation.__class__.__name__ | act_class = subcell.activation.__class__.__name__ | ||||
| act_name = act_class.lower() | act_name = act_class.lower() | ||||
| if act_name == "fastgelu": | |||||
| act_name = "fast_gelu" | |||||
| if subcell.out_channels == 1001: | if subcell.out_channels == 1001: | ||||
| new_subcell = nn.Dense_Thor(in_channels=subcell.in_channels, | new_subcell = nn.Dense_Thor(in_channels=subcell.in_channels, | ||||
| out_channels=subcell.out_channels, | out_channels=subcell.out_channels, | ||||
| @@ -18,7 +18,7 @@ network config setting, will be used in train.py and eval.py | |||||
| from easydict import EasyDict as ed | from easydict import EasyDict as ed | ||||
| # config optimizer for resnet50, imagenet2012. Momentum is default, Thor is optional. | # config optimizer for resnet50, imagenet2012. Momentum is default, Thor is optional. | ||||
| cfg = ed({ | cfg = ed({ | ||||
| 'optimizer': 'Thor', | |||||
| 'optimizer': 'Momentum', | |||||
| }) | }) | ||||
| # config for resnet50, cifar10 | # config for resnet50, cifar10 | ||||
| @@ -18,7 +18,7 @@ import argparse | |||||
| import ast | import ast | ||||
| from mindspore import context | from mindspore import context | ||||
| from mindspore import Tensor | from mindspore import Tensor | ||||
| from mindspore.nn.optim.momentum import Momentum, THOR | |||||
| from mindspore.nn.optim import Momentum, THOR | |||||
| from mindspore.train.model import Model | from mindspore.train.model import Model | ||||
| from mindspore.context import ParallelMode | from mindspore.context import ParallelMode | ||||
| from mindspore.train.train_thor import ConvertModelUtils | from mindspore.train.train_thor import ConvertModelUtils | ||||
| @@ -100,8 +100,9 @@ def _get_optimizer(args_opt, network): | |||||
| optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=cfg.AdamWeightDecay.eps) | optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=cfg.AdamWeightDecay.eps) | ||||
| elif cfg.optimizer == "Thor": | elif cfg.optimizer == "Thor": | ||||
| from src.utils import get_bert_thor_lr, get_bert_thor_damping | from src.utils import get_bert_thor_lr, get_bert_thor_damping | ||||
| lr = get_bert_thor_lr() | |||||
| damping = get_bert_thor_damping() | |||||
| lr = get_bert_thor_lr(cfg.Thor.lr_max, cfg.Thor.lr_min, cfg.Thor.lr_power, cfg.Thor.lr_total_steps) | |||||
| damping = get_bert_thor_damping(cfg.Thor.damping_max, cfg.Thor.damping_min, cfg.Thor.damping_power, | |||||
| cfg.Thor.damping_total_steps) | |||||
| split_indices = None | split_indices = None | ||||
| if bert_net_cfg.num_hidden_layers == 12: | if bert_net_cfg.num_hidden_layers == 12: | ||||
| if bert_net_cfg.use_relative_positions: | if bert_net_cfg.use_relative_positions: | ||||
| @@ -49,6 +49,14 @@ cfg = edict({ | |||||
| 'momentum': 0.9, | 'momentum': 0.9, | ||||
| }), | }), | ||||
| 'Thor': edict({ | 'Thor': edict({ | ||||
| 'lr_max': 0.0034, | |||||
| 'lr_min': 3.244e-5, | |||||
| 'lr_power': 1.0, | |||||
| 'lr_total_steps': 30000, | |||||
| 'damping_max': 5e-2, | |||||
| 'damping_min': 1e-6, | |||||
| 'damping_power': 1.0, | |||||
| 'damping_total_steps': 30000, | |||||
| 'momentum': 0.9, | 'momentum': 0.9, | ||||
| 'weight_decay': 5e-4, | 'weight_decay': 5e-4, | ||||
| 'loss_scale': 1.0, | 'loss_scale': 1.0, | ||||
| @@ -22,7 +22,6 @@ import math | |||||
| import collections | import collections | ||||
| import numpy as np | import numpy as np | ||||
| import mindspore.nn as nn | import mindspore.nn as nn | ||||
| from mindspore import context | |||||
| from mindspore import log as logger | from mindspore import log as logger | ||||
| from mindspore.ops import operations as P | from mindspore.ops import operations as P | ||||
| from mindspore.common.tensor import Tensor | from mindspore.common.tensor import Tensor | ||||
| @@ -107,10 +106,11 @@ class LossCallBack(Callback): | |||||
| percent = 1 | percent = 1 | ||||
| epoch_num -= 1 | epoch_num -= 1 | ||||
| print("epoch: {}, current epoch percent: {}, step: {}, outputs are {}" | print("epoch: {}, current epoch percent: {}, step: {}, outputs are {}" | ||||
| .format(int(epoch_num), "%.3f" % percent, cb_params.cur_step_num, str(cb_params.net_outputs))) | |||||
| .format(int(epoch_num), "%.3f" % percent, cb_params.cur_step_num, str(cb_params.net_outputs)), | |||||
| flush=True) | |||||
| else: | else: | ||||
| print("epoch: {}, step: {}, outputs are {}".format(cb_params.cur_epoch_num, cb_params.cur_step_num, | print("epoch: {}, step: {}, outputs are {}".format(cb_params.cur_epoch_num, cb_params.cur_step_num, | ||||
| str(cb_params.net_outputs))) | |||||
| str(cb_params.net_outputs)), flush=True) | |||||
| def LoadNewestCkpt(load_finetune_checkpoint_dir, steps_per_epoch, epoch_num, prefix): | def LoadNewestCkpt(load_finetune_checkpoint_dir, steps_per_epoch, epoch_num, prefix): | ||||
| """ | """ | ||||
| @@ -220,22 +220,13 @@ def _get_poly_lr(global_step, lr_init, lr_end, lr_max, warmup_steps, total_steps | |||||
| return learning_rate | return learning_rate | ||||
| def get_bert_thor_lr(): | |||||
| if context.get_context("device_target") == "Ascend": | |||||
| learning_rate = _get_poly_lr(global_step=0, lr_init=0.0, lr_end=3.244018779068399e-05, | |||||
| lr_max=0.0034022148941459055, warmup_steps=0, total_steps=30000, poly_power=1) | |||||
| else: | |||||
| learning_rate = _get_poly_lr(global_step=0, lr_init=0.0, lr_end=1e-6, lr_max=1.7e-3, warmup_steps=0, | |||||
| total_steps=30000, poly_power=1) | |||||
| def get_bert_thor_lr(lr_max=0.0034, lr_min=3.244e-05, lr_power=1.0, lr_total_steps=30000): | |||||
| learning_rate = _get_poly_lr(global_step=0, lr_init=0.0, lr_end=lr_min, lr_max=lr_max, warmup_steps=0, | |||||
| total_steps=lr_total_steps, poly_power=lr_power) | |||||
| return Tensor(learning_rate) | return Tensor(learning_rate) | ||||
| def get_bert_thor_damping(): | |||||
| if context.get_context("device_target") == "Ascend": | |||||
| damping = _get_poly_lr(global_step=0, lr_init=0.0, lr_end=1e-6, lr_max=5e-2, warmup_steps=0, total_steps=30000, | |||||
| poly_power=1) | |||||
| else: | |||||
| damping = _get_poly_lr(global_step=0, lr_init=0.0, lr_end=1e-6, lr_max=3.5e-2, warmup_steps=0, | |||||
| total_steps=30000, poly_power=1) | |||||
| def get_bert_thor_damping(damping_max=5e-2, damping_min=1e-6, damping_power=1.0, damping_total_steps=30000): | |||||
| damping = _get_poly_lr(global_step=0, lr_init=0.0, lr_end=damping_min, lr_max=damping_max, warmup_steps=0, | |||||
| total_steps=damping_total_steps, poly_power=damping_power) | |||||
| return Tensor(damping) | return Tensor(damping) | ||||