Skip to main content

Weight Initialization

在训练过程中,适当的初始化策略有利于加快训练速度或获得更高的性能。MMCV 提供了一些常用的初始化模块的方法,如 nn.Conv2d。MMdetection 中的模型初始化主要使用 init_cfg。用户可以通过以下两个步骤来初始化模型。

  1. 在 model_cfg 中为模型或其组件定义 init_cfg,但子组件的 init_cfg 具有更高的优先级,并将覆盖父模块的 init_cfg。
  2. 像往常一样建立模型,但调用 model.init_weights() 方法,模型参数将被初始化为配置。
info

对于高级的工作流的初始化,在 mmdetection 的调用顺序为:

model_cfg(init_cfg) -> build_from_cfg -> model -> init_weight() -> initialize(self, self.init_cfg) -> children's init_weight()

数据结构描述

它是dict或list[dict],并包含以下键和值。

  • type(str),包含 INTIALIZERS 中的初始化器名称,后面是初始化器的参数。
  • layer(str 或 list[str]),包含 Pytorch 或 MMCV 中带有可学习参数的基本层名称,将被初始化,例如 'Conv2d', 'DeformConv2d'
  • override (dict 或 list[dict]),包含不继承自 BaseModule 的子模块,其初始化配置与 'layer' 键中的其他层不同。type 中定义的初始化器将对 layer 中定义的所有层起作用,所以如果子模块不是 BaseModule 的派生类,但可以用 layer 中的相同方式初始化,就不需要使用 override了。
    • type,然后是初始化器的参数。
    • name,表示将被初始化的子模块。

初始化的参数

mmcv.runner.BaseModulemmdet.models 继承一个新模型 这里我们展示一个 FooModel 的例子。

import torch.nn as nn
from mmcv.runner import BaseModule

class FooModel(BaseModule)
def __init__(self,
arg1,
arg2,
init_cfg=None):
super(FooModel, self).__init__(init_cfg)
...
  • 使用 init_cfg 初始化模型

    import torch.nn as nn
    from mmcv.runner import BaseModule
    # or directly inherit mmdet models

    class FooModel(BaseModule)
    def __init__(self,
    arg1,
    arg2,
    init_cfg=XXX):
    super(FooModel, self).__init__(init_cfg)
    ...
  • mmcv.Sequentialmmcv.ModuleList 代码中直接使用 init_cfg 来初始化模型

    from mmcv.runner import BaseModule, ModuleList

    class FooModel(BaseModule)
    def __init__(self,
    arg1,
    arg2,
    init_cfg=None):
    super(FooModel, self).__init__(init_cfg)
    ...
    self.conv1 = ModuleList(init_cfg=XXX)
  • 通过使用配置文件中的 init_cfg 来初始化模型

    model = dict(
    ...
    model = dict(
    type='FooModel',
    arg1=XXX,
    arg2=XXX,
    init_cfg=XXX),
    ...

init_cfg 的使用方式

  1. layer 初始化模型

    如果我们只定义 layer,它只是在 layer 中初始化层。

    注意layer 的值是带有 Pytorch 属性权重和偏向的类名,(所以不支持 MultiheadAttention 层)。

    • 定义 layer,用于以相同的配置初始化模块。

      init_cfg = dict(type='Constant', layer=['Conv1d', 'Conv2d', 'Linear'], val=1)
      # initialize whole module with same configuration
    • 定义 layer,用于初始化具有不同配置的层。

      init_cfg = [dict(type='Constant', layer='Conv1d', val=1),
      dict(type='Constant', layer='Conv2d', val=2),
      dict(type='Constant', layer='Linear', val=3)]
      # nn.Conv1d will be initialized with dict(type='Constant', val=1)
      # nn.Conv2d will be initialized with dict(type='Constant', val=2)
      # nn.Linear will be initialized with dict(type='Constant', val=3)
  2. override 来初始化模型

    • 当用属性名称初始化某些特定的部分时,我们可以使用 override 键,override中的值将忽略 init_cfg 中的值。

      # layers:
      # self.feat = nn.Conv1d(3, 1, 3)
      # self.reg = nn.Conv2d(3, 3, 3)
      # self.cls = nn.Linear(1,2)

      init_cfg = dict(type='Constant',
      layer=['Conv1d','Conv2d'], val=1, bias=2,
      override=dict(type='Constant', name='reg', val=3, bias=4))
      # self.feat and self.cls will be initialized with dict(type='Constant', val=1, bias=2)
      # The module called 'reg' will be initialized with dict(type='Constant', val=3, bias=4)
    • 如果 init_cfg 中的 layer 是 None,那么只有 override 中名字的子模块会被初始化,override 中的 type 和其他 args 可以省略。

      # layers:
      # self.feat = nn.Conv1d(3, 1, 3)
      # self.reg = nn.Conv2d(3, 3, 3)
      # self.cls = nn.Linear(1,2)

      init_cfg = dict(type='Constant', val=1, bias=2, override=dict(name='reg'))

      # self.feat and self.cls will be initialized by Pytorch
      # The module called 'reg' will be initialized with dict(type='Constant', val=1, bias=2)
    • 如果我们不定义 layeroverride,它将不会初始化任何东西。

    • 不正确的使用方式

      # It is invalid that override don't have name key
      init_cfg = dict(type='Constant', layer=['Conv1d','Conv2d'], val=1, bias=2,
      override=dict(type='Constant', val=3, bias=4))

      # It is also invalid that override has name and other args except type
      init_cfg = dict(type='Constant', layer=['Conv1d','Conv2d'], val=1, bias=2,
      override=dict(name='reg', val=3, bias=4))
  3. 用预训练模型初始化模型

    init_cfg = dict(type='Pretrained',
    checkpoint='torchvision://resnet50')