Merge branch 'ml-explore:main' into main

This commit is contained in:
Luca Arnaboldi
2024-03-04 10:57:32 +01:00
committed by GitHub
94 changed files with 5858 additions and 1575 deletions

View File

@@ -37,6 +37,7 @@ from mlx.nn.layers.activations import (
relu,
relu6,
selu,
sigmoid,
silu,
softmax,
softplus,
@@ -67,3 +68,4 @@ from mlx.nn.layers.transformer import (
TransformerEncoder,
TransformerEncoderLayer,
)
from mlx.nn.layers.upsample import Upsample

View File

@@ -18,7 +18,7 @@ def _make_activation_module(f):
@partial(mx.compile, shapeless=True)
def sigmoid(x):
r"""Applies the element-wise function:
r"""Applies the sigmoid function.
.. math::
\text{Sigmoid}(x) = \sigma(x) = \frac{1}{1 + \exp(-x)}
@@ -142,11 +142,11 @@ def log_sigmoid(x):
@partial(mx.compile, shapeless=True)
def gelu(x):
def gelu(x) -> mx.array:
r"""Applies the Gaussian Error Linear Units function.
.. math::
\\textrm{GELU}(x) = x * \Phi(x)
\textrm{GELU}(x) = x * \Phi(x)
where :math:`\Phi(x)` is the Gaussian CDF.
@@ -185,11 +185,15 @@ def gelu_fast_approx(x):
.. math::
x = x \sigma\left(1.773 x\right)
x = x \sigma\left(1.702 x\right)
where :math:`\sigma(\cdot)` is the logistic sigmoid.
References:
- https://github.com/hendrycks/GELUs
- https://arxiv.org/abs/1606.08415
"""
return x * mx.sigmoid(1.773 * x)
return x * mx.sigmoid(1.702 * x)
def glu(x: mx.array, axis: int = -1) -> mx.array:
@@ -199,7 +203,7 @@ def glu(x: mx.array, axis: int = -1) -> mx.array:
(:math:`a` and :math:`b`) and applies :math:`a * \sigma(b)`.
.. math::
textrm{GLU}(x) = a * \sigma(b)
\textrm{GLU}(x) = a * \sigma(b)
Args:
axis (int): The dimension to split along. Default: ``-1``
@@ -260,6 +264,7 @@ def prelu(x: mx.array, alpha: mx.array) -> mx.array:
@partial(mx.compile, shapeless=True)
def mish(x: mx.array) -> mx.array:
r"""Applies the Mish function, element-wise.
Mish: A Self Regularized Non-Monotonic Neural Activation Function.
Reference: https://arxiv.org/abs/1908.08681
@@ -297,7 +302,7 @@ class GLU(Module):
(:math:`a` and :math:`b`) and applies :math:`a * \sigma(b)`.
.. math::
textrm{GLU}(x) = a * \sigma(b)
\textrm{GLU}(x) = a * \sigma(b)
Args:
axis (int): The dimension to split along. Default: ``-1``

View File

@@ -7,6 +7,42 @@ import mlx.core as mx
from mlx.utils import tree_flatten, tree_unflatten
def _unwrap(model, value_key, value, filter_fn, map_fn, is_leaf_fn):
if is_leaf_fn(model, value_key, value):
return map_fn(value)
elif isinstance(value, Module):
return {
k: _unwrap(value, k, v, filter_fn, map_fn, is_leaf_fn)
for k, v in value.items()
if filter_fn(value, k, v)
}
elif isinstance(value, dict):
nd = {}
for k, v in v.items():
tk = f"{value_key}.{k}"
nd[k] = (
_unwrap(model, tk, v, filter_fn, map_fn, is_leaf_fn)
if filter_fn(model, tk, v)
else {}
)
return nd
elif isinstance(value, list):
nl = []
for i, vi in enumerate(value):
tk = f"{value_key}.{i}"
nl.append(
_unwrap(model, tk, vi, filter_fn, map_fn, is_leaf_fn)
if filter_fn(model, tk, vi)
else {}
)
return nl
raise RuntimeError("Unexpected leaf found while traversing the module")
class Module(dict):
"""Base class for building neural networks with MLX.
@@ -98,10 +134,13 @@ class Module(dict):
if key in self:
return self[key]
else:
raise AttributeError(f"{type(self)!r} has no attribute {key!r}")
super(Module, self).__getattribute__(key)
def __setattr__(self, key: str, val: Any):
self[key] = val
if isinstance(val, (mx.array, dict, list, tuple)):
self[key] = val
else:
super(Module, self).__setattr__(key, val)
def load_weights(
self,
@@ -245,31 +284,11 @@ class Module(dict):
is_leaf_fn = is_leaf_fn or (
lambda m, k, v: not isinstance(v, (Module, dict, list))
)
def unwrap(vk, v):
if is_leaf_fn(self, vk, v):
return map_fn(v)
if isinstance(v, Module):
return v.filter_and_map(filter_fn, map_fn, is_leaf_fn)
if isinstance(v, dict):
nd = {}
for k, v in v.items():
tk = f"{vk}.{k}"
nd[k] = unwrap(tk, v) if filter_fn(self, tk, v) else {}
return nd
if isinstance(v, list):
nl = []
for i, vi in enumerate(v):
tk = f"{vk}.{i}"
nl.append(unwrap(tk, vi) if filter_fn(self, tk, vi) else {})
return nl
raise RuntimeError("Unexpected leaf found while traversing the module")
return {k: unwrap(k, v) for k, v in self.items() if filter_fn(self, k, v)}
return {
k: _unwrap(self, k, v, filter_fn, map_fn, is_leaf_fn)
for k, v in self.items()
if filter_fn(self, k, v)
}
def parameters(self):
"""Recursively return all the :class:`mlx.core.array` members of this Module

View File

@@ -0,0 +1,205 @@
# Copyright © 2023-2024 Apple Inc.
import operator
from functools import reduce
from itertools import product
from typing import Literal, Tuple, Union
import mlx.core as mx
from mlx.nn.layers.base import Module
def _scaled_indices(N, scale, align_corners, dim, ndims):
M = int(scale * N)
if align_corners:
indices = mx.arange(M, dtype=mx.float32) * ((N - 1) / (M - 1))
else:
step = 1 / scale
start = ((M - 1) * step - N + 1) / 2
indices = mx.arange(M, dtype=mx.float32) * step - start
indices = mx.clip(indices, 0, N - 1)
shape = [1] * ndims
shape[dim] = -1
return indices.reshape(shape)
def _nearest_indices(N, scale, dim, ndims):
return _scaled_indices(N, scale, True, dim, ndims).astype(mx.int32)
def _linear_indices(N, scale, align_corners, dim, ndims):
indices = _scaled_indices(N, scale, align_corners, dim, ndims)
indices_l = mx.floor(indices)
indices_r = mx.ceil(indices)
weight = indices - indices_l
weight = mx.expand_dims(weight, -1)
return (
(indices_l.astype(mx.int32), 1 - weight),
(indices_r.astype(mx.int32), weight),
)
def upsample_nearest(x: mx.array, scale_factor: Tuple):
dims = x.ndim - 2
if dims != len(scale_factor):
raise ValueError("A scale needs to be provided for each spatial dimension")
# Integer scale_factors means we can simply expand-broadcast and reshape
if tuple(map(int, scale_factor)) == scale_factor:
shape = list(x.shape)
for d in range(dims):
shape.insert(2 + 2 * d, 1)
x = x.reshape(shape)
for d in range(dims):
shape[2 + 2 * d] = int(scale_factor[d])
x = mx.broadcast_to(x, shape)
for d in range(dims):
shape[d + 1] *= shape[d + 2]
shape.pop(d + 2)
x = x.reshape(shape)
return x
else:
B, *N, C = x.shape
indices = [slice(None)]
for i, (n, s) in enumerate(zip(N, scale_factor)):
indices.append(_nearest_indices(n, s, i, dims))
indices = tuple(indices)
return x[indices]
def upsample_linear(x: mx.array, scale_factor: Tuple, align_corners: bool = False):
dims = x.ndim - 2
if dims != len(scale_factor):
raise ValueError("A scale needs to be provided for each spatial dimension")
B, *N, C = x.shape
# Compute the sampling grid
indices = []
for i, (n, s) in enumerate(zip(N, scale_factor)):
indices.append(_linear_indices(n, s, align_corners, i, dims))
# Sample and compute the weights
samples = []
weights = []
for idx_weight in product(*indices):
idx, weight = zip(*idx_weight)
samples.append(x[(slice(None),) + idx])
weights.append(reduce(operator.mul, weight))
# Interpolate
return sum(wi * xi for wi, xi in zip(weights, samples))
class Upsample(Module):
r"""Upsample the input signal spatially.
The spatial dimensions are by convention dimensions ``1`` to ``x.ndim -
2``. The first is the batch dimension and the last is the feature
dimension.
For example, an audio signal would be 3D with 1 spatial dimension, an image
4D with 2 and so on and so forth.
There are two upsampling algorithms implemented nearest neighbor upsampling
and linear interpolation. Both can be applied to any number of spatial
dimensions and the linear interpolation will be bilinear, trilinear etc
when applied to more than one spatial dimension.
.. note::
When using one of the linear interpolation modes the ``align_corners``
argument changes how the corners are treated in the input image. If
``align_corners=True`` then the top and left edge of the input and
output will be matching as will the bottom right edge.
Parameters:
scale_factor (float or tuple): The multiplier for the spatial size.
If a ``float`` is provided, it is the multiplier for all spatial dimensions.
Otherwise, the number of scale factors provided must match the
number of spatial dimensions.
mode (str, optional): The upsampling algorithm, either ``"nearest"`` or
``"linear"``. Default: ``"nearest"``.
align_corners (bool, optional): Changes the way the corners are treated
during ``"linear"`` upsampling. See the note above and the
examples below for more details. Default: ``False``.
Examples:
>>> import mlx.core as mx
>>> import mlx.nn as nn
>>> x = mx.arange(1, 5).reshape((1, 2, 2, 1))
>>> x
array([[[[1],
[2]],
[[3],
[4]]]], dtype=int32)
>>> n = nn.Upsample(scale_factor=2, mode='nearest')
>>> n(x).squeeze()
array([[1, 1, 2, 2],
[1, 1, 2, 2],
[3, 3, 4, 4],
[3, 3, 4, 4]], dtype=int32)
>>> b = nn.Upsample(scale_factor=2, mode='linear')
>>> b(x).squeeze()
array([[1, 1.25, 1.75, 2],
[1.5, 1.75, 2.25, 2.5],
[2.5, 2.75, 3.25, 3.5],
[3, 3.25, 3.75, 4]], dtype=float32)
>>> b = nn.Upsample(scale_factor=2, mode='linear', align_corners=True)
>>> b(x).squeeze()
array([[1, 1.33333, 1.66667, 2],
[1.66667, 2, 2.33333, 2.66667],
[2.33333, 2.66667, 3, 3.33333],
[3, 3.33333, 3.66667, 4]], dtype=float32)
"""
def __init__(
self,
scale_factor: Union[float, Tuple],
mode: Literal["nearest", "linear"] = "nearest",
align_corners: bool = False,
):
super().__init__()
if mode not in ["nearest", "linear"]:
raise ValueError(f"[Upsample] Got unsupported upsampling algorithm: {mode}")
if isinstance(scale_factor, (list, tuple)):
self.scale_factor = tuple(map(float, scale_factor))
else:
self.scale_factor = float(scale_factor)
self.mode = mode
self.align_corners = align_corners
def _extra_repr(self) -> str:
return (
f"scale_factor={self.scale_factor}, mode={self.mode!r}, "
f"align_corners={self.align_corners}"
)
def __call__(self, x: mx.array) -> mx.array:
dims = x.ndim - 2
if dims <= 0:
raise ValueError(
f"[Upsample] The input should have at least 1 spatial "
f"dimension which means it should be at least 3D but "
f"{x.ndim}D was provided"
)
scale_factor = self.scale_factor
if isinstance(scale_factor, tuple):
if len(scale_factor) != dims:
raise ValueError(
f"[Upsample] One scale per spatial dimension is required but "
f"scale_factor={scale_factor} and the number of spatial "
f"dimensions were {dims}"
)
else:
scale_factor = (scale_factor,) * dims
if self.mode == "nearest":
return upsample_nearest(x, scale_factor)
else:
return upsample_linear(x, scale_factor, self.align_corners)

View File

@@ -1,11 +1,12 @@
# Copyright © 2023-2024 Apple Inc.
import math
from typing import Callable, List
import mlx.core as mx
def exponential_decay(init: float, decay_rate: float):
def exponential_decay(init: float, decay_rate: float) -> Callable:
r"""Make an exponential decay scheduler.
Args:
@@ -30,7 +31,7 @@ def exponential_decay(init: float, decay_rate: float):
return schedule
def step_decay(init: float, decay_rate: float, step_size: int):
def step_decay(init: float, decay_rate: float, step_size: int) -> Callable:
r"""Make a step decay scheduler.
Args:
@@ -57,7 +58,7 @@ def step_decay(init: float, decay_rate: float, step_size: int):
return schedule
def cosine_decay(init: float, decay_steps: int):
def cosine_decay(init: float, decay_steps: int) -> Callable:
r"""Make a cosine decay scheduler.
Args:
@@ -84,3 +85,73 @@ def cosine_decay(init: float, decay_steps: int):
return init * decay
return scheduler
def join_schedules(schedules: List[Callable], boundaries: List[int]) -> Callable:
r"""Join multiple schedules to create a new schedule.
Args:
schedules (list(Callable)): A list of schedules. Schedule :math:`i+1`
receives a step count indicating the number of steps since
the :math:`i`-th boundary.
boundaries (list(int)): A list of integers of length ``len(schedules) - 1``
that indicates when to transition between schedules.
Example:
>>> warmup = optim.linear_schedule(0, 1e-1, steps=10)
>>> cosine = optim.cosine_decay(1e-1, 200)
>>> lr_schedule = optim.join_schedules([warmup, cosine], [10])
>>> optimizer = optim.Adam(learning_rate=lr_schedule)
>>> optimizer.learning_rate
array(0.0, dtype=float32)
>>> for _ in range(12): optimizer.update({}, {})
...
>>> optimizer.learning_rate
array(0.0999938, dtype=float32)
"""
if len(schedules) == 0:
raise ValueError("Must provide at least 1 schedule to join.")
if len(schedules) != len(boundaries) + 1:
raise ValueError(
f"Received {len(boundaries)} boundaries but "
f"expected {len(schedules) - 1}."
)
def schedule(step):
output = schedules[0](step)
for boundary, schedule in zip(boundaries, schedules[1:]):
output = mx.where(step < boundary, output, schedule(step - boundary))
return output
return schedule
def linear_schedule(init: float, end: float, steps: int) -> Callable:
r"""Make a linear scheduler.
Args:
init (float): Initial value.
end (float): Final value.
steps (int): Number of steps to apply the schedule over. The value is
``end`` for any steps beyond ``steps``.
Example:
>>> warmup = optim.linear_schedule(0, 1e-1, 100)
>>> optimizer = optim.Adam(learning_rate=warmup)
>>> optimizer.learning_rate
array(0.0, dtype=float32)
>>> for _ in range(101): optimizer.update({}, {})
...
>>> optimizer.learning_rate
array(0.1, dtype=float32)
"""
if steps < 1:
raise ValueError(f"steps must be greater than 0, but got {steps}.")
def step_fn(step):
step = mx.minimum(step, steps)
return step * ((end - init) / steps) + init
return step_fn

View File

@@ -14,6 +14,7 @@ pybind11_add_module(
${CMAKE_CURRENT_SOURCE_DIR}/random.cpp
${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp
${CMAKE_CURRENT_SOURCE_DIR}/constants.cpp
${CMAKE_CURRENT_SOURCE_DIR}/trees.cpp
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
)

View File

@@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#include <cstdint>
#include <cstring>
@@ -7,6 +7,7 @@
#include <pybind11/numpy.h>
#include "python/src/indexing.h"
#include "python/src/pybind11_numpy_fp16.h"
#include "python/src/utils.h"
#include "mlx/ops.h"
@@ -350,55 +351,53 @@ array np_array_to_mlx(py::array np_array, std::optional<Dtype> dtype) {
shape.push_back(np_array.shape(i));
}
// Get dtype
auto type = np_array.dtype();
// Copy data and make array
if (type.is(py::dtype::of<int>())) {
if (py::isinstance<py::array_t<int32_t>>(np_array)) {
return np_array_to_mlx_contiguous<int32_t>(
np_array, shape, dtype.value_or(int32));
} else if (type.is(py::dtype::of<uint32_t>())) {
} else if (py::isinstance<py::array_t<uint32_t>>(np_array)) {
return np_array_to_mlx_contiguous<uint32_t>(
np_array, shape, dtype.value_or(uint32));
} else if (type.is(py::dtype::of<bool>())) {
} else if (py::isinstance<py::array_t<bool>>(np_array)) {
return np_array_to_mlx_contiguous<bool>(
np_array, shape, dtype.value_or(bool_));
} else if (type.is(py::dtype::of<double>())) {
} else if (py::isinstance<py::array_t<double>>(np_array)) {
return np_array_to_mlx_contiguous<double>(
np_array, shape, dtype.value_or(float32));
} else if (type.is(py::dtype::of<float>())) {
} else if (py::isinstance<py::array_t<float>>(np_array)) {
return np_array_to_mlx_contiguous<float>(
np_array, shape, dtype.value_or(float32));
} else if (type.is(py::dtype("float16"))) {
} else if (py::isinstance<py::array_t<float16_t>>(np_array)) {
return np_array_to_mlx_contiguous<float>(
np_array, shape, dtype.value_or(float16));
} else if (type.is(py::dtype::of<uint8_t>())) {
} else if (py::isinstance<py::array_t<uint8_t>>(np_array)) {
return np_array_to_mlx_contiguous<uint8_t>(
np_array, shape, dtype.value_or(uint8));
} else if (type.is(py::dtype::of<uint16_t>())) {
} else if (py::isinstance<py::array_t<uint16_t>>(np_array)) {
return np_array_to_mlx_contiguous<uint16_t>(
np_array, shape, dtype.value_or(uint16));
} else if (type.is(py::dtype::of<uint64_t>())) {
} else if (py::isinstance<py::array_t<uint64_t>>(np_array)) {
return np_array_to_mlx_contiguous<uint64_t>(
np_array, shape, dtype.value_or(uint64));
} else if (type.is(py::dtype::of<int8_t>())) {
} else if (py::isinstance<py::array_t<int8_t>>(np_array)) {
return np_array_to_mlx_contiguous<int8_t>(
np_array, shape, dtype.value_or(int8));
} else if (type.is(py::dtype::of<int16_t>())) {
} else if (py::isinstance<py::array_t<int16_t>>(np_array)) {
return np_array_to_mlx_contiguous<int16_t>(
np_array, shape, dtype.value_or(int16));
} else if (type.is(py::dtype::of<int64_t>())) {
} else if (py::isinstance<py::array_t<int64_t>>(np_array)) {
return np_array_to_mlx_contiguous<int64_t>(
np_array, shape, dtype.value_or(int64));
} else if (type.is(py::dtype::of<std::complex<float>>())) {
} else if (py::isinstance<py::array_t<std::complex<float>>>(np_array)) {
return np_array_to_mlx_contiguous<std::complex<float>>(
np_array, shape, dtype.value_or(complex64));
} else if (type.is(py::dtype::of<std::complex<double>>())) {
} else if (py::isinstance<py::array_t<std::complex<double>>>(np_array)) {
return np_array_to_mlx_contiguous<std::complex<float>>(
np_array, shape, dtype.value_or(complex64));
} else {
std::ostringstream msg;
msg << "Cannot convert numpy array of type " << type << " to mlx array.";
msg << "Cannot convert numpy array of type " << np_array.dtype()
<< " to mlx array.";
throw std::invalid_argument(msg.str());
}
}

View File

@@ -5,18 +5,88 @@
#include "mlx/backend/metal/metal.h"
namespace py = pybind11;
using namespace py::literals;
using namespace mlx::core;
void init_metal(py::module_& m) {
py::module_ metal = m.def_submodule("metal", "mlx.metal");
metal.def("is_available", &metal::is_available);
metal.def(
"cache_enabled",
&metal::cache_enabled,
"check if metal buffer cache is enabled, default is true");
"is_available",
&metal::is_available,
R"pbdoc(
Check if the Metal back-end is available.
)pbdoc");
metal.def(
"set_cache_enabled",
&metal::set_cache_enabled,
"enable or disable metal buffer cache");
"get_active_memory",
&metal::get_active_memory,
R"pbdoc(
Get the actively used memory in bytes.
Note, this will not always match memory use reported by the system because
it does not include cached memory buffers.
)pbdoc");
metal.def(
"get_peak_memory",
&metal::get_peak_memory,
R"pbdoc(
Get the peak amount of used memory in bytes.
The maximum memory used is recorded from the beginning of the program
execution.
)pbdoc");
metal.def(
"get_cache_memory",
&metal::get_cache_memory,
R"pbdoc(
Get the cache size in bytes.
The cache includes memory not currently used that has not been returned
to the system allocator.
)pbdoc");
metal.def(
"set_memory_limit",
&metal::set_memory_limit,
"limit"_a,
py::kw_only(),
"relaxed"_a = true,
R"pbdoc(
Set the memory limit.
Memory allocations will wait on scheduled tasks to complete if the limit
is exceeded. If there are no more scheduled tasks an error will be raised
if ``relaxed`` is ``False``. Otherwise memory will be allocated
(including the potential for swap) if ``relaxed`` is ``True``.
The memory limit defaults to 1.5 times the maximum recommended working set
size reported by the device.
Args:
limit (int): Memory limit in bytes.
relaxed (bool, optional): If `False`` an error is raised if the limit
is exceeded. Default: ``True``
Returns:
int: The previous memory limit in bytes.
)pbdoc");
metal.def(
"set_cache_limit",
&metal::set_cache_limit,
"limit"_a,
R"pbdoc(
Set the free cache limit.
If using more than the given limit, free memory will be reclaimed
from the cache on the next allocation. To disable the cache, set
the limit to ``0``.
The cache limit defaults to the memory limit. See
:func:`set_memory_limit` for more details.
Args:
limit (int): The cache limit in bytes.
Returns:
int: The previous cache limit in bytes.
)pbdoc");
}

View File

@@ -3081,7 +3081,7 @@ void init_ops(py::module_& m) {
py::kw_only(),
"stream"_a = none,
R"pbdoc(
conv2d(input: array, weight: array, /, stride: Union[int, Tuple[int, int]] = 1, padding: Union[int, Tuple[int, int]] = 0, dilation: Union[int, Tuple[int, int]] = 1, groups: Union[int, Tuple[int, int]] = 1, *, stream: Union[None, Stream, Device] = None) -> array
conv2d(input: array, weight: array, /, stride: Union[int, Tuple[int, int]] = 1, padding: Union[int, Tuple[int, int]] = 0, dilation: Union[int, Tuple[int, int]] = 1, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array
2D convolution over an input with several channels
@@ -3105,6 +3105,114 @@ void init_ops(py::module_& m) {
array: The convolved array.
)pbdoc");
m.def(
"conv_general",
[](const array& input,
const array& weight,
const std::variant<int, std::vector<int>>& stride,
const std::variant<
int,
std::vector<int>,
std::pair<std::vector<int>, std::vector<int>>>& padding,
const std::variant<int, std::vector<int>>& kernel_dilation,
const std::variant<int, std::vector<int>>& input_dilation,
int groups,
bool flip,
StreamOrDevice s) {
std::vector<int> stride_vec;
std::vector<int> padding_lo_vec;
std::vector<int> padding_hi_vec;
std::vector<int> kernel_dilation_vec;
std::vector<int> input_dilation_vec;
if (auto pv = std::get_if<int>(&stride); pv) {
stride_vec.push_back(*pv);
} else {
stride_vec = std::get<std::vector<int>>(stride);
}
if (auto pv = std::get_if<int>(&padding); pv) {
padding_lo_vec.push_back(*pv);
padding_hi_vec.push_back(*pv);
} else if (auto pv = std::get_if<std::vector<int>>(&padding); pv) {
padding_lo_vec = *pv;
padding_hi_vec = *pv;
} else {
auto [pl, ph] =
std::get<std::pair<std::vector<int>, std::vector<int>>>(padding);
padding_lo_vec = pl;
padding_hi_vec = ph;
}
if (auto pv = std::get_if<int>(&kernel_dilation); pv) {
kernel_dilation_vec.push_back(*pv);
} else {
kernel_dilation_vec = std::get<std::vector<int>>(kernel_dilation);
}
if (auto pv = std::get_if<int>(&input_dilation); pv) {
input_dilation_vec.push_back(*pv);
} else {
input_dilation_vec = std::get<std::vector<int>>(input_dilation);
}
return conv_general(
/* const array& input = */ input,
/* const array& weight = */ weight,
/* std::vector<int> stride = */ stride_vec,
/* std::vector<int> padding_lo = */ padding_lo_vec,
/* std::vector<int> padding_hi = */ padding_lo_vec,
/* std::vector<int> kernel_dilation = */ kernel_dilation_vec,
/* std::vector<int> input_dilation = */ input_dilation_vec,
/* int groups = */ groups,
/* bool flip = */ flip,
s);
},
"input"_a,
"weight"_a,
py::pos_only(),
"stride"_a = 1,
"padding"_a = 0,
"kernel_dilation"_a = 1,
"input_dilation"_a = 1,
"groups"_a = 1,
"flip"_a = false,
py::kw_only(),
"stream"_a = none,
R"pbdoc(
conv_general(input: array, weight: array, /, stride: Union[int, List[int]] = 1, padding: Union[int, List[int], Tuple[List[int], List[int]]] = 0, kernel_dilation: Union[int, List[int]] = 1, input_dilation: Union[int, List[int]] = 1, groups: int = 1, flip: bool = false, *, stream: Union[None, Stream, Device] = None) -> array
General convolution over an input with several channels
.. note::
* Only 1d and 2d convolutions are supported at the moment
* the default ``groups=1`` is currently supported.
Args:
input (array): Input array of shape ``(N, ..., C_in)``
weight (array): Weight array of shape ``(C_out, ..., C_in)``
stride (int or list(int), optional): :obj:`list` with kernel strides.
All spatial dimensions get the same stride if
only one number is specified. Default: ``1``.
padding (int, list(int), or tuple(list(int), list(int)), optional):
:obj:`list` with input padding. All spatial dimensions get the same
padding if only one number is specified. Default: ``0``.
kernel_dilation (int or list(int), optional): :obj:`list` with
kernel dilation. All spatial dimensions get the same dilation
if only one number is specified. Default: ``1``
input_dilation (int or list(int), optional): :obj:`list` with
input dilation. All spatial dimensions get the same dilation
if only one number is specified. Default: ``1``
groups (int, optional): Input feature groups. Default: ``1``.
flip (bool, optional): Flip the order in which the spatial dimensions of
the weights are processed. Performs the cross-correlation operator when
``flip`` is ``False`` and the convolution operator otherwise.
Default: ``False``.
Returns:
array: The convolved array.
)pbdoc");
m.def(
"save",
&mlx_save_helper,
"file"_a,
@@ -3638,62 +3746,69 @@ void init_ops(py::module_& m) {
)pbdoc");
m.def(
"atleast_1d",
&atleast_1d,
"a"_a,
py::pos_only(),
[](const py::args& arys, StreamOrDevice s) -> py::object {
if (arys.size() == 1) {
return py::cast(atleast_1d(arys[0].cast<array>(), s));
}
return py::cast(atleast_1d(arys.cast<std::vector<array>>(), s));
},
py::kw_only(),
"stream"_a = none,
R"pbdoc(
atleast_1d(a: array, stream: Union[None, Stream, Device] = None) -> array
atleast_1d(*arys: array, stream: Union[None, Stream, Device] = None) -> Union[array, List[array]]
Convert array to have at least one dimension.
Convert all arrays to have at least one dimension.
args:
a (array): Input array
Args:
*arys: Input arrays.
stream (Union[None, Stream, Device], optional): The stream to execute the operation on.
Returns:
array: An array with at least one dimension.
array or list(array): An array or list of arrays with at least one dimension.
)pbdoc");
m.def(
"atleast_2d",
&atleast_2d,
"a"_a,
py::pos_only(),
[](const py::args& arys, StreamOrDevice s) -> py::object {
if (arys.size() == 1) {
return py::cast(atleast_2d(arys[0].cast<array>(), s));
}
return py::cast(atleast_2d(arys.cast<std::vector<array>>(), s));
},
py::kw_only(),
"stream"_a = none,
R"pbdoc(
atleast_2d(a: array, stream: Union[None, Stream, Device] = None) -> array
atleast_2d(*arys: array, stream: Union[None, Stream, Device] = None) -> Union[array, List[array]]
Convert array to have at least two dimensions.
Convert all arrays to have at least two dimensions.
args:
a (array): Input array
Args:
*arys: Input arrays.
stream (Union[None, Stream, Device], optional): The stream to execute the operation on.
Returns:
array: An array with at least two dimensions.
array or list(array): An array or list of arrays with at least two dimensions.
)pbdoc");
m.def(
"atleast_3d",
&atleast_3d,
"a"_a,
py::pos_only(),
[](const py::args& arys, StreamOrDevice s) -> py::object {
if (arys.size() == 1) {
return py::cast(atleast_3d(arys[0].cast<array>(), s));
}
return py::cast(atleast_3d(arys.cast<std::vector<array>>(), s));
},
py::kw_only(),
"stream"_a = none,
R"pbdoc(
atleast_3d(a: array, stream: Union[None, Stream, Device] = None) -> array
atleast_3d(*arys: array, stream: Union[None, Stream, Device] = None) -> Union[array, List[array]]
Convert array to have at least three dimensions.
Convert all arrays to have at least three dimensions.
args:
a (array): Input array
Args:
*arys: Input arrays.
stream (Union[None, Stream, Device], optional): The stream to execute the operation on.
Returns:
array: An array with at least three dimensions.
array or list(array): An array or list of arrays with at least three dimensions.
)pbdoc");
}

View File

@@ -0,0 +1,60 @@
// Copyright © 2023-2024 Apple Inc.
#pragma once
// A patch to get float16_t to work with pybind11 numpy arrays
// Derived from:
// https://github.com/pybind/pybind11/issues/1776#issuecomment-492230679
#include <pybind11/numpy.h>
namespace pybind11::detail {
template <typename T>
struct npy_scalar_caster {
PYBIND11_TYPE_CASTER(T, _("PleaseOverride"));
using Array = array_t<T>;
bool load(handle src, bool convert) {
// Taken from Eigen casters. Permits either scalar dtype or scalar array.
handle type = dtype::of<T>().attr("type"); // Could make more efficient.
if (!convert && !isinstance<Array>(src) && !isinstance(src, type))
return false;
Array tmp = Array::ensure(src);
if (tmp && tmp.size() == 1 && tmp.ndim() == 0) {
this->value = *tmp.data();
return true;
}
return false;
}
static handle cast(T src, return_value_policy, handle) {
Array tmp({1});
tmp.mutable_at(0) = src;
tmp.resize({});
// You could also just return the array if you want a scalar array.
object scalar = tmp[tuple()];
return scalar.release();
}
};
// Similar to enums in `pybind11/numpy.h`. Determined by doing:
// python3 -c 'import numpy as np; print(np.dtype(np.float16).num)'
constexpr int NPY_FLOAT16 = 23;
// Kinda following:
// https://github.com/pybind/pybind11/blob/9bb3313162c0b856125e481ceece9d8faa567716/include/pybind11/numpy.h#L1000
template <>
struct npy_format_descriptor<float16_t> {
static constexpr auto name = _("float16");
static pybind11::dtype dtype() {
handle ptr = npy_api::get().PyArray_DescrFromType_(NPY_FLOAT16);
return reinterpret_borrow<pybind11::dtype>(ptr);
}
};
template <>
struct type_caster<float16_t> : npy_scalar_caster<float16_t> {
static constexpr auto name = _("float16");
};
} // namespace pybind11::detail

View File

@@ -11,6 +11,7 @@
#include "mlx/graph_utils.h"
#include "mlx/transforms.h"
#include "mlx/transforms_impl.h"
#include "python/src/trees.h"
namespace py = pybind11;
using namespace py::literals;
@@ -30,246 +31,6 @@ std::vector<T> to_vector(const std::variant<T, std::vector<T>>& v) {
return vals;
}
void tree_visit(py::object tree, std::function<void(py::handle)> visitor) {
std::function<void(py::handle)> recurse;
recurse = [&](py::handle subtree) {
if (py::isinstance<py::list>(subtree) ||
py::isinstance<py::tuple>(subtree)) {
for (auto item : subtree) {
recurse(item);
}
} else if (py::isinstance<py::dict>(subtree)) {
for (auto item : py::cast<py::dict>(subtree)) {
recurse(item.second);
}
} else {
visitor(subtree);
}
};
recurse(tree);
}
template <typename T, typename U, typename V>
void validate_subtrees(const std::vector<py::object>& subtrees) {
int len = py::cast<T>(subtrees[0]).size();
for (auto& subtree : subtrees) {
if ((py::isinstance<T>(subtree) && py::cast<T>(subtree).size() != len) ||
py::isinstance<U>(subtree) || py::isinstance<V>(subtree)) {
throw std::invalid_argument(
"[tree_map] Additional input tree is not a valid prefix of the first tree.");
}
}
}
py::object tree_map(
const std::vector<py::object>& trees,
std::function<py::object(const std::vector<py::object>&)> transform) {
std::function<py::object(const std::vector<py::object>&)> recurse;
recurse = [&](const std::vector<py::object>& subtrees) {
if (py::isinstance<py::list>(subtrees[0])) {
py::list l;
std::vector<py::object> items(subtrees.size());
validate_subtrees<py::list, py::tuple, py::dict>(subtrees);
for (int i = 0; i < py::cast<py::list>(subtrees[0]).size(); ++i) {
for (int j = 0; j < subtrees.size(); ++j) {
if (py::isinstance<py::list>(subtrees[j])) {
items[j] = py::cast<py::list>(subtrees[j])[i];
} else {
items[j] = subtrees[j];
}
}
l.append(recurse(items));
}
return py::cast<py::object>(l);
} else if (py::isinstance<py::tuple>(subtrees[0])) {
// Check the rest of the subtrees
std::vector<py::object> items(subtrees.size());
int len = py::cast<py::tuple>(subtrees[0]).size();
py::tuple l(len);
validate_subtrees<py::tuple, py::list, py::dict>(subtrees);
for (int i = 0; i < len; ++i) {
for (int j = 0; j < subtrees.size(); ++j) {
if (py::isinstance<py::tuple>(subtrees[j])) {
items[j] = py::cast<py::tuple>(subtrees[j])[i];
} else {
items[j] = subtrees[j];
}
}
l[i] = recurse(items);
}
return py::cast<py::object>(l);
} else if (py::isinstance<py::dict>(subtrees[0])) {
std::vector<py::object> items(subtrees.size());
validate_subtrees<py::dict, py::list, py::tuple>(subtrees);
py::dict d;
for (auto item : py::cast<py::dict>(subtrees[0])) {
for (int j = 0; j < subtrees.size(); ++j) {
if (py::isinstance<py::dict>(subtrees[j])) {
auto subdict = py::cast<py::dict>(subtrees[j]);
if (!subdict.contains(item.first)) {
throw std::invalid_argument(
"[tree_map] Tree is not a valid prefix tree of the first tree.");
}
items[j] = subdict[item.first];
} else {
items[j] = subtrees[j];
}
}
d[item.first] = recurse(items);
}
return py::cast<py::object>(d);
} else {
return transform(subtrees);
}
};
return recurse(trees);
}
py::object tree_map(
py::object tree,
std::function<py::object(py::handle)> transform) {
return tree_map({tree}, [&](std::vector<py::object> inputs) {
return transform(inputs[0]);
});
}
void tree_visit_update(
py::object tree,
std::function<py::object(py::handle)> visitor) {
std::function<py::object(py::handle)> recurse;
recurse = [&](py::handle subtree) {
if (py::isinstance<py::list>(subtree)) {
auto l = py::cast<py::list>(subtree);
for (int i = 0; i < l.size(); ++i) {
l[i] = recurse(l[i]);
}
return py::cast<py::object>(l);
} else if (py::isinstance<py::tuple>(subtree)) {
for (auto item : subtree) {
recurse(item);
}
return py::cast<py::object>(subtree);
} else if (py::isinstance<py::dict>(subtree)) {
auto d = py::cast<py::dict>(subtree);
for (auto item : d) {
d[item.first] = recurse(item.second);
}
return py::cast<py::object>(d);
} else if (py::isinstance<array>(subtree)) {
return visitor(subtree);
} else {
return py::cast<py::object>(subtree);
}
};
recurse(tree);
}
// Fill a pytree (recursive dict or list of dict or list)
// in place with the given arrays
// Non dict or list nodes are ignored
void tree_fill(py::object& tree, const std::vector<array>& values) {
size_t index = 0;
tree_visit_update(
tree, [&](py::handle node) { return py::cast(values[index++]); });
}
// Replace all the arrays from the src values with the dst values in the tree
void tree_replace(
py::object& tree,
const std::vector<array>& src,
const std::vector<array>& dst) {
std::unordered_map<uintptr_t, array> src_to_dst;
for (int i = 0; i < src.size(); ++i) {
src_to_dst.insert({src[i].id(), dst[i]});
}
tree_visit_update(tree, [&](py::handle node) {
auto arr = py::cast<array>(node);
if (auto it = src_to_dst.find(arr.id()); it != src_to_dst.end()) {
return py::cast(it->second);
}
return py::cast(arr);
});
}
std::vector<array> tree_flatten(py::object tree, bool strict = true) {
std::vector<array> flat_tree;
tree_visit(tree, [&](py::handle obj) {
if (py::isinstance<array>(obj)) {
flat_tree.push_back(py::cast<array>(obj));
} else if (strict) {
throw std::invalid_argument(
"[tree_flatten] The argument should contain only arrays");
}
});
return flat_tree;
}
py::object tree_unflatten(
py::object tree,
const std::vector<array>& values,
int index = 0) {
return tree_map(tree, [&](py::handle obj) {
if (py::isinstance<array>(obj)) {
return py::cast(values[index++]);
} else {
return py::cast<py::object>(obj);
}
});
}
py::object structure_sentinel() {
static py::object sentinel;
if (sentinel.ptr() == nullptr) {
sentinel = py::capsule(&sentinel);
// probably not needed but this should make certain that we won't ever
// delete the sentinel
sentinel.inc_ref();
}
return sentinel;
}
std::pair<std::vector<array>, py::object> tree_flatten_with_structure(
py::object tree,
bool strict = true) {
auto sentinel = structure_sentinel();
std::vector<array> flat_tree;
auto structure = tree_map(
tree,
[&flat_tree, sentinel = std::move(sentinel), strict](py::handle obj) {
if (py::isinstance<array>(obj)) {
flat_tree.push_back(py::cast<array>(obj));
return sentinel;
} else if (!strict) {
return py::cast<py::object>(obj);
} else {
throw std::invalid_argument(
"[tree_flatten] The argument should contain only arrays");
}
});
return {flat_tree, structure};
}
py::object tree_unflatten_from_structure(
py::object structure,
const std::vector<array>& values,
int index = 0) {
auto sentinel = structure_sentinel();
return tree_map(structure, [&](py::handle obj) {
if (obj.is(sentinel)) {
return py::cast(values[index++]);
} else {
return py::cast<py::object>(obj);
}
});
}
auto validate_argnums_argnames(
const std::optional<IntOrVec>& argnums,
const StrOrVec& argnames) {
@@ -582,9 +343,69 @@ struct PyCompiledFun {
};
py::object operator()(const py::args& args, const py::kwargs& kwargs) {
auto inputs = tree_flatten(args, false);
// Flat array inputs
std::vector<array> inputs;
auto compile_fun = [this, &args, &kwargs, num_args = inputs.size()](
// Compilation constants which includes the tree structure of the arguments
std::vector<uint64_t> constants;
// Reserve some large primes to signify the presence of an array, a list or
// a dict in order to encode the structure of the pytree. We choose primes
// to reduce slightly the chances of these numbers occurring by a
// multiplication as values in the constants list.
constexpr uint64_t array_identifier = 18446744073709551557UL;
constexpr uint64_t list_identifier = 18446744073709551533UL;
constexpr uint64_t dict_identifier = 18446744073709551521UL;
// Flatten the tree with hashed constants and structure
std::function<void(py::handle)> recurse;
recurse = [&](py::handle obj) {
if (py::isinstance<py::list>(obj)) {
auto l = py::cast<py::list>(obj);
constants.push_back(list_identifier);
for (int i = 0; i < l.size(); ++i) {
recurse(l[i]);
}
} else if (py::isinstance<py::tuple>(obj)) {
auto l = py::cast<py::tuple>(obj);
constants.push_back(list_identifier);
for (auto item : obj) {
recurse(item);
}
} else if (py::isinstance<py::dict>(obj)) {
auto d = py::cast<py::dict>(obj);
constants.push_back(dict_identifier);
for (auto item : d) {
auto r = py::hash(item.first);
constants.push_back(*reinterpret_cast<uint64_t*>(&r));
recurse(item.second);
}
} else if (py::isinstance<array>(obj)) {
inputs.push_back(py::cast<array>(obj));
constants.push_back(array_identifier);
} else if (py::isinstance<py::str>(obj)) {
auto r = py::hash(obj);
constants.push_back(*reinterpret_cast<uint64_t*>(&r));
} else if (py::isinstance<py::int_>(obj)) {
auto r = obj.cast<int64_t>();
constants.push_back(*reinterpret_cast<uint64_t*>(&r));
} else if (py::isinstance<py::float_>(obj)) {
auto r = obj.cast<double>();
constants.push_back(*reinterpret_cast<uint64_t*>(&r));
} else {
std::ostringstream msg;
msg << "[compile] Function arguments must be trees of arrays "
<< "or constants (floats, ints, or strings), but received "
<< "type " << obj.get_type() << ".";
throw std::invalid_argument(msg.str());
}
};
recurse(args);
int num_args = inputs.size();
recurse(kwargs);
auto compile_fun = [this, &args, &kwargs, num_args](
const std::vector<array>& a) {
// Put tracers into captured inputs
std::vector<array> flat_in_captures;
@@ -619,14 +440,6 @@ struct PyCompiledFun {
return outputs;
};
{
auto flat_kwargs = tree_flatten(kwargs, false);
inputs.insert(
inputs.end(),
std::make_move_iterator(flat_kwargs.begin()),
std::make_move_iterator(flat_kwargs.end()));
}
if (!py::isinstance<py::none>(captured_inputs)) {
auto flat_in_captures = tree_flatten(captured_inputs, false);
inputs.insert(
@@ -635,36 +448,6 @@ struct PyCompiledFun {
std::make_move_iterator(flat_in_captures.end()));
}
// Collect the compilation constants
std::vector<uint64_t> constants;
auto value_hash = [](py::handle o) -> std::optional<uint64_t> {
// Consider expanding tuples to their contents including start and end
// ids
if (py::isinstance<py::tuple>(o) || py::isinstance<py::str>(o)) {
auto r = py::hash(o);
return *reinterpret_cast<uint64_t*>(&r);
} else if (py::isinstance<py::int_>(o)) {
auto r = o.cast<int64_t>();
return *reinterpret_cast<uint64_t*>(&r);
} else if (py::isinstance<py::float_>(o)) {
auto r = o.cast<double>();
return *reinterpret_cast<uint64_t*>(&r);
} else {
return std::nullopt;
}
};
for (int i = 0; i < args.size(); i++) {
if (auto h = value_hash(args[i]); h.has_value()) {
constants.push_back(*h);
}
}
for (auto& pair : kwargs) {
if (auto h = value_hash(pair.second); h.has_value()) {
constants.push_back(*value_hash(pair.first));
constants.push_back(*h);
}
}
// Compile and call
auto outputs =
detail::compile(compile_fun, fun_id, shapeless, constants)(inputs);
@@ -1017,7 +800,38 @@ void init_transforms(py::module_& m) {
const py::object& inputs,
const py::object& outputs,
bool shapeless) {
return py::cpp_function(PyCompiledFun{fun, inputs, outputs, shapeless});
py::options options;
options.disable_function_signatures();
std::ostringstream doc;
auto name = fun.attr("__name__").cast<std::string>();
doc << name;
// Try to get the signature
auto inspect = py::module::import("inspect");
if (!inspect.attr("isbuiltin")(fun).cast<bool>()) {
doc << inspect.attr("signature")(fun)
.attr("__str__")()
.cast<std::string>();
}
// Try to get the doc string
if (auto d = fun.attr("__doc__"); py::isinstance<py::str>(d)) {
doc << "\n\n";
auto dstr = d.cast<std::string>();
// Add spaces to match first line indentation with remainder of
// docstring
int i = 0;
for (int i = dstr.size() - 1; i >= 0 && dstr[i] == ' '; i--) {
doc << ' ';
}
doc << dstr;
}
auto doc_str = doc.str();
return py::cpp_function(
PyCompiledFun{fun, inputs, outputs, shapeless},
py::name(name.c_str()),
py::doc(doc_str.c_str()));
},
"fun"_a,
"inputs"_a = std::nullopt,

243
python/src/trees.cpp Normal file
View File

@@ -0,0 +1,243 @@
// Copyright © 2023-2024 Apple Inc.
#include "python/src/trees.h"
void tree_visit(py::object tree, std::function<void(py::handle)> visitor) {
std::function<void(py::handle)> recurse;
recurse = [&](py::handle subtree) {
if (py::isinstance<py::list>(subtree) ||
py::isinstance<py::tuple>(subtree)) {
for (auto item : subtree) {
recurse(item);
}
} else if (py::isinstance<py::dict>(subtree)) {
for (auto item : py::cast<py::dict>(subtree)) {
recurse(item.second);
}
} else {
visitor(subtree);
}
};
recurse(tree);
}
template <typename T, typename U, typename V>
void validate_subtrees(const std::vector<py::object>& subtrees) {
int len = py::cast<T>(subtrees[0]).size();
for (auto& subtree : subtrees) {
if ((py::isinstance<T>(subtree) && py::cast<T>(subtree).size() != len) ||
py::isinstance<U>(subtree) || py::isinstance<V>(subtree)) {
throw std::invalid_argument(
"[tree_map] Additional input tree is not a valid prefix of the first tree.");
}
}
}
py::object tree_map(
const std::vector<py::object>& trees,
std::function<py::object(const std::vector<py::object>&)> transform) {
std::function<py::object(const std::vector<py::object>&)> recurse;
recurse = [&](const std::vector<py::object>& subtrees) {
if (py::isinstance<py::list>(subtrees[0])) {
py::list l;
std::vector<py::object> items(subtrees.size());
validate_subtrees<py::list, py::tuple, py::dict>(subtrees);
for (int i = 0; i < py::cast<py::list>(subtrees[0]).size(); ++i) {
for (int j = 0; j < subtrees.size(); ++j) {
if (py::isinstance<py::list>(subtrees[j])) {
items[j] = py::cast<py::list>(subtrees[j])[i];
} else {
items[j] = subtrees[j];
}
}
l.append(recurse(items));
}
return py::cast<py::object>(l);
} else if (py::isinstance<py::tuple>(subtrees[0])) {
// Check the rest of the subtrees
std::vector<py::object> items(subtrees.size());
int len = py::cast<py::tuple>(subtrees[0]).size();
py::tuple l(len);
validate_subtrees<py::tuple, py::list, py::dict>(subtrees);
for (int i = 0; i < len; ++i) {
for (int j = 0; j < subtrees.size(); ++j) {
if (py::isinstance<py::tuple>(subtrees[j])) {
items[j] = py::cast<py::tuple>(subtrees[j])[i];
} else {
items[j] = subtrees[j];
}
}
l[i] = recurse(items);
}
return py::cast<py::object>(l);
} else if (py::isinstance<py::dict>(subtrees[0])) {
std::vector<py::object> items(subtrees.size());
validate_subtrees<py::dict, py::list, py::tuple>(subtrees);
py::dict d;
for (auto item : py::cast<py::dict>(subtrees[0])) {
for (int j = 0; j < subtrees.size(); ++j) {
if (py::isinstance<py::dict>(subtrees[j])) {
auto subdict = py::cast<py::dict>(subtrees[j]);
if (!subdict.contains(item.first)) {
throw std::invalid_argument(
"[tree_map] Tree is not a valid prefix tree of the first tree.");
}
items[j] = subdict[item.first];
} else {
items[j] = subtrees[j];
}
}
d[item.first] = recurse(items);
}
return py::cast<py::object>(d);
} else {
return transform(subtrees);
}
};
return recurse(trees);
}
py::object tree_map(
py::object tree,
std::function<py::object(py::handle)> transform) {
return tree_map({tree}, [&](std::vector<py::object> inputs) {
return transform(inputs[0]);
});
}
void tree_visit_update(
py::object tree,
std::function<py::object(py::handle)> visitor) {
std::function<py::object(py::handle)> recurse;
recurse = [&](py::handle subtree) {
if (py::isinstance<py::list>(subtree)) {
auto l = py::cast<py::list>(subtree);
for (int i = 0; i < l.size(); ++i) {
l[i] = recurse(l[i]);
}
return py::cast<py::object>(l);
} else if (py::isinstance<py::tuple>(subtree)) {
for (auto item : subtree) {
recurse(item);
}
return py::cast<py::object>(subtree);
} else if (py::isinstance<py::dict>(subtree)) {
auto d = py::cast<py::dict>(subtree);
for (auto item : d) {
d[item.first] = recurse(item.second);
}
return py::cast<py::object>(d);
} else if (py::isinstance<array>(subtree)) {
return visitor(subtree);
} else {
return py::cast<py::object>(subtree);
}
};
recurse(tree);
}
// Fill a pytree (recursive dict or list of dict or list)
// in place with the given arrays
// Non dict or list nodes are ignored
void tree_fill(py::object& tree, const std::vector<array>& values) {
size_t index = 0;
tree_visit_update(
tree, [&](py::handle node) { return py::cast(values[index++]); });
}
// Replace all the arrays from the src values with the dst values in the tree
void tree_replace(
py::object& tree,
const std::vector<array>& src,
const std::vector<array>& dst) {
std::unordered_map<uintptr_t, array> src_to_dst;
for (int i = 0; i < src.size(); ++i) {
src_to_dst.insert({src[i].id(), dst[i]});
}
tree_visit_update(tree, [&](py::handle node) {
auto arr = py::cast<array>(node);
if (auto it = src_to_dst.find(arr.id()); it != src_to_dst.end()) {
return py::cast(it->second);
}
return py::cast(arr);
});
}
std::vector<array> tree_flatten(py::object tree, bool strict /* = true */) {
std::vector<array> flat_tree;
tree_visit(tree, [&](py::handle obj) {
if (py::isinstance<array>(obj)) {
flat_tree.push_back(py::cast<array>(obj));
} else if (strict) {
throw std::invalid_argument(
"[tree_flatten] The argument should contain only arrays");
}
});
return flat_tree;
}
py::object tree_unflatten(
py::object tree,
const std::vector<array>& values,
int index /* = 0 */) {
return tree_map(tree, [&](py::handle obj) {
if (py::isinstance<array>(obj)) {
return py::cast(values[index++]);
} else {
return py::cast<py::object>(obj);
}
});
}
py::object structure_sentinel() {
static py::object sentinel;
if (sentinel.ptr() == nullptr) {
sentinel = py::capsule(&sentinel);
// probably not needed but this should make certain that we won't ever
// delete the sentinel
sentinel.inc_ref();
}
return sentinel;
}
std::pair<std::vector<array>, py::object> tree_flatten_with_structure(
py::object tree,
bool strict /* = true */) {
auto sentinel = structure_sentinel();
std::vector<array> flat_tree;
auto structure = tree_map(
tree,
[&flat_tree, sentinel = std::move(sentinel), strict](py::handle obj) {
if (py::isinstance<array>(obj)) {
flat_tree.push_back(py::cast<array>(obj));
return sentinel;
} else if (!strict) {
return py::cast<py::object>(obj);
} else {
throw std::invalid_argument(
"[tree_flatten] The argument should contain only arrays");
}
});
return {flat_tree, structure};
}
py::object tree_unflatten_from_structure(
py::object structure,
const std::vector<array>& values,
int index /* = 0 */) {
auto sentinel = structure_sentinel();
return tree_map(structure, [&](py::handle obj) {
if (obj.is(sentinel)) {
return py::cast(values[index++]);
} else {
return py::cast<py::object>(obj);
}
});
}

60
python/src/trees.h Normal file
View File

@@ -0,0 +1,60 @@
// Copyright © 2023-2024 Apple Inc.
#pragma once
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "mlx/array.h"
namespace py = pybind11;
using namespace mlx::core;
void tree_visit(py::object tree, std::function<void(py::handle)> visitor);
py::object tree_map(
const std::vector<py::object>& trees,
std::function<py::object(const std::vector<py::object>&)> transform);
py::object tree_map(
py::object tree,
std::function<py::object(py::handle)> transform);
void tree_visit_update(
py::object tree,
std::function<py::object(py::handle)> visitor);
/**
* Fill a pytree (recursive dict or list of dict or list) in place with the
* given arrays. */
void tree_fill(py::object& tree, const std::vector<array>& values);
/**
* Replace all the arrays from the src values with the dst values in the
* tree.
*/
void tree_replace(
py::object& tree,
const std::vector<array>& src,
const std::vector<array>& dst);
/**
* Flatten a tree into a vector of arrays. If strict is true, then the
* function will throw if the tree contains a leaf which is not an array.
*/
std::vector<array> tree_flatten(py::object tree, bool strict = true);
/**
* Unflatten a tree from a vector of arrays.
*/
py::object tree_unflatten(
py::object tree,
const std::vector<array>& values,
int index = 0);
std::pair<std::vector<array>, py::object> tree_flatten_with_structure(
py::object tree,
bool strict = true);
py::object tree_unflatten_from_structure(
py::object structure,
const std::vector<array>& values,
int index = 0);

View File

@@ -1,6 +1,7 @@
# Copyright © 2023 Apple Inc.
import operator
import pickle
import unittest
import weakref
from itertools import permutations
@@ -1440,6 +1441,15 @@ class TestArray(mlx_tests.MLXTestCase):
b @= a
self.assertTrue(mx.array_equal(a, b))
def test_load_from_pickled_np(self):
a = np.array([1, 2, 3], dtype=np.int32)
b = pickle.loads(pickle.dumps(a))
self.assertTrue(mx.array_equal(mx.array(a), mx.array(b)))
a = np.array([1.0, 2.0, 3.0], dtype=np.float16)
b = pickle.loads(pickle.dumps(a))
self.assertTrue(mx.array_equal(mx.array(a), mx.array(b)))
if __name__ == "__main__":
unittest.main()

View File

@@ -415,6 +415,14 @@ class TestAutograd(mlx_tests.MLXTestCase):
_, vjps = mx.vjp(func, (arr,), (cotan,))
self.assertEqual(vjps[0].item(), 8.0)
def test_power_grad(self):
def fun(x, y):
res = x - y
return res**x
grad = mx.grad(fun)(mx.array(1.0), mx.array(1.0))
self.assertEqual(grad.item(), 1.0)
if __name__ == "__main__":
unittest.main()

View File

@@ -539,6 +539,72 @@ class TestCompile(mlx_tests.MLXTestCase):
z = fun(mx.array(1), "two")
self.assertEqual(z.item(), 3)
# Test nested constant
@partial(mx.compile)
def fun(x, y):
if y[0][0] == 1:
return x + 1
else:
return x + 2
z = fun(mx.array(1), [[1]])
self.assertEqual(z.item(), 2)
z = fun(mx.array(1), [[0]])
self.assertEqual(z.item(), 3)
@partial(mx.compile)
def fun(x, a, b):
for ai in a:
for bi in b:
x = bi * x + ai
return x
z = fun(mx.array(1), [1, 1], [2])
self.assertEqual(z.item(), 7)
z = fun(mx.array(1), [1], [1, 2])
self.assertEqual(z.item(), 5)
counter = [0]
@partial(mx.compile)
def fun(x, y):
counter[0] += 1
return x + y
z = fun(mx.array(1), 1)
self.assertEqual(z.item(), 2)
z = fun(1, mx.array(1))
self.assertEqual(z.item(), 2)
self.assertEqual(counter[0], 2)
def test_compile_inf(self):
@mx.compile
def fun(x):
return mx.isinf(x + 2)
out = fun(mx.array([0.0]))
self.assertEqual(out.item(), False)
def test_unsupported_input_types(self):
class MyClass:
value = 1
@mx.compile
def fun(x, y):
return x + y.value
with self.assertRaises(ValueError):
out = fun(mx.array(0.0), MyClass())
with self.assertRaises(ValueError):
out = fun(mx.array(0.0), y=MyClass())
if __name__ == "__main__":
unittest.main()

View File

@@ -1,4 +1,4 @@
# Copyright © 2023 Apple Inc.
# Copyright © 2023-2024 Apple Inc.
import math
import unittest
@@ -388,13 +388,8 @@ class TestConv(mlx_tests.MLXTestCase):
_, outs_mx = mx.vjp(
f,
[
in_mx,
wt_mx,
],
[
ct_mx,
],
[in_mx, wt_mx],
[ct_mx],
)
pt_grad_in = F.grad.conv1d_input(
in_pt.shape,
@@ -428,18 +423,218 @@ class TestConv(mlx_tests.MLXTestCase):
self.assertTrue(np.allclose(pt_grad_wt, mx_grad_wt, atol=atol))
for dtype in ("float32",):
for N, C, O in (
(1, 1, 1),
(1, 6, 1),
(1, 1, 6),
(4, 32, 64),
):
for idim, kdim, stride, padding in (
((1, 1), (1, 1), (1, 1), (0, 0)),
((3, 3), (3, 1), (1, 1), (0, 0)),
((31, 31), (5, 5), (5, 5), (2, 2)),
for N, C, O in ((1, 1, 1), (1, 6, 1), (1, 1, 6), (4, 32, 64), (4, 16, 32)):
for idim, kdim, stride, padding, dilation in (
((1, 1), (1, 1), (1, 1), (0, 0), (1, 1)),
((3, 3), (3, 1), (1, 1), (0, 0), (1, 1)),
((31, 31), (5, 5), (5, 5), (2, 2), (1, 1)),
((32, 32), (3, 3), (2, 2), (1, 1), (1, 1)),
((31, 31), (5, 5), (5, 5), (2, 2), (3, 2)),
((32, 32), (3, 3), (2, 2), (1, 1), (3, 2)),
):
run_conv2D_grad(N, C, O, idim, kdim, stride, padding, dtype=dtype)
run_conv2D_grad(
N, C, O, idim, kdim, stride, padding, dilation, dtype=dtype
)
def __conv_general_test(
self,
in_shape,
wt_shape,
stride=1,
padding=0,
kernel_dilation=1,
input_dilation=1,
groups=1,
flip=False,
np_dtype=np.float32,
atol=1e-5,
):
with self.subTest(
in_shape=in_shape,
wt_shape=wt_shape,
stride=stride,
padding=padding,
kernel_dilation=kernel_dilation,
input_dilation=input_dilation,
groups=groups,
flip=flip,
np_dtype=np_dtype,
):
scale = 1.0 / math.sqrt(np.prod(wt_shape[1:]))
in_np = np.random.normal(0.0, scale, in_shape).astype(np_dtype)
wt_np = np.random.normal(0.0, scale, wt_shape).astype(np_dtype)
in_mx, wt_mx = map(mx.array, (in_np, wt_np))
in_pt, wt_pt = map(
lambda x: torch.from_numpy(np.moveaxis(x, -1, 1)).to("cpu"),
(in_np, wt_np),
)
out_mx = mx.conv_general(
in_mx,
wt_mx,
stride=stride,
padding=padding,
kernel_dilation=kernel_dilation,
input_dilation=input_dilation,
groups=groups,
flip=flip,
)
def conv_general_pt(
inp, wt, stride, padding, kernel_dilation, input_dilation, groups, flip
):
C = inp.size()[1]
ndim = inp.ndim - 2
map_ints = lambda x: [x] * ndim if isinstance(x, int) else x
stride, padding, kernel_dilation, input_dilation = map(
map_ints, (stride, padding, kernel_dilation, input_dilation)
)
torch_convt_list = (
F.conv_transpose1d,
F.conv_transpose2d,
F.conv_transpose3d,
)
torch_conv_list = (F.conv1d, F.conv2d, F.conv3d)
conv_f = torch_conv_list[ndim - 1]
convt_f = torch_convt_list[ndim - 1]
if flip:
wt = torch.flip(wt, tuple(np.arange(2, wt.ndim)))
if not np.all(input_dilation == 1):
ones = torch.ones(
[C]
+ [
1,
]
* (ndim + 1)
).to(inp.dtype)
inp = convt_f(inp, ones, stride=input_dilation, groups=C)
return conv_f(
inp,
wt,
stride=stride,
padding=padding,
dilation=kernel_dilation,
groups=groups,
)
out_pt = conv_general_pt(
in_pt,
wt_pt,
stride=stride,
padding=padding,
kernel_dilation=kernel_dilation,
input_dilation=input_dilation,
groups=groups,
flip=flip,
)
out_pt = np.moveaxis(out_pt.numpy(), 1, -1)
self.assertEqual(out_mx.shape, out_pt.shape)
self.assertTrue(np.allclose(out_mx, out_pt, atol=atol))
@unittest.skipIf(not has_torch, "requires Torch")
def test_torch_conv_general(self):
in_shape = (2, 32, 32, 16)
wt_shape = (32, 5, 5, 16)
stride = (1, 1)
padding = (2, 2)
kernel_dilation = (2, 3)
input_dilation = (1, 1)
flip = False
self.__conv_general_test(
in_shape,
wt_shape,
stride,
padding,
kernel_dilation,
input_dilation,
flip=flip,
)
in_shape = (2, 32, 32, 16)
wt_shape = (32, 5, 10, 16)
stride = (2, 3)
padding = (0, 0)
kernel_dilation = (3, 2)
input_dilation = (2, 4)
flip = False
self.__conv_general_test(
in_shape,
wt_shape,
stride,
padding,
kernel_dilation,
input_dilation,
flip=flip,
)
in_shape = (2, 32, 32, 16)
wt_shape = (32, 5, 10, 16)
stride = (2, 2)
padding = (3, 2)
kernel_dilation = (3, 2)
input_dilation = (2, 4)
flip = False
self.__conv_general_test(
in_shape,
wt_shape,
stride,
padding,
kernel_dilation,
input_dilation,
flip=flip,
)
in_shape = (2, 32, 32, 16)
wt_shape = (32, 5, 10, 16)
stride = (2, 3)
padding = (3, 2)
kernel_dilation = (3, 2)
input_dilation = (2, 5)
flip = False
self.__conv_general_test(
in_shape,
wt_shape,
stride,
padding,
kernel_dilation,
input_dilation,
flip=flip,
)
in_shape = (2, 32, 32, 16)
wt_shape = (32, 5, 5, 16)
stride = (2, 3)
padding = (0, 0)
kernel_dilation = (3, 1)
input_dilation = (2, 5)
flip = True
self.__conv_general_test(
in_shape,
wt_shape,
stride,
padding,
kernel_dilation,
input_dilation,
flip=flip,
)
if __name__ == "__main__":

View File

@@ -66,13 +66,15 @@ class TestLoad(mlx_tests.MLXTestCase):
def test_save_and_load_safetensors(self):
if not os.path.isdir(self.test_dir):
os.mkdir(self.test_dir)
test_file = os.path.join(self.test_dir, "test.safetensors")
with self.assertRaises(Exception):
mx.save_safetensors("test", {"a": mx.ones((4, 4))}, {"testing": 0})
mx.save_safetensors(test_file, {"a": mx.ones((4, 4))}, {"testing": 0})
mx.save_safetensors(
"test", {"test": mx.ones((2, 2))}, {"testing": "test", "format": "mlx"}
test_file, {"test": mx.ones((2, 2))}, {"testing": "test", "format": "mlx"}
)
res = mx.load("test.safetensors", return_metadata=True)
res = mx.load(test_file, return_metadata=True)
self.assertEqual(len(res), 2)
self.assertEqual(res[1], {"testing": "test", "format": "mlx"})

View File

@@ -0,0 +1,45 @@
# Copyright © 2023-2024 Apple Inc.
import unittest
import mlx.core as mx
import mlx_tests
class TestMetal(mlx_tests.MLXTestCase):
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
def test_memory_info(self):
old_limit = mx.metal.set_cache_limit(0)
a = mx.zeros((4096,))
mx.eval(a)
del a
self.assertEqual(mx.metal.get_cache_memory(), 0)
self.assertEqual(mx.metal.set_cache_limit(old_limit), 0)
self.assertEqual(mx.metal.set_cache_limit(old_limit), old_limit)
old_limit = mx.metal.set_memory_limit(10)
self.assertTrue(mx.metal.set_memory_limit(old_limit), 10)
self.assertTrue(mx.metal.set_memory_limit(old_limit), old_limit)
# Query active and peak memory
a = mx.zeros((4096,))
mx.eval(a)
active_mem = mx.metal.get_active_memory()
self.assertTrue(active_mem >= 4096 * 4)
b = mx.zeros((4096,))
mx.eval(b)
del b
new_active_mem = mx.metal.get_active_memory()
self.assertEqual(new_active_mem, active_mem)
peak_mem = mx.metal.get_peak_memory()
self.assertTrue(peak_mem >= 4096 * 8)
cache_mem = mx.metal.get_cache_memory()
self.assertTrue(cache_mem >= 4096 * 4)
if __name__ == "__main__":
unittest.main()

View File

@@ -1,4 +1,4 @@
# Copyright © 2023 Apple Inc.
# Copyright © 2023-2024 Apple Inc.
import os
import tempfile
@@ -8,7 +8,7 @@ import mlx.core as mx
import mlx.nn as nn
import mlx_tests
import numpy as np
from mlx.utils import tree_flatten, tree_map, tree_unflatten
from mlx.utils import tree_flatten, tree_map
class TestBase(mlx_tests.MLXTestCase):
@@ -665,7 +665,7 @@ class TestLayers(mlx_tests.MLXTestCase):
y_hat1 = nn.gelu_approx(x)
y_hat2 = nn.gelu_fast_approx(x)
self.assertLess(mx.abs(y - y_hat1).max(), 0.0003)
self.assertLess(mx.abs(y - y_hat2).max(), 0.02)
self.assertLess(mx.abs(y - y_hat2).max(), 0.025)
def test_sin_pe(self):
m = nn.SinusoidalPositionalEncoding(16, min_freq=0.01)
@@ -905,6 +905,228 @@ class TestLayers(mlx_tests.MLXTestCase):
self.assertTrue(y.shape, x.shape)
self.assertTrue(y.dtype, mx.float16)
def test_upsample(self):
b, h, w, c = 1, 2, 2, 1
scale_factor = 2
upsample_nearest = nn.Upsample(
scale_factor=scale_factor, mode="nearest", align_corners=True
)
upsample_bilinear = nn.Upsample(
scale_factor=scale_factor, mode="linear", align_corners=True
)
upsample_nearest = nn.Upsample(
scale_factor=scale_factor, mode="nearest", align_corners=True
)
upsample_bilinear_no_align_corners = nn.Upsample(
scale_factor=scale_factor, mode="linear", align_corners=False
)
upsample_nearest_no_align_corners = nn.Upsample(
scale_factor=scale_factor, mode="nearest", align_corners=False
)
# Test single feature map, align corners
x = mx.arange(b * h * w * c).reshape((b, c, h, w)).transpose((0, 2, 3, 1))
expected_nearest = mx.array(
[[[[0, 0, 1, 1], [0, 0, 1, 1], [2, 2, 3, 3], [2, 2, 3, 3]]]]
).transpose((0, 2, 3, 1))
expected_bilinear = mx.array(
[
[
[
[0, 0.333333, 0.666667, 1],
[0.666667, 1, 1.33333, 1.66667],
[1.33333, 1.66667, 2, 2.33333],
[2, 2.33333, 2.66667, 3],
]
]
]
).transpose((0, 2, 3, 1))
# Test single feature map, no align corners
x = (
mx.arange(1, b * h * w * c + 1)
.reshape((b, c, h, w))
.transpose((0, 2, 3, 1))
)
expected_bilinear_no_align_corners = mx.array(
[
[
[
[1.0000, 1.2500, 1.7500, 2.0000],
[1.5000, 1.7500, 2.2500, 2.5000],
[2.5000, 2.7500, 3.2500, 3.5000],
[3.0000, 3.2500, 3.7500, 4.0000],
]
]
]
).transpose((0, 2, 3, 1))
expected_nearest_no_align_corners = mx.array(
[[[[1, 1, 2, 2], [1, 1, 2, 2], [3, 3, 4, 4], [3, 3, 4, 4]]]]
).transpose((0, 2, 3, 1))
self.assertTrue(
np.allclose(
upsample_nearest_no_align_corners(x), expected_nearest_no_align_corners
)
)
self.assertTrue(
np.allclose(
upsample_bilinear_no_align_corners(x),
expected_bilinear_no_align_corners,
)
)
# Test a more complex batch
b, h, w, c = 2, 3, 3, 2
scale_factor = 2
x = mx.arange((b * h * w * c)).reshape((b, c, h, w)).transpose((0, 2, 3, 1))
upsample_nearest = nn.Upsample(
scale_factor=scale_factor, mode="nearest", align_corners=True
)
upsample_bilinear = nn.Upsample(
scale_factor=scale_factor, mode="linear", align_corners=True
)
expected_nearest = mx.array(
[
[
[
[0.0, 0.0, 1.0, 1.0, 2.0, 2.0],
[0.0, 0.0, 1.0, 1.0, 2.0, 2.0],
[3.0, 3.0, 4.0, 4.0, 5.0, 5.0],
[3.0, 3.0, 4.0, 4.0, 5.0, 5.0],
[6.0, 6.0, 7.0, 7.0, 8.0, 8.0],
[6.0, 6.0, 7.0, 7.0, 8.0, 8.0],
],
[
[9.0, 9.0, 10.0, 10.0, 11.0, 11.0],
[9.0, 9.0, 10.0, 10.0, 11.0, 11.0],
[12.0, 12.0, 13.0, 13.0, 14.0, 14.0],
[12.0, 12.0, 13.0, 13.0, 14.0, 14.0],
[15.0, 15.0, 16.0, 16.0, 17.0, 17.0],
[15.0, 15.0, 16.0, 16.0, 17.0, 17.0],
],
],
[
[
[18.0, 18.0, 19.0, 19.0, 20.0, 20.0],
[18.0, 18.0, 19.0, 19.0, 20.0, 20.0],
[21.0, 21.0, 22.0, 22.0, 23.0, 23.0],
[21.0, 21.0, 22.0, 22.0, 23.0, 23.0],
[24.0, 24.0, 25.0, 25.0, 26.0, 26.0],
[24.0, 24.0, 25.0, 25.0, 26.0, 26.0],
],
[
[27.0, 27.0, 28.0, 28.0, 29.0, 29.0],
[27.0, 27.0, 28.0, 28.0, 29.0, 29.0],
[30.0, 30.0, 31.0, 31.0, 32.0, 32.0],
[30.0, 30.0, 31.0, 31.0, 32.0, 32.0],
[33.0, 33.0, 34.0, 34.0, 35.0, 35.0],
[33.0, 33.0, 34.0, 34.0, 35.0, 35.0],
],
],
]
).transpose((0, 2, 3, 1))
expected_bilinear = mx.array(
[
[
[
[0.0, 0.4, 0.8, 1.2, 1.6, 2.0],
[1.2, 1.6, 2.0, 2.4, 2.8, 3.2],
[2.4, 2.8, 3.2, 3.6, 4.0, 4.4],
[3.6, 4.0, 4.4, 4.8, 5.2, 5.6],
[4.8, 5.2, 5.6, 6.0, 6.4, 6.8],
[6.0, 6.4, 6.8, 7.2, 7.6, 8.0],
],
[
[9.0, 9.4, 9.8, 10.2, 10.6, 11.0],
[10.2, 10.6, 11.0, 11.4, 11.8, 12.2],
[11.4, 11.8, 12.2, 12.6, 13.0, 13.4],
[12.6, 13.0, 13.4, 13.8, 14.2, 14.6],
[13.8, 14.2, 14.6, 15.0, 15.4, 15.8],
[15.0, 15.4, 15.8, 16.2, 16.6, 17.0],
],
],
[
[
[18.0, 18.4, 18.8, 19.2, 19.6, 20.0],
[19.2, 19.6, 20.0, 20.4, 20.8, 21.2],
[20.4, 20.8, 21.2, 21.6, 22.0, 22.4],
[21.6, 22.0, 22.4, 22.8, 23.2, 23.6],
[22.8, 23.2, 23.6, 24.0, 24.4, 24.8],
[24.0, 24.4, 24.8, 25.2, 25.6, 26.0],
],
[
[27.0, 27.4, 27.8, 28.2, 28.6, 29.0],
[28.2, 28.6, 29.0, 29.4, 29.8, 30.2],
[29.4, 29.8, 30.2, 30.6, 31.0, 31.4],
[30.6, 31.0, 31.4, 31.8, 32.2, 32.6],
[31.8, 32.2, 32.6, 33.0, 33.4, 33.8],
[33.0, 33.4, 33.8, 34.2, 34.6, 35.0],
],
],
]
).transpose((0, 2, 3, 1))
self.assertTrue(np.allclose(upsample_nearest(x), expected_nearest))
self.assertTrue(np.allclose(upsample_bilinear(x), expected_bilinear))
# Test different height and width scale_factor
b, h, w, c = 1, 2, 2, 2
x = mx.arange(b * h * w * c).reshape((b, c, h, w)).transpose((0, 2, 3, 1))
upsample_nearest = nn.Upsample(
scale_factor=(2, 3), mode="nearest", align_corners=True
)
upsample_bilinear = nn.Upsample(
scale_factor=(2, 3), mode="linear", align_corners=True
)
expected_nearest = mx.array(
[
[
[
[0, 0, 0, 1, 1, 1],
[0, 0, 0, 1, 1, 1],
[2, 2, 2, 3, 3, 3],
[2, 2, 2, 3, 3, 3],
],
[
[4, 4, 4, 5, 5, 5],
[4, 4, 4, 5, 5, 5],
[6, 6, 6, 7, 7, 7],
[6, 6, 6, 7, 7, 7],
],
]
]
).transpose((0, 2, 3, 1))
expected_bilinear = mx.array(
[
[
[
[0, 0.2, 0.4, 0.6, 0.8, 1],
[0.666667, 0.866667, 1.06667, 1.26667, 1.46667, 1.66667],
[1.33333, 1.53333, 1.73333, 1.93333, 2.13333, 2.33333],
[2, 2.2, 2.4, 2.6, 2.8, 3],
],
[
[4, 4.2, 4.4, 4.6, 4.8, 5],
[4.66667, 4.86667, 5.06667, 5.26667, 5.46667, 5.66667],
[5.33333, 5.53333, 5.73333, 5.93333, 6.13333, 6.33333],
[6, 6.2, 6.4, 6.6, 6.8, 7],
],
]
]
).transpose((0, 2, 3, 1))
self.assertTrue(np.allclose(upsample_nearest(x), expected_nearest))
self.assertTrue(np.allclose(upsample_bilinear(x), expected_bilinear))
# Test repr
self.assertEqual(
str(nn.Upsample(scale_factor=2)),
"Upsample(scale_factor=2.0, mode='nearest', align_corners=False)",
)
self.assertEqual(
str(nn.Upsample(scale_factor=(2, 3))),
"Upsample(scale_factor=(2.0, 3.0), mode='nearest', align_corners=False)",
)
def test_pooling(self):
# Test 1d pooling
x = mx.array(

View File

@@ -1047,6 +1047,11 @@ class TestOps(mlx_tests.MLXTestCase):
a = mx.arange(0, float("inf"), float("inf"))
with self.assertRaises(ValueError):
a = mx.arange(float("inf"), 1, float("inf"))
with self.assertRaises(ValueError):
a = mx.arange(float("inf"), 1, 5)
with self.assertRaises(ValueError):
INT_MAX = 2147483647
a = mx.arange(0, INT_MAX + 1, 1)
a = mx.arange(5)
expected = [0, 1, 2, 3, 4]
@@ -1132,6 +1137,27 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertListEqual(a.tolist(), expected)
self.assertEqual(a.dtype, mx.int32)
a = mx.arange(0, 10, 100)
expected = [0]
self.assertListEqual(a.tolist(), expected)
self.assertEqual(a.dtype, mx.int32)
a = mx.arange(10, 0, 1)
expected = []
self.assertListEqual(a.tolist(), expected)
a = mx.arange(10, 0, float("inf"))
expected = []
self.assertListEqual(a.tolist(), expected)
a = mx.arange(0, 10, float("inf"))
expected = [0]
self.assertListEqual(a.tolist(), expected)
a = mx.arange(0, -10, float("-inf"))
expected = [0]
self.assertListEqual(a.tolist(), expected)
def test_unary_ops(self):
def test_ops(npop, mlxop, x, y, atol):
r_np = npop(x)
@@ -1563,7 +1589,7 @@ class TestOps(mlx_tests.MLXTestCase):
shape = (3, 4, 5)
for dtype in ("int32", "float32"):
for axis in (None, 0, 1, 2):
for kth in (-2, 2):
for kth in (-2, 0, 2):
with self.subTest(dtype=dtype, axis=axis, kth=kth):
np.random.seed(0)
np_dtype = getattr(np, dtype)
@@ -1579,13 +1605,16 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertTrue(np.array_equal(c_np, c_mx))
self.assertEqual(b_mx.dtype, a_mx.dtype)
top_k_mx = mx.topk(a_mx, kth, axis=axis)
self.assertTrue(np.all(c_np <= top_k_mx))
self.assertEqual(top_k_mx.dtype, a_mx.dtype)
if kth >= 0:
d_np = np.take(b_mx, np.arange(kth), axis=axis)
self.assertTrue(np.all(d_np <= c_mx))
top_k_mx = mx.topk(a_mx, kth, axis=axis)
top_k_np = np.take(
np.partition(a_np, -kth, axis=axis), (-kth,), axis=axis
)
self.assertTrue(np.all(top_k_np <= top_k_mx))
self.assertEqual(top_k_mx.dtype, a_mx.dtype)
N = a_mx.shape[axis] if axis is not None else a_mx.size
M = top_k_mx.shape[axis or 0]
self.assertEqual(M, (kth + N) % N)
@unittest.skipIf(
os.getenv("LOW_MEMORY", None) is not None,
@@ -1906,12 +1935,16 @@ class TestOps(mlx_tests.MLXTestCase):
[[[[1]], [[2]], [[3]]]],
]
for array in arrays:
mx_arrays = [mx.atleast_1d(mx.array(x)) for x in arrays]
atleast_arrays = mx.atleast_1d(*mx_arrays)
for i, array in enumerate(arrays):
mx_res = mx.atleast_1d(mx.array(array))
np_res = np.atleast_1d(np.array(array))
self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist()))
self.assertEqual(mx_res.shape, np_res.shape)
self.assertEqual(mx_res.ndim, np_res.ndim)
self.assertTrue(mx.all(mx.equal(mx_res, atleast_arrays[i])))
def test_atleast_2d(self):
def compare_nested_lists(x, y):
@@ -1936,12 +1969,16 @@ class TestOps(mlx_tests.MLXTestCase):
[[[[1]], [[2]], [[3]]]],
]
for array in arrays:
mx_arrays = [mx.atleast_2d(mx.array(x)) for x in arrays]
atleast_arrays = mx.atleast_2d(*mx_arrays)
for i, array in enumerate(arrays):
mx_res = mx.atleast_2d(mx.array(array))
np_res = np.atleast_2d(np.array(array))
self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist()))
self.assertEqual(mx_res.shape, np_res.shape)
self.assertEqual(mx_res.ndim, np_res.ndim)
self.assertTrue(mx.all(mx.equal(mx_res, atleast_arrays[i])))
def test_atleast_3d(self):
def compare_nested_lists(x, y):
@@ -1966,12 +2003,16 @@ class TestOps(mlx_tests.MLXTestCase):
[[[[1]], [[2]], [[3]]]],
]
for array in arrays:
mx_arrays = [mx.atleast_3d(mx.array(x)) for x in arrays]
atleast_arrays = mx.atleast_3d(*mx_arrays)
for i, array in enumerate(arrays):
mx_res = mx.atleast_3d(mx.array(array))
np_res = np.atleast_3d(np.array(array))
self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist()))
self.assertEqual(mx_res.shape, np_res.shape)
self.assertEqual(mx_res.ndim, np_res.ndim)
self.assertTrue(mx.all(mx.equal(mx_res, atleast_arrays[i])))
if __name__ == "__main__":

View File

@@ -328,6 +328,37 @@ class TestSchedulers(unittest.TestCase):
expected_lr = 0.1 * 0.5 * (1.0 + math.cos(math.pi * 4 / 10))
self.assertAlmostEqual(lr, expected_lr, delta=1e-7)
def test_schedule_joiner(self):
boundaries = [2, 3, 4]
schedules = [lambda _: 3, lambda _: 4, lambda _: 5]
with self.assertRaises(ValueError):
opt.schedulers.join_schedules(schedules, boundaries)
boundaries = [2, 4]
schedule = opt.schedulers.join_schedules(schedules, boundaries)
self.assertEqual(schedule(0).item(), 3)
self.assertEqual(schedule(1).item(), 3)
self.assertEqual(schedule(2).item(), 4)
self.assertEqual(schedule(3).item(), 4)
self.assertEqual(schedule(5).item(), 5)
self.assertEqual(schedule(7).item(), 5)
def test_linear_warmup_with_cosine_decay(self):
warmup_schedule = opt.schedulers.linear_schedule(0.0, 1e-5, 100)
cosine_schedule = opt.schedulers.cosine_decay(1e-5, 100)
cos_with_warmup = opt.schedulers.join_schedules(
[warmup_schedule, cosine_schedule], [101]
)
self.assertEqual(cos_with_warmup(0), 0.0)
self.assertAlmostEqual(cos_with_warmup(101), 1e-5, delta=1e-1)
optimizer = opt.Adam(learning_rate=cos_with_warmup)
for _ in range(100):
optimizer.update({}, {})
self.assertAlmostEqual(optimizer.learning_rate.item(), 1e-5, delta=1e-1)
for _ in range(100):
optimizer.update({}, {})
expected_lr = 1e-5 * 0.5 * (1.0 + math.cos(math.pi * 200 / 10))
self.assertAlmostEqual(optimizer.learning_rate.item(), expected_lr, delta=1e-1)
def test_compile_with_schedule(self):
lr_schedule = opt.exponential_decay(1e-1, 0.9)
optimizer = opt.SGD(learning_rate=lr_schedule)