|
- # Copyright 2020 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """Learning rate utilities."""
-
def linear_warmup(warmup_steps, current_step):
    """Return the linear warm-up multiplier for ``current_step``.

    The multiplier grows linearly as ``current_step / warmup_steps`` and is
    capped at 1.0 once ``current_step`` reaches ``warmup_steps``.

    Args:
        warmup_steps: Length of the warm-up phase, in steps. A value of zero
            or less is treated as "no warm-up".
        current_step: Step for which the multiplier is computed.

    Returns:
        float, the warm-up multiplier (at most 1.0).
    """
    total = float(warmup_steps)
    # Without a warm-up phase the ratio below is undefined (division by
    # zero), so the multiplier is already at its final value.
    if total <= 0.0:
        return 1.0
    return min(1.0, float(current_step) / total)
-
def rsqrt_decay(warmup_steps, current_step):
    """Return the inverse-square-root decay factor.

    The factor is ``max(current_step, warmup_steps) ** -0.5``, so it stays
    constant at ``warmup_steps ** -0.5`` until ``current_step`` exceeds
    ``warmup_steps``.

    Args:
        warmup_steps: Lower bound applied to the step before taking the root.
        current_step: Step for which the factor is computed.

    Returns:
        float, the decay factor.
    """
    effective_step = max(current_step, warmup_steps)
    return float(effective_step) ** -0.5
-
def rsqrt_hidden(hidden_size):
    """Return ``hidden_size ** -0.5`` as a float.

    Args:
        hidden_size: Value whose inverse square root is taken.

    Returns:
        float, the inverse square root of ``hidden_size``.
    """
    width = float(hidden_size)
    return width ** -0.5
-
def create_dynamic_lr(schedule, training_steps, learning_rate, warmup_steps, hidden_size,
                      start_decay_step=0, min_lr=0.):
    """Generate a per-step learning-rate list from a composite schedule.

    The rate for each step is the product of the factors named in
    ``schedule``, multiplied in the order they are listed.

    Args:
        schedule: Factor names joined by ``*``. Supported names are
            ``constant``, ``rsqrt_hidden``, ``linear_warmup`` and
            ``rsqrt_decay``.
        training_steps: Number of steps to generate; steps are numbered from
            1 to ``training_steps`` inclusive.
        learning_rate: Value contributed by the ``constant`` factor.
        warmup_steps: Warm-up length passed to ``linear_warmup`` and
            ``rsqrt_decay``.
        hidden_size: Value passed to ``rsqrt_hidden``.
        start_decay_step: Step at which ``rsqrt_decay`` starts decaying.
            Raised to ``warmup_steps`` when smaller. Default: 0.
        min_lr: Lower bound applied to the rate on every step after warm-up.
            Default: 0.

    Returns:
        list of float with ``training_steps`` entries.

    Raises:
        ValueError: If ``schedule`` contains an unsupported factor name.
    """
    # Decay can never start before warm-up has finished.
    if start_decay_step < warmup_steps:
        start_decay_step = warmup_steps

    # The factor list does not depend on the step, so split it only once.
    factor_names = schedule.split("*")

    lr = []
    for current_step in range(1, training_steps + 1):
        cur_lr = 1.0
        for name in factor_names:
            if name == "constant":
                cur_lr *= float(learning_rate)
            elif name == "rsqrt_hidden":
                cur_lr *= rsqrt_hidden(hidden_size)
            elif name == "linear_warmup":
                cur_lr *= linear_warmup(warmup_steps, current_step)
            elif name == "rsqrt_decay":
                # Shift the step so the decay curve starts at start_decay_step
                # instead of at warmup_steps.
                cur_lr *= rsqrt_decay(warmup_steps, current_step - start_decay_step + warmup_steps)
            else:
                raise ValueError("unknown learning rate schedule")

        # Between the end of warm-up and the start of decay the previous rate
        # is held. With warmup_steps == 0 the very first step falls in this
        # window while no previous rate exists yet, so the freshly computed
        # value is kept instead of indexing an empty list.
        if lr and warmup_steps < current_step < start_decay_step:
            cur_lr = lr[-1]

        # The floor applies only after warm-up, so warm-up can still ramp up
        # from values below min_lr.
        if current_step > warmup_steps:
            cur_lr = max(cur_lr, min_lr)
        lr.append(cur_lr)
    return lr
|