mirror of https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00

Commit: Merge branch 'ml-explore:main' into main
@@ -37,6 +37,7 @@ from mlx.nn.layers.activations import (
    relu,
    relu6,
    selu,
    sigmoid,
    silu,
    softmax,
    softplus,
@@ -67,3 +68,4 @@ from mlx.nn.layers.transformer import (
    TransformerEncoder,
    TransformerEncoderLayer,
)
+from mlx.nn.layers.upsample import Upsample
@@ -18,7 +18,7 @@ def _make_activation_module(f):

@partial(mx.compile, shapeless=True)
def sigmoid(x):
-    r"""Applies the element-wise function:
+    r"""Applies the sigmoid function.

    .. math::
        \text{Sigmoid}(x) = \sigma(x) = \frac{1}{1 + \exp(-x)}
@@ -142,11 +142,11 @@ def log_sigmoid(x):


@partial(mx.compile, shapeless=True)
-def gelu(x):
+def gelu(x) -> mx.array:
    r"""Applies the Gaussian Error Linear Units function.

    .. math::
-        \\textrm{GELU}(x) = x * \Phi(x)
+        \textrm{GELU}(x) = x * \Phi(x)

    where :math:`\Phi(x)` is the Gaussian CDF.
@@ -185,11 +185,15 @@ def gelu_fast_approx(x):

    .. math::

-        x = x \sigma\left(1.773 x\right)
+        x = x \sigma\left(1.702 x\right)

    where :math:`\sigma(\cdot)` is the logistic sigmoid.

    References:
        - https://github.com/hendrycks/GELUs
        - https://arxiv.org/abs/1606.08415
    """
-    return x * mx.sigmoid(1.773 * x)
+    return x * mx.sigmoid(1.702 * x)
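The corrected coefficient 1.702 is the one used for the sigmoid-based GELU approximation in the Hendrycks & Gimpel reference above. A quick sketch of how closely the fast approximation tracks the exact GELU, using both functions as exposed by mlx.nn:

import mlx.core as mx
import mlx.nn as nn

x = mx.linspace(-3, 3, 7)
exact = nn.gelu(x)                  # x * Phi(x), with the Gaussian CDF
fast = nn.gelu_fast_approx(x)       # x * sigmoid(1.702 * x)
print(mx.abs(exact - fast).max())   # small; roughly on the order of 1e-2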
def glu(x: mx.array, axis: int = -1) -> mx.array:
@@ -199,7 +203,7 @@ def glu(x: mx.array, axis: int = -1) -> mx.array:
    (:math:`a` and :math:`b`) and applies :math:`a * \sigma(b)`.

    .. math::
-        textrm{GLU}(x) = a * \sigma(b)
+        \textrm{GLU}(x) = a * \sigma(b)

    Args:
        axis (int): The dimension to split along. Default: ``-1``
@@ -260,6 +264,7 @@ def prelu(x: mx.array, alpha: mx.array) -> mx.array:
@partial(mx.compile, shapeless=True)
def mish(x: mx.array) -> mx.array:
    r"""Applies the Mish function, element-wise.

    Mish: A Self Regularized Non-Monotonic Neural Activation Function.
+
+    Reference: https://arxiv.org/abs/1908.08681
@@ -297,7 +302,7 @@ class GLU(Module):
    (:math:`a` and :math:`b`) and applies :math:`a * \sigma(b)`.

    .. math::
-        textrm{GLU}(x) = a * \sigma(b)
+        \textrm{GLU}(x) = a * \sigma(b)

    Args:
        axis (int): The dimension to split along. Default: ``-1``
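As a sanity check on the GLU definition, a short sketch: splitting the chosen axis in half and gating one half with the sigmoid of the other reproduces nn.glu:

import mlx.core as mx
import mlx.nn as nn

x = mx.arange(6, dtype=mx.float32).reshape(1, 6)
a, b = mx.split(x, 2, axis=-1)   # the two halves of the last axis
assert mx.allclose(nn.glu(x, axis=-1), a * mx.sigmoid(b))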
@@ -7,6 +7,42 @@ import mlx.core as mx
from mlx.utils import tree_flatten, tree_unflatten


+def _unwrap(model, value_key, value, filter_fn, map_fn, is_leaf_fn):
+    if is_leaf_fn(model, value_key, value):
+        return map_fn(value)
+
+    elif isinstance(value, Module):
+        return {
+            k: _unwrap(value, k, v, filter_fn, map_fn, is_leaf_fn)
+            for k, v in value.items()
+            if filter_fn(value, k, v)
+        }
+
+    elif isinstance(value, dict):
+        nd = {}
+        for k, v in value.items():
+            tk = f"{value_key}.{k}"
+            nd[k] = (
+                _unwrap(model, tk, v, filter_fn, map_fn, is_leaf_fn)
+                if filter_fn(model, tk, v)
+                else {}
+            )
+        return nd
+
+    elif isinstance(value, list):
+        nl = []
+        for i, vi in enumerate(value):
+            tk = f"{value_key}.{i}"
+            nl.append(
+                _unwrap(model, tk, vi, filter_fn, map_fn, is_leaf_fn)
+                if filter_fn(model, tk, vi)
+                else {}
+            )
+        return nl
+
+    raise RuntimeError("Unexpected leaf found while traversing the module")
+
+
class Module(dict):
    """Base class for building neural networks with MLX.

@@ -98,10 +134,13 @@ class Module(dict):
        if key in self:
            return self[key]
        else:
-            raise AttributeError(f"{type(self)!r} has no attribute {key!r}")
+            super(Module, self).__getattribute__(key)

    def __setattr__(self, key: str, val: Any):
-        self[key] = val
+        if isinstance(val, (mx.array, dict, list, tuple)):
+            self[key] = val
+        else:
+            super(Module, self).__setattr__(key, val)

    def load_weights(
        self,
@@ -245,31 +284,11 @@ class Module(dict):
        is_leaf_fn = is_leaf_fn or (
            lambda m, k, v: not isinstance(v, (Module, dict, list))
        )

-        def unwrap(vk, v):
-            if is_leaf_fn(self, vk, v):
-                return map_fn(v)
-
-            if isinstance(v, Module):
-                return v.filter_and_map(filter_fn, map_fn, is_leaf_fn)
-
-            if isinstance(v, dict):
-                nd = {}
-                for k, v in v.items():
-                    tk = f"{vk}.{k}"
-                    nd[k] = unwrap(tk, v) if filter_fn(self, tk, v) else {}
-                return nd
-
-            if isinstance(v, list):
-                nl = []
-                for i, vi in enumerate(v):
-                    tk = f"{vk}.{i}"
-                    nl.append(unwrap(tk, vi) if filter_fn(self, tk, vi) else {})
-                return nl
-
-            raise RuntimeError("Unexpected leaf found while traversing the module")
-
-        return {k: unwrap(k, v) for k, v in self.items() if filter_fn(self, k, v)}
+        return {
+            k: _unwrap(self, k, v, filter_fn, map_fn, is_leaf_fn)
+            for k, v in self.items()
+            if filter_fn(self, k, v)
+        }

    def parameters(self):
        """Recursively return all the :class:`mlx.core.array` members of this Module
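A small sketch of what the refactored traversal does from the user's side; filter_and_map now delegates to the module-level _unwrap but behaves the same (the shape shown is illustrative):

import mlx.nn as nn

model = nn.Linear(4, 2)
shapes = model.filter_and_map(
    lambda module, key, value: key == "weight",
    map_fn=lambda value: value.shape,
)
print(shapes)  # e.g. {'weight': (2, 4)} - 'bias' is filtered out entirely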
python/mlx/nn/layers/upsample.py (new file, 205 lines)
@@ -0,0 +1,205 @@
# Copyright © 2023-2024 Apple Inc.

import operator
from functools import reduce
from itertools import product
from typing import Literal, Tuple, Union

import mlx.core as mx
from mlx.nn.layers.base import Module


def _scaled_indices(N, scale, align_corners, dim, ndims):
    M = int(scale * N)
    if align_corners:
        indices = mx.arange(M, dtype=mx.float32) * ((N - 1) / (M - 1))
    else:
        step = 1 / scale
        start = ((M - 1) * step - N + 1) / 2
        indices = mx.arange(M, dtype=mx.float32) * step - start
        indices = mx.clip(indices, 0, N - 1)
    shape = [1] * ndims
    shape[dim] = -1

    return indices.reshape(shape)


def _nearest_indices(N, scale, dim, ndims):
    return _scaled_indices(N, scale, True, dim, ndims).astype(mx.int32)


def _linear_indices(N, scale, align_corners, dim, ndims):
    indices = _scaled_indices(N, scale, align_corners, dim, ndims)
    indices_l = mx.floor(indices)
    indices_r = mx.ceil(indices)
    weight = indices - indices_l
    weight = mx.expand_dims(weight, -1)

    return (
        (indices_l.astype(mx.int32), 1 - weight),
        (indices_r.astype(mx.int32), weight),
    )
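To make the two index grids concrete, here is the arithmetic _scaled_indices performs when upsampling N=2 samples to M=4 (values worked out by hand from the code above):

import mlx.core as mx

# align_corners=True: four samples spaced evenly from 0 to N-1
print(mx.arange(4, dtype=mx.float32) * ((2 - 1) / (4 - 1)))
# [0, 0.333333, 0.666667, 1]

# align_corners=False: a centered grid, clipped into [0, N-1]
step = 1 / 2.0
start = ((4 - 1) * step - 2 + 1) / 2
print(mx.clip(mx.arange(4, dtype=mx.float32) * step - start, 0, 1))
# [0, 0.25, 0.75, 1]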
def upsample_nearest(x: mx.array, scale_factor: Tuple):
    dims = x.ndim - 2
    if dims != len(scale_factor):
        raise ValueError("A scale needs to be provided for each spatial dimension")

    # Integer scale_factors means we can simply expand-broadcast and reshape
    if tuple(map(int, scale_factor)) == scale_factor:
        shape = list(x.shape)
        for d in range(dims):
            shape.insert(2 + 2 * d, 1)
        x = x.reshape(shape)
        for d in range(dims):
            shape[2 + 2 * d] = int(scale_factor[d])
        x = mx.broadcast_to(x, shape)
        for d in range(dims):
            shape[d + 1] *= shape[d + 2]
            shape.pop(d + 2)
        x = x.reshape(shape)
        return x

    else:
        B, *N, C = x.shape
        indices = [slice(None)]
        for i, (n, s) in enumerate(zip(N, scale_factor)):
            indices.append(_nearest_indices(n, s, i, dims))
        indices = tuple(indices)

        return x[indices]
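The integer-scale fast path above is just reshape, broadcast, reshape. Spelled out by hand for a 2x2 image scaled by 2 (a sketch of the same trick, not the function itself):

import mlx.core as mx

x = mx.array([[1, 2], [3, 4]]).reshape(1, 2, 2, 1)   # (B, H, W, C)
y = x.reshape(1, 2, 1, 2, 1, 1)                      # a size-1 axis after H and after W
y = mx.broadcast_to(y, (1, 2, 2, 2, 2, 1))           # repeat along the new axes
y = y.reshape(1, 4, 4, 1)                            # merge each dim with its repeats
# y.squeeze() == [[1, 1, 2, 2], [1, 1, 2, 2], [3, 3, 4, 4], [3, 3, 4, 4]]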
def upsample_linear(x: mx.array, scale_factor: Tuple, align_corners: bool = False):
    dims = x.ndim - 2
    if dims != len(scale_factor):
        raise ValueError("A scale needs to be provided for each spatial dimension")

    B, *N, C = x.shape

    # Compute the sampling grid
    indices = []
    for i, (n, s) in enumerate(zip(N, scale_factor)):
        indices.append(_linear_indices(n, s, align_corners, i, dims))

    # Sample and compute the weights
    samples = []
    weights = []
    for idx_weight in product(*indices):
        idx, weight = zip(*idx_weight)
        samples.append(x[(slice(None),) + idx])
        weights.append(reduce(operator.mul, weight))

    # Interpolate
    return sum(wi * xi for wi, xi in zip(weights, samples))
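For two spatial dimensions the product(*indices) loop enumerates the four corner samples, and each weight is the product of the per-axis weights, which is standard bilinear interpolation. With fractional offsets t_y and t_x and corner values v00..v11:

out = (1 - t_y)(1 - t_x) v00 + (1 - t_y) t_x v01 + t_y (1 - t_x) v10 + t_y t_x v11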
class Upsample(Module):
    r"""Upsample the input signal spatially.

    The spatial dimensions are by convention dimensions ``1`` to ``x.ndim -
    2``. The first is the batch dimension and the last is the feature
    dimension.

    For example, an audio signal would be 3D with 1 spatial dimension, an image
    4D with 2 and so on and so forth.

    There are two upsampling algorithms implemented: nearest neighbor upsampling
    and linear interpolation. Both can be applied to any number of spatial
    dimensions and the linear interpolation will be bilinear, trilinear etc.
    when applied to more than one spatial dimension.

    .. note::
       When using one of the linear interpolation modes the ``align_corners``
       argument changes how the corners are treated in the input image. If
       ``align_corners=True`` then the top and left edge of the input and
       output will be matching as will the bottom right edge.

    Parameters:
        scale_factor (float or tuple): The multiplier for the spatial size.
            If a ``float`` is provided, it is the multiplier for all spatial dimensions.
            Otherwise, the number of scale factors provided must match the
            number of spatial dimensions.
        mode (str, optional): The upsampling algorithm, either ``"nearest"`` or
            ``"linear"``. Default: ``"nearest"``.
        align_corners (bool, optional): Changes the way the corners are treated
            during ``"linear"`` upsampling. See the note above and the
            examples below for more details. Default: ``False``.

    Examples:
        >>> import mlx.core as mx
        >>> import mlx.nn as nn
        >>> x = mx.arange(1, 5).reshape((1, 2, 2, 1))
        >>> x
        array([[[[1],
                 [2]],
                [[3],
                 [4]]]], dtype=int32)
        >>> n = nn.Upsample(scale_factor=2, mode='nearest')
        >>> n(x).squeeze()
        array([[1, 1, 2, 2],
               [1, 1, 2, 2],
               [3, 3, 4, 4],
               [3, 3, 4, 4]], dtype=int32)
        >>> b = nn.Upsample(scale_factor=2, mode='linear')
        >>> b(x).squeeze()
        array([[1, 1.25, 1.75, 2],
               [1.5, 1.75, 2.25, 2.5],
               [2.5, 2.75, 3.25, 3.5],
               [3, 3.25, 3.75, 4]], dtype=float32)
        >>> b = nn.Upsample(scale_factor=2, mode='linear', align_corners=True)
        >>> b(x).squeeze()
        array([[1, 1.33333, 1.66667, 2],
               [1.66667, 2, 2.33333, 2.66667],
               [2.33333, 2.66667, 3, 3.33333],
               [3, 3.33333, 3.66667, 4]], dtype=float32)
    """

    def __init__(
        self,
        scale_factor: Union[float, Tuple],
        mode: Literal["nearest", "linear"] = "nearest",
        align_corners: bool = False,
    ):
        super().__init__()
        if mode not in ["nearest", "linear"]:
            raise ValueError(f"[Upsample] Got unsupported upsampling algorithm: {mode}")
        if isinstance(scale_factor, (list, tuple)):
            self.scale_factor = tuple(map(float, scale_factor))
        else:
            self.scale_factor = float(scale_factor)
        self.mode = mode
        self.align_corners = align_corners

    def _extra_repr(self) -> str:
        return (
            f"scale_factor={self.scale_factor}, mode={self.mode!r}, "
            f"align_corners={self.align_corners}"
        )

    def __call__(self, x: mx.array) -> mx.array:
        dims = x.ndim - 2
        if dims <= 0:
            raise ValueError(
                f"[Upsample] The input should have at least 1 spatial "
                f"dimension which means it should be at least 3D but "
                f"{x.ndim}D was provided"
            )

        scale_factor = self.scale_factor
        if isinstance(scale_factor, tuple):
            if len(scale_factor) != dims:
                raise ValueError(
                    f"[Upsample] One scale per spatial dimension is required but "
                    f"scale_factor={scale_factor} and the number of spatial "
                    f"dimensions were {dims}"
                )
        else:
            scale_factor = (scale_factor,) * dims

        if self.mode == "nearest":
            return upsample_nearest(x, scale_factor)
        else:
            return upsample_linear(x, scale_factor, self.align_corners)
@@ -1,11 +1,12 @@
# Copyright © 2023-2024 Apple Inc.

import math
+from typing import Callable, List

import mlx.core as mx


-def exponential_decay(init: float, decay_rate: float):
+def exponential_decay(init: float, decay_rate: float) -> Callable:
    r"""Make an exponential decay scheduler.

    Args:
@@ -30,7 +31,7 @@ def exponential_decay(init: float, decay_rate: float):
    return schedule


-def step_decay(init: float, decay_rate: float, step_size: int):
+def step_decay(init: float, decay_rate: float, step_size: int) -> Callable:
    r"""Make a step decay scheduler.

    Args:
@@ -57,7 +58,7 @@ def step_decay(init: float, decay_rate: float, step_size: int):
    return schedule


-def cosine_decay(init: float, decay_steps: int):
+def cosine_decay(init: float, decay_steps: int) -> Callable:
    r"""Make a cosine decay scheduler.

    Args:
@@ -84,3 +85,73 @@ def cosine_decay(init: float, decay_steps: int):
        return init * decay

    return scheduler


def join_schedules(schedules: List[Callable], boundaries: List[int]) -> Callable:
    r"""Join multiple schedules to create a new schedule.

    Args:
        schedules (list(Callable)): A list of schedules. Schedule :math:`i+1`
            receives a step count indicating the number of steps since
            the :math:`i`-th boundary.
        boundaries (list(int)): A list of integers of length ``len(schedules) - 1``
            that indicates when to transition between schedules.

    Example:
        >>> warmup = optim.linear_schedule(0, 1e-1, steps=10)
        >>> cosine = optim.cosine_decay(1e-1, 200)
        >>> lr_schedule = optim.join_schedules([warmup, cosine], [10])
        >>> optimizer = optim.Adam(learning_rate=lr_schedule)
        >>> optimizer.learning_rate
        array(0.0, dtype=float32)
        >>> for _ in range(12): optimizer.update({}, {})
        ...
        >>> optimizer.learning_rate
        array(0.0999938, dtype=float32)
    """
    if len(schedules) == 0:
        raise ValueError("Must provide at least 1 schedule to join.")

    if len(schedules) != len(boundaries) + 1:
        raise ValueError(
            f"Received {len(boundaries)} boundaries but "
            f"expected {len(schedules) - 1}."
        )

    def schedule(step):
        output = schedules[0](step)
        for boundary, schedule in zip(boundaries, schedules[1:]):
            output = mx.where(step < boundary, output, schedule(step - boundary))
        return output

    return schedule


def linear_schedule(init: float, end: float, steps: int) -> Callable:
    r"""Make a linear scheduler.

    Args:
        init (float): Initial value.
        end (float): Final value.
        steps (int): Number of steps to apply the schedule over. The value is
            ``end`` for any steps beyond ``steps``.

    Example:

        >>> warmup = optim.linear_schedule(0, 1e-1, 100)
        >>> optimizer = optim.Adam(learning_rate=warmup)
        >>> optimizer.learning_rate
        array(0.0, dtype=float32)
        >>> for _ in range(101): optimizer.update({}, {})
        ...
        >>> optimizer.learning_rate
        array(0.1, dtype=float32)
    """
    if steps < 1:
        raise ValueError(f"steps must be greater than 0, but got {steps}.")

    def step_fn(step):
        step = mx.minimum(step, steps)
        return step * ((end - init) / steps) + init

    return step_fn
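A brief sketch of the decay schedules in isolation; each factory returns a callable mapping a step count to a value (the printed values assume a plain-float implementation, up to floating-point rounding):

import mlx.optimizers as optim

lr = optim.step_decay(1e-1, decay_rate=0.5, step_size=10)
print(lr(0), lr(10), lr(25))   # roughly 0.1, 0.05, 0.025 - halved every 10 steps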
@@ -14,6 +14,7 @@ pybind11_add_module(
  ${CMAKE_CURRENT_SOURCE_DIR}/random.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/constants.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/trees.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
)


@@ -1,4 +1,4 @@
-// Copyright © 2023 Apple Inc.
+// Copyright © 2023-2024 Apple Inc.

#include <cstdint>
#include <cstring>
@@ -7,6 +7,7 @@
#include <pybind11/numpy.h>

#include "python/src/indexing.h"
+#include "python/src/pybind11_numpy_fp16.h"
#include "python/src/utils.h"

#include "mlx/ops.h"
@@ -350,55 +351,53 @@ array np_array_to_mlx(py::array np_array, std::optional<Dtype> dtype) {
    shape.push_back(np_array.shape(i));
  }

-  // Get dtype
-  auto type = np_array.dtype();
-
  // Copy data and make array
-  if (type.is(py::dtype::of<int>())) {
+  if (py::isinstance<py::array_t<int32_t>>(np_array)) {
    return np_array_to_mlx_contiguous<int32_t>(
        np_array, shape, dtype.value_or(int32));
-  } else if (type.is(py::dtype::of<uint32_t>())) {
+  } else if (py::isinstance<py::array_t<uint32_t>>(np_array)) {
    return np_array_to_mlx_contiguous<uint32_t>(
        np_array, shape, dtype.value_or(uint32));
-  } else if (type.is(py::dtype::of<bool>())) {
+  } else if (py::isinstance<py::array_t<bool>>(np_array)) {
    return np_array_to_mlx_contiguous<bool>(
        np_array, shape, dtype.value_or(bool_));
-  } else if (type.is(py::dtype::of<double>())) {
+  } else if (py::isinstance<py::array_t<double>>(np_array)) {
    return np_array_to_mlx_contiguous<double>(
        np_array, shape, dtype.value_or(float32));
-  } else if (type.is(py::dtype::of<float>())) {
+  } else if (py::isinstance<py::array_t<float>>(np_array)) {
    return np_array_to_mlx_contiguous<float>(
        np_array, shape, dtype.value_or(float32));
-  } else if (type.is(py::dtype("float16"))) {
+  } else if (py::isinstance<py::array_t<float16_t>>(np_array)) {
    return np_array_to_mlx_contiguous<float>(
        np_array, shape, dtype.value_or(float16));
-  } else if (type.is(py::dtype::of<uint8_t>())) {
+  } else if (py::isinstance<py::array_t<uint8_t>>(np_array)) {
    return np_array_to_mlx_contiguous<uint8_t>(
        np_array, shape, dtype.value_or(uint8));
-  } else if (type.is(py::dtype::of<uint16_t>())) {
+  } else if (py::isinstance<py::array_t<uint16_t>>(np_array)) {
    return np_array_to_mlx_contiguous<uint16_t>(
        np_array, shape, dtype.value_or(uint16));
-  } else if (type.is(py::dtype::of<uint64_t>())) {
+  } else if (py::isinstance<py::array_t<uint64_t>>(np_array)) {
    return np_array_to_mlx_contiguous<uint64_t>(
        np_array, shape, dtype.value_or(uint64));
-  } else if (type.is(py::dtype::of<int8_t>())) {
+  } else if (py::isinstance<py::array_t<int8_t>>(np_array)) {
    return np_array_to_mlx_contiguous<int8_t>(
        np_array, shape, dtype.value_or(int8));
-  } else if (type.is(py::dtype::of<int16_t>())) {
+  } else if (py::isinstance<py::array_t<int16_t>>(np_array)) {
    return np_array_to_mlx_contiguous<int16_t>(
        np_array, shape, dtype.value_or(int16));
-  } else if (type.is(py::dtype::of<int64_t>())) {
+  } else if (py::isinstance<py::array_t<int64_t>>(np_array)) {
    return np_array_to_mlx_contiguous<int64_t>(
        np_array, shape, dtype.value_or(int64));
-  } else if (type.is(py::dtype::of<std::complex<float>>())) {
+  } else if (py::isinstance<py::array_t<std::complex<float>>>(np_array)) {
    return np_array_to_mlx_contiguous<std::complex<float>>(
        np_array, shape, dtype.value_or(complex64));
-  } else if (type.is(py::dtype::of<std::complex<double>>())) {
+  } else if (py::isinstance<py::array_t<std::complex<double>>>(np_array)) {
    return np_array_to_mlx_contiguous<std::complex<float>>(
        np_array, shape, dtype.value_or(complex64));
  } else {
    std::ostringstream msg;
-    msg << "Cannot convert numpy array of type " << type << " to mlx array.";
+    msg << "Cannot convert numpy array of type " << np_array.dtype()
+        << " to mlx array.";
    throw std::invalid_argument(msg.str());
  }
}
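A sketch of the conversion behavior from Python: NumPy dtypes map across directly, and float64 input (which MLX does not store) falls back to float32, matching the dtype.value_or(float32) branch above:

import numpy as np
import mlx.core as mx

print(mx.array(np.ones(3, dtype=np.int16)).dtype)    # mlx.core.int16
print(mx.array(np.ones(3, dtype=np.float64)).dtype)  # mlx.core.float32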
@@ -5,18 +5,88 @@
#include "mlx/backend/metal/metal.h"

namespace py = pybind11;
using namespace py::literals;

using namespace mlx::core;

void init_metal(py::module_& m) {
  py::module_ metal = m.def_submodule("metal", "mlx.metal");
-  metal.def("is_available", &metal::is_available);
  metal.def(
-      "cache_enabled",
-      &metal::cache_enabled,
-      "check if metal buffer cache is enabled, default is true");
+      "is_available",
+      &metal::is_available,
+      R"pbdoc(
+      Check if the Metal back-end is available.
+      )pbdoc");
  metal.def(
-      "set_cache_enabled",
-      &metal::set_cache_enabled,
-      "enable or disable metal buffer cache");
+      "get_active_memory",
+      &metal::get_active_memory,
+      R"pbdoc(
+      Get the actively used memory in bytes.
+
+      Note, this will not always match memory use reported by the system because
+      it does not include cached memory buffers.
+      )pbdoc");
+  metal.def(
+      "get_peak_memory",
+      &metal::get_peak_memory,
+      R"pbdoc(
+      Get the peak amount of used memory in bytes.
+
+      The maximum memory used is recorded from the beginning of the program
+      execution.
+      )pbdoc");
+  metal.def(
+      "get_cache_memory",
+      &metal::get_cache_memory,
+      R"pbdoc(
+      Get the cache size in bytes.
+
+      The cache includes memory not currently used that has not been returned
+      to the system allocator.
+      )pbdoc");
+  metal.def(
+      "set_memory_limit",
+      &metal::set_memory_limit,
+      "limit"_a,
+      py::kw_only(),
+      "relaxed"_a = true,
+      R"pbdoc(
+      Set the memory limit.
+
+      Memory allocations will wait on scheduled tasks to complete if the limit
+      is exceeded. If there are no more scheduled tasks an error will be raised
+      if ``relaxed`` is ``False``. Otherwise memory will be allocated
+      (including the potential for swap) if ``relaxed`` is ``True``.
+
+      The memory limit defaults to 1.5 times the maximum recommended working set
+      size reported by the device.
+
+      Args:
+        limit (int): Memory limit in bytes.
+        relaxed (bool, optional): If ``False`` an error is raised if the limit
+          is exceeded. Default: ``True``
+
+      Returns:
+        int: The previous memory limit in bytes.
+      )pbdoc");
+  metal.def(
+      "set_cache_limit",
+      &metal::set_cache_limit,
+      "limit"_a,
+      R"pbdoc(
+      Set the free cache limit.
+
+      If using more than the given limit, free memory will be reclaimed
+      from the cache on the next allocation. To disable the cache, set
+      the limit to ``0``.
+
+      The cache limit defaults to the memory limit. See
+      :func:`set_memory_limit` for more details.
+
+      Args:
+        limit (int): The cache limit in bytes.
+
+      Returns:
+        int: The previous cache limit in bytes.
+      )pbdoc");
}
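A sketch of the new memory introspection API from the Python side; all of these are bindings added in this hunk:

import mlx.core as mx

if mx.metal.is_available():
    previous = mx.metal.set_cache_limit(0)   # disable the buffer cache
    print(mx.metal.get_active_memory())      # bytes currently in use
    print(mx.metal.get_peak_memory())        # high-water mark in bytes
    mx.metal.set_cache_limit(previous)       # restore the old limit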
@@ -3081,7 +3081,7 @@ void init_ops(py::module_& m) {
      py::kw_only(),
      "stream"_a = none,
      R"pbdoc(
-        conv2d(input: array, weight: array, /, stride: Union[int, Tuple[int, int]] = 1, padding: Union[int, Tuple[int, int]] = 0, dilation: Union[int, Tuple[int, int]] = 1, groups: Union[int, Tuple[int, int]] = 1, *, stream: Union[None, Stream, Device] = None) -> array
+        conv2d(input: array, weight: array, /, stride: Union[int, Tuple[int, int]] = 1, padding: Union[int, Tuple[int, int]] = 0, dilation: Union[int, Tuple[int, int]] = 1, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array

        2D convolution over an input with several channels

@@ -3105,6 +3105,114 @@ void init_ops(py::module_& m) {
        array: The convolved array.
      )pbdoc");
  m.def(
      "conv_general",
      [](const array& input,
         const array& weight,
         const std::variant<int, std::vector<int>>& stride,
         const std::variant<
             int,
             std::vector<int>,
             std::pair<std::vector<int>, std::vector<int>>>& padding,
         const std::variant<int, std::vector<int>>& kernel_dilation,
         const std::variant<int, std::vector<int>>& input_dilation,
         int groups,
         bool flip,
         StreamOrDevice s) {
        std::vector<int> stride_vec;
        std::vector<int> padding_lo_vec;
        std::vector<int> padding_hi_vec;
        std::vector<int> kernel_dilation_vec;
        std::vector<int> input_dilation_vec;

        if (auto pv = std::get_if<int>(&stride); pv) {
          stride_vec.push_back(*pv);
        } else {
          stride_vec = std::get<std::vector<int>>(stride);
        }

        if (auto pv = std::get_if<int>(&padding); pv) {
          padding_lo_vec.push_back(*pv);
          padding_hi_vec.push_back(*pv);
        } else if (auto pv = std::get_if<std::vector<int>>(&padding); pv) {
          padding_lo_vec = *pv;
          padding_hi_vec = *pv;
        } else {
          auto [pl, ph] =
              std::get<std::pair<std::vector<int>, std::vector<int>>>(padding);
          padding_lo_vec = pl;
          padding_hi_vec = ph;
        }

        if (auto pv = std::get_if<int>(&kernel_dilation); pv) {
          kernel_dilation_vec.push_back(*pv);
        } else {
          kernel_dilation_vec = std::get<std::vector<int>>(kernel_dilation);
        }

        if (auto pv = std::get_if<int>(&input_dilation); pv) {
          input_dilation_vec.push_back(*pv);
        } else {
          input_dilation_vec = std::get<std::vector<int>>(input_dilation);
        }

        return conv_general(
            /* const array& input = */ input,
            /* const array& weight = */ weight,
            /* std::vector<int> stride = */ stride_vec,
            /* std::vector<int> padding_lo = */ padding_lo_vec,
            /* std::vector<int> padding_hi = */ padding_hi_vec,
            /* std::vector<int> kernel_dilation = */ kernel_dilation_vec,
            /* std::vector<int> input_dilation = */ input_dilation_vec,
            /* int groups = */ groups,
            /* bool flip = */ flip,
            s);
      },
      "input"_a,
      "weight"_a,
      py::pos_only(),
      "stride"_a = 1,
      "padding"_a = 0,
      "kernel_dilation"_a = 1,
      "input_dilation"_a = 1,
      "groups"_a = 1,
      "flip"_a = false,
      py::kw_only(),
      "stream"_a = none,
      R"pbdoc(
        conv_general(input: array, weight: array, /, stride: Union[int, List[int]] = 1, padding: Union[int, List[int], Tuple[List[int], List[int]]] = 0, kernel_dilation: Union[int, List[int]] = 1, input_dilation: Union[int, List[int]] = 1, groups: int = 1, flip: bool = False, *, stream: Union[None, Stream, Device] = None) -> array

        General convolution over an input with several channels

        .. note::

           * Only 1d and 2d convolutions are supported at the moment
           * Only the default ``groups=1`` is currently supported.

        Args:
            input (array): Input array of shape ``(N, ..., C_in)``
            weight (array): Weight array of shape ``(C_out, ..., C_in)``
            stride (int or list(int), optional): :obj:`list` with kernel strides.
                All spatial dimensions get the same stride if
                only one number is specified. Default: ``1``.
            padding (int, list(int), or tuple(list(int), list(int)), optional):
                :obj:`list` with input padding. All spatial dimensions get the same
                padding if only one number is specified. Default: ``0``.
            kernel_dilation (int or list(int), optional): :obj:`list` with
                kernel dilation. All spatial dimensions get the same dilation
                if only one number is specified. Default: ``1``
            input_dilation (int or list(int), optional): :obj:`list` with
                input dilation. All spatial dimensions get the same dilation
                if only one number is specified. Default: ``1``
            groups (int, optional): Input feature groups. Default: ``1``.
            flip (bool, optional): Flip the order in which the spatial dimensions of
                the weights are processed. Performs the cross-correlation operator when
                ``flip`` is ``False`` and the convolution operator otherwise.
                Default: ``False``.

        Returns:
            array: The convolved array.
      )pbdoc");
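A minimal call to the new binding, assuming a 1-D convolution; with a length-3 kernel and padding of 1 the spatial size is preserved:

import mlx.core as mx

x = mx.ones((1, 5, 1))   # (N, L, C_in)
w = mx.ones((1, 3, 1))   # (C_out, K, C_in)
y = mx.conv_general(x, w, stride=1, padding=1)
print(y.shape)           # (1, 5, 1)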
  m.def(
      "save",
      &mlx_save_helper,
      "file"_a,
@@ -3638,62 +3746,69 @@ void init_ops(py::module_& m) {
      )pbdoc");
  m.def(
      "atleast_1d",
-      &atleast_1d,
-      "a"_a,
-      py::pos_only(),
      [](const py::args& arys, StreamOrDevice s) -> py::object {
        if (arys.size() == 1) {
          return py::cast(atleast_1d(arys[0].cast<array>(), s));
        }
        return py::cast(atleast_1d(arys.cast<std::vector<array>>(), s));
      },
      py::kw_only(),
      "stream"_a = none,
      R"pbdoc(
-        atleast_1d(a: array, stream: Union[None, Stream, Device] = None) -> array
+        atleast_1d(*arys: array, stream: Union[None, Stream, Device] = None) -> Union[array, List[array]]

-        Convert array to have at least one dimension.
+        Convert all arrays to have at least one dimension.

-        args:
-            a (array): Input array
+        Args:
+            *arys: Input arrays.
            stream (Union[None, Stream, Device], optional): The stream to execute the operation on.

        Returns:
-            array: An array with at least one dimension.
+            array or list(array): An array or list of arrays with at least one dimension.
      )pbdoc");
  m.def(
      "atleast_2d",
-      &atleast_2d,
-      "a"_a,
-      py::pos_only(),
      [](const py::args& arys, StreamOrDevice s) -> py::object {
        if (arys.size() == 1) {
          return py::cast(atleast_2d(arys[0].cast<array>(), s));
        }
        return py::cast(atleast_2d(arys.cast<std::vector<array>>(), s));
      },
      py::kw_only(),
      "stream"_a = none,
      R"pbdoc(
-        atleast_2d(a: array, stream: Union[None, Stream, Device] = None) -> array
+        atleast_2d(*arys: array, stream: Union[None, Stream, Device] = None) -> Union[array, List[array]]

-        Convert array to have at least two dimensions.
+        Convert all arrays to have at least two dimensions.

-        args:
-            a (array): Input array
+        Args:
+            *arys: Input arrays.
            stream (Union[None, Stream, Device], optional): The stream to execute the operation on.

        Returns:
-            array: An array with at least two dimensions.
+            array or list(array): An array or list of arrays with at least two dimensions.
      )pbdoc");

  m.def(
      "atleast_3d",
-      &atleast_3d,
-      "a"_a,
-      py::pos_only(),
      [](const py::args& arys, StreamOrDevice s) -> py::object {
        if (arys.size() == 1) {
          return py::cast(atleast_3d(arys[0].cast<array>(), s));
        }
        return py::cast(atleast_3d(arys.cast<std::vector<array>>(), s));
      },
      py::kw_only(),
      "stream"_a = none,
      R"pbdoc(
-        atleast_3d(a: array, stream: Union[None, Stream, Device] = None) -> array
+        atleast_3d(*arys: array, stream: Union[None, Stream, Device] = None) -> Union[array, List[array]]

-        Convert array to have at least three dimensions.
+        Convert all arrays to have at least three dimensions.

-        args:
-            a (array): Input array
+        Args:
+            *arys: Input arrays.
            stream (Union[None, Stream, Device], optional): The stream to execute the operation on.

        Returns:
-            array: An array with at least three dimensions.
+            array or list(array): An array or list of arrays with at least three dimensions.
      )pbdoc");
}
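The change above makes the atleast_* functions variadic; with several inputs they return a list, with one input a single array:

import mlx.core as mx

outs = mx.atleast_1d(mx.array(1.0), mx.array([1.0, 2.0]))
print([o.shape for o in outs])              # [(1,), (2,)]
print(mx.atleast_1d(mx.array(1.0)).shape)   # (1,)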
python/src/pybind11_numpy_fp16.h (new file, 60 lines)
@@ -0,0 +1,60 @@
// Copyright © 2023-2024 Apple Inc.
#pragma once

// A patch to get float16_t to work with pybind11 numpy arrays
// Derived from:
// https://github.com/pybind/pybind11/issues/1776#issuecomment-492230679

#include <pybind11/numpy.h>

namespace pybind11::detail {

template <typename T>
struct npy_scalar_caster {
  PYBIND11_TYPE_CASTER(T, _("PleaseOverride"));
  using Array = array_t<T>;

  bool load(handle src, bool convert) {
    // Taken from Eigen casters. Permits either scalar dtype or scalar array.
    handle type = dtype::of<T>().attr("type"); // Could make more efficient.
    if (!convert && !isinstance<Array>(src) && !isinstance(src, type))
      return false;
    Array tmp = Array::ensure(src);
    if (tmp && tmp.size() == 1 && tmp.ndim() == 0) {
      this->value = *tmp.data();
      return true;
    }
    return false;
  }

  static handle cast(T src, return_value_policy, handle) {
    Array tmp({1});
    tmp.mutable_at(0) = src;
    tmp.resize({});
    // You could also just return the array if you want a scalar array.
    object scalar = tmp[tuple()];
    return scalar.release();
  }
};

// Similar to enums in `pybind11/numpy.h`. Determined by doing:
// python3 -c 'import numpy as np; print(np.dtype(np.float16).num)'
constexpr int NPY_FLOAT16 = 23;

// Kinda following:
// https://github.com/pybind/pybind11/blob/9bb3313162c0b856125e481ceece9d8faa567716/include/pybind11/numpy.h#L1000
template <>
struct npy_format_descriptor<float16_t> {
  static constexpr auto name = _("float16");
  static pybind11::dtype dtype() {
    handle ptr = npy_api::get().PyArray_DescrFromType_(NPY_FLOAT16);
    return reinterpret_borrow<pybind11::dtype>(ptr);
  }
};

template <>
struct type_caster<float16_t> : npy_scalar_caster<float16_t> {
  static constexpr auto name = _("float16");
};

} // namespace pybind11::detail
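With this caster in place, float16 values and arrays round-trip between NumPy and MLX; a quick sketch:

import numpy as np
import mlx.core as mx

a = np.array([1.0, 2.5], dtype=np.float16)
b = mx.array(a)
print(b.dtype)             # mlx.core.float16
print(np.array(b).dtype)   # float16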
@@ -11,6 +11,7 @@
#include "mlx/graph_utils.h"
#include "mlx/transforms.h"
#include "mlx/transforms_impl.h"
+#include "python/src/trees.h"

namespace py = pybind11;
using namespace py::literals;
@@ -30,246 +31,6 @@ std::vector<T> to_vector(const std::variant<T, std::vector<T>>& v) {
  return vals;
}
[Removed here: the tree utilities previously defined in this file (tree_visit, validate_subtrees, tree_map, tree_visit_update, tree_fill, tree_replace, tree_flatten, tree_unflatten, structure_sentinel, tree_flatten_with_structure, tree_unflatten_from_structure). They move verbatim into the new file python/src/trees.cpp, shown in full below.]
auto validate_argnums_argnames(
    const std::optional<IntOrVec>& argnums,
    const StrOrVec& argnames) {
@@ -582,9 +343,69 @@ struct PyCompiledFun {
  };

  py::object operator()(const py::args& args, const py::kwargs& kwargs) {
-    auto inputs = tree_flatten(args, false);
+    // Flat array inputs
+    std::vector<array> inputs;

-    auto compile_fun = [this, &args, &kwargs, num_args = inputs.size()](
+    // Compilation constants which includes the tree structure of the arguments
+    std::vector<uint64_t> constants;
+
+    // Reserve some large primes to signify the presence of an array, a list or
+    // a dict in order to encode the structure of the pytree. We choose primes
+    // to reduce slightly the chances of these numbers occurring by a
+    // multiplication as values in the constants list.
+    constexpr uint64_t array_identifier = 18446744073709551557UL;
+    constexpr uint64_t list_identifier = 18446744073709551533UL;
+    constexpr uint64_t dict_identifier = 18446744073709551521UL;
+
+    // Flatten the tree with hashed constants and structure
+    std::function<void(py::handle)> recurse;
+    recurse = [&](py::handle obj) {
+      if (py::isinstance<py::list>(obj)) {
+        auto l = py::cast<py::list>(obj);
+        constants.push_back(list_identifier);
+        for (int i = 0; i < l.size(); ++i) {
+          recurse(l[i]);
+        }
+      } else if (py::isinstance<py::tuple>(obj)) {
+        auto l = py::cast<py::tuple>(obj);
+        constants.push_back(list_identifier);
+        for (auto item : obj) {
+          recurse(item);
+        }
+      } else if (py::isinstance<py::dict>(obj)) {
+        auto d = py::cast<py::dict>(obj);
+        constants.push_back(dict_identifier);
+        for (auto item : d) {
+          auto r = py::hash(item.first);
+          constants.push_back(*reinterpret_cast<uint64_t*>(&r));
+          recurse(item.second);
+        }
+      } else if (py::isinstance<array>(obj)) {
+        inputs.push_back(py::cast<array>(obj));
+        constants.push_back(array_identifier);
+      } else if (py::isinstance<py::str>(obj)) {
+        auto r = py::hash(obj);
+        constants.push_back(*reinterpret_cast<uint64_t*>(&r));
+      } else if (py::isinstance<py::int_>(obj)) {
+        auto r = obj.cast<int64_t>();
+        constants.push_back(*reinterpret_cast<uint64_t*>(&r));
+      } else if (py::isinstance<py::float_>(obj)) {
+        auto r = obj.cast<double>();
+        constants.push_back(*reinterpret_cast<uint64_t*>(&r));
+      } else {
+        std::ostringstream msg;
+        msg << "[compile] Function arguments must be trees of arrays "
+            << "or constants (floats, ints, or strings), but received "
+            << "type " << obj.get_type() << ".";
+        throw std::invalid_argument(msg.str());
+      }
+    };
+
+    recurse(args);
+    int num_args = inputs.size();
+    recurse(kwargs);
+
+    auto compile_fun = [this, &args, &kwargs, num_args](
        const std::vector<array>& a) {
      // Put tracers into captured inputs
      std::vector<array> flat_in_captures;
@@ -619,14 +440,6 @@ struct PyCompiledFun {
      return outputs;
    };

-    {
-      auto flat_kwargs = tree_flatten(kwargs, false);
-      inputs.insert(
-          inputs.end(),
-          std::make_move_iterator(flat_kwargs.begin()),
-          std::make_move_iterator(flat_kwargs.end()));
-    }
-
    if (!py::isinstance<py::none>(captured_inputs)) {
      auto flat_in_captures = tree_flatten(captured_inputs, false);
      inputs.insert(
@@ -635,36 +448,6 @@ struct PyCompiledFun {
          std::make_move_iterator(flat_in_captures.end()));
    }

-    // Collect the compilation constants
-    std::vector<uint64_t> constants;
-    auto value_hash = [](py::handle o) -> std::optional<uint64_t> {
-      // Consider expanding tuples to their contents including start and end
-      // ids
-      if (py::isinstance<py::tuple>(o) || py::isinstance<py::str>(o)) {
-        auto r = py::hash(o);
-        return *reinterpret_cast<uint64_t*>(&r);
-      } else if (py::isinstance<py::int_>(o)) {
-        auto r = o.cast<int64_t>();
-        return *reinterpret_cast<uint64_t*>(&r);
-      } else if (py::isinstance<py::float_>(o)) {
-        auto r = o.cast<double>();
-        return *reinterpret_cast<uint64_t*>(&r);
-      } else {
-        return std::nullopt;
-      }
-    };
-    for (int i = 0; i < args.size(); i++) {
-      if (auto h = value_hash(args[i]); h.has_value()) {
-        constants.push_back(*h);
-      }
-    }
-    for (auto& pair : kwargs) {
-      if (auto h = value_hash(pair.second); h.has_value()) {
-        constants.push_back(*value_hash(pair.first));
-        constants.push_back(*h);
-      }
-    }

    // Compile and call
    auto outputs =
        detail::compile(compile_fun, fun_id, shapeless, constants)(inputs);
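The practical effect of hashing scalars and tree structure into constants: a compiled function is re-traced when a non-array argument changes, instead of reusing a stale graph. A sketch:

import mlx.core as mx

@mx.compile
def scaled(x, alpha):
    return alpha * x

x = mx.ones(3)
scaled(x, 2.0)   # traces with alpha == 2.0 baked in as a constant
scaled(x, 3.0)   # different constant -> a separate compiled instance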
@@ -1017,7 +800,38 @@ void init_transforms(py::module_& m) {
         const py::object& inputs,
         const py::object& outputs,
         bool shapeless) {
-        return py::cpp_function(PyCompiledFun{fun, inputs, outputs, shapeless});
+        py::options options;
+        options.disable_function_signatures();
+
+        std::ostringstream doc;
+        auto name = fun.attr("__name__").cast<std::string>();
+        doc << name;
+
+        // Try to get the signature
+        auto inspect = py::module::import("inspect");
+        if (!inspect.attr("isbuiltin")(fun).cast<bool>()) {
+          doc << inspect.attr("signature")(fun)
+                     .attr("__str__")()
+                     .cast<std::string>();
+        }
+
+        // Try to get the doc string
+        if (auto d = fun.attr("__doc__"); py::isinstance<py::str>(d)) {
+          doc << "\n\n";
+          auto dstr = d.cast<std::string>();
+          // Add spaces to match first line indentation with remainder of
+          // docstring
+          int i = 0;
+          for (int i = dstr.size() - 1; i >= 0 && dstr[i] == ' '; i--) {
+            doc << ' ';
+          }
+          doc << dstr;
+        }
+        auto doc_str = doc.str();
+        return py::cpp_function(
+            PyCompiledFun{fun, inputs, outputs, shapeless},
+            py::name(name.c_str()),
+            py::doc(doc_str.c_str()));
      },
      "fun"_a,
      "inputs"_a = std::nullopt,
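The effect of the docstring plumbing above is that mx.compile should now preserve the wrapped function's name, signature, and docstring; a sketch (exact rendering of the doc text is an assumption):

import mlx.core as mx

def loss(x):
    """Mean squared value."""
    return (x * x).mean()

closs = mx.compile(loss)
print(closs.__name__)   # loss
print(closs.__doc__)    # loss(x) followed by the original docstring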
243
python/src/trees.cpp
Normal file
243
python/src/trees.cpp
Normal file
@@ -0,0 +1,243 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include "python/src/trees.h"
|
||||
|
||||
void tree_visit(py::object tree, std::function<void(py::handle)> visitor) {
|
||||
std::function<void(py::handle)> recurse;
|
||||
recurse = [&](py::handle subtree) {
|
||||
if (py::isinstance<py::list>(subtree) ||
|
||||
py::isinstance<py::tuple>(subtree)) {
|
||||
for (auto item : subtree) {
|
||||
recurse(item);
|
||||
}
|
||||
} else if (py::isinstance<py::dict>(subtree)) {
|
||||
for (auto item : py::cast<py::dict>(subtree)) {
|
||||
recurse(item.second);
|
||||
}
|
||||
} else {
|
||||
visitor(subtree);
|
||||
}
|
||||
};
|
||||
|
||||
recurse(tree);
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename V>
|
||||
void validate_subtrees(const std::vector<py::object>& subtrees) {
|
||||
int len = py::cast<T>(subtrees[0]).size();
|
||||
for (auto& subtree : subtrees) {
|
||||
if ((py::isinstance<T>(subtree) && py::cast<T>(subtree).size() != len) ||
|
||||
py::isinstance<U>(subtree) || py::isinstance<V>(subtree)) {
|
||||
throw std::invalid_argument(
|
||||
"[tree_map] Additional input tree is not a valid prefix of the first tree.");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
py::object tree_map(
|
||||
const std::vector<py::object>& trees,
|
||||
std::function<py::object(const std::vector<py::object>&)> transform) {
|
||||
std::function<py::object(const std::vector<py::object>&)> recurse;
|
||||
|
||||
recurse = [&](const std::vector<py::object>& subtrees) {
|
||||
if (py::isinstance<py::list>(subtrees[0])) {
|
||||
py::list l;
|
||||
std::vector<py::object> items(subtrees.size());
|
||||
validate_subtrees<py::list, py::tuple, py::dict>(subtrees);
|
||||
for (int i = 0; i < py::cast<py::list>(subtrees[0]).size(); ++i) {
|
||||
for (int j = 0; j < subtrees.size(); ++j) {
|
||||
if (py::isinstance<py::list>(subtrees[j])) {
|
||||
items[j] = py::cast<py::list>(subtrees[j])[i];
|
||||
} else {
|
||||
items[j] = subtrees[j];
|
||||
}
|
||||
}
|
||||
l.append(recurse(items));
|
||||
}
|
||||
return py::cast<py::object>(l);
|
||||
} else if (py::isinstance<py::tuple>(subtrees[0])) {
|
||||
// Check the rest of the subtrees
|
||||
std::vector<py::object> items(subtrees.size());
|
||||
int len = py::cast<py::tuple>(subtrees[0]).size();
|
||||
py::tuple l(len);
|
||||
validate_subtrees<py::tuple, py::list, py::dict>(subtrees);
|
||||
for (int i = 0; i < len; ++i) {
|
||||
for (int j = 0; j < subtrees.size(); ++j) {
|
||||
if (py::isinstance<py::tuple>(subtrees[j])) {
|
||||
items[j] = py::cast<py::tuple>(subtrees[j])[i];
|
||||
} else {
|
||||
items[j] = subtrees[j];
|
||||
}
|
||||
}
|
||||
l[i] = recurse(items);
|
||||
}
|
||||
return py::cast<py::object>(l);
|
||||
} else if (py::isinstance<py::dict>(subtrees[0])) {
|
||||
std::vector<py::object> items(subtrees.size());
|
||||
validate_subtrees<py::dict, py::list, py::tuple>(subtrees);
|
||||
py::dict d;
|
||||
for (auto item : py::cast<py::dict>(subtrees[0])) {
|
||||
for (int j = 0; j < subtrees.size(); ++j) {
|
||||
if (py::isinstance<py::dict>(subtrees[j])) {
|
||||
auto subdict = py::cast<py::dict>(subtrees[j]);
|
||||
if (!subdict.contains(item.first)) {
|
||||
throw std::invalid_argument(
|
||||
"[tree_map] Tree is not a valid prefix tree of the first tree.");
|
||||
}
|
||||
items[j] = subdict[item.first];
|
||||
} else {
|
||||
items[j] = subtrees[j];
|
||||
}
|
||||
}
|
||||
        d[item.first] = recurse(items);
      }
      return py::cast<py::object>(d);
    } else {
      return transform(subtrees);
    }
  };
  return recurse(trees);
}

py::object tree_map(
    py::object tree,
    std::function<py::object(py::handle)> transform) {
  return tree_map({tree}, [&](std::vector<py::object> inputs) {
    return transform(inputs[0]);
  });
}

void tree_visit_update(
    py::object tree,
    std::function<py::object(py::handle)> visitor) {
  std::function<py::object(py::handle)> recurse;
  recurse = [&](py::handle subtree) {
    if (py::isinstance<py::list>(subtree)) {
      auto l = py::cast<py::list>(subtree);
      for (int i = 0; i < l.size(); ++i) {
        l[i] = recurse(l[i]);
      }
      return py::cast<py::object>(l);
    } else if (py::isinstance<py::tuple>(subtree)) {
      for (auto item : subtree) {
        recurse(item);
      }
      return py::cast<py::object>(subtree);
    } else if (py::isinstance<py::dict>(subtree)) {
      auto d = py::cast<py::dict>(subtree);
      for (auto item : d) {
        d[item.first] = recurse(item.second);
      }
      return py::cast<py::object>(d);
    } else if (py::isinstance<array>(subtree)) {
      return visitor(subtree);
    } else {
      return py::cast<py::object>(subtree);
    }
  };
  recurse(tree);
}

// Fill a pytree (recursive dict or list of dict or list) in place with the
// given arrays. Non-dict, non-list nodes are ignored.
void tree_fill(py::object& tree, const std::vector<array>& values) {
  size_t index = 0;
  tree_visit_update(
      tree, [&](py::handle node) { return py::cast(values[index++]); });
}

// Replace all the arrays from the src values with the dst values in the tree
void tree_replace(
    py::object& tree,
    const std::vector<array>& src,
    const std::vector<array>& dst) {
  std::unordered_map<uintptr_t, array> src_to_dst;
  for (int i = 0; i < src.size(); ++i) {
    src_to_dst.insert({src[i].id(), dst[i]});
  }
  tree_visit_update(tree, [&](py::handle node) {
    auto arr = py::cast<array>(node);
    if (auto it = src_to_dst.find(arr.id()); it != src_to_dst.end()) {
      return py::cast(it->second);
    }
    return py::cast(arr);
  });
}

std::vector<array> tree_flatten(py::object tree, bool strict /* = true */) {
  std::vector<array> flat_tree;

  tree_visit(tree, [&](py::handle obj) {
    if (py::isinstance<array>(obj)) {
      flat_tree.push_back(py::cast<array>(obj));
    } else if (strict) {
      throw std::invalid_argument(
          "[tree_flatten] The argument should contain only arrays");
    }
  });

  return flat_tree;
}

py::object tree_unflatten(
    py::object tree,
    const std::vector<array>& values,
    int index /* = 0 */) {
  return tree_map(tree, [&](py::handle obj) {
    if (py::isinstance<array>(obj)) {
      return py::cast(values[index++]);
    } else {
      return py::cast<py::object>(obj);
    }
  });
}

py::object structure_sentinel() {
  static py::object sentinel;

  if (sentinel.ptr() == nullptr) {
    sentinel = py::capsule(&sentinel);
    // probably not needed but this should make certain that we won't ever
    // delete the sentinel
    sentinel.inc_ref();
  }

  return sentinel;
}

std::pair<std::vector<array>, py::object> tree_flatten_with_structure(
    py::object tree,
    bool strict /* = true */) {
  auto sentinel = structure_sentinel();
  std::vector<array> flat_tree;
  auto structure = tree_map(
      tree,
      [&flat_tree, sentinel = std::move(sentinel), strict](py::handle obj) {
        if (py::isinstance<array>(obj)) {
          flat_tree.push_back(py::cast<array>(obj));
          return sentinel;
        } else if (!strict) {
          return py::cast<py::object>(obj);
        } else {
          throw std::invalid_argument(
              "[tree_flatten] The argument should contain only arrays");
        }
      });

  return {flat_tree, structure};
}

py::object tree_unflatten_from_structure(
    py::object structure,
    const std::vector<array>& values,
    int index /* = 0 */) {
  auto sentinel = structure_sentinel();
  return tree_map(structure, [&](py::handle obj) {
    if (obj.is(sentinel)) {
      return py::cast(values[index++]);
    } else {
      return py::cast<py::object>(obj);
    }
  });
}
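The flatten/unflatten pair above is the contract the rest of the bindings build on: flatten collects the array leaves in one deterministic traversal order, and unflatten re-walks a template tree consuming values in that same order. A minimal Python sketch of that contract, with illustrative names (this is not the mlx.utils API):

def flatten(tree):
    # Leaves are anything that is not a dict, list, or tuple.
    if not isinstance(tree, (dict, list, tuple)):
        return [tree]
    items = tree.values() if isinstance(tree, dict) else tree
    return [leaf for sub in items for leaf in flatten(sub)]

def unflatten(tree, leaves):
    it = iter(leaves)

    def recurse(node):
        if isinstance(node, dict):
            return {k: recurse(v) for k, v in node.items()}
        if isinstance(node, (list, tuple)):
            return type(node)(recurse(v) for v in node)
        return next(it)  # consume leaves in traversal order

    return recurse(tree)

params = {"w": 1.0, "layers": [{"b": 2.0}, {"b": 3.0}]}
doubled = unflatten(params, [2 * leaf for leaf in flatten(params)])
# {'w': 2.0, 'layers': [{'b': 4.0}, {'b': 6.0}]}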
python/src/trees.h (new file, 60 lines)
@@ -0,0 +1,60 @@
// Copyright © 2023-2024 Apple Inc.

#pragma once

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include "mlx/array.h"

namespace py = pybind11;
using namespace mlx::core;

void tree_visit(py::object tree, std::function<void(py::handle)> visitor);

py::object tree_map(
    const std::vector<py::object>& trees,
    std::function<py::object(const std::vector<py::object>&)> transform);

py::object tree_map(
    py::object tree,
    std::function<py::object(py::handle)> transform);

void tree_visit_update(
    py::object tree,
    std::function<py::object(py::handle)> visitor);

/**
 * Fill a pytree (recursive dict or list of dict or list) in place with the
 * given arrays. */
void tree_fill(py::object& tree, const std::vector<array>& values);

/**
 * Replace all the arrays from the src values with the dst values in the
 * tree.
 */
void tree_replace(
    py::object& tree,
    const std::vector<array>& src,
    const std::vector<array>& dst);

/**
 * Flatten a tree into a vector of arrays. If strict is true, then the
 * function will throw if the tree contains a leaf which is not an array.
 */
std::vector<array> tree_flatten(py::object tree, bool strict = true);

/**
 * Unflatten a tree from a vector of arrays.
 */
py::object tree_unflatten(
    py::object tree,
    const std::vector<array>& values,
    int index = 0);

std::pair<std::vector<array>, py::object> tree_flatten_with_structure(
    py::object tree,
    bool strict = true);

py::object tree_unflatten_from_structure(
    py::object structure,
    const std::vector<array>& values,
    int index = 0);
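The first tree_map overload declared above walks several trees of identical structure in lockstep and combines corresponding leaves. A hedged Python sketch of the same idea (tree_map_multi and the gradient dicts below are illustrative names, not part of the API):

def tree_map_multi(fn, *trees):
    first = trees[0]
    if isinstance(first, dict):
        return {k: tree_map_multi(fn, *(t[k] for t in trees)) for k in first}
    if isinstance(first, (list, tuple)):
        return type(first)(tree_map_multi(fn, *parts) for parts in zip(*trees))
    return fn(*trees)  # corresponding leaves from every tree

grads_a = {"w": 0.1, "b": [0.2, 0.3]}
grads_b = {"w": 0.4, "b": [0.5, 0.6]}
summed = tree_map_multi(lambda x, y: x + y, grads_a, grads_b)
# {'w': 0.5, 'b': [0.7, 0.9]}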
@@ -1,6 +1,7 @@
# Copyright © 2023 Apple Inc.

import operator
import pickle
import unittest
import weakref
from itertools import permutations
@@ -1440,6 +1441,15 @@ class TestArray(mlx_tests.MLXTestCase):
        b @= a
        self.assertTrue(mx.array_equal(a, b))

    def test_load_from_pickled_np(self):
        a = np.array([1, 2, 3], dtype=np.int32)
        b = pickle.loads(pickle.dumps(a))
        self.assertTrue(mx.array_equal(mx.array(a), mx.array(b)))

        a = np.array([1.0, 2.0, 3.0], dtype=np.float16)
        b = pickle.loads(pickle.dumps(a))
        self.assertTrue(mx.array_equal(mx.array(a), mx.array(b)))


if __name__ == "__main__":
    unittest.main()
@@ -415,6 +415,14 @@ class TestAutograd(mlx_tests.MLXTestCase):
        _, vjps = mx.vjp(func, (arr,), (cotan,))
        self.assertEqual(vjps[0].item(), 8.0)

    def test_power_grad(self):
        def fun(x, y):
            res = x - y
            return res**x

        grad = mx.grad(fun)(mx.array(1.0), mx.array(1.0))
        self.assertEqual(grad.item(), 1.0)
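        # Why 1.0: with res = x - y, d/dx res**x expands to
        # res**x * log(res) + x * res**(x - 1) * d(res)/dx. At x == y == 1
        # (so res == 0), the usual conventions 0**0 == 1 and 0 * log(0) == 0
        # leave exactly 1.0. (This reading of the power vjp is an assumption.)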


if __name__ == "__main__":
    unittest.main()
@@ -539,6 +539,72 @@ class TestCompile(mlx_tests.MLXTestCase):
        z = fun(mx.array(1), "two")
        self.assertEqual(z.item(), 3)

        # Test nested constant
        @partial(mx.compile)
        def fun(x, y):
            if y[0][0] == 1:
                return x + 1
            else:
                return x + 2

        z = fun(mx.array(1), [[1]])
        self.assertEqual(z.item(), 2)

        z = fun(mx.array(1), [[0]])
        self.assertEqual(z.item(), 3)

        @partial(mx.compile)
        def fun(x, a, b):
            for ai in a:
                for bi in b:
                    x = bi * x + ai
            return x

        z = fun(mx.array(1), [1, 1], [2])
        self.assertEqual(z.item(), 7)

        z = fun(mx.array(1), [1], [1, 2])
        self.assertEqual(z.item(), 5)

        counter = [0]

        @partial(mx.compile)
        def fun(x, y):
            counter[0] += 1
            return x + y

        z = fun(mx.array(1), 1)
        self.assertEqual(z.item(), 2)

        z = fun(1, mx.array(1))
        self.assertEqual(z.item(), 2)

        self.assertEqual(counter[0], 2)

    def test_compile_inf(self):

        @mx.compile
        def fun(x):
            return mx.isinf(x + 2)

        out = fun(mx.array([0.0]))
        self.assertEqual(out.item(), False)

    def test_unsupported_input_types(self):

        class MyClass:
            value = 1

        @mx.compile
        def fun(x, y):
            return x + y.value

        with self.assertRaises(ValueError):
            out = fun(mx.array(0.0), MyClass())

        with self.assertRaises(ValueError):
            out = fun(mx.array(0.0), y=MyClass())


if __name__ == "__main__":
    unittest.main()
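What these compile tests pin down, restated as a standalone hedged sketch: non-array Python arguments behave as compile-time constants (a new constant value triggers a retrace), while arrays are traced symbolically and cached by shape and dtype. The names scale and trace_count are illustrative, and the caching assumptions are inferred from the tests above:

from functools import partial

import mlx.core as mx

trace_count = [0]

@partial(mx.compile)
def scale(x, factor):
    trace_count[0] += 1  # runs only when the function is (re)traced
    return x * factor

scale(mx.array(2.0), 3)  # traced
scale(mx.array(4.0), 3)  # same constant, shape, and dtype: cached
scale(mx.array(2.0), 5)  # new constant: retraced
print(trace_count[0])    # 2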
@@ -1,4 +1,4 @@
-# Copyright © 2023 Apple Inc.
+# Copyright © 2023-2024 Apple Inc.

import math
import unittest
@@ -388,13 +388,8 @@ class TestConv(mlx_tests.MLXTestCase):

            _, outs_mx = mx.vjp(
                f,
-               [
-                   in_mx,
-                   wt_mx,
-               ],
-               [
-                   ct_mx,
-               ],
+               [in_mx, wt_mx],
+               [ct_mx],
            )
            pt_grad_in = F.grad.conv1d_input(
                in_pt.shape,
@@ -428,18 +423,218 @@ class TestConv(mlx_tests.MLXTestCase):
            self.assertTrue(np.allclose(pt_grad_wt, mx_grad_wt, atol=atol))

        for dtype in ("float32",):
-           for N, C, O in (
-               (1, 1, 1),
-               (1, 6, 1),
-               (1, 1, 6),
-               (4, 32, 64),
-           ):
-               for idim, kdim, stride, padding in (
-                   ((1, 1), (1, 1), (1, 1), (0, 0)),
-                   ((3, 3), (3, 1), (1, 1), (0, 0)),
-                   ((31, 31), (5, 5), (5, 5), (2, 2)),
+           for N, C, O in ((1, 1, 1), (1, 6, 1), (1, 1, 6), (4, 32, 64), (4, 16, 32)):
+               for idim, kdim, stride, padding, dilation in (
+                   ((1, 1), (1, 1), (1, 1), (0, 0), (1, 1)),
+                   ((3, 3), (3, 1), (1, 1), (0, 0), (1, 1)),
+                   ((31, 31), (5, 5), (5, 5), (2, 2), (1, 1)),
+                   ((32, 32), (3, 3), (2, 2), (1, 1), (1, 1)),
+                   ((31, 31), (5, 5), (5, 5), (2, 2), (3, 2)),
+                   ((32, 32), (3, 3), (2, 2), (1, 1), (3, 2)),
                ):
-                   run_conv2D_grad(N, C, O, idim, kdim, stride, padding, dtype=dtype)
+                   run_conv2D_grad(
+                       N, C, O, idim, kdim, stride, padding, dilation, dtype=dtype
+                   )

    def __conv_general_test(
        self,
        in_shape,
        wt_shape,
        stride=1,
        padding=0,
        kernel_dilation=1,
        input_dilation=1,
        groups=1,
        flip=False,
        np_dtype=np.float32,
        atol=1e-5,
    ):

        with self.subTest(
            in_shape=in_shape,
            wt_shape=wt_shape,
            stride=stride,
            padding=padding,
            kernel_dilation=kernel_dilation,
            input_dilation=input_dilation,
            groups=groups,
            flip=flip,
            np_dtype=np_dtype,
        ):

            scale = 1.0 / math.sqrt(np.prod(wt_shape[1:]))
            in_np = np.random.normal(0.0, scale, in_shape).astype(np_dtype)
            wt_np = np.random.normal(0.0, scale, wt_shape).astype(np_dtype)

            in_mx, wt_mx = map(mx.array, (in_np, wt_np))

            in_pt, wt_pt = map(
                lambda x: torch.from_numpy(np.moveaxis(x, -1, 1)).to("cpu"),
                (in_np, wt_np),
            )

            out_mx = mx.conv_general(
                in_mx,
                wt_mx,
                stride=stride,
                padding=padding,
                kernel_dilation=kernel_dilation,
                input_dilation=input_dilation,
                groups=groups,
                flip=flip,
            )

            def conv_general_pt(
                inp, wt, stride, padding, kernel_dilation, input_dilation, groups, flip
            ):

                C = inp.size()[1]
                ndim = inp.ndim - 2
                map_ints = lambda x: [x] * ndim if isinstance(x, int) else x

                stride, padding, kernel_dilation, input_dilation = map(
                    map_ints, (stride, padding, kernel_dilation, input_dilation)
                )

                torch_convt_list = (
                    F.conv_transpose1d,
                    F.conv_transpose2d,
                    F.conv_transpose3d,
                )
                torch_conv_list = (F.conv1d, F.conv2d, F.conv3d)

                conv_f = torch_conv_list[ndim - 1]
                convt_f = torch_convt_list[ndim - 1]

                if flip:
                    wt = torch.flip(wt, tuple(np.arange(2, wt.ndim)))

                if not np.all(input_dilation == 1):
                    ones = torch.ones([C] + [1] * (ndim + 1)).to(inp.dtype)
                    inp = convt_f(inp, ones, stride=input_dilation, groups=C)

                return conv_f(
                    inp,
                    wt,
                    stride=stride,
                    padding=padding,
                    dilation=kernel_dilation,
                    groups=groups,
                )

            out_pt = conv_general_pt(
                in_pt,
                wt_pt,
                stride=stride,
                padding=padding,
                kernel_dilation=kernel_dilation,
                input_dilation=input_dilation,
                groups=groups,
                flip=flip,
            )

            out_pt = np.moveaxis(out_pt.numpy(), 1, -1)

            self.assertEqual(out_mx.shape, out_pt.shape)
            self.assertTrue(np.allclose(out_mx, out_pt, atol=atol))

    @unittest.skipIf(not has_torch, "requires Torch")
    def test_torch_conv_general(self):
        in_shape = (2, 32, 32, 16)
        wt_shape = (32, 5, 5, 16)
        stride = (1, 1)
        padding = (2, 2)
        kernel_dilation = (2, 3)
        input_dilation = (1, 1)
        flip = False

        self.__conv_general_test(
            in_shape,
            wt_shape,
            stride,
            padding,
            kernel_dilation,
            input_dilation,
            flip=flip,
        )

        in_shape = (2, 32, 32, 16)
        wt_shape = (32, 5, 10, 16)
        stride = (2, 3)
        padding = (0, 0)
        kernel_dilation = (3, 2)
        input_dilation = (2, 4)
        flip = False

        self.__conv_general_test(
            in_shape,
            wt_shape,
            stride,
            padding,
            kernel_dilation,
            input_dilation,
            flip=flip,
        )

        in_shape = (2, 32, 32, 16)
        wt_shape = (32, 5, 10, 16)
        stride = (2, 2)
        padding = (3, 2)
        kernel_dilation = (3, 2)
        input_dilation = (2, 4)
        flip = False

        self.__conv_general_test(
            in_shape,
            wt_shape,
            stride,
            padding,
            kernel_dilation,
            input_dilation,
            flip=flip,
        )

        in_shape = (2, 32, 32, 16)
        wt_shape = (32, 5, 10, 16)
        stride = (2, 3)
        padding = (3, 2)
        kernel_dilation = (3, 2)
        input_dilation = (2, 5)
        flip = False

        self.__conv_general_test(
            in_shape,
            wt_shape,
            stride,
            padding,
            kernel_dilation,
            input_dilation,
            flip=flip,
        )

        in_shape = (2, 32, 32, 16)
        wt_shape = (32, 5, 5, 16)
        stride = (2, 3)
        padding = (0, 0)
        kernel_dilation = (3, 1)
        input_dilation = (2, 5)
        flip = True

        self.__conv_general_test(
            in_shape,
            wt_shape,
            stride,
            padding,
            kernel_dilation,
            input_dilation,
            flip=flip,
        )


if __name__ == "__main__":
    unittest.main()
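conv_general_pt above emulates input dilation with a transposed convolution whose kernel is a per-channel identity; the effect is plain zero-insertion between input samples. A small hedged sketch of that equivalence (dilate_input is an illustrative helper, not part of the test suite):

import numpy as np

def dilate_input(x, d):
    # Insert d - 1 zeros between consecutive samples of a 1-D signal.
    out = np.zeros((len(x) - 1) * d + 1, dtype=x.dtype)
    out[::d] = x
    return out

print(dilate_input(np.array([1.0, 2.0, 3.0]), 2))  # [1. 0. 2. 0. 3.]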
@@ -66,13 +66,15 @@ class TestLoad(mlx_tests.MLXTestCase):
    def test_save_and_load_safetensors(self):
        if not os.path.isdir(self.test_dir):
            os.mkdir(self.test_dir)

+       test_file = os.path.join(self.test_dir, "test.safetensors")
        with self.assertRaises(Exception):
-           mx.save_safetensors("test", {"a": mx.ones((4, 4))}, {"testing": 0})
+           mx.save_safetensors(test_file, {"a": mx.ones((4, 4))}, {"testing": 0})

        mx.save_safetensors(
-           "test", {"test": mx.ones((2, 2))}, {"testing": "test", "format": "mlx"}
+           test_file, {"test": mx.ones((2, 2))}, {"testing": "test", "format": "mlx"}
        )
-       res = mx.load("test.safetensors", return_metadata=True)
+       res = mx.load(test_file, return_metadata=True)
        self.assertEqual(len(res), 2)
        self.assertEqual(res[1], {"testing": "test", "format": "mlx"})
python/tests/test_metal.py (new file, 45 lines)
@@ -0,0 +1,45 @@
# Copyright © 2023-2024 Apple Inc.

import unittest

import mlx.core as mx
import mlx_tests


class TestMetal(mlx_tests.MLXTestCase):

    @unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
    def test_memory_info(self):
        old_limit = mx.metal.set_cache_limit(0)

        a = mx.zeros((4096,))
        mx.eval(a)
        del a
        self.assertEqual(mx.metal.get_cache_memory(), 0)
        self.assertEqual(mx.metal.set_cache_limit(old_limit), 0)
        self.assertEqual(mx.metal.set_cache_limit(old_limit), old_limit)

        old_limit = mx.metal.set_memory_limit(10)
        self.assertTrue(mx.metal.set_memory_limit(old_limit), 10)
        self.assertTrue(mx.metal.set_memory_limit(old_limit), old_limit)

        # Query active and peak memory
        a = mx.zeros((4096,))
        mx.eval(a)
        active_mem = mx.metal.get_active_memory()
        self.assertTrue(active_mem >= 4096 * 4)

        b = mx.zeros((4096,))
        mx.eval(b)
        del b

        new_active_mem = mx.metal.get_active_memory()
        self.assertEqual(new_active_mem, active_mem)
        peak_mem = mx.metal.get_peak_memory()
        self.assertTrue(peak_mem >= 4096 * 8)
        cache_mem = mx.metal.get_cache_memory()
        self.assertTrue(cache_mem >= 4096 * 4)


if __name__ == "__main__":
    unittest.main()
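A hedged usage sketch of the memory API this test exercises: temporarily disable the buffer cache around a memory-sensitive region, then restore the previous limit (the setters return the limit that was previously in effect, as the assertions above rely on):

import mlx.core as mx

if mx.metal.is_available():
    old_limit = mx.metal.set_cache_limit(0)  # returns the previous limit
    try:
        x = mx.zeros((4096,))
        mx.eval(x)
    finally:
        mx.metal.set_cache_limit(old_limit)
    print(mx.metal.get_peak_memory())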
@@ -1,4 +1,4 @@
-# Copyright © 2023 Apple Inc.
+# Copyright © 2023-2024 Apple Inc.

import os
import tempfile
@@ -8,7 +8,7 @@ import mlx.core as mx
import mlx.nn as nn
import mlx_tests
import numpy as np
-from mlx.utils import tree_flatten, tree_map, tree_unflatten
+from mlx.utils import tree_flatten, tree_map


class TestBase(mlx_tests.MLXTestCase):
@@ -665,7 +665,7 @@ class TestLayers(mlx_tests.MLXTestCase):
        y_hat1 = nn.gelu_approx(x)
        y_hat2 = nn.gelu_fast_approx(x)
        self.assertLess(mx.abs(y - y_hat1).max(), 0.0003)
-       self.assertLess(mx.abs(y - y_hat2).max(), 0.02)
+       self.assertLess(mx.abs(y - y_hat2).max(), 0.025)

    def test_sin_pe(self):
        m = nn.SinusoidalPositionalEncoding(16, min_freq=0.01)
@@ -905,6 +905,228 @@ class TestLayers(mlx_tests.MLXTestCase):
        self.assertTrue(y.shape, x.shape)
        self.assertTrue(y.dtype, mx.float16)

    def test_upsample(self):
        b, h, w, c = 1, 2, 2, 1
        scale_factor = 2
        upsample_nearest = nn.Upsample(
            scale_factor=scale_factor, mode="nearest", align_corners=True
        )
        upsample_bilinear = nn.Upsample(
            scale_factor=scale_factor, mode="linear", align_corners=True
        )
        upsample_bilinear_no_align_corners = nn.Upsample(
            scale_factor=scale_factor, mode="linear", align_corners=False
        )
        upsample_nearest_no_align_corners = nn.Upsample(
            scale_factor=scale_factor, mode="nearest", align_corners=False
        )

        # Test single feature map, align corners
        x = mx.arange(b * h * w * c).reshape((b, c, h, w)).transpose((0, 2, 3, 1))
        expected_nearest = mx.array(
            [[[[0, 0, 1, 1], [0, 0, 1, 1], [2, 2, 3, 3], [2, 2, 3, 3]]]]
        ).transpose((0, 2, 3, 1))
        expected_bilinear = mx.array(
            [
                [
                    [
                        [0, 0.333333, 0.666667, 1],
                        [0.666667, 1, 1.33333, 1.66667],
                        [1.33333, 1.66667, 2, 2.33333],
                        [2, 2.33333, 2.66667, 3],
                    ]
                ]
            ]
        ).transpose((0, 2, 3, 1))
        self.assertTrue(np.allclose(upsample_nearest(x), expected_nearest))
        self.assertTrue(np.allclose(upsample_bilinear(x), expected_bilinear))

        # Test single feature map, no align corners
        x = (
            mx.arange(1, b * h * w * c + 1)
            .reshape((b, c, h, w))
            .transpose((0, 2, 3, 1))
        )
        expected_bilinear_no_align_corners = mx.array(
            [
                [
                    [
                        [1.0000, 1.2500, 1.7500, 2.0000],
                        [1.5000, 1.7500, 2.2500, 2.5000],
                        [2.5000, 2.7500, 3.2500, 3.5000],
                        [3.0000, 3.2500, 3.7500, 4.0000],
                    ]
                ]
            ]
        ).transpose((0, 2, 3, 1))
        expected_nearest_no_align_corners = mx.array(
            [[[[1, 1, 2, 2], [1, 1, 2, 2], [3, 3, 4, 4], [3, 3, 4, 4]]]]
        ).transpose((0, 2, 3, 1))
        self.assertTrue(
            np.allclose(
                upsample_nearest_no_align_corners(x), expected_nearest_no_align_corners
            )
        )
        self.assertTrue(
            np.allclose(
                upsample_bilinear_no_align_corners(x),
                expected_bilinear_no_align_corners,
            )
        )

        # Test a more complex batch
        b, h, w, c = 2, 3, 3, 2
        scale_factor = 2
        x = mx.arange(b * h * w * c).reshape((b, c, h, w)).transpose((0, 2, 3, 1))

        upsample_nearest = nn.Upsample(
            scale_factor=scale_factor, mode="nearest", align_corners=True
        )
        upsample_bilinear = nn.Upsample(
            scale_factor=scale_factor, mode="linear", align_corners=True
        )

        expected_nearest = mx.array(
            [
                [
                    [
                        [0.0, 0.0, 1.0, 1.0, 2.0, 2.0],
                        [0.0, 0.0, 1.0, 1.0, 2.0, 2.0],
                        [3.0, 3.0, 4.0, 4.0, 5.0, 5.0],
                        [3.0, 3.0, 4.0, 4.0, 5.0, 5.0],
                        [6.0, 6.0, 7.0, 7.0, 8.0, 8.0],
                        [6.0, 6.0, 7.0, 7.0, 8.0, 8.0],
                    ],
                    [
                        [9.0, 9.0, 10.0, 10.0, 11.0, 11.0],
                        [9.0, 9.0, 10.0, 10.0, 11.0, 11.0],
                        [12.0, 12.0, 13.0, 13.0, 14.0, 14.0],
                        [12.0, 12.0, 13.0, 13.0, 14.0, 14.0],
                        [15.0, 15.0, 16.0, 16.0, 17.0, 17.0],
                        [15.0, 15.0, 16.0, 16.0, 17.0, 17.0],
                    ],
                ],
                [
                    [
                        [18.0, 18.0, 19.0, 19.0, 20.0, 20.0],
                        [18.0, 18.0, 19.0, 19.0, 20.0, 20.0],
                        [21.0, 21.0, 22.0, 22.0, 23.0, 23.0],
                        [21.0, 21.0, 22.0, 22.0, 23.0, 23.0],
                        [24.0, 24.0, 25.0, 25.0, 26.0, 26.0],
                        [24.0, 24.0, 25.0, 25.0, 26.0, 26.0],
                    ],
                    [
                        [27.0, 27.0, 28.0, 28.0, 29.0, 29.0],
                        [27.0, 27.0, 28.0, 28.0, 29.0, 29.0],
                        [30.0, 30.0, 31.0, 31.0, 32.0, 32.0],
                        [30.0, 30.0, 31.0, 31.0, 32.0, 32.0],
                        [33.0, 33.0, 34.0, 34.0, 35.0, 35.0],
                        [33.0, 33.0, 34.0, 34.0, 35.0, 35.0],
                    ],
                ],
            ]
        ).transpose((0, 2, 3, 1))
        expected_bilinear = mx.array(
            [
                [
                    [
                        [0.0, 0.4, 0.8, 1.2, 1.6, 2.0],
                        [1.2, 1.6, 2.0, 2.4, 2.8, 3.2],
                        [2.4, 2.8, 3.2, 3.6, 4.0, 4.4],
                        [3.6, 4.0, 4.4, 4.8, 5.2, 5.6],
                        [4.8, 5.2, 5.6, 6.0, 6.4, 6.8],
                        [6.0, 6.4, 6.8, 7.2, 7.6, 8.0],
                    ],
                    [
                        [9.0, 9.4, 9.8, 10.2, 10.6, 11.0],
                        [10.2, 10.6, 11.0, 11.4, 11.8, 12.2],
                        [11.4, 11.8, 12.2, 12.6, 13.0, 13.4],
                        [12.6, 13.0, 13.4, 13.8, 14.2, 14.6],
                        [13.8, 14.2, 14.6, 15.0, 15.4, 15.8],
                        [15.0, 15.4, 15.8, 16.2, 16.6, 17.0],
                    ],
                ],
                [
                    [
                        [18.0, 18.4, 18.8, 19.2, 19.6, 20.0],
                        [19.2, 19.6, 20.0, 20.4, 20.8, 21.2],
                        [20.4, 20.8, 21.2, 21.6, 22.0, 22.4],
                        [21.6, 22.0, 22.4, 22.8, 23.2, 23.6],
                        [22.8, 23.2, 23.6, 24.0, 24.4, 24.8],
                        [24.0, 24.4, 24.8, 25.2, 25.6, 26.0],
                    ],
                    [
                        [27.0, 27.4, 27.8, 28.2, 28.6, 29.0],
                        [28.2, 28.6, 29.0, 29.4, 29.8, 30.2],
                        [29.4, 29.8, 30.2, 30.6, 31.0, 31.4],
                        [30.6, 31.0, 31.4, 31.8, 32.2, 32.6],
                        [31.8, 32.2, 32.6, 33.0, 33.4, 33.8],
                        [33.0, 33.4, 33.8, 34.2, 34.6, 35.0],
                    ],
                ],
            ]
        ).transpose((0, 2, 3, 1))
        self.assertTrue(np.allclose(upsample_nearest(x), expected_nearest))
        self.assertTrue(np.allclose(upsample_bilinear(x), expected_bilinear))

        # Test different height and width scale_factor
        b, h, w, c = 1, 2, 2, 2
        x = mx.arange(b * h * w * c).reshape((b, c, h, w)).transpose((0, 2, 3, 1))
        upsample_nearest = nn.Upsample(
            scale_factor=(2, 3), mode="nearest", align_corners=True
        )
        upsample_bilinear = nn.Upsample(
            scale_factor=(2, 3), mode="linear", align_corners=True
        )

        expected_nearest = mx.array(
            [
                [
                    [
                        [0, 0, 0, 1, 1, 1],
                        [0, 0, 0, 1, 1, 1],
                        [2, 2, 2, 3, 3, 3],
                        [2, 2, 2, 3, 3, 3],
                    ],
                    [
                        [4, 4, 4, 5, 5, 5],
                        [4, 4, 4, 5, 5, 5],
                        [6, 6, 6, 7, 7, 7],
                        [6, 6, 6, 7, 7, 7],
                    ],
                ]
            ]
        ).transpose((0, 2, 3, 1))
        expected_bilinear = mx.array(
            [
                [
                    [
                        [0, 0.2, 0.4, 0.6, 0.8, 1],
                        [0.666667, 0.866667, 1.06667, 1.26667, 1.46667, 1.66667],
                        [1.33333, 1.53333, 1.73333, 1.93333, 2.13333, 2.33333],
                        [2, 2.2, 2.4, 2.6, 2.8, 3],
                    ],
                    [
                        [4, 4.2, 4.4, 4.6, 4.8, 5],
                        [4.66667, 4.86667, 5.06667, 5.26667, 5.46667, 5.66667],
                        [5.33333, 5.53333, 5.73333, 5.93333, 6.13333, 6.33333],
                        [6, 6.2, 6.4, 6.6, 6.8, 7],
                    ],
                ]
            ]
        ).transpose((0, 2, 3, 1))
        self.assertTrue(np.allclose(upsample_nearest(x), expected_nearest))
        self.assertTrue(np.allclose(upsample_bilinear(x), expected_bilinear))

        # Test repr
        self.assertEqual(
            str(nn.Upsample(scale_factor=2)),
            "Upsample(scale_factor=2.0, mode='nearest', align_corners=False)",
        )
        self.assertEqual(
            str(nn.Upsample(scale_factor=(2, 3))),
            "Upsample(scale_factor=(2.0, 3.0), mode='nearest', align_corners=False)",
        )
|
||||
# Test 1d pooling
|
||||
x = mx.array(
|
||||
|
||||
@@ -1047,6 +1047,11 @@ class TestOps(mlx_tests.MLXTestCase):
            a = mx.arange(0, float("inf"), float("inf"))
        with self.assertRaises(ValueError):
            a = mx.arange(float("inf"), 1, float("inf"))
        with self.assertRaises(ValueError):
            a = mx.arange(float("inf"), 1, 5)
        with self.assertRaises(ValueError):
            INT_MAX = 2147483647
            a = mx.arange(0, INT_MAX + 1, 1)

        a = mx.arange(5)
        expected = [0, 1, 2, 3, 4]
@@ -1132,6 +1137,27 @@ class TestOps(mlx_tests.MLXTestCase):
        self.assertListEqual(a.tolist(), expected)
        self.assertEqual(a.dtype, mx.int32)

        a = mx.arange(0, 10, 100)
        expected = [0]
        self.assertListEqual(a.tolist(), expected)
        self.assertEqual(a.dtype, mx.int32)

        a = mx.arange(10, 0, 1)
        expected = []
        self.assertListEqual(a.tolist(), expected)

        a = mx.arange(10, 0, float("inf"))
        expected = []
        self.assertListEqual(a.tolist(), expected)

        a = mx.arange(0, 10, float("inf"))
        expected = [0]
        self.assertListEqual(a.tolist(), expected)

        a = mx.arange(0, -10, float("-inf"))
        expected = [0]
        self.assertListEqual(a.tolist(), expected)

    def test_unary_ops(self):
        def test_ops(npop, mlxop, x, y, atol):
            r_np = npop(x)
@@ -1563,7 +1589,7 @@ class TestOps(mlx_tests.MLXTestCase):
        shape = (3, 4, 5)
        for dtype in ("int32", "float32"):
            for axis in (None, 0, 1, 2):
-               for kth in (-2, 2):
+               for kth in (-2, 0, 2):
                    with self.subTest(dtype=dtype, axis=axis, kth=kth):
                        np.random.seed(0)
                        np_dtype = getattr(np, dtype)
@@ -1579,13 +1605,16 @@ class TestOps(mlx_tests.MLXTestCase):
                        self.assertTrue(np.array_equal(c_np, c_mx))
                        self.assertEqual(b_mx.dtype, a_mx.dtype)

-                       top_k_mx = mx.topk(a_mx, kth, axis=axis)
-                       self.assertTrue(np.all(c_np <= top_k_mx))
-                       self.assertEqual(top_k_mx.dtype, a_mx.dtype)
+                       if kth >= 0:
+                           d_np = np.take(b_mx, np.arange(kth), axis=axis)
+                           self.assertTrue(np.all(d_np <= c_mx))
+                       top_k_mx = mx.topk(a_mx, kth, axis=axis)
+                       top_k_np = np.take(
+                           np.partition(a_np, -kth, axis=axis), (-kth,), axis=axis
+                       )
+                       self.assertTrue(np.all(top_k_np <= top_k_mx))
+                       self.assertEqual(top_k_mx.dtype, a_mx.dtype)
+                       N = a_mx.shape[axis] if axis is not None else a_mx.size
+                       M = top_k_mx.shape[axis or 0]
+                       self.assertEqual(M, (kth + N) % N)

    @unittest.skipIf(
        os.getenv("LOW_MEMORY", None) is not None,
@@ -1906,12 +1935,16 @@ class TestOps(mlx_tests.MLXTestCase):
            [[[[1]], [[2]], [[3]]]],
        ]

-       for array in arrays:
+       mx_arrays = [mx.atleast_1d(mx.array(x)) for x in arrays]
+       atleast_arrays = mx.atleast_1d(*mx_arrays)
+
+       for i, array in enumerate(arrays):
            mx_res = mx.atleast_1d(mx.array(array))
            np_res = np.atleast_1d(np.array(array))
            self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist()))
            self.assertEqual(mx_res.shape, np_res.shape)
            self.assertEqual(mx_res.ndim, np_res.ndim)
+           self.assertTrue(mx.all(mx.equal(mx_res, atleast_arrays[i])))

    def test_atleast_2d(self):
        def compare_nested_lists(x, y):
@@ -1936,12 +1969,16 @@ class TestOps(mlx_tests.MLXTestCase):
            [[[[1]], [[2]], [[3]]]],
        ]

-       for array in arrays:
+       mx_arrays = [mx.atleast_2d(mx.array(x)) for x in arrays]
+       atleast_arrays = mx.atleast_2d(*mx_arrays)
+
+       for i, array in enumerate(arrays):
            mx_res = mx.atleast_2d(mx.array(array))
            np_res = np.atleast_2d(np.array(array))
            self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist()))
            self.assertEqual(mx_res.shape, np_res.shape)
            self.assertEqual(mx_res.ndim, np_res.ndim)
+           self.assertTrue(mx.all(mx.equal(mx_res, atleast_arrays[i])))

    def test_atleast_3d(self):
        def compare_nested_lists(x, y):
@@ -1966,12 +2003,16 @@ class TestOps(mlx_tests.MLXTestCase):
            [[[[1]], [[2]], [[3]]]],
        ]

-       for array in arrays:
+       mx_arrays = [mx.atleast_3d(mx.array(x)) for x in arrays]
+       atleast_arrays = mx.atleast_3d(*mx_arrays)
+
+       for i, array in enumerate(arrays):
            mx_res = mx.atleast_3d(mx.array(array))
            np_res = np.atleast_3d(np.array(array))
            self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist()))
            self.assertEqual(mx_res.shape, np_res.shape)
            self.assertEqual(mx_res.ndim, np_res.ndim)
+           self.assertTrue(mx.all(mx.equal(mx_res, atleast_arrays[i])))


if __name__ == "__main__":
@@ -328,6 +328,37 @@ class TestSchedulers(unittest.TestCase):
        expected_lr = 0.1 * 0.5 * (1.0 + math.cos(math.pi * 4 / 10))
        self.assertAlmostEqual(lr, expected_lr, delta=1e-7)

    def test_schedule_joiner(self):
        boundaries = [2, 3, 4]
        schedules = [lambda _: 3, lambda _: 4, lambda _: 5]
        with self.assertRaises(ValueError):
            opt.schedulers.join_schedules(schedules, boundaries)
        boundaries = [2, 4]
        schedule = opt.schedulers.join_schedules(schedules, boundaries)
        self.assertEqual(schedule(0).item(), 3)
        self.assertEqual(schedule(1).item(), 3)
        self.assertEqual(schedule(2).item(), 4)
        self.assertEqual(schedule(3).item(), 4)
        self.assertEqual(schedule(5).item(), 5)
        self.assertEqual(schedule(7).item(), 5)
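Restating the joiner semantics this test pins down, as a hedged pure-Python sketch: there must be one more schedule than boundaries, boundaries[i] is the step at which schedule i + 1 takes over, and each schedule sees steps relative to its own start (join_schedules here is a standalone re-implementation, not the opt.schedulers one):

def join_schedules(schedules, boundaries):
    if len(schedules) != len(boundaries) + 1:
        raise ValueError("Expected one more schedule than boundaries")

    def schedule(step):
        for i, b in enumerate(boundaries):
            if step < b:
                # Offset by the previous boundary so each piece starts at 0.
                return schedules[i](step - (boundaries[i - 1] if i else 0))
        return schedules[-1](step - boundaries[-1])

    return schedule

s = join_schedules([lambda _: 3, lambda _: 4, lambda _: 5], [2, 4])
print([s(t) for t in range(8)])  # [3, 3, 4, 4, 5, 5, 5, 5]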
    def test_linear_warmup_with_cosine_decay(self):
        warmup_schedule = opt.schedulers.linear_schedule(0.0, 1e-5, 100)
        cosine_schedule = opt.schedulers.cosine_decay(1e-5, 100)
        cos_with_warmup = opt.schedulers.join_schedules(
            [warmup_schedule, cosine_schedule], [101]
        )
        self.assertEqual(cos_with_warmup(0), 0.0)
        self.assertAlmostEqual(cos_with_warmup(101), 1e-5, delta=1e-1)
        optimizer = opt.Adam(learning_rate=cos_with_warmup)
        for _ in range(100):
            optimizer.update({}, {})
        self.assertAlmostEqual(optimizer.learning_rate.item(), 1e-5, delta=1e-1)
        for _ in range(100):
            optimizer.update({}, {})
        expected_lr = 1e-5 * 0.5 * (1.0 + math.cos(math.pi * 200 / 10))
        self.assertAlmostEqual(optimizer.learning_rate.item(), expected_lr, delta=1e-1)

    def test_compile_with_schedule(self):
        lr_schedule = opt.exponential_decay(1e-1, 0.9)
        optimizer = opt.SGD(learning_rate=lr_schedule)