# Source code for paddlespeech.s2t.training.scheduler

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from espnet(https://github.com/espnet/espnet)
from typing import Any
from typing import Dict
from typing import Text
from typing import Union

import paddle
from paddle.optimizer.lr import LRScheduler
from typeguard import check_argument_types

from paddlespeech.s2t.utils.dynamic_import import dynamic_import
from paddlespeech.s2t.utils.dynamic_import import instance_class
from paddlespeech.s2t.utils.log import Log

# Public API: the scheduler classes defined in this module plus the factory.
# ConstantLR and NewBobScheduler are registered schedulers too, so they are
# exported alongside WarmupLR.
__all__ = [
    "WarmupLR",
    "ConstantLR",
    "NewBobScheduler",
    "LRSchedulerFactory",
]

logger = Log(__name__).getlog()

# Maps a lower-case scheduler alias to its "module:ClassName" import path.
# Classes decorated with @register_scheduler are added to this mapping under
# their lower-cased class name.
SCHEDULER_DICT = dict(
    noam="paddle.optimizer.lr:NoamDecay",
    expdecaylr="paddle.optimizer.lr:ExponentialDecay",
    piecewisedecay="paddle.optimizer.lr:PiecewiseDecay",
)


def register_scheduler(cls):
    """Register ``cls`` in ``SCHEDULER_DICT`` under its lower-cased name.

    Intended to be used as a class decorator. After registration the class
    can be looked up by alias (e.g. ``"warmuplr"`` for ``WarmupLR``).

    Args:
        cls (type): the scheduler class to register.

    Returns:
        type: ``cls`` itself, unchanged, so this works as a decorator.
    """
    alias = cls.__name__.lower()
    SCHEDULER_DICT[alias] = f"{cls.__module__}:{cls.__name__}"
    return cls


@register_scheduler
class WarmupLR(LRScheduler):
    """The WarmupLR scheduler.

    This scheduler is almost the same as the NoamLR scheduler, except for
    the following difference:

    NoamLR:
        lr = optimizer.lr * model_size ** -0.5
             * min(step ** -0.5, step * warmup_step ** -1.5)
    WarmupLR:
        lr = optimizer.lr * warmup_step ** 0.5
             * min(step ** -0.5, step * warmup_step ** -1.5)

    Note that the maximum lr equals optimizer.lr in this scheduler; it is
    reached at ``step == warmup_steps``. When ``warmup_steps`` is 0 there is
    no warmup phase, and the schedule reduces to
    ``optimizer.lr * step ** -0.5``.

    Args:
        warmup_steps (int or float): number of warmup steps. Default: 25000.
        learning_rate (float): peak learning rate. Default: 1.0.
        last_epoch (int): index of the last step; -1 starts from scratch.
        verbose (bool): forwarded to ``LRScheduler``.
        **kwargs: accepted and ignored.
    """

    def __init__(self,
                 warmup_steps: Union[int, float]=25000,
                 learning_rate=1.0,
                 last_epoch=-1,
                 verbose=False,
                 **kwargs):
        assert check_argument_types()
        # Set before super().__init__() in case the base constructor
        # computes the initial lr through get_lr(), which reads this.
        self.warmup_steps = warmup_steps
        super().__init__(learning_rate, last_epoch, verbose)

    def __repr__(self):
        return f"{self.__class__.__name__}(warmup_steps={self.warmup_steps}, lr={self.base_lr}, last_epoch={self.last_epoch})"

    def get_lr(self):
        """Return the learning rate for the current step."""
        # self.last_epoch starts from zero, so step_num starts from one.
        step_num = self.last_epoch + 1
        if self.warmup_steps == 0:
            # No warmup: warmup_steps ** -1.5 would raise ZeroDivisionError,
            # so fall back to the plain inverse-square-root decay.
            return self.base_lr * step_num**-0.5
        return self.base_lr * self.warmup_steps**0.5 * min(
            step_num**-0.5, step_num * self.warmup_steps**-1.5)

    def set_step(self, step: int=None):
        '''
        It will update the learning rate in optimizer according to current ``epoch`` .
        The new learning rate will take effect on next ``optimizer.step`` .

        Args:
            step (int, None): specify current epoch. Default: None. Auto-increment from last_epoch=-1.
        Returns:
            None
        '''
        self.step(epoch=step)
@register_scheduler
class ConstantLR(LRScheduler):
    """Scheduler that keeps the learning rate fixed at its initial value.

    Args:
        learning_rate (float): The initial learning rate. It is a python float number.
        last_epoch (int, optional): The index of last epoch. Can be set to restart training.
            Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update.
            Default: ``False`` .

    Returns:
        ``ConstantLR`` instance to schedule learning rate.
    """

    def __init__(self, learning_rate, last_epoch=-1, verbose=False):
        super().__init__(learning_rate, last_epoch, verbose)

    def get_lr(self):
        return self.base_lr


@register_scheduler
class NewBobScheduler(LRScheduler):
    """Scheduler with new-bob technique, used for LR annealing.

    The learning rate is annealed based on the validation performance
    (lower metric values are treated as better). In particular:
    if (past_loss - current_loss) / past_loss < improvement_threshold:
    lr = lr * annealing_factor.

    Args:
        learning_rate (float): The initial learning rate.
        last_epoch (int, optional): The index of last epoch. Default: -1.
        verbose (bool, optional): If ``True``, prints a message to stdout
            for each update. Default: ``False``.
        annealing_factor (float, optional): Factor the learning rate is
            multiplied by when it is annealed. Default: 0.5.
        improvement_threshold (float, optional): Minimum relative
            improvement between consecutive metric values. A smaller
            improvement counts as a violation. Default: 0.0025.
        patient (int, optional): Number of violations tolerated before the
            learning rate is actually reduced. Default: 0.

    Example:
        >>> scheduler = NewBobScheduler(learning_rate=1.0)
        >>> scheduler.step(metric_value=10.0)
        >>> scheduler.last_lr
        1.0
        >>> scheduler.step(metric_value=2.0)
        >>> scheduler.last_lr
        1.0
        >>> scheduler.step(metric_value=2.5)
        >>> scheduler.last_lr
        0.5
    """

    def __init__(
            self,
            learning_rate,
            last_epoch=-1,
            verbose=False,
            annealing_factor=0.5,
            improvement_threshold=0.0025,
            patient=0, ):
        self.hyperparam_value = learning_rate
        self.annealing_factor = annealing_factor
        self.improvement_threshold = improvement_threshold
        self.patient = patient
        self.metric_values = []
        self.current_patient = self.patient
        super().__init__(learning_rate, last_epoch, verbose)

    def step(self, metric_value=None):
        """Advance one epoch and update ``last_lr``.

        ``step`` should be called after ``optimizer.step``. The new learning
        rate will take effect on the next ``optimizer.step``.

        Args:
            metric_value (float, None): metric of the epoch that just
                finished. If None, the epoch counter still advances, but the
                learning rate is left unchanged.

        Returns:
            None
        """
        self.last_epoch += 1
        if metric_value is None:
            self.last_lr = self.hyperparam_value
        else:
            self.last_lr = self.get_lr(metric_value)
        if self.verbose:
            print('Epoch {}: {} set learning rate to {}.'.format(
                self.last_epoch, self.__class__.__name__, self.last_lr))

    def get_lr(self, metric_value):
        """Return the new value for the hyperparameter.

        Unlike a plain ``get_lr``, this also updates internal state: the
        metric history, the patience counter and ``hyperparam_value``.

        Args:
            metric_value (float): A number for determining whether to change
                the hyperparameter value.

        Returns:
            The (possibly annealed) new hyperparameter value.
        """
        new_value = self.hyperparam_value
        if len(self.metric_values) > 0:
            prev_metric = self.metric_values[-1]
            # Update value if improvement too small and patience is 0
            if prev_metric == 0:  # Prevent division by zero
                improvement = 0
            else:
                improvement = (prev_metric - metric_value) / prev_metric
            if improvement < self.improvement_threshold:
                if self.current_patient == 0:
                    new_value *= self.annealing_factor
                    self.current_patient = self.patient
                else:
                    self.current_patient -= 1

        # Store relevant info
        self.metric_values.append(metric_value)
        self.hyperparam_value = new_value
        return new_value

    def save(self):
        """Return a snapshot of the scheduler state as a plain dict."""
        data = {
            "current_epoch_index": self.last_epoch,
            "hyperparam_value": self.hyperparam_value,
            # Copy so that later step() calls do not mutate this snapshot.
            "metric_values": list(self.metric_values),
            "current_patient": self.current_patient
        }
        return data

    def load(self, data):
        """Restore the state produced by ``save``.

        Args:
            data (dict): mapping with the keys written by ``save``.
        """
        self.last_epoch = data["current_epoch_index"]
        self.hyperparam_value = data["hyperparam_value"]
        # Copy so the scheduler does not share (and mutate) the caller's list.
        self.metric_values = list(data["metric_values"])
        self.current_patient = data["current_patient"]
        # step() always leaves last_lr equal to hyperparam_value. Restore that
        # invariant so the reported learning rate matches the checkpoint
        # rather than the constructor's initial value.
        self.last_lr = self.hyperparam_value


def dynamic_import_scheduler(module):
    """Import Scheduler class dynamically.

    Args:
        module (str): module_name:class_name or alias in `SCHEDULER_DICT`

    Returns:
        type: Scheduler class

    Raises:
        AssertionError: if the imported class is not an ``LRScheduler``
            subclass. The check uses ``assert``, so it is skipped under
            ``python -O``.
    """
    module_class = dynamic_import(module, SCHEDULER_DICT)
    assert issubclass(module_class,
                      LRScheduler), f"{module} does not implement LRScheduler"
    return module_class
class LRSchedulerFactory():
    """Build learning-rate schedulers from a name plus constructor args."""

    @classmethod
    def from_args(cls, name: str, args: Dict[Text, Any]):
        """Instantiate the scheduler identified by ``name``.

        Args:
            name (str): either an alias registered in ``SCHEDULER_DICT``,
                matched case-insensitively, or a ``module:ClassName`` import
                path, which is passed through unchanged.
            args (Dict[Text, Any]): constructor arguments, handed to
                ``instance_class``.

        Returns:
            The scheduler instance built by ``instance_class``.
        """
        # Only bare aliases are lower-cased: lower-casing a "module:ClassName"
        # path would corrupt the class name and break the import.
        key = name if ":" in name else name.lower()
        module_class = dynamic_import_scheduler(key)
        return instance_class(module_class, args)