# Source code for megengine.optimizer.lr_scheduler

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from abc import ABCMeta

from .optimizer import Optimizer


class LRScheduler(metaclass=ABCMeta):
    r"""Base class for all learning rate based schedulers.

    Subclasses provide :meth:`get_lr` (and, to support checkpointing,
    :meth:`state_dict` / :meth:`load_state_dict`). The constructor performs
    one :meth:`step` before returning. As a result, the optimizer's learning
    rates are already set from :meth:`get_lr` once the scheduler exists.

    :param optimizer: wrapped optimizer.
    :param current_epoch: the index of current epoch. Default: -1.
        With the default, each param group stores its current ``lr`` under
        ``initial_lr`` (an existing ``initial_lr`` entry is kept). Any other
        value is treated as resuming, and every param group must already
        contain ``initial_lr``; otherwise :class:`KeyError` is raised.
    :raises TypeError: if ``optimizer`` is not an :class:`Optimizer`.
    """

    def __init__(  # pylint: disable=too-many-branches
        self, optimizer: Optimizer, current_epoch: int = -1
    ):
        if not isinstance(optimizer, Optimizer):
            raise TypeError(
                "optimizer argument given to the lr_scheduler should be Optimizer"
            )
        self.optimizer = optimizer
        self.current_epoch = current_epoch

        if current_epoch == -1:
            # Fresh start: record each group's starting lr. setdefault leaves
            # an already-present initial_lr untouched.
            for group in self.optimizer.param_groups:
                group.setdefault("initial_lr", group["lr"])
        else:
            # Resuming: the base learning rates must already be stored in
            # every param group.
            for i, group in enumerate(self.optimizer.param_groups):
                if "initial_lr" not in group:
                    raise KeyError(
                        "param 'initial_lr' is not specified in "
                        "param_groups[{}] when resuming an optimizer".format(i)
                    )

        # One base lr per param group, in param_groups order.
        self.base_lrs = [
            group["initial_lr"] for group in self.optimizer.param_groups
        ]
        # Advance to epoch current_epoch + 1 and apply the lrs from get_lr().
        self.step()

    def state_dict(self):
        r"""Returns the state of the scheduler as a :class:`dict`.

        It contains an entry for every variable in ``self.__dict__`` which
        is not the optimizer. Must be implemented by subclasses.
        """
        raise NotImplementedError

    def load_state_dict(self, state_dict):
        r"""Loads the scheduler's state.

        Must be implemented by subclasses.

        :type state_dict: dict
        :param state_dict: scheduler state.
        """
        raise NotImplementedError

    def get_lr(self):
        r"""Compute current learning rate for the scheduler.

        Must be implemented by subclasses. :meth:`step` assigns the returned
        values, in order, to ``optimizer.param_groups``.
        """
        raise NotImplementedError

    def step(self, epoch=None):
        r"""Advance the scheduler and write the new learning rates.

        :param epoch: epoch index to move to. If ``None`` (default), the
            current epoch is incremented by one.
        """
        if epoch is None:
            self.current_epoch += 1
        else:
            self.current_epoch = epoch

        values = self.get_lr()
        # zip stops at the shorter sequence: surplus values or param groups
        # are silently left untouched.
        for param_group, lr in zip(self.optimizer.param_groups, values):
            param_group["lr"] = lr