# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from collections import OrderedDict
from .module import Module
class Sequential(Module):
    r"""A sequential container.

    Modules will be added to it in the order they are passed in the constructor.
    Alternatively, an ordered dict of modules can also be passed in.

    Every contained module is stored as an attribute of the container, and the
    attribute names are recorded, in order, in ``layer_keys``. Positional
    modules are named ``"0"``, ``"1"``, ...; modules passed through an
    ``OrderedDict`` keep their dictionary keys as names.

    Args:
        *args: either the modules to chain, or a single ``OrderedDict``
            mapping names to modules.
        **kwargs: forwarded unchanged to the base ``Module`` constructor.

    Examples:

    .. testcode::

        import numpy as np
        import megengine as mge
        import megengine.module as M
        import megengine.functional as F
        from collections import OrderedDict

        batch_size = 64
        data = mge.tensor(np.zeros((batch_size, 28 * 28)), dtype=np.float32)
        label = mge.tensor(np.zeros(batch_size,), dtype=np.int32)

        net0 = M.Sequential(
            M.Linear(28 * 28, 320),
            M.Linear(320, 10)
        )
        pred0 = net0(data)

        modules = OrderedDict()
        modules["fc0"] = M.Linear(28 * 28, 320)
        modules["fc1"] = M.Linear(320, 10)
        net1 = M.Sequential(modules)
        pred1 = net1(data)
    """

    def __init__(self, *args, **kwargs):
        super().__init__(**kwargs)
        # Ordered attribute names of the contained modules.
        self.layer_keys = []
        if len(args) == 1 and isinstance(args[0], OrderedDict):
            named_modules = args[0].items()
        else:
            named_modules = ((str(idx), module) for idx, module in enumerate(args))
        for key, module in named_modules:
            setattr(self, key, module)
            self.layer_keys.append(key)

    def __getitem__(self, idx):
        """Return the module at integer ``idx``, or a new container for a slice.

        A slice yields a new instance of the same class, built from the
        selected ``(name, module)`` pairs, so the original names are kept.
        """
        if isinstance(idx, slice):
            return self.__class__(
                OrderedDict(zip(self.layer_keys[idx], self.layer_values[idx]))
            )
        return getattr(self, self.layer_keys[idx])

    def __setitem__(self, idx, module):
        """Replace the module at integer position ``idx``, keeping its name.

        Raises:
            TypeError: if ``idx`` is a slice; slice assignment is not supported.
            IndexError: if ``idx`` is out of range.
        """
        if isinstance(idx, slice):
            # A slice would yield a list of names, which is not a valid
            # attribute name; fail with an explicit message instead.
            raise TypeError("Sequential does not support slice assignment")
        setattr(self, self.layer_keys[idx], module)

    def __delitem__(self, idx):
        """Remove the module(s) selected by an integer index or a slice.

        The names of the remaining modules are left unchanged.
        """
        keys = self.layer_keys[idx]
        if not isinstance(idx, slice):
            keys = [keys]
        # ``keys`` is a copy, so ``layer_keys`` can be safely mutated afterwards.
        for key in keys:
            delattr(self, key)
        del self.layer_keys[idx]

    def __len__(self):
        """Return the number of contained modules."""
        return len(self.layer_keys)

    def __iter__(self):
        """Iterate over the contained modules in order."""
        return iter(self.layer_values)

    @property
    def layer_values(self):
        """List of the contained modules, in the order of ``layer_keys``."""
        return [getattr(self, key) for key in self.layer_keys]

    def forward(self, inp):
        """Feed ``inp`` through every contained module in order.

        The output of each module is the input of the next one; the output of
        the last module is returned (``inp`` itself if the container is empty).
        """
        # avoid layer_values as a name prefix, see Module.__getattribute__
        for layer in [getattr(self, key) for key in self.layer_keys]:
            inp = layer(inp)
        return inp