2023-12-01 03:12:53 +08:00
|
|
|
# Copyright © 2023 Apple Inc.
|
|
|
|
|
2023-11-30 02:30:41 +08:00
|
|
|
from mlx.nn.layers.base import Module
|
|
|
|
|
|
|
|
|
|
|
|
class Sequential(Module):
    """A container that applies a series of callables one after another.

    Each entry may be an ``nn.Module`` or any plain callable. The output of
    one entry is fed as the input of the next. Anything that owns learnable
    parameters should be written as an ``nn.Module`` instance rather than a
    plain function.

    Args:
        *modules (Callable): The callables to apply, in the order given.
    """

    def __init__(self, *modules):
        super().__init__()
        # Keep the callables in a list attribute named ``layers``; order is
        # preserved exactly as passed.
        self.layers = [*modules]

    def __call__(self, x):
        # Thread the running value through every stored callable in order.
        result = x
        for layer in self.layers:
            result = layer(result)
        return result
|