resolved conflicts

This commit is contained in:
Gabrijel Boduljak 2024-01-03 00:52:16 +01:00
commit b5c2630104
43 changed files with 1461 additions and 413 deletions

View File

@ -8,7 +8,7 @@ with a short description of your contribution(s) below. For example:
MLX was developed with contributions from the following individuals:
- Juarez Bochi: Fixed bug in cross attention.
- Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, linear and logistic regression python example.
- Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
# Third-Party Software

View File

@ -18,7 +18,7 @@ option(MLX_BUILD_METAL "Build metal backend" ON)
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
if(NOT MLX_VERSION)
set(MLX_VERSION 0.0.3)
set(MLX_VERSION 0.0.6)
endif()
# --------------------- Processor tests -------------------------

View File

@ -53,7 +53,7 @@ variety of examples, including:
- [Transformer language model](https://github.com/ml-explore/mlx-examples/tree/main/transformer_lm) training.
- Large-scale text generation with
[LLaMA](https://github.com/ml-explore/mlx-examples/tree/main/llama) and
[LLaMA](https://github.com/ml-explore/mlx-examples/tree/main/llms/llama) and
finetuning with [LoRA](https://github.com/ml-explore/mlx-examples/tree/main/lora).
- Generating images with [Stable Diffusion](https://github.com/ml-explore/mlx-examples/tree/main/stable_diffusion).
- Speech recognition with [OpenAI's Whisper](https://github.com/ml-explore/mlx-examples/tree/main/whisper).

View File

@ -125,6 +125,14 @@ if __name__ == "__main__":
compare_filtered("sum_axis --size 16x128x1024 --axis 1")
compare_filtered("sum_axis --size 16x128x1024 --axis 0 --cpu")
compare_filtered("sum_axis --size 16x128x1024 --axis 0")
compare_filtered("sum_axis --size 16x128x1024 --axis 0,1 --cpu")
compare_filtered("sum_axis --size 16x128x1024 --axis 0,1")
compare_filtered("sum_axis --size 16x128x1024 --axis 0,2 --cpu")
compare_filtered("sum_axis --size 16x128x1024 --axis 0,2")
compare_filtered("sum_axis --size 16x128x1024 --axis 0,1 --transpose 0,2,1 --cpu")
compare_filtered("sum_axis --size 16x128x1024 --axis 0,1 --transpose 0,2,1")
compare_filtered("sum_axis --size 16x128x1024 --axis 0,2 --transpose 0,2,1 --cpu")
compare_filtered("sum_axis --size 16x128x1024 --axis 0,2 --transpose 0,2,1")
compare_filtered("argmax --size 10x1024x128 --axis 1 --cpu")
compare_filtered("argmax --size 10x1024x128 --axis 1")
compare_filtered("argmax --size 10x1024x128 --axis 2 --cpu")

View File

@ -10,8 +10,8 @@ import subprocess
project = "MLX"
copyright = "2023, MLX Contributors"
author = "MLX Contributors"
version = "0.0.5"
release = "0.0.5"
version = "0.0.6"
release = "0.0.6"
# -- General configuration ---------------------------------------------------

View File

@ -57,6 +57,7 @@ are the CPU and GPU.
python/random
python/transforms
python/fft
python/linalg
python/nn
python/optimizers
python/tree_utils

View File

@ -0,0 +1,11 @@
.. _linalg:
Linear Algebra
==============
.. currentmodule:: mlx.core.linalg
.. autosummary::
:toctree: _autosummary
norm

View File

@ -20,6 +20,7 @@ Layers
Linear
Conv1d
Conv2d
BatchNorm
LayerNorm
RMSNorm
GroupNorm
@ -27,3 +28,6 @@ Layers
MultiHeadAttention
Sequential
QuantizedLinear
Dropout
Dropout2d

View File

@ -17,3 +17,6 @@ Loss Functions
nll_loss
smooth_l1_loss
triplet_loss
hinge_loss
huber_loss
log_cosh_loss

View File

@ -14,6 +14,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/scheduler.cpp
${CMAKE_CURRENT_SOURCE_DIR}/transforms.cpp
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp
${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/metal.h
)

View File

@ -126,7 +126,7 @@ struct ReductionPlan {
ReductionPlan get_reduction_plan(const array& x, const std::vector<int> axes) {
// The data is all there and we are reducing over everything
if (x.size() == x.data_size() && axes.size() == x.ndim() &&
(x.flags().row_contiguous || x.flags().col_contiguous)) {
x.flags().contiguous) {
return ContiguousAllReduce;
}
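
When every element of x lives in one dense buffer (x.size() == x.data_size() and the contiguous flag is set), a reduction over all axes is just a reduction over the flat buffer, whatever the storage order. A minimal NumPy stand-in sketch of that equivalence (not the MLX implementation):

import numpy as np

# Minimal stand-in sketch: for a dense buffer, reducing over all axes equals
# reducing over the flat view, whether the layout is row- or column-contiguous.
x_row = np.arange(24, dtype=np.float32).reshape(2, 3, 4)  # row-contiguous
x_col = np.asfortranarray(x_row)                          # column-contiguous
for x in (x_row, x_col):
    assert np.isclose(x.reshape(-1, order="A").sum(), x.sum())
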

View File

@ -19,6 +19,9 @@ namespace mlx::core::metal {
namespace {
// Catch things related to the main-thread static variables
static std::shared_ptr<void> global_memory_pool = new_scoped_memory_pool();
// TODO nicer way to set this or possibly expose as an environment variable
static constexpr int MAX_BUFFERS_PER_QUEUE = 12;
@ -110,15 +113,22 @@ MTL::Library* load_library(
} // namespace
Device::Device()
: pool_(NS::AutoreleasePool::alloc()->init()),
device_(load_device()),
library_map_({{"mlx", load_library(device_)}}) {}
Device::Device() {
auto pool = new_scoped_memory_pool();
device_ = load_device();
library_map_ = {{"mlx", load_library(device_)}};
}
Device::~Device() {
for (auto& q : queue_map_) {
q.second->release();
}
for (auto& b : buffer_map_) {
b.second.second->release();
}
for (auto& e : encoder_map_) {
e.second->release();
}
for (auto& k : kernel_map_) {
k.second->release();
}
@ -126,7 +136,6 @@ Device::~Device() {
l.second->release();
}
device_->release();
pool_->release();
}
void Device::new_queue(int index) {
@ -235,6 +244,7 @@ void Device::register_library(
MTL::ComputePipelineState* Device::get_kernel(
const std::string& name,
const std::string& lib_name /* = "mlx" */) {
auto pool = new_scoped_memory_pool();
// Look for cached kernel
if (auto it = kernel_map_.find(name); it != kernel_map_.end()) {
return it->second;
@ -277,18 +287,18 @@ MTL::ComputePipelineState* Device::get_kernel(
}
Device& device(mlx::core::Device) {
static Device metal_device_;
return metal_device_;
static Device metal_device;
return metal_device;
}
NS::AutoreleasePool*& thread_autorelease_pool() {
static thread_local NS::AutoreleasePool* p =
NS::AutoreleasePool::alloc()->init();
return p;
std::shared_ptr<void> new_scoped_memory_pool() {
auto dtor = [](void* ptr) {
static_cast<NS::AutoreleasePool*>(ptr)->release();
};
return std::shared_ptr<void>(NS::AutoreleasePool::alloc()->init(), dtor);
}
void new_stream(Stream stream) {
thread_autorelease_pool();
if (stream.device == mlx::core::Device::gpu) {
device(stream.device).new_queue(stream.index);
}

View File

@ -67,7 +67,6 @@ class Device {
const std::vector<MTL::ArgumentDescriptor*>& arg_descs) const;
private:
NS::AutoreleasePool* pool_;
MTL::Device* device_;
std::unordered_map<int32_t, MTL::CommandQueue*> queue_map_;
std::unordered_map<int32_t, std::pair<int, MTL::CommandBuffer*>> buffer_map_;
@ -78,6 +77,5 @@ class Device {
};
Device& device(mlx::core::Device);
NS::AutoreleasePool*& thread_autorelease_pool();
} // namespace mlx::core::metal

View File

@ -112,80 +112,22 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
///////////////////////////////////////////////////////////////////////////////
// General reduce
///////////////////////////////////////////////////////////////////////////////
template <typename T, typename U, typename Op>
[[kernel]] void general_reduce(
const device T *in [[buffer(0)]],
device mlx_atomic<U> *out [[buffer(1)]],
const device int *in_shape [[buffer(2)]],
const device size_t *in_strides [[buffer(3)]],
const device size_t *out_strides [[buffer(4)]],
const device size_t& ndim [[buffer(5)]],
uint gid [[thread_position_in_grid]]) {
Op op;
auto in_idx = elem_to_loc(gid, in_shape, in_strides, ndim);
auto out_idx = elem_to_loc(gid, in_shape, out_strides, ndim);
op.atomic_update(out, static_cast<U>(in[in_idx]), out_idx);
}
template <typename T, typename U, typename Op, int NDIM>
[[kernel]] void general_reduce(
const device T *in [[buffer(0)]],
device mlx_atomic<U> *out [[buffer(1)]],
const device int *in_shape [[buffer(2)]],
const device size_t *in_strides [[buffer(3)]],
const device size_t *out_strides [[buffer(4)]],
uint gid [[thread_position_in_grid]]) {
Op op;
auto in_idx = elem_to_loc_nd<NDIM>(gid, in_shape, in_strides);
auto out_idx = elem_to_loc_nd<NDIM>(gid, in_shape, out_strides);
op.atomic_update(out, static_cast<U>(in[in_idx]), out_idx);
}
#define instantiate_general_reduce_helper(name, itype, otype, op) \
template [[host_name("general_reduce_" #name)]] \
[[kernel]] void general_reduce<itype, otype, op>( \
const device itype *in [[buffer(0)]], \
device mlx_atomic<otype> *out [[buffer(1)]], \
const device int *in_shape [[buffer(2)]], \
const device size_t *in_strides [[buffer(3)]], \
const device size_t *out_strides [[buffer(4)]], \
const device size_t& ndim [[buffer(5)]], \
uint gid [[thread_position_in_grid]]);
#define instantiate_general_reduce_helper_nd(name, itype, otype, op, n) \
template [[host_name("general_reduce_" #name "_dim_" #n)]] \
[[kernel]] void general_reduce<itype, otype, op, n>( \
const device itype *in [[buffer(0)]], \
device mlx_atomic<otype> *out [[buffer(1)]], \
const device int *in_shape [[buffer(2)]], \
const device size_t *in_strides [[buffer(3)]], \
const device size_t *out_strides [[buffer(4)]], \
uint gid [[thread_position_in_grid]]);
#define instantiate_general_reduce(name, itype, otype, op) \
instantiate_general_reduce_helper(name, itype, otype, op) \
instantiate_general_reduce_helper_nd(name, itype, otype, op, 1) \
instantiate_general_reduce_helper_nd(name, itype, otype, op, 2) \
instantiate_general_reduce_helper_nd(name, itype, otype, op, 3) \
instantiate_general_reduce_helper_nd(name, itype, otype, op, 4)
///////////////////////////////////////////////////////////////////////////////
// Row atomics
///////////////////////////////////////////////////////////////////////////////
template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
[[kernel]] void row_reduce(
[[kernel]] void row_reduce_general(
const device T *in [[buffer(0)]],
device U *out [[buffer(1)]],
const device size_t& reduction_size [[buffer(2)]],
uint lid [[thread_position_in_threadgroup]],
uint lsize [[threads_per_threadgroup]],
uint tid [[threadgroup_position_in_grid]],
device mlx_atomic<U> *out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& out_size [[buffer(3)]],
const constant int* shape [[buffer(4)]],
const constant size_t* strides [[buffer(5)]],
const constant int& ndim [[buffer(6)]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 lsize [[threads_per_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_per_group [[simdgroups_per_threadgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
@ -193,7 +135,10 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
Op op;
// Each threadgroup handles 1 reduction
in += tid * reduction_size + lid * N_READS;
// TODO: Specializing elem_to_loc would be slightly faster
int idx = tid.y * out_size + tid.x;
int extra_offset = elem_to_loc(idx, shape, strides, ndim);
in += extra_offset + lid.x * N_READS;
// The reduction is accumulated here
U total_val = Op::init;
@ -201,7 +146,7 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
// Loop over the reduction size within thread group
int r = 0;
for (; r < (int)ceildiv(reduction_size, N_READS*lsize) - 1; r++) {
for (; r < (int)ceildiv(reduction_size, N_READS*lsize.x) - 1; r++) {
T vals[N_READS];
for(int i = 0; i < N_READS; i++) {
vals[i] = in[i];
@ -210,11 +155,11 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
total_val = op(static_cast<U>(vals[i]), total_val);
}
in += lsize * N_READS;
in += lsize.x * N_READS;
}
// Sepate case for the last set as we close the reduction size
size_t reduction_index = (lid + (size_t)lsize * r) * N_READS;
// Separate case for the last set as we close the reduction size
size_t reduction_index = (lid.x + (size_t)lsize.x * r) * N_READS;
if(reduction_index < reduction_size) {
int max_reads = reduction_size - reduction_index;
@ -240,24 +185,28 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
// Reduction within thread group
// Only needed if multiple simd groups
if(reduction_size > simd_size) {
total_val = lid < simd_per_group ? local_vals[lid] : op.init;
total_val = lid.x < simd_per_group ? local_vals[lid.x] : op.init;
total_val = op.simd_reduce(total_val);
}
// Update output
if (lid == 0) {
out[tid] = total_val;
if (lid.x == 0) {
op.atomic_update(out, total_val, tid.x);
}
}
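
The index math above resolves the flat index tid.y * out_size + tid.x into an input offset through elem_to_loc, which walks the collapsed shape/strides passed in buffers 4-6. A hedged pure-Python stand-in for what that helper computes (the real helper is a Metal function defined elsewhere in the kernels):

# Pure-Python stand-in for elem_to_loc: decompose a row-major flat index
# over `shape` and dot it with `strides` to get a buffer offset.
def elem_to_loc(idx, shape, strides):
    loc = 0
    for dim, stride in zip(reversed(shape), reversed(strides)):
        loc += (idx % dim) * stride
        idx //= dim
    return loc

# Example: a 2x3 view stored transposed (strides counted in elements).
print([elem_to_loc(i, [2, 3], [1, 2]) for i in range(6)])  # [0, 2, 4, 1, 3, 5]
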
#define instantiate_row_reduce(name, itype, otype, op) \
template [[host_name("row_reduce_" #name)]] \
[[kernel]] void row_reduce<itype, otype, op>( \
#define instantiate_row_reduce_general(name, itype, otype, op) \
template [[host_name("row_reduce_general_" #name)]] \
[[kernel]] void row_reduce_general<itype, otype, op>( \
const device itype *in [[buffer(0)]], \
device otype *out [[buffer(1)]], \
const device size_t& reduction_size [[buffer(2)]], \
uint lid [[thread_position_in_threadgroup]], \
uint lsize [[threads_per_threadgroup]], \
uint tid [[threadgroup_position_in_grid]], \
device mlx_atomic<otype> *out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& out_size [[buffer(3)]], \
const constant int* shape [[buffer(4)]], \
const constant size_t* strides [[buffer(5)]], \
const constant int& ndim [[buffer(6)]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint3 lsize [[threads_per_threadgroup]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_per_group [[simdgroups_per_threadgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
@ -311,62 +260,26 @@ inline void _contiguous_strided_reduce(
}
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
[[kernel]] void col_reduce(
[[kernel]] void col_reduce_general(
const device T *in [[buffer(0)]],
device mlx_atomic<U> *out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& reduction_stride [[buffer(3)]],
const constant size_t& out_size [[buffer(4)]],
const constant int* shape [[buffer(5)]],
const constant size_t* strides [[buffer(6)]],
const constant int& ndim [[buffer(7)]],
threadgroup U *local_data [[threadgroup(0)]],
uint2 tid [[threadgroup_position_in_grid]],
uint2 lid [[thread_position_in_threadgroup]],
uint2 lsize [[threads_per_threadgroup]]) {
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 lsize [[threads_per_threadgroup]]) {
auto out_idx = tid.x * lsize.x + lid.x;
if(out_idx < out_size) {
_contiguous_strided_reduce<T, U, Op, N_READS>(
in,
out,
local_data,
out_idx,
out_idx,
reduction_size,
reduction_stride,
tid,
lid,
lsize);
}
}
#define instantiate_col_reduce(name, itype, otype, op) \
template [[host_name("col_reduce_" #name)]] \
[[kernel]] void col_reduce<itype, otype, op>( \
const device itype *in [[buffer(0)]], \
device mlx_atomic<otype> *out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& reduction_stride [[buffer(3)]], \
const constant size_t& out_size [[buffer(4)]], \
threadgroup otype *local_data [[threadgroup(0)]], \
uint2 tid [[threadgroup_position_in_grid]], \
uint2 lid [[thread_position_in_threadgroup]], \
uint2 lsize [[threads_per_threadgroup]]);
template <typename T, typename U, typename Op, int NDIM, int N_READS = 16>
[[kernel]] void contiguous_strided_reduce(
const device T *in [[buffer(0)]],
device mlx_atomic<U> *out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& reduction_stride [[buffer(3)]],
const constant size_t& out_size [[buffer(4)]],
const device int* in_shape [[buffer(5)]],
const device size_t* in_strides [[buffer(6)]],
threadgroup U *local_data [[threadgroup(0)]],
uint2 tid [[threadgroup_position_in_grid]],
uint2 lid [[thread_position_in_threadgroup]],
uint2 lsize [[threads_per_threadgroup]]) {
auto out_idx = tid.x * lsize.x + lid.x;
auto in_idx = elem_to_loc_nd<NDIM>(out_idx, in_shape, in_strides);
auto in_idx = elem_to_loc(
out_idx + tid.z * out_size,
shape,
strides,
ndim
);
if(out_idx < out_size) {
_contiguous_strided_reduce<T, U, Op, N_READS>(
@ -377,82 +290,27 @@ template <typename T, typename U, typename Op, int NDIM, int N_READS = 16>
out_idx,
reduction_size,
reduction_stride,
tid,
lid,
lsize);
tid.xy,
lid.xy,
lsize.xy);
}
}
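
Per output element, the strided (column) reduction above accumulates reduction_size values spaced reduction_stride apart, starting from the offset resolved from out_idx + tid.z * out_size. A hedged NumPy sketch of that access pattern (threadgroups and the shared-memory local_data are an optimization over this loop, not a change in the result):

import numpy as np

# One output element of a strided reduction: walk the buffer with a fixed stride.
def strided_reduce_one(buf, start, reduction_size, reduction_stride):
    total = 0.0
    for r in range(reduction_size):
        total += buf[start + r * reduction_stride]
    return total

x = np.arange(12, dtype=np.float32).reshape(3, 4)   # reduce over axis 0
out = [strided_reduce_one(x.reshape(-1), j, 3, 4) for j in range(4)]
assert np.allclose(out, x.sum(axis=0))
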
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
[[kernel]] void contiguous_strided_reduce(
const device T *in [[buffer(0)]],
device mlx_atomic<U> *out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& reduction_stride [[buffer(3)]],
const constant size_t& out_size [[buffer(4)]],
const device int* in_shape [[buffer(5)]],
const device size_t* in_strides [[buffer(6)]],
const device size_t& in_dim [[buffer(7)]],
threadgroup U *local_data [[threadgroup(0)]],
uint2 tid [[threadgroup_position_in_grid]],
uint2 lid [[thread_position_in_threadgroup]],
uint2 lsize [[threads_per_threadgroup]]) {
auto out_idx = tid.x * lsize.x + lid.x;
auto in_idx = elem_to_loc(out_idx, in_shape, in_strides, in_dim);
if(out_idx < out_size) {
_contiguous_strided_reduce<T, U, Op, N_READS>(
in,
out,
local_data,
in_idx,
out_idx,
reduction_size,
reduction_stride,
tid,
lid,
lsize);
}
}
#define instantiate_contiguous_strided_helper(name, itype, otype, op) \
template [[host_name("contiguous_strided_reduce_" #name)]] \
[[kernel]] void contiguous_strided_reduce<itype, otype, op>( \
#define instantiate_col_reduce_general(name, itype, otype, op) \
template [[host_name("col_reduce_general_" #name)]] \
[[kernel]] void col_reduce_general<itype, otype, op>( \
const device itype *in [[buffer(0)]], \
device mlx_atomic<otype> *out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& reduction_stride [[buffer(3)]], \
const constant size_t& out_size [[buffer(4)]], \
const device int* in_shape [[buffer(5)]], \
const device size_t* in_strides [[buffer(6)]], \
const device size_t& in_dim [[buffer(7)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
threadgroup otype *local_data [[threadgroup(0)]], \
uint2 tid [[threadgroup_position_in_grid]], \
uint2 lid [[thread_position_in_threadgroup]], \
uint2 lsize [[threads_per_threadgroup]]);
#define instantiate_contiguous_strided_helper_nd(name, itype, otype, op, n) \
template [[host_name("contiguous_strided_reduce_" #name "_dim_" #n)]] \
[[kernel]] void contiguous_strided_reduce<itype, otype, op, n>( \
const device itype *in [[buffer(0)]], \
device mlx_atomic<otype> *out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& reduction_stride [[buffer(3)]], \
const constant size_t& out_size [[buffer(4)]], \
const device int* in_shape [[buffer(5)]], \
const device size_t* in_strides [[buffer(6)]], \
threadgroup otype *local_data [[threadgroup(0)]], \
uint2 tid [[threadgroup_position_in_grid]], \
uint2 lid [[thread_position_in_threadgroup]], \
uint2 lsize [[threads_per_threadgroup]]);
#define instantiate_contiguous_strided(name, itype, otype, op) \
instantiate_contiguous_strided_helper(name, itype, otype, op) \
instantiate_contiguous_strided_helper_nd(name, itype, otype, op, 1) \
instantiate_contiguous_strided_helper_nd(name, itype, otype, op, 2) \
instantiate_contiguous_strided_helper_nd(name, itype, otype, op, 3) \
instantiate_contiguous_strided_helper_nd(name, itype, otype, op, 4)
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint3 lsize [[threads_per_threadgroup]]);
///////////////////////////////////////////////////////////////////////////////
@ -461,10 +319,8 @@ template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
#define instantiate_reduce(name, itype, otype, op) \
instantiate_all_reduce(name, itype, otype, op) \
instantiate_row_reduce(name, itype, otype, op) \
instantiate_col_reduce(name, itype, otype, op) \
instantiate_contiguous_strided(name, itype, otype, op) \
instantiate_general_reduce(name, itype, otype, op)
instantiate_row_reduce_general(name, itype, otype, op) \
instantiate_col_reduce_general(name, itype, otype, op)
#define instantiate_same_reduce(name, tname, type, op) \
instantiate_init_reduce(name ##tname, type, op<type>) \

View File

@ -50,6 +50,7 @@ std::function<void()> make_task(
bool retain_graph) {
auto task =
[retain_graph, arr, deps = std::move(deps), p = std::move(p)]() mutable {
auto pool = new_scoped_memory_pool();
for (auto& d : deps) {
d.wait();
}
@ -66,12 +67,6 @@ std::function<void()> make_task(
arr.detach();
}
p->set_value();
// Signal this thread to clear the pool on a synchroniztion.
scheduler::enqueue(s, []() {
thread_autorelease_pool()->release();
thread_autorelease_pool() =
NS::AutoreleasePool::alloc()->init();
});
scheduler::notify_task_completion(s);
});
metal::device(s.device).commit_command_buffer(s.index);

View File

@ -20,6 +20,7 @@ constexpr bool is_available() {
}
void new_stream(Stream stream);
std::shared_ptr<void> new_scoped_memory_pool();
std::function<void()> make_task(
array& arr,

View File

@ -2,9 +2,11 @@
#include <algorithm>
#include <cassert>
#include <iostream>
#include <sstream>
#include "mlx/backend/common/reduce.h"
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/utils.h"
@ -61,22 +63,47 @@ void all_reduce_dispatch(
compute_encoder->dispatchThreads(grid_dims, group_dims);
}
void row_reduce_dispatch(
void row_reduce_general_dispatch(
const array& in,
array& out,
const std::string& op_name,
const std::vector<int>& axes_,
const ReductionPlan& plan,
const std::vector<int>& axes,
MTL::ComputeCommandEncoder* compute_encoder,
metal::Device& d) {
auto kernel = d.get_kernel("row_reduce_" + op_name + type_to_name(in));
auto kernel =
d.get_kernel("row_reduce_general_" + op_name + type_to_name(in));
// Prepare the arguments for the kernel
int n_reads = REDUCE_N_READS;
size_t reduction_size = in.size() / out.size();
size_t reduction_size = plan.shape.back();
size_t out_size = out.size();
auto shape = plan.shape;
auto strides = plan.strides;
shape.pop_back();
strides.pop_back();
size_t non_row_reductions = 1;
for (auto s : shape) {
non_row_reductions *= static_cast<size_t>(s);
}
auto [rem_shape, rem_strides] = shapes_without_reduction_axes(in, axes);
for (auto s : rem_shape) {
shape.push_back(s);
}
for (auto s : rem_strides) {
strides.push_back(s);
}
int ndim = shape.size();
// Set the arguments for the kernel
compute_encoder->setComputePipelineState(kernel);
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, out, 1);
compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
compute_encoder->setBytes(&out_size, sizeof(size_t), 3);
compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 4);
compute_encoder->setBytes(strides.data(), strides.size() * sizeof(size_t), 5);
compute_encoder->setBytes(&ndim, sizeof(int), 6);
// Each thread group is responsible for 1 output
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
@ -91,92 +118,54 @@ void row_reduce_dispatch(
// Launch enough thread groups for each output
size_t n_threads = out.size() * thread_group_size;
MTL::Size grid_dims = MTL::Size(n_threads, 1, 1);
MTL::Size grid_dims = MTL::Size(n_threads, non_row_reductions, 1);
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
}
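
A hedged host-side sketch (hypothetical numbers) of how these arguments line up for one of the benchmark cases added above, sum_axis --size 16x128x1024 --axis 0,2 on a row-contiguous input: the innermost reduction axis becomes the row, the remaining reduction axes become the grid's y extent, and the non-reduced shape/strides are appended so the kernel can locate each row.

# Hypothetical values for a row-contiguous (16, 128, 1024) input reduced over axes (0, 2):
reduction_size     = 1024                    # innermost reduction axis = the "row"
non_row_reductions = 16                      # product of the remaining reduction axes
out_size           = 128                     # one output per non-reduced position
shape, strides     = [16, 128], [128 * 1024, 1024]
# Threadgroup (tid.x=j, tid.y=r) reduces the row starting at
# elem_to_loc(r * out_size + j, shape, strides) = r * 131072 + j * 1024
# and atomically accumulates into out[j].
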
void col_reduce_dispatch(
void strided_reduce_general_dispatch(
const array& in,
array& out,
const std::string& op_name,
const std::vector<int>& axes_,
const ReductionPlan& plan,
const std::vector<int>& axes,
MTL::ComputeCommandEncoder* compute_encoder,
metal::Device& d) {
std::ostringstream kernel_name;
auto kernel =
d.get_kernel("col_reduce_general_" + op_name + type_to_name(in));
bool encode_in_shape = false;
bool encode_ndim = false;
// If the slowest moving axis can be merged into the reductions,
// we call the column reduce kernel
// In this case, a linear index in the output corresponds to the
// linear index in the input where the reduction starts
if (axes_[axes_.size() - 1] == (axes_.size() - 1)) {
kernel_name << "col_reduce_" << op_name << type_to_name(in);
}
// Otherwise, while all the reduction axes can be merged, the mapping between
// indices in the output and input require resolving using shapes and strides
else {
kernel_name << "contiguous_strided_reduce_" << op_name << type_to_name(in);
encode_in_shape = true;
// We check for a viable template with the required number of dimensions
// we only care about encoding non-reduced shapes and strides in the input
size_t non_reducing_dims = in.ndim() - axes_.size();
if (non_reducing_dims >= 1 &&
non_reducing_dims <= MAX_REDUCE_SPECIALIZED_DIMS) {
kernel_name << "_dim_" << non_reducing_dims;
} else {
encode_ndim = true;
}
}
auto kernel = d.get_kernel(kernel_name.str());
size_t in_size = in.size();
// Prepare the arguments for the kernel
size_t reduction_size = plan.shape.back();
size_t reduction_stride = plan.strides.back();
size_t out_size = out.size();
auto shape = plan.shape;
auto strides = plan.strides;
shape.pop_back();
strides.pop_back();
size_t non_col_reductions = 1;
for (auto s : shape) {
non_col_reductions *= static_cast<size_t>(s);
}
auto [rem_shape, rem_strides] = shapes_without_reduction_axes(in, axes);
for (auto s : rem_shape) {
shape.push_back(s);
}
for (auto s : rem_strides) {
strides.push_back(s);
}
int ndim = shape.size();
// Set the arguments for the kernel
compute_encoder->setComputePipelineState(kernel);
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, out, 1);
// Calculate the number of inputs to reduce and the stride b/w them
size_t reduction_size = 1;
size_t in_ndim = in.ndim();
size_t reduction_stride = in_size;
for (int i : axes_) {
reduction_size *= in.shape(i);
reduction_stride = std::min(reduction_stride, in.strides()[i]);
}
compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
compute_encoder->setBytes(&reduction_stride, sizeof(size_t), 3);
compute_encoder->setBytes(&out_size, sizeof(size_t), 4);
if (encode_in_shape) {
// Obtain the non-reducing shape and strides of the input to encode
std::vector<int> inp_shape_mod;
std::vector<size_t> inp_strides_mod;
for (size_t i = 0, j = 0; i < in.ndim(); i++) {
if (j < axes_.size() && axes_[j] == i) {
j++;
} else {
inp_shape_mod.push_back(in.shape(i));
inp_strides_mod.push_back(in.strides()[i]);
}
}
size_t ndim = inp_shape_mod.size();
compute_encoder->setBytes(inp_shape_mod.data(), ndim * sizeof(int), 5);
compute_encoder->setBytes(inp_strides_mod.data(), ndim * sizeof(size_t), 6);
if (encode_ndim) {
compute_encoder->setBytes(&ndim, sizeof(size_t), 7);
}
}
compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 5);
compute_encoder->setBytes(strides.data(), strides.size() * sizeof(size_t), 6);
compute_encoder->setBytes(&ndim, sizeof(int), 7);
// Select block dimensions
@ -200,7 +189,8 @@ void col_reduce_dispatch(
(n_threads_per_output + threadgroup_dim_y - 1) / threadgroup_dim_y;
// Launch enough thread groups for each output
MTL::Size grid_dims = MTL::Size(n_threadgroups_x, n_threadgroups_y, 1);
MTL::Size grid_dims =
MTL::Size(n_threadgroups_x, n_threadgroups_y, non_col_reductions);
MTL::Size group_dims = MTL::Size(threadgroup_dim_x, threadgroup_dim_y, 1);
// We set shared memory to be exploited here for reductions within a
@ -216,60 +206,6 @@ void col_reduce_dispatch(
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
}
void general_reduce_dispatch(
const array& in,
array& out,
const std::string& op_name,
const std::vector<int>& axes_,
MTL::ComputeCommandEncoder* compute_encoder,
metal::Device& d) {
bool encode_ndim = true;
std::ostringstream kernel_name;
kernel_name << "general_reduce_" << op_name << type_to_name(in);
// Check for specialzed kernels for input ndim
if (in.ndim() >= 1 && in.ndim() <= MAX_REDUCE_SPECIALIZED_DIMS) {
kernel_name << "_dim_" << in.ndim();
encode_ndim = false;
}
auto kernel = d.get_kernel(kernel_name.str());
size_t in_size = in.size();
size_t ndim = in.ndim();
// We set the reducing strides to 0 to induce collisions for the reduction
std::vector<size_t> out_strides(ndim);
size_t stride = 1;
for (int i = ndim - 1, j = axes_.size() - 1; i >= 0; --i) {
if (j >= 0 && axes_[j] == i) {
out_strides[i] = 0;
--j;
} else {
out_strides[i] = stride;
stride *= in.shape(i);
}
}
compute_encoder->setComputePipelineState(kernel);
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, out, 1);
compute_encoder->setBytes(in.shape().data(), ndim * sizeof(int), 2);
compute_encoder->setBytes(in.strides().data(), ndim * sizeof(size_t), 3);
compute_encoder->setBytes(out_strides.data(), ndim * sizeof(size_t), 4);
if (encode_ndim) {
compute_encoder->setBytes(&ndim, sizeof(size_t), 5);
}
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size > in_size) {
thread_group_size = in_size;
}
size_t nthreads = in_size;
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
}
} // namespace
//////////////////////////////////////////////////////////////////////
@ -278,7 +214,7 @@ void general_reduce_dispatch(
void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
array in = inputs[0];
// TODO: Allow specific row and column reductions with types disabled
// due to atomics ?
@ -335,37 +271,47 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
// Reduce
{
// Check for contiguous data
if (in.size() == in.data_size() &&
(in.flags().row_contiguous || in.flags().col_contiguous)) {
// Go to all reduce if reducing over all axes
if (axes_.size() == in.ndim()) {
std::vector<array> copies;
ReductionPlan plan = get_reduction_plan(in, axes_);
// If it is a general reduce then copy the input to a contiguous array and
// recompute the plan.
if (plan.type == GeneralReduce) {
array in_copy(in.shape(), in.dtype(), nullptr, {});
copy_gpu(in, in_copy, CopyType::General, s);
copies.push_back(in_copy);
in = in_copy;
plan = get_reduction_plan(in, axes_);
}
// Reducing over everything and the data is all there: no broadcasting,
// slicing, etc.
if (plan.type == ContiguousAllReduce) {
all_reduce_dispatch(in, out, op_name, compute_encoder, d);
return;
}
// Use specialized kernels if the input is row contiguous and
// the reducing axes can be merged into one
// At least the last dimension is row contiguous and we are reducing over
// the last dim.
else if (
in.flags().row_contiguous && in.strides().back() == 1 &&
(axes_.back() - axes_.front()) == axes_.size() - 1) {
// If the fastest moving axis is being reduced, go to row reduce
if (axes_[0] == (in.ndim() - axes_.size())) {
row_reduce_dispatch(in, out, op_name, axes_, compute_encoder, d);
return;
plan.type == ContiguousReduce || plan.type == GeneralContiguousReduce) {
row_reduce_general_dispatch(
in, out, op_name, plan, axes_, compute_encoder, d);
}
// Otherwise go to to generalized strided reduce
// Note: bool isn't support here yet due to the use of atomics
// once that is updated, this should be the else condition of this
// branch
else if (in.dtype() != bool_) {
col_reduce_dispatch(in, out, op_name, axes_, compute_encoder, d);
return;
// At least the last two dimensions are contiguous and we are doing a
// strided reduce over these.
else if (
plan.type == ContiguousStridedReduce ||
plan.type == GeneralStridedReduce) {
strided_reduce_general_dispatch(
in, out, op_name, plan, axes_, compute_encoder, d);
}
if (!copies.empty()) {
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
}
}
// Fall back to the general case
general_reduce_dispatch(in, out, op_name, axes_, compute_encoder, d);
}
}
} // namespace mlx::core
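
At the user level the plan selection is invisible. A minimal sketch, assuming the mlx.core Python API, that exercises both a fast path and (for the transposed input, as in the new benchmark cases) the copy-and-replan branch:

import mlx.core as mx

# Minimal sketch, assuming the mlx.core API: plan selection is an
# implementation detail, so both layouts must produce the same sums.
x = mx.random.normal((16, 128, 1024))
a = mx.sum(x, axis=[0, 2])                            # contiguous row-reduce path
b = mx.sum(mx.transpose(x, [0, 2, 1]), axis=[0, 1])   # may trigger copy + replan
assert mx.allclose(a, b).item()
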

View File

@ -7,6 +7,9 @@
namespace mlx::core::metal {
void new_stream(Stream) {}
std::shared_ptr<void> new_scoped_memory_pool() {
return nullptr;
}
std::function<void()> make_task(
array& arr,

mlx/linalg.cpp (new file, +175 lines)
View File

@ -0,0 +1,175 @@
// Copyright © 2023 Apple Inc.
#include <numeric>
#include <ostream>
#include <vector>
#include "mlx/dtype.h"
#include "mlx/linalg.h"
namespace mlx::core::linalg {
Dtype at_least_float(const Dtype& d) {
return is_floating_point(d) ? d : promote_types(d, float32);
}
inline array l2_norm(
const array& a,
const std::vector<int>& axis,
bool keepdims,
StreamOrDevice s) {
if (is_complex(a.dtype())) {
return sqrt(sum(abs(a, s) * abs(a, s), axis, keepdims, s), s);
} else {
return sqrt(sum(square(a, s), axis, keepdims, s), s);
}
}
inline array vector_norm(
const array& a,
const double ord,
const std::vector<int>& axis,
bool keepdims,
StreamOrDevice s) {
auto dtype = at_least_float(a.dtype());
if (ord == 0.0) {
return astype(sum(not_equal(a, array(0), s), axis, keepdims, s), dtype, s);
} else if (ord == 1.0) {
return astype(sum(abs(a, s), axis, keepdims, s), dtype, s);
} else if (ord == 2.0) {
return l2_norm(a, axis, keepdims, s);
} else if (ord == std::numeric_limits<double>::infinity()) {
return astype(max(abs(a, s), axis, keepdims, s), dtype, s);
} else if (ord == -std::numeric_limits<double>::infinity()) {
return astype(min(abs(a, s), axis, keepdims, s), dtype, s);
} else {
return power(
sum(power(abs(a, s), array(ord, dtype), s), axis, keepdims, s),
array(1.0 / ord, dtype),
s);
}
}
inline array matrix_norm(
const array& a,
const double ord,
const std::vector<int>& axis,
bool keepdims,
StreamOrDevice s) {
auto dtype = at_least_float(a.dtype());
auto row_axis = axis[0];
auto col_axis = axis[1];
if (ord == -1.0) {
col_axis -= (!keepdims && col_axis > row_axis && col_axis > 0);
return astype(
min(sum(abs(a, s), row_axis, keepdims, s), col_axis, keepdims, s),
dtype,
s);
} else if (ord == 1.0) {
col_axis -= (!keepdims && col_axis > row_axis && col_axis > 0);
return astype(
max(sum(abs(a, s), row_axis, keepdims, s), col_axis, keepdims, s),
dtype,
s);
} else if (ord == std::numeric_limits<double>::infinity()) {
row_axis -= (!keepdims && row_axis > col_axis && row_axis > 0);
return astype(
max(sum(abs(a, s), col_axis, keepdims, s), row_axis, keepdims, s),
dtype,
s);
} else if (ord == -std::numeric_limits<double>::infinity()) {
row_axis -= (!keepdims && row_axis > col_axis && row_axis > 0);
return astype(
min(sum(abs(a, s), col_axis, keepdims, s), row_axis, keepdims, s),
dtype,
s);
} else if (ord == 2.0 || ord == -2.0) {
throw std::runtime_error(
"[linalg::norm] Singular value norms are not implemented.");
} else {
std::ostringstream msg;
msg << "[linalg::norm] Invalid ord " << ord << " for matrix norm.";
throw std::invalid_argument(msg.str());
}
}
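
The col_axis/row_axis decrements above account for the axis removed by the first reduction when keepdims is false: every axis after the reduced one shifts down by one before the second (max/min) reduction. A hedged NumPy check of the same computation for the matrix 1-norm:

import numpy as np

# After summing |a| over row_axis without keepdims, the old col_axis has
# shifted down by one, so the max is taken over (col_axis - 1).
a = np.random.randn(4, 5, 6)
row_axis, col_axis = 1, 2
one_norm = np.abs(a).sum(axis=row_axis).max(axis=col_axis - 1)
assert np.allclose(one_norm, np.linalg.norm(a, ord=1, axis=(row_axis, col_axis)))
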
inline array matrix_norm(
const array& a,
const std::string& ord,
const std::vector<int>& axis,
bool keepdims,
StreamOrDevice s) {
if (ord == "f" || ord == "fro") {
return l2_norm(a, axis, keepdims, s);
} else if (ord == "nuc") {
throw std::runtime_error(
"[linalg::norm] Nuclear norm not yet implemented.");
} else {
std::ostringstream msg;
msg << "[linalg::norm] Invalid ord value '" << ord << "' for matrix norm.";
throw std::invalid_argument(msg.str());
}
}
array norm(
const array& a,
const std::optional<std::vector<int>>& axis /* = std::nullopt */,
bool keepdims /* = false */,
StreamOrDevice s /* = {} */) {
if (!axis) {
return norm(flatten(a, s), std::vector<int>{0}, keepdims, s);
}
if (axis.value().size() > 2) {
throw std::invalid_argument(
"[linalg::norm] Received too many axes for norm.");
}
return l2_norm(a, axis.value(), keepdims, s);
}
array norm(
const array& a,
const double ord,
const std::optional<std::vector<int>>& axis /* = std::nullopt */,
bool keepdims /* = false */,
StreamOrDevice s /* = {} */) {
std::vector<int> ax;
if (!axis) {
ax.resize(a.ndim());
std::iota(ax.begin(), ax.end(), 0);
} else {
ax = axis.value();
}
if (ax.size() == 1) {
return vector_norm(a, ord, ax, keepdims, s);
} else if (ax.size() == 2) {
return matrix_norm(a, ord, ax, keepdims, s);
} else {
throw std::invalid_argument(
"[linalg::norm] Received too many axes for norm.");
}
}
array norm(
const array& a,
const std::string& ord,
const std::optional<std::vector<int>>& axis /* = std::nullopt */,
bool keepdims /* = false */,
StreamOrDevice s /* = {} */) {
std::vector<int> ax;
if (!axis) {
ax.resize(a.ndim());
std::iota(ax.begin(), ax.end(), 0);
} else {
ax = axis.value();
}
if (ax.size() != 2) {
std::ostringstream msg;
msg << "[linalg::norm] Norm '" << ord << "' only supported for matrices,"
<< " but received " << ax.size() << " axis/axes.";
throw std::invalid_argument(msg.str());
}
return matrix_norm(a, ord, ax, keepdims, s);
}
} // namespace mlx::core::linalg

mlx/linalg.h (new file, +63 lines)
View File

@ -0,0 +1,63 @@
// Copyright © 2023 Apple Inc.
#pragma once
#include <optional>
#include "mlx/array.h"
#include "mlx/device.h"
#include "mlx/ops.h"
#include "mlx/stream.h"
namespace mlx::core::linalg {
/**
* Compute vector or matrix norms.
*
* - If axis and ord are both unspecified, computes the 2-norm of flatten(x).
* - If axis is not provided but ord is, then x must be either 1D or 2D.
* - If axis is provided, but ord is not, then the 2-norm (or Frobenius norm
* for matrices) is computed along the given axes. At most 2 axes can be
* specified.
* - If both axis and ord are provided, then the corresponding matrix or vector
* norm is computed. At most 2 axes can be specified.
*/
array norm(
const array& a,
const double ord,
const std::optional<std::vector<int>>& axis = std::nullopt,
bool keepdims = false,
StreamOrDevice s = {});
inline array norm(
const array& a,
const double ord,
int axis,
bool keepdims = false,
StreamOrDevice s = {}) {
return norm(a, ord, std::vector<int>{axis}, keepdims, s);
}
array norm(
const array& a,
const std::string& ord,
const std::optional<std::vector<int>>& axis = std::nullopt,
bool keepdims = false,
StreamOrDevice s = {});
inline array norm(
const array& a,
const std::string& ord,
int axis,
bool keepdims = false,
StreamOrDevice s = {}) {
return norm(a, ord, std::vector<int>{axis}, keepdims, s);
}
array norm(
const array& a,
const std::optional<std::vector<int>>& axis = std::nullopt,
bool keepdims = false,
StreamOrDevice s = {});
inline array
norm(const array& a, int axis, bool keepdims = false, StreamOrDevice s = {}) {
return norm(a, std::vector<int>{axis}, keepdims, s);
}
} // namespace mlx::core::linalg

View File

@ -6,6 +6,7 @@
#include "mlx/backend/metal/metal.h"
#include "mlx/device.h"
#include "mlx/fft.h"
#include "mlx/linalg.h"
#include "mlx/ops.h"
#include "mlx/random.h"
#include "mlx/stream.h"

View File

@ -103,7 +103,9 @@ array uniform(
}
auto stream = to_stream(s);
auto range = subtract(high, low, stream);
auto lo = astype(low, dtype, stream);
auto hi = astype(high, dtype, stream);
auto range = subtract(hi, lo, stream);
auto out_shape = broadcast_shapes(shape, range.shape());
if (out_shape != shape) {
std::ostringstream msg;
@ -136,7 +138,7 @@ array uniform(
auto out = bits(shape, size_of(dtype), key, stream);
out = astype(divide(out, maxval, stream), dtype, stream);
out = minimum(out, upper, stream);
return add(multiply(range, out, stream), low, stream);
return add(multiply(range, out, stream), lo, stream);
}
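
The fix casts low/high to the requested dtype before forming the range and adds the cast lo back at the end, so integer or wider-precision bounds no longer leak their dtype into the result. A minimal sketch of the intended behavior, assuming the mlx.core.random Python API:

import mlx.core as mx

# Minimal sketch, assuming the mx.random API: integer bounds, float16 output.
u = mx.random.uniform(low=mx.array(0), high=mx.array(10), shape=[4], dtype=mx.float16)
assert u.dtype == mx.float16
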
array uniform(

View File

@ -35,6 +35,7 @@ struct StreamThread {
}
void thread_fn() {
auto thread_pool = metal::new_scoped_memory_pool();
metal::new_stream(stream);
while (true) {
std::function<void()> task;

View File

@ -33,10 +33,16 @@ from mlx.nn.layers.activations import (
from mlx.nn.layers.base import Module
from mlx.nn.layers.containers import Sequential
from mlx.nn.layers.convolution import Conv1d, Conv2d
from mlx.nn.layers.dropout import Dropout
from mlx.nn.layers.dropout import Dropout, Dropout2d
from mlx.nn.layers.embedding import Embedding
from mlx.nn.layers.linear import Linear
from mlx.nn.layers.normalization import GroupNorm, InstanceNorm, LayerNorm, RMSNorm
from mlx.nn.layers.normalization import (
BatchNorm,
GroupNorm,
InstanceNorm,
LayerNorm,
RMSNorm,
)
from mlx.nn.layers.positional_encoding import ALiBi, RoPE, SinusoidalPositionalEncoding
from mlx.nn.layers.quantized import QuantizedLinear
from mlx.nn.layers.transformer import (

View File

@ -5,7 +5,7 @@ from mlx.nn.layers.base import Module
class Dropout(Module):
"""Randomly zero a portion of the elements during training.
r"""Randomly zero a portion of the elements during training.
The remaining elements are multiplied with :math:`\frac{1}{1-p}` where
:math:`p` is the probability of zeroing an element. This is done so the
@ -32,4 +32,57 @@ class Dropout(Module):
mask = mx.random.bernoulli(self._p_1, x.shape)
return (1 / self._p_1) * mask.astype(x.dtype) * x
return (1 / self._p_1) * mask * x
class Dropout2d(Module):
r"""Apply 2D channel-wise dropout during training.
Randomly zero out entire channels independently with probability :math:`p`.
This layer expects the channels to be last, i.e. the input shape should be
``NWHC`` or ``WHC`` where ``N`` is the batch dimension, ``H`` is the input
image height, ``W`` is the input image width, and ``C`` is the number of
input channels.
The remaining channels are scaled by :math:`\frac{1}{1-p}` to
maintain the expected value of each element. Unlike traditional dropout,
which zeros individual entries, this layer zeros entire channels. This is
beneficial for early convolution layers where adjacent pixels are
correlated. In such cases, traditional dropout may not effectively
regularize activations. For more details, see [1].
[1]: Tompson, J., Goroshin, R., Jain, A., LeCun, Y. and Bregler C., 2015.
Efficient Object Localization Using Convolutional Networks. CVPR 2015.
Args:
p (float): Probability of zeroing a channel during training.
"""
def __init__(self, p: float = 0.5):
super().__init__()
if p < 0 or p >= 1:
raise ValueError("The dropout probability should be in [0, 1)")
self._p_1 = 1 - p
def _extra_repr(self):
return f"p={1-self._p_1}"
def __call__(self, x):
if x.ndim not in (3, 4):
raise ValueError(
f"Received input with {x.ndim} dimensions. Expected 3 or 4 dimensions."
)
if self._p_1 == 1 or not self.training:
return x
# Dropout is applied on the whole channel
# 3D input: (1, 1, C)
# 4D input: (B, 1, 1, C)
mask_shape = x.shape
mask_shape[-2] = mask_shape[-3] = 1
mask = mx.random.bernoulli(p=self._p_1, shape=mask_shape)
return (1 / self._p_1) * mask * x
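
A minimal usage sketch, assuming the mlx.nn API above: channels-last input, whole channels (not individual entries) are zeroed while training, and the layer is the identity at evaluation time.

import mlx.core as mx
import mlx.nn as nn

# Minimal usage sketch, assuming the mlx.nn API above.
x = mx.random.normal((2, 8, 8, 3))   # batch, spatial, spatial, channels-last
drop = nn.Dropout2d(p=0.5)
y = drop(x)                          # per-(batch, channel) masks of shape (2, 1, 1, 3)
drop.eval()
assert mx.array_equal(drop(x), x).item()   # identity in evaluation mode
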

View File

@ -1,5 +1,7 @@
# Copyright © 2023 Apple Inc.
from typing import Tuple
import mlx.core as mx
from mlx.nn.layers.base import Module
@ -252,3 +254,121 @@ class GroupNorm(Module):
)
x = group_norm(x)
return (self.weight * x + self.bias) if "weight" in self else x
class BatchNorm(Module):
r"""Applies Batch Normalization over a 2D or 3D input.
Computes
.. math::
y = \frac{x - E[x]}{\sqrt{Var[x]} + \epsilon} \gamma + \beta,
where :math:`\gamma` and :math:`\beta` are learned per feature dimension
parameters initialized at 1 and 0 respectively.
The input shape is specified as ``NC`` or ``NLC``, where ``N`` is the
batch, ``C`` is the number of features or channels, and ``L`` is the
sequence length. The output has the same shape as the input. For
four-dimensional arrays, the shape is ``NHWC``, where ``H`` and ``W`` are
the height and width respectively.
For more information on Batch Normalization, see the original paper `Batch
Normalization: Accelerating Deep Network Training by Reducing Internal
Covariate Shift <https://arxiv.org/abs/1502.03167>`_.
Args:
num_features (int): The feature dimension to normalize over.
eps (float, optional): A small additive constant for numerical
stability. Default: ``1e-5``.
momentum (float, optional): The momentum for updating the running
mean and variance. Default: ``0.1``.
affine (bool, optional): If ``True``, apply a learned affine
transformation after the normalization. Default: ``True``.
track_running_stats (bool, optional): If ``True``, track the
running mean and variance. Default: ``True``.
Examples:
>>> import mlx.core as mx
>>> import mlx.nn as nn
>>> x = mx.random.normal((5, 4))
>>> bn = nn.BatchNorm(num_features=4, affine=True)
>>> output = bn(x)
"""
def __init__(
self,
num_features: int,
eps: float = 1e-5,
momentum: float = 0.1,
affine: bool = True,
track_running_stats: bool = True,
):
super().__init__()
self.num_features = num_features
self.eps = eps
self.momentum = momentum
self.track_running_stats = track_running_stats
if affine:
self.weight = mx.ones((num_features,))
self.bias = mx.zeros((num_features,))
if self.track_running_stats:
self._running_mean = mx.zeros((num_features,))
self._running_var = mx.ones((num_features,))
def _extra_repr(self):
return (
f"{self.num_features}, eps={self.eps}, "
f"momentum={self.momentum}, affine={'weight' in self}, "
f"track_running_stats={self.track_running_stats}"
)
def _calc_stats(self, x: mx.array) -> Tuple[mx.array, mx.array]:
"""
Calculate the mean and variance of the input tensor.
Args:
x (mx.array): Input tensor.
Returns:
tuple: Tuple containing mean and variance.
"""
reduction_axes = tuple(range(0, x.ndim - 1))
means = mx.mean(x, axis=reduction_axes, keepdims=True)
var = mx.var(x, axis=reduction_axes, keepdims=True)
if self.track_running_stats and self.training:
self._running_mean = (
1 - self.momentum
) * self._running_mean + self.momentum * means
self._running_var = (
1 - self.momentum
) * self._running_var + self.momentum * var
return means, var
def __call__(self, x: mx.array) -> mx.array:
"""
Forward pass of BatchNorm.
Args:
x (mx.array): Input tensor.
Returns:
mx.array: Output tensor.
"""
if x.ndim < 2 or x.ndim > 4:
raise ValueError(
f"Expected input tensor to have 2, 3 or 4 dimensions, but got {x.ndim}"
)
if self.training or not self.track_running_stats:
means, var = self._calc_stats(x)
else:
means, var = self._running_mean, self._running_var
x = (x - means) * mx.rsqrt(var + self.eps)
return (self.weight * x + self.bias) if "weight" in self else x
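
A hedged sketch of the running-statistics behavior, assuming the mlx.nn API above: batch statistics in training mode with a momentum update of the tracked buffers, tracked statistics in evaluation mode.

import mlx.core as mx
import mlx.nn as nn

# Hedged sketch, assuming the mlx.nn API above.
bn = nn.BatchNorm(num_features=4)    # momentum defaults to 0.1
x = mx.random.normal((8, 4))
y_train = bn(x)                      # batch mean/var; running buffers updated
bn.eval()
y_eval = bn(x)                       # tracked running mean/var used instead
# After one training step: running_mean = 0.9 * 0 + 0.1 * batch_mean
assert mx.allclose(bn._running_mean, 0.1 * mx.mean(x, axis=0, keepdims=True)).item()
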

View File

@ -1,5 +1,7 @@
# Copyright © 2023 Apple Inc.
import math
import mlx.core as mx
from mlx.nn.layers.base import Module
@ -131,10 +133,6 @@ def mse_loss(
f"targets shape {targets.shape}."
)
assert (
predictions.shape == targets.shape
), f"Shape of predictions {predictions.shape} and targets {targets.shape} must match"
loss = mx.square(predictions - targets)
return _reduce(loss, reduction)
@ -283,3 +281,94 @@ def _reduce(loss: mx.array, reduction: str = "none"):
return loss
else:
raise ValueError("Invalid reduction. Must be 'none', 'mean', or 'sum'.")
def hinge_loss(
inputs: mx.array, targets: mx.array, reduction: str = "none"
) -> mx.array:
r"""
Computes the hinge loss between inputs and targets.
.. math::
\text{hinge}(y, y_{\text{pred}}) = \max(0, 1 - y \cdot y_{\text{pred}})
Args:
inputs (array): The predicted values.
targets (array): The target values. They should be -1 or 1.
reduction (str, optional): Specifies the reduction to apply to the output:
``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.
Returns:
array: The computed hinge loss.
"""
loss = mx.maximum(1 - inputs * targets, 0)
return _reduce(loss, reduction)
def huber_loss(
inputs: mx.array, targets: mx.array, delta: float = 1.0, reduction: str = "none"
) -> mx.array:
r"""
Computes the Huber loss between inputs and targets.
.. math::
L_{\delta}(a) =
\left\{ \begin{array}{ll}
\frac{1}{2} a^2 & \text{for } |a| \leq \delta, \\
\delta \left( |a| - \frac{1}{2} \delta \right) & \text{otherwise.}
\end{array} \right.
Args:
inputs (array): The predicted values.
targets (array): The target values.
delta (float, optional): The threshold at which to change between L1 and L2 loss.
Default: ``1.0``.
reduction (str, optional): Specifies the reduction to apply to the output:
``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.
Returns:
array: The computed Huber loss.
"""
errors = inputs - targets
abs_errors = mx.abs(errors)
quadratic = mx.minimum(abs_errors, delta)
linear = abs_errors - quadratic
loss = 0.5 * quadratic**2 + delta * linear
return _reduce(loss, reduction)
def log_cosh_loss(
inputs: mx.array, targets: mx.array, reduction: str = "none"
) -> mx.array:
r"""
Computes the log cosh loss between inputs and targets.
Logcosh acts like L2 loss for small errors, ensuring stable gradients,
and like the L1 loss for large errors, reducing sensitivity to outliers. This
dual behavior offers a balanced, robust approach for regression tasks.
.. math::
\text{logcosh}(y_{\text{true}}, y_{\text{pred}}) =
\frac{1}{n} \sum_{i=1}^{n}
\log(\cosh(y_{\text{pred}}^{(i)} - y_{\text{true}}^{(i)}))
Args:
inputs (array): The predicted values.
targets (array): The target values.
reduction (str, optional): Specifies the reduction to apply to the output:
``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.
Returns:
array: The computed log cosh loss.
"""
errors = inputs - targets
loss = mx.logaddexp(errors, -errors) - math.log(2)
return _reduce(loss, reduction)
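
A hedged numeric sketch of the three new losses on a tiny example, assuming the mlx.nn.losses module above; the expected values follow the formulas in the docstrings.

import mlx.core as mx
import mlx.nn as nn

pred = mx.array([0.5, -2.0])
tgt = mx.array([1.0, 1.0])                        # hinge targets are +/-1
print(nn.losses.hinge_loss(pred, tgt))            # max(0, 1 - y*y_pred)    -> [0.5, 3.0]
print(nn.losses.huber_loss(pred, tgt, delta=1.0)) # [0.5**2 / 2, 3.0 - 0.5] -> [0.125, 2.5]
print(nn.losses.log_cosh_loss(pred, tgt))         # log(cosh([-0.5, -3.0])) -> [~0.120, ~2.309]
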

View File

@ -11,6 +11,7 @@ pybind11_add_module(
${CMAKE_CURRENT_SOURCE_DIR}/stream.cpp
${CMAKE_CURRENT_SOURCE_DIR}/transforms.cpp
${CMAKE_CURRENT_SOURCE_DIR}/random.cpp
${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp
)
if (NOT MLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY)

View File

@ -510,6 +510,14 @@ void init_array(py::module_& m) {
"size", &array::size, R"pbdoc(Number of elments in the array.)pbdoc")
.def_property_readonly(
"ndim", &array::ndim, R"pbdoc(The array's dimension.)pbdoc")
.def_property_readonly(
"itemsize",
&array::itemsize,
R"pbdoc(The size of the array's datatype in bytes.)pbdoc")
.def_property_readonly(
"nbytes",
&array::nbytes,
R"pbdoc(The number of bytes in the array.)pbdoc")
// TODO, this makes a deep copy of the shape
// implement alternatives to use reference
// https://pybind11.readthedocs.io/en/stable/advanced/cast/stl.html
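
A minimal sketch of the two new properties, assuming the mlx Python API: nbytes is size * itemsize.

import mlx.core as mx

x = mx.zeros((2, 3), dtype=mx.float32)
assert x.itemsize == 4                 # bytes per float32 element
assert x.nbytes == x.size * x.itemsize == 24
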

python/src/linalg.cpp (new file, +180 lines)
View File

@ -0,0 +1,180 @@
// Copyright © 2023 Apple Inc.
#include <variant>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "mlx/linalg.h"
#include "python/src/load.h"
#include "python/src/utils.h"
namespace py = pybind11;
using namespace py::literals;
using namespace mlx::core;
using namespace mlx::core::linalg;
void init_linalg(py::module_& parent_module) {
py::options options;
options.disable_function_signatures();
auto m = parent_module.def_submodule(
"linalg", "mlx.core.linalg: linear algebra routines.");
m.def(
"norm",
[](const array& a,
const std::variant<std::monostate, int, double, std::string>& ord_,
const std::variant<std::monostate, int, std::vector<int>>& axis_,
const bool keepdims,
const StreamOrDevice stream) {
std::optional<std::vector<int>> axis = std::nullopt;
if (auto pv = std::get_if<int>(&axis_); pv) {
axis = std::vector<int>{*pv};
} else if (auto pv = std::get_if<std::vector<int>>(&axis_); pv) {
axis = *pv;
}
if (std::holds_alternative<std::monostate>(ord_)) {
return norm(a, axis, keepdims, stream);
} else {
if (auto pv = std::get_if<std::string>(&ord_); pv) {
return norm(a, *pv, axis, keepdims, stream);
}
double ord;
if (auto pv = std::get_if<int>(&ord_); pv) {
ord = *pv;
} else {
ord = std::get<double>(ord_);
}
return norm(a, ord, axis, keepdims, stream);
}
},
"a"_a,
py::pos_only(),
"ord"_a = none,
"axis"_a = none,
"keepdims"_a = false,
py::kw_only(),
"stream"_a = none,
R"pbdoc(
norm(a: array, /, ord: Union[None, scalar, str] = None, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array
Matrix or vector norm.
This function computes vector or matrix norms depending on the value of
the ``ord`` and ``axis`` parameters.
Args:
a (array): Input array. If ``axis`` is ``None``, ``a`` must be 1-D or 2-D,
unless ``ord`` is ``None``. If both ``axis`` and ``ord`` are ``None``, the
2-norm of ``a.flatten`` will be returned.
ord (scalar or str, optional): Order of the norm (see table under ``Notes``).
If ``None``, the 2-norm (or Frobenius norm for matrices) will be computed
along the given ``axis``. Default: ``None``.
axis (int or list(int), optional): If ``axis`` is an integer, it specifies the
axis of ``a`` along which to compute the vector norms. If ``axis`` is a
2-tuple, it specifies the axes that hold 2-D matrices, and the matrix
norms of these matrices are computed. If `axis` is ``None`` then
either a vector norm (when ``a`` is 1-D) or a matrix norm (when ``a`` is
2-D) is returned. Default: ``None``.
keepdims (bool, optional): If ``True``, the axes which are normed over are
left in the result as dimensions with size one. Default ``False``.
Returns:
array: The output containing the norm(s).
Notes:
For values of ``ord < 1``, the result is, strictly speaking, not a
mathematical norm, but it may still be useful for various numerical
purposes.
The following norms can be calculated:
===== ============================ ==========================
ord norm for matrices norm for vectors
===== ============================ ==========================
None Frobenius norm 2-norm
'fro' Frobenius norm --
inf max(sum(abs(x), axis=1)) max(abs(x))
-inf min(sum(abs(x), axis=1)) min(abs(x))
0 -- sum(x != 0)
1 max(sum(abs(x), axis=0)) as below
-1 min(sum(abs(x), axis=0)) as below
2 2-norm (largest sing. value) as below
-2 smallest singular value as below
other -- sum(abs(x)**ord)**(1./ord)
===== ============================ ==========================
.. warning::
Nuclear norm and norms based on singular values are not yet implemented.
The Frobenius norm is given by [1]_:
:math:`||A||_F = [\sum_{i,j} abs(a_{i,j})^2]^{1/2}`
The nuclear norm is the sum of the singular values.
Both the Frobenius and nuclear norm orders are only defined for
matrices and raise a ``ValueError`` when ``a.ndim != 2``.
References:
.. [1] G. H. Golub and C. F. Van Loan, *Matrix Computations*,
Baltimore, MD, Johns Hopkins University Press, 1985, pg. 15
Examples:
>>> import mlx.core as mx
>>> from mlx.core import linalg as la
>>> a = mx.arange(9) - 4
>>> a
array([-4, -3, -2, ..., 2, 3, 4], dtype=int32)
>>> b = a.reshape((3,3))
>>> b
array([[-4, -3, -2],
[-1, 0, 1],
[ 2, 3, 4]], dtype=int32)
>>> la.norm(a)
array(7.74597, dtype=float32)
>>> la.norm(b)
array(7.74597, dtype=float32)
>>> la.norm(b, 'fro')
array(7.74597, dtype=float32)
>>> la.norm(a, float("inf"))
array(4, dtype=float32)
>>> la.norm(b, float("inf"))
array(9, dtype=float32)
>>> la.norm(a, -float("inf"))
array(0, dtype=float32)
>>> la.norm(b, -float("inf"))
array(2, dtype=float32)
>>> la.norm(a, 1)
array(20, dtype=float32)
>>> la.norm(b, 1)
array(7, dtype=float32)
>>> la.norm(a, -1)
array(0, dtype=float32)
>>> la.norm(b, -1)
array(6, dtype=float32)
>>> la.norm(a, 2)
array(7.74597, dtype=float32)
>>> la.norm(a, 3)
array(5.84804, dtype=float32)
>>> la.norm(a, -3)
array(0, dtype=float32)
>>> c = mx.array([[ 1, 2, 3],
... [-1, 1, 4]])
>>> la.norm(c, axis=0)
array([1.41421, 2.23607, 5], dtype=float32)
>>> la.norm(c, axis=1)
array([3.74166, 4.24264], dtype=float32)
>>> la.norm(c, ord=1, axis=1)
array([6, 6], dtype=float32)
>>> m = mx.arange(8).reshape(2,2,2)
>>> la.norm(m, axis=(1,2))
array([3.74166, 11.225], dtype=float32)
>>> la.norm(m[0, :, :]), la.norm(m[1, :, :])
(array(3.74166, dtype=float32), array(11.225, dtype=float32))
)pbdoc");
}

View File

@ -15,6 +15,7 @@ void init_ops(py::module_&);
void init_transforms(py::module_&);
void init_random(py::module_&);
void init_fft(py::module_&);
void init_linalg(py::module_&);
PYBIND11_MODULE(core, m) {
m.doc() = "mlx: A framework for machine learning on Apple silicon.";
@ -29,5 +30,6 @@ PYBIND11_MODULE(core, m) {
init_transforms(m);
init_random(m);
init_fft(m);
init_linalg(m);
m.attr("__version__") = TOSTRING(_VERSION_);
}

View File

@ -2129,7 +2129,7 @@ void init_ops(py::module_& m) {
singleton dimensions, defaults to `False`.
Returns:
array: The output array with the indices of the minimum values.
array: The output array with the indices of the maximum values.
)pbdoc");
m.def(
"sort",

View File

@ -569,7 +569,7 @@ void init_transforms(py::module_& m) {
return lvalue
# Returns lvalue, dlvalue/dparams
lvalue, grads = mx.value_and_grad(mse)
lvalue, grads = mx.value_and_grad(mse)(params, inputs, targets)
def lasso(params, inputs, targets, a=1.0, b=1.0):
outputs = forward(params, inputs)
@ -580,7 +580,7 @@ void init_transforms(py::module_& m) {
return loss, mse, l1
(loss, mse, l1), grads = mx.value_and_grad(lasso)
(loss, mse, l1), grads = mx.value_and_grad(lasso)(params, inputs, targets)
Args:
fun (function): A function which takes a variable number of

View File

@ -84,6 +84,8 @@ class TestArray(mlx_tests.MLXTestCase):
x = mx.array(1)
self.assertEqual(x.size, 1)
self.assertEqual(x.ndim, 0)
self.assertEqual(x.itemsize, 4)
self.assertEqual(x.nbytes, 4)
self.assertEqual(x.shape, [])
self.assertEqual(x.dtype, mx.int32)
self.assertEqual(x.item(), 1)

View File

@ -0,0 +1,94 @@
# Copyright © 2023 Apple Inc.
import itertools
import math
import unittest
import mlx.core as mx
import mlx_tests
import numpy as np
class TestLinalg(mlx_tests.MLXTestCase):
def test_norm(self):
vector_ords = [None, 0.5, 0, 1, 2, 3, -1, float("inf"), -float("inf")]
matrix_ords = [None, "fro", -1, 1, float("inf"), -float("inf")]
for shape in [(3,), (2, 3), (2, 3, 3)]:
x_mx = mx.arange(1, math.prod(shape) + 1).reshape(shape)
x_np = np.arange(1, math.prod(shape) + 1).reshape(shape)
# Test when at least one axis is provided
for num_axes in range(1, len(shape)):
if num_axes == 1:
ords = vector_ords
else:
ords = matrix_ords
for axis in itertools.combinations(range(len(shape)), num_axes):
for keepdims in [True, False]:
for o in ords:
out_np = np.linalg.norm(
x_np, ord=o, axis=axis, keepdims=keepdims
)
out_mx = mx.linalg.norm(
x_mx, ord=o, axis=axis, keepdims=keepdims
)
with self.subTest(
shape=shape, ord=o, axis=axis, keepdims=keepdims
):
self.assertTrue(
np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6)
)
# Test only ord provided
for shape in [(3,), (2, 3)]:
x_mx = mx.arange(1, math.prod(shape) + 1).reshape(shape)
x_np = np.arange(1, math.prod(shape) + 1).reshape(shape)
for o in [None, 1, -1, float("inf"), -float("inf")]:
for keepdims in [True, False]:
out_np = np.linalg.norm(x_np, ord=o, keepdims=keepdims)
out_mx = mx.linalg.norm(x_mx, ord=o, keepdims=keepdims)
with self.subTest(shape=shape, ord=o, keepdims=keepdims):
self.assertTrue(
np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6)
)
# Test no ord and no axis provided
for shape in [(3,), (2, 3), (2, 3, 3)]:
x_mx = mx.arange(1, math.prod(shape) + 1).reshape(shape)
x_np = np.arange(1, math.prod(shape) + 1).reshape(shape)
for keepdims in [True, False]:
out_np = np.linalg.norm(x_np, keepdims=keepdims)
out_mx = mx.linalg.norm(x_mx, keepdims=keepdims)
with self.subTest(shape=shape, keepdims=keepdims):
self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6))
def test_complex_norm(self):
for shape in [(3,), (2, 3), (2, 3, 3)]:
x_np = np.random.uniform(size=shape).astype(
np.float32
) + 1j * np.random.uniform(size=shape).astype(np.float32)
x_mx = mx.array(x_np)
out_np = np.linalg.norm(x_np)
out_mx = mx.linalg.norm(x_mx)
with self.subTest(shape=shape):
self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6))
for num_axes in range(1, len(shape)):
for axis in itertools.combinations(range(len(shape)), num_axes):
out_np = np.linalg.norm(x_np, axis=axis)
out_mx = mx.linalg.norm(x_mx, axis=axis)
with self.subTest(shape=shape, axis=axis):
self.assertTrue(
np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6)
)
x_np = np.random.uniform(size=(4, 4)).astype(
np.float32
) + 1j * np.random.uniform(size=(4, 4)).astype(np.float32)
x_mx = mx.array(x_np)
out_np = np.linalg.norm(x_np, ord="fro")
out_mx = mx.linalg.norm(x_mx, ord="fro")
self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6))
if __name__ == "__main__":
unittest.main()

View File

@@ -511,6 +511,143 @@ class TestNN(mlx_tests.MLXTestCase):
self.assertTrue(x.shape == y.shape)
self.assertTrue(np.allclose(y, expected_y, atol=1e-5))
def test_batch_norm(self):
mx.random.seed(42)
x = mx.random.normal((5, 4), dtype=mx.float32)
# Batch norm
bn = nn.BatchNorm(num_features=4, affine=True)
self.assertTrue(mx.allclose(bn._running_mean, mx.zeros_like(bn._running_mean)))
self.assertTrue(mx.allclose(bn._running_var, mx.ones_like(bn._running_var)))
y = bn(x)
expected_y = mx.array(
[
[-0.439520, 1.647328, -0.955515, 1.966031],
[-1.726690, -1.449826, -0.234026, -0.723364],
[0.938414, -0.349603, -0.354470, -0.175369],
[0.305006, 0.234914, -0.393017, -0.459385],
[0.922789, -0.082813, 1.937028, -0.607913],
],
)
expected_mean = mx.array([0.008929, 0.005680, -0.016092, 0.027778])
expected_var = mx.array([0.928435, 1.00455, 1.04117, 0.94258])
self.assertTrue(x.shape == y.shape)
self.assertTrue(mx.allclose(y, expected_y, atol=1e-5))
self.assertTrue(mx.allclose(bn._running_mean, expected_mean, atol=1e-5))
self.assertTrue(mx.allclose(bn._running_var, expected_var, atol=1e-5))
# test eval mode
bn.eval()
y = bn(x)
expected_y = mx.array(
[
[-0.15984, 1.73159, -1.25456, 1.57891],
[-0.872193, -1.4281, -0.414439, -0.228678],
[0.602743, -0.30566, -0.554687, 0.139639],
[0.252199, 0.29066, -0.599572, -0.0512532],
[0.594096, -0.0334829, 2.11359, -0.151081],
]
)
self.assertTrue(x.shape == y.shape)
self.assertTrue(mx.allclose(y, expected_y, atol=1e-5))
# test_no_affine
bn = nn.BatchNorm(num_features=4, affine=False)
y = bn(x)
expected_y = mx.array(
[
[-0.439520, 1.647328, -0.955515, 1.966031],
[-1.726690, -1.449826, -0.234026, -0.723364],
[0.938414, -0.349603, -0.354470, -0.175369],
[0.305006, 0.234914, -0.393017, -0.459385],
[0.922789, -0.082813, 1.937028, -0.607913],
]
)
self.assertTrue(x.shape == y.shape)
self.assertTrue(mx.allclose(y, expected_y, atol=1e-5))
# test with 3D input
mx.random.seed(42)
N = 2
L = 4
C = 5
x = mx.random.normal((N, L, C), dtype=mx.float32)
# Batch norm
bn = nn.BatchNorm(num_features=C, affine=True)
self.assertTrue(mx.allclose(bn._running_mean, mx.zeros_like(bn._running_mean)))
self.assertTrue(mx.allclose(bn._running_var, mx.ones_like(bn._running_var)))
y = bn(x)
self.assertTrue(x.shape == y.shape)
expected_y = mx.array(
[
[
[-0.335754, 0.342054, 1.02653, 0.628588, -1.63899],
[1.92092, 0.432319, 0.343043, 1.95489, 1.0696],
[-0.853748, 1.3661, 0.868569, 0.0199196, -0.887284],
[0.459206, -0.684822, -0.706354, -0.271531, 0.566341],
],
[
[-0.921179, 0.684951, -0.77466, -0.490372, -0.247032],
[1.10839, -2.13179, 0.628924, -1.62639, -0.539708],
[-0.348943, 0.412194, -2.03818, 0.524972, 1.64568],
[-1.02889, -0.421, 0.652127, -0.740079, 0.0313996],
],
]
)
self.assertTrue(mx.allclose(y, expected_y, atol=1e-5))
expected_mean = mx.array(
[[[0.00207845, -5.3259e-05, 0.04755, -0.0697296, 0.0236228]]]
)
expected_var = mx.array([[[0.968415, 1.05322, 0.96913, 0.932305, 0.967224]]])
self.assertTrue(mx.allclose(bn._running_mean, expected_mean, atol=1e-5))
self.assertTrue(mx.allclose(bn._running_var, expected_var, atol=1e-5))
x = mx.random.normal((N, L, C, L, C), dtype=mx.float32)
with self.assertRaises(ValueError):
y = bn(x)
def test_batch_norm_stats(self):
batch_size = 2
num_features = 4
h = 3
w = 3
momentum = 0.1
batch_norm = nn.BatchNorm(num_features)
batch_norm.train()
running_mean = np.array(batch_norm._running_mean)
running_var = np.array(batch_norm._running_var)
data = mx.random.normal((batch_size, num_features))
normalized_data = batch_norm(data)
np_data = np.array(data)
means = np.mean(np_data, axis=0)
variances = np.var(np_data, axis=0)
running_mean = (1 - momentum) * running_mean + momentum * means
running_var = (1 - momentum) * running_var + momentum * variances
self.assertTrue(np.allclose(batch_norm._running_mean, running_mean, atol=1e-5))
self.assertTrue(np.allclose(batch_norm._running_var, running_var, atol=1e-5))
batch_norm = nn.BatchNorm(num_features)
batch_norm.train()
running_mean = np.array(batch_norm._running_mean)
running_var = np.array(batch_norm._running_var)
data = mx.random.normal((batch_size, h, w, num_features))
normalized_data = batch_norm(data)
np_data = np.array(data)
means = np.mean(np_data, axis=(0, 1, 2))
variances = np.var(np_data, axis=(0, 1, 2))
running_mean = (1 - momentum) * running_mean + momentum * means
running_var = (1 - momentum) * running_var + momentum * variances
self.assertTrue(np.allclose(batch_norm._running_mean, running_mean, atol=1e-5))
self.assertTrue(np.allclose(batch_norm._running_var, running_var, atol=1e-5))
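The assertions in ``test_batch_norm_stats`` encode the exponential-moving-average update that BatchNorm is expected to apply in training mode. A minimal sketch of that update in isolation (momentum 0.1 is the default the test assumes; the variance is computed by hand to keep the sketch self-contained):

import mlx.core as mx

momentum = 0.1
running_mean = mx.zeros((4,))
running_var = mx.ones((4,))

x = mx.random.normal((8, 4))
batch_mean = mx.mean(x, axis=0)
batch_var = mx.mean(mx.square(x - batch_mean), axis=0)

# running statistics drift toward the batch statistics at rate `momentum`
running_mean = (1 - momentum) * running_mean + momentum * batch_mean
running_var = (1 - momentum) * running_var + momentum * batch_var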
def test_conv1d(self):
N = 5
L = 12
@@ -772,6 +909,24 @@ class TestNN(mlx_tests.MLXTestCase):
y = alibi(x.astype(mx.float16))
self.assertTrue(y.dtype, mx.float16)
def test_hinge_loss(self):
inputs = mx.ones((2, 4))
targets = mx.zeros((2, 4))
loss = nn.losses.hinge_loss(inputs, targets, reduction="mean")
self.assertEqual(loss, 1.0)
def test_huber_loss(self):
inputs = mx.ones((2, 4))
targets = mx.zeros((2, 4))
loss = nn.losses.huber_loss(inputs, targets, reduction="mean")
self.assertEqual(loss, 0.5)
def test_log_cosh_loss(self):
inputs = mx.ones((2, 4))
targets = mx.zeros((2, 4))
loss = nn.losses.log_cosh_loss(inputs, targets, reduction="mean")
self.assertAlmostEqual(loss.item(), 0.433781, places=6)
if __name__ == "__main__":
unittest.main()
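The expected values in the three loss tests above can be derived by hand: with inputs of all ones and targets of all zeros, every element-wise error is 1. Assuming the standard definitions these tests exercise (hinge as mean(max(0, 1 - target * input)), Huber with delta = 1, log-cosh as mean(log(cosh(error)))), a quick arithmetic sketch:

import math

error = 1.0                              # input 1, target 0
hinge = max(0.0, 1.0 - 0.0 * 1.0)        # 1.0
huber = 0.5 * error ** 2                 # 0.5 (quadratic branch, |error| <= delta)
log_cosh = math.log(math.cosh(error))    # ~0.433781

print(hinge, huber, log_cosh)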

View File

@@ -13,7 +13,8 @@ class TestQuantized(mlx_tests.MLXTestCase):
w_q, scales, biases = mx.quantize(w, 64, b)
w_hat = mx.dequantize(w_q, scales, biases, 64, b)
errors = (w - w_hat).abs().reshape(*scales.shape, -1)
self.assertTrue((errors <= scales[..., None] / 2).all())
eps = 1e-6
self.assertTrue((errors <= (scales[..., None] / 2 + eps)).all())
def test_qmm(self):
key = mx.random.key(0)
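The added ``eps`` only absorbs floating-point rounding; the underlying bound comes from round-to-nearest quantization, where each reconstructed weight lies within half a quantization step of the original. A minimal NumPy sketch of that bound for a single group (a simplified affine model, not MLX's exact kernel):

import numpy as np

w = np.random.uniform(-1, 1, size=64).astype(np.float32)
bits = 4
lo, hi = w.min(), w.max()
scale = (hi - lo) / (2**bits - 1)        # width of one quantization step

q = np.round((w - lo) / scale)           # round-to-nearest level index
w_hat = q * scale + lo                   # dequantized weights

# rounding keeps every element within half a step of the original
assert np.all(np.abs(w - w_hat) <= scale / 2 + 1e-6)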

View File

@@ -58,6 +58,9 @@ class TestRandom(mlx_tests.MLXTestCase):
a = mx.random.uniform(shape=(1000,), low=mx.array(-1), high=5)
self.assertTrue(mx.all((a > -1) < 5).item())
a = mx.random.uniform(low=-0.1, high=0.1, shape=(1,), dtype=mx.bfloat16)
self.assertEqual(a.dtype, mx.bfloat16)
def test_normal(self):
key = mx.random.key(0)
a = mx.random.normal(key=key)

View File

@@ -165,7 +165,7 @@ if __name__ == "__main__":
setup(
name="mlx",
version=get_version("0.0.5"),
version=get_version("0.0.6"),
author="MLX Contributors",
author_email="mlx@group.apple.com",
description="A framework for machine learning on Apple silicon.",

View File

@@ -31,6 +31,7 @@ target_sources(tests PRIVATE
scheduler_tests.cpp
utils_tests.cpp
vmap_tests.cpp
linalg_tests.cpp
${METAL_TEST_SOURCES}
)

250 tests/linalg_tests.cpp Normal file
View File

@@ -0,0 +1,250 @@
// Copyright © 2023 Apple Inc.
#include "doctest/doctest.h"
#include <cmath>
#include "mlx/mlx.h"
using namespace mlx::core;
using namespace mlx::core::linalg;
TEST_CASE("[mlx.core.linalg.norm] no ord") {
// Zero dimensions
array x(2.0);
CHECK_EQ(norm(x).item<float>(), 2.0f);
CHECK_THROWS(norm(x, 0));
x = array({1, 2, 3});
float expected = std::sqrt(1 + 4 + 9);
CHECK_EQ(norm(x).item<float>(), doctest::Approx(expected));
CHECK_EQ(norm(x, 0, false).item<float>(), doctest::Approx(expected));
CHECK_EQ(norm(x, -1, false).item<float>(), doctest::Approx(expected));
CHECK_EQ(norm(x, -1, true).ndim(), 1);
CHECK_THROWS(norm(x, 1));
x = reshape(arange(9), {3, 3});
expected =
std::sqrt(0 + 1 + 2 * 2 + 3 * 3 + 4 * 4 + 5 * 5 + 6 * 6 + 7 * 7 + 8 * 8);
CHECK_EQ(norm(x).item<float>(), doctest::Approx(expected));
CHECK_EQ(
norm(x, std::vector<int>{0, 1}).item<float>(), doctest::Approx(expected));
CHECK(array_equal(
norm(x, 0, false),
array(
{std::sqrt(0 + 3 * 3 + 6 * 6),
std::sqrt(1 + 4 * 4 + 7 * 7),
std::sqrt(2 * 2 + 5 * 5 + 8 * 8)}))
.item<bool>());
CHECK(allclose(
norm(x, 1, false),
array(
{std::sqrt(0 + 1 + 2 * 2),
std::sqrt(3 * 3 + 4 * 4 + 5 * 5),
std::sqrt(6 * 6 + 7 * 7 + 8 * 8)}))
.item<bool>());
x = reshape(arange(18), {2, 3, 3});
CHECK(allclose(
norm(x, 2, false),
array(
{
std::sqrt(0 + 1 + 2 * 2),
std::sqrt(3 * 3 + 4 * 4 + 5 * 5),
std::sqrt(6 * 6 + 7 * 7 + 8 * 8),
std::sqrt(9 * 9 + 10 * 10 + 11 * 11),
std::sqrt(12 * 12 + 13 * 13 + 14 * 14),
std::sqrt(15 * 15 + 16 * 16 + 17 * 17),
},
{2, 3}))
.item<bool>());
CHECK(allclose(
norm(x, std::vector<int>{1, 2}, false),
array(
{std::sqrt(
0 + 1 + 2 * 2 + 3 * 3 + 4 * 4 + 5 * 5 + 6 * 6 + 7 * 7 +
8 * 8),
std::sqrt(
9 * 9 + 10 * 10 + 11 * 11 + 12 * 12 + 13 * 13 + 14 * 14 +
15 * 15 + 16 * 16 + 17 * 17)},
{2}))
.item<bool>());
CHECK_THROWS(norm(x, std::vector<int>{0, 1, 2}));
}
TEST_CASE("[mlx.core.linalg.norm] double ord") {
CHECK_THROWS(norm(array(0), 2.0));
array x({1, 2, 3});
float expected = std::sqrt(1 + 4 + 9);
CHECK_EQ(norm(x, 2.0).item<float>(), doctest::Approx(expected));
CHECK_EQ(norm(x, 2.0, 0).item<float>(), doctest::Approx(expected));
CHECK_THROWS(norm(x, 2.0, 1));
expected = 1 + 2 + 3;
CHECK_EQ(norm(x, 1.0).item<float>(), doctest::Approx(expected));
expected = 3;
CHECK_EQ(norm(x, 0.0).item<float>(), doctest::Approx(expected));
expected = 3;
CHECK_EQ(
norm(x, std::numeric_limits<double>::infinity()).item<float>(),
doctest::Approx(expected));
expected = 1;
CHECK_EQ(
norm(x, -std::numeric_limits<double>::infinity()).item<float>(),
doctest::Approx(expected));
x = reshape(arange(9), {3, 3});
CHECK(allclose(
norm(x, 2.0, 0, false),
array(
{std::sqrt(0 + 3 * 3 + 6 * 6),
std::sqrt(1 + 4 * 4 + 7 * 7),
std::sqrt(2 * 2 + 5 * 5 + 8 * 8)}))
.item<bool>());
CHECK(allclose(
norm(x, 2.0, 1, false),
array(
{sqrt(0 + 1 + 2 * 2),
sqrt(3 * 3 + 4 * 4 + 5 * 5),
sqrt(6 * 6 + 7 * 7 + 8 * 8)}))
.item<bool>());
CHECK_EQ(
norm(x, 1.0, std::vector<int>{0, 1}).item<float>(),
doctest::Approx(15.0));
CHECK_EQ(
norm(x, 1.0, std::vector<int>{1, 0}).item<float>(),
doctest::Approx(21.0));
CHECK_EQ(
norm(x, -1.0, std::vector<int>{0, 1}).item<float>(),
doctest::Approx(9.0));
CHECK_EQ(
norm(x, -1.0, std::vector<int>{1, 0}).item<float>(),
doctest::Approx(3.0));
CHECK_EQ(
norm(x, 1.0, std::vector<int>{0, 1}, true).shape(),
std::vector<int>{1, 1});
CHECK_EQ(
norm(x, 1.0, std::vector<int>{1, 0}, true).shape(),
std::vector<int>{1, 1});
CHECK_EQ(
norm(x, -1.0, std::vector<int>{0, 1}, true).shape(),
std::vector<int>{1, 1});
CHECK_EQ(
norm(x, -1.0, std::vector<int>{1, 0}, true).shape(),
std::vector<int>{1, 1});
CHECK_EQ(
norm(x, -1.0, std::vector<int>{-2, -1}, false).item<float>(),
doctest::Approx(9.0));
CHECK_EQ(
norm(x, 1.0, std::vector<int>{-2, -1}, false).item<float>(),
doctest::Approx(15.0));
x = reshape(arange(18), {2, 3, 3});
CHECK_THROWS(norm(x, 2.0, std::vector{0, 1, 2}));
CHECK(allclose(
norm(x, 3.0, 0),
array(
{9.,
10.00333222,
11.02199456,
12.06217728,
13.12502645,
14.2094363,
15.31340617,
16.43469751,
17.57113899},
{3, 3}))
.item<bool>());
CHECK(allclose(
norm(x, 3.0, 2),
array(
{2.08008382,
6.,
10.23127655,
14.5180117,
18.82291607,
23.13593104},
{2, 3}))
.item<bool>());
CHECK(
allclose(
norm(x, 0.0, 0), array({1., 2., 2., 2., 2., 2., 2., 2., 2.}, {3, 3}))
.item<bool>());
CHECK(allclose(norm(x, 0.0, 1), array({2., 3., 3., 3., 3., 3.}, {2, 3}))
.item<bool>());
CHECK(allclose(norm(x, 0.0, 2), array({2., 3., 3., 3., 3., 3.}, {2, 3}))
.item<bool>());
CHECK(allclose(
norm(x, 1.0, 0),
array({9., 11., 13., 15., 17., 19., 21., 23., 25.}, {3, 3}))
.item<bool>());
CHECK(allclose(norm(x, 1.0, 1), array({9., 12., 15., 36., 39., 42.}, {2, 3}))
.item<bool>());
CHECK(allclose(norm(x, 1.0, 2), array({3., 12., 21., 30., 39., 48.}, {2, 3}))
.item<bool>());
CHECK(allclose(norm(x, 1.0, std::vector<int>{0, 1}), array({21., 23., 25.}))
.item<bool>());
CHECK(allclose(norm(x, 1.0, std::vector<int>{1, 2}), array({15., 42.}))
.item<bool>());
CHECK(allclose(norm(x, -1.0, std::vector<int>{0, 1}), array({9., 11., 13.}))
.item<bool>());
CHECK(allclose(norm(x, -1.0, std::vector<int>{1, 2}), array({9., 36.}))
.item<bool>());
CHECK(allclose(norm(x, -1.0, std::vector<int>{1, 0}), array({9., 12., 15.}))
.item<bool>());
CHECK(allclose(norm(x, -1.0, std::vector<int>{2, 1}), array({3, 30}))
.item<bool>());
CHECK(allclose(norm(x, -1.0, std::vector<int>{1, 2}), array({9, 36}))
.item<bool>());
}
TEST_CASE("[mlx.core.linalg.norm] string ord") {
array x({1, 2, 3});
CHECK_THROWS(norm(x, "fro"));
x = reshape(arange(9), {3, 3});
CHECK_THROWS(norm(x, "bad ord"));
CHECK_EQ(
norm(x, "f", std::vector<int>{0, 1}).item<float>(),
doctest::Approx(14.2828568570857));
CHECK_EQ(
norm(x, "fro", std::vector<int>{0, 1}).item<float>(),
doctest::Approx(14.2828568570857));
x = reshape(arange(18), {2, 3, 3});
CHECK(allclose(
norm(x, "fro", std::vector<int>{0, 1}),
array({22.24859546, 24.31049156, 26.43860813}))
.item<bool>());
CHECK(allclose(
norm(x, "fro", std::vector<int>{1, 2}),
array({14.28285686, 39.7617907}))
.item<bool>());
CHECK(allclose(
norm(x, "f", std::vector<int>{0, 1}),
array({22.24859546, 24.31049156, 26.43860813}))
.item<bool>());
CHECK(allclose(
norm(x, "f", std::vector<int>{1, 0}),
array({22.24859546, 24.31049156, 26.43860813}))
.item<bool>());
CHECK(allclose(
norm(x, "f", std::vector<int>{1, 2}),
array({14.28285686, 39.7617907}))
.item<bool>());
CHECK(allclose(
norm(x, "f", std::vector<int>{2, 1}),
array({14.28285686, 39.7617907}))
.item<bool>());
}
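The {0, 1} versus {1, 0} cases above exist because, for matrix ord = ±1, the absolute values are summed over the first listed axis and the max/min is taken over the second, so the axis order changes the result. A minimal NumPy sketch of the same convention, which these tests appear to mirror (15 vs 21 for arange(9).reshape(3, 3)):

import numpy as np

x = np.arange(9).reshape(3, 3)

assert np.linalg.norm(x, ord=1, axis=(0, 1)) == 15.0    # max column sum
assert np.linalg.norm(x, ord=1, axis=(1, 0)) == 21.0    # max row sum
assert np.linalg.norm(x, ord=-1, axis=(0, 1)) == 9.0    # min column sum
assert np.linalg.norm(x, ord=-1, axis=(1, 0)) == 3.0    # min row sum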

View File

@@ -260,6 +260,10 @@ TEST_CASE("test random uniform") {
// Non float type throws
CHECK_THROWS_AS(random::uniform({}, int32), std::invalid_argument);
// dtype respected
x = random::uniform(-.1, .1, {0}, bfloat16);
CHECK_EQ(x.dtype(), bfloat16);
// Check broadcasting
x = random::uniform(zeros({3, 1}), ones({1, 3}), {3, 3});
CHECK_EQ(x.shape(), std::vector<int>{3, 3});