mlx/python/mlx/nn/layers/containers.py

25 lines
618 B
Python
Raw Normal View History

2023-12-01 03:12:53 +08:00
# Copyright © 2023 Apple Inc.
2023-11-30 02:30:41 +08:00
from mlx.nn.layers.base import Module
class Sequential(Module):
    """Chain a series of callables, feeding each one's output to the next.

    Every positional argument may be either an ``nn.Module`` or a plain
    callable. Calling the container passes the input through them in the
    order they were given. With no callables at all, the input is returned
    unchanged. Any step that has learnable parameters should be implemented
    as an ``nn.Module`` instance rather than a bare function.

    Args:
        modules (tuple of Callables): The modules to call in order
    """

    def __init__(self, *modules):
        super().__init__()
        # Keep the callables as a list under ``layers``, preserving the
        # order in which they were passed.
        self.layers = [*modules]

    def __call__(self, x):
        result = x
        # Thread the running value through every stored callable in turn.
        for layer in self.layers:
            result = layer(result)
        return result