Mirror of https://github.com/ml-explore/mlx.git (synced 2025-08-21 20:38:07 +08:00)

Commit b5c2630104: resolved conflicts
@@ -8,7 +8,7 @@ with a short description of your contribution(s) below. For example:
 MLX was developed with contributions from the following individuals:

 - Juarez Bochi: Fixed bug in cross attention.
-- Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, linear and logistic regression python example.
+- Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.

 # Third-Party Software

@@ -18,7 +18,7 @@ option(MLX_BUILD_METAL "Build metal backend" ON)
 option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)

 if(NOT MLX_VERSION)
-  set(MLX_VERSION 0.0.3)
+  set(MLX_VERSION 0.0.6)
 endif()

 # --------------------- Processor tests -------------------------
@@ -53,7 +53,7 @@ variety of examples, including:

 - [Transformer language model](https://github.com/ml-explore/mlx-examples/tree/main/transformer_lm) training.
 - Large-scale text generation with
-  [LLaMA](https://github.com/ml-explore/mlx-examples/tree/main/llama) and
+  [LLaMA](https://github.com/ml-explore/mlx-examples/tree/main/llms/llama) and
   finetuning with [LoRA](https://github.com/ml-explore/mlx-examples/tree/main/lora).
 - Generating images with [Stable Diffusion](https://github.com/ml-explore/mlx-examples/tree/main/stable_diffusion).
 - Speech recognition with [OpenAI's Whisper](https://github.com/ml-explore/mlx-examples/tree/main/whisper).
@@ -125,6 +125,14 @@ if __name__ == "__main__":
     compare_filtered("sum_axis --size 16x128x1024 --axis 1")
     compare_filtered("sum_axis --size 16x128x1024 --axis 0 --cpu")
     compare_filtered("sum_axis --size 16x128x1024 --axis 0")
+    compare_filtered("sum_axis --size 16x128x1024 --axis 0,1 --cpu")
+    compare_filtered("sum_axis --size 16x128x1024 --axis 0,1")
+    compare_filtered("sum_axis --size 16x128x1024 --axis 0,2 --cpu")
+    compare_filtered("sum_axis --size 16x128x1024 --axis 0,2")
+    compare_filtered("sum_axis --size 16x128x1024 --axis 0,1 --transpose 0,2,1 --cpu")
+    compare_filtered("sum_axis --size 16x128x1024 --axis 0,1 --transpose 0,2,1")
+    compare_filtered("sum_axis --size 16x128x1024 --axis 0,2 --transpose 0,2,1 --cpu")
+    compare_filtered("sum_axis --size 16x128x1024 --axis 0,2 --transpose 0,2,1")
    compare_filtered("argmax --size 10x1024x128 --axis 1 --cpu")
    compare_filtered("argmax --size 10x1024x128 --axis 1")
    compare_filtered("argmax --size 10x1024x128 --axis 2 --cpu")
@@ -10,8 +10,8 @@ import subprocess
 project = "MLX"
 copyright = "2023, MLX Contributors"
 author = "MLX Contributors"
-version = "0.0.5"
-release = "0.0.5"
+version = "0.0.6"
+release = "0.0.6"

 # -- General configuration ---------------------------------------------------

@@ -57,6 +57,7 @@ are the CPU and GPU.
    python/random
    python/transforms
    python/fft
+   python/linalg
    python/nn
    python/optimizers
    python/tree_utils
docs/src/python/linalg.rst (new file, 11 lines)
@@ -0,0 +1,11 @@
+.. _linalg:
+
+Linear Algebra
+==============
+
+.. currentmodule:: mlx.core.linalg
+
+.. autosummary::
+   :toctree: _autosummary
+
+   norm
@@ -20,6 +20,7 @@ Layers
    Linear
    Conv1d
    Conv2d
+   BatchNorm
    LayerNorm
    RMSNorm
    GroupNorm
@@ -27,3 +28,6 @@ Layers
    MultiHeadAttention
    Sequential
    QuantizedLinear
+   Dropout
+   Dropout2d
+
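Note (added for context, not part of the commit): BatchNorm normalizes each feature with batch statistics before a learned affine transform, and Dropout2d zeroes whole channels rather than individual activations. The standard batch-norm computation is

    \( y = \gamma \, \frac{x - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}} + \beta \)

where \(\mu_B\) and \(\sigma_B^2\) are the per-feature mean and variance over the batch.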
@@ -17,3 +17,6 @@ Loss Functions
    nll_loss
    smooth_l1_loss
    triplet_loss
+   hinge_loss
+   huber_loss
+   log_cosh_loss
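For reference (standard definitions, not taken from this commit), the three newly documented losses are, for target \(y\), prediction \(\hat{y}\), and residual \(a = \hat{y} - y\):

    \( \mathrm{hinge}(y, \hat{y}) = \max(0,\, 1 - y\,\hat{y}) \)
    \( \mathrm{huber}_\delta(a) = \tfrac{1}{2}a^2 \ \text{if } |a| \le \delta,\quad \delta\,(|a| - \tfrac{1}{2}\delta) \ \text{otherwise} \)
    \( \mathrm{logcosh}(y, \hat{y}) = \log\cosh(a) \)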
@@ -14,6 +14,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/scheduler.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/transforms.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/metal.h
 )

@@ -126,7 +126,7 @@ struct ReductionPlan {
 ReductionPlan get_reduction_plan(const array& x, const std::vector<int> axes) {
   // The data is all there and we are reducing over everything
   if (x.size() == x.data_size() && axes.size() == x.ndim() &&
-      (x.flags().row_contiguous || x.flags().col_contiguous)) {
+      x.flags().contiguous) {
     return ContiguousAllReduce;
   }

@@ -19,6 +19,9 @@ namespace mlx::core::metal {

 namespace {

+// Catch things related to the main-thread static variables
+static std::shared_ptr<void> global_memory_pool = new_scoped_memory_pool();
+
 // TODO nicer way to set this or possibly expose as an environment variable
 static constexpr int MAX_BUFFERS_PER_QUEUE = 12;

@@ -110,15 +113,22 @@ MTL::Library* load_library(

 } // namespace

-Device::Device()
-    : pool_(NS::AutoreleasePool::alloc()->init()),
-      device_(load_device()),
-      library_map_({{"mlx", load_library(device_)}}) {}
+Device::Device() {
+  auto pool = new_scoped_memory_pool();
+  device_ = load_device();
+  library_map_ = {{"mlx", load_library(device_)}};
+}

 Device::~Device() {
   for (auto& q : queue_map_) {
     q.second->release();
   }
+  for (auto& b : buffer_map_) {
+    b.second.second->release();
+  }
+  for (auto& e : encoder_map_) {
+    e.second->release();
+  }
   for (auto& k : kernel_map_) {
     k.second->release();
   }
@@ -126,7 +136,6 @@ Device::~Device() {
     l.second->release();
   }
   device_->release();
-  pool_->release();
 }

 void Device::new_queue(int index) {
@@ -235,6 +244,7 @@ void Device::register_library(
 MTL::ComputePipelineState* Device::get_kernel(
     const std::string& name,
     const std::string& lib_name /* = "mlx" */) {
+  auto pool = new_scoped_memory_pool();
   // Look for cached kernel
   if (auto it = kernel_map_.find(name); it != kernel_map_.end()) {
     return it->second;
@@ -277,18 +287,18 @@ MTL::ComputePipelineState* Device::get_kernel(
 }

 Device& device(mlx::core::Device) {
-  static Device metal_device_;
-  return metal_device_;
+  static Device metal_device;
+  return metal_device;
 }

-NS::AutoreleasePool*& thread_autorelease_pool() {
-  static thread_local NS::AutoreleasePool* p =
-      NS::AutoreleasePool::alloc()->init();
-  return p;
+std::shared_ptr<void> new_scoped_memory_pool() {
+  auto dtor = [](void* ptr) {
+    static_cast<NS::AutoreleasePool*>(ptr)->release();
+  };
+  return std::shared_ptr<void>(NS::AutoreleasePool::alloc()->init(), dtor);
 }

 void new_stream(Stream stream) {
-  thread_autorelease_pool();
   if (stream.device == mlx::core::Device::gpu) {
     device(stream.device).new_queue(stream.index);
   }
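The new new_scoped_memory_pool() replaces the thread-local autorelease pool with an RAII handle: the returned std::shared_ptr<void> drains the Objective-C autorelease pool when its last owner goes away. A minimal usage sketch (illustrative only; worker_iteration and do_metal_work are hypothetical names):

    #include <memory>
    #include "mlx/backend/metal/metal.h"

    void worker_iteration() {
      // Keep a scoped pool alive for the duration of this scope; the
      // shared_ptr deleter releases the underlying NS::AutoreleasePool,
      // exactly as in new_scoped_memory_pool() above.
      auto pool = mlx::core::metal::new_scoped_memory_pool();
      // do_metal_work();  // hypothetical: encode kernels, create MTL objects, ...
    }  // pool destroyed here -> autorelease pool drained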
@@ -67,7 +67,6 @@ class Device {
       const std::vector<MTL::ArgumentDescriptor*>& arg_descs) const;

  private:
-  NS::AutoreleasePool* pool_;
   MTL::Device* device_;
   std::unordered_map<int32_t, MTL::CommandQueue*> queue_map_;
   std::unordered_map<int32_t, std::pair<int, MTL::CommandBuffer*>> buffer_map_;
@@ -78,6 +77,5 @@ class Device {
 };

 Device& device(mlx::core::Device);
-NS::AutoreleasePool*& thread_autorelease_pool();

 } // namespace mlx::core::metal
@@ -112,80 +112,22 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
     uint simd_group_id [[simdgroup_index_in_threadgroup]]);


-///////////////////////////////////////////////////////////////////////////////
-// General reduce
-///////////////////////////////////////////////////////////////////////////////
-
-template <typename T, typename U, typename Op>
-[[kernel]] void general_reduce(
-    const device T *in [[buffer(0)]],
-    device mlx_atomic<U> *out [[buffer(1)]],
-    const device int *in_shape [[buffer(2)]],
-    const device size_t *in_strides [[buffer(3)]],
-    const device size_t *out_strides [[buffer(4)]],
-    const device size_t& ndim [[buffer(5)]],
-    uint gid [[thread_position_in_grid]]) {
-  Op op;
-  auto in_idx = elem_to_loc(gid, in_shape, in_strides, ndim);
-  auto out_idx = elem_to_loc(gid, in_shape, out_strides, ndim);
-  op.atomic_update(out, static_cast<U>(in[in_idx]), out_idx);
-}
-
-template <typename T, typename U, typename Op, int NDIM>
-[[kernel]] void general_reduce(
-    const device T *in [[buffer(0)]],
-    device mlx_atomic<U> *out [[buffer(1)]],
-    const device int *in_shape [[buffer(2)]],
-    const device size_t *in_strides [[buffer(3)]],
-    const device size_t *out_strides [[buffer(4)]],
-    uint gid [[thread_position_in_grid]]) {
-  Op op;
-  auto in_idx = elem_to_loc_nd<NDIM>(gid, in_shape, in_strides);
-  auto out_idx = elem_to_loc_nd<NDIM>(gid, in_shape, out_strides);
-  op.atomic_update(out, static_cast<U>(in[in_idx]), out_idx);
-}
-
-#define instantiate_general_reduce_helper(name, itype, otype, op) \
-  template [[host_name("general_reduce_" #name)]] \
-  [[kernel]] void general_reduce<itype, otype, op>( \
-      const device itype *in [[buffer(0)]], \
-      device mlx_atomic<otype> *out [[buffer(1)]], \
-      const device int *in_shape [[buffer(2)]], \
-      const device size_t *in_strides [[buffer(3)]], \
-      const device size_t *out_strides [[buffer(4)]], \
-      const device size_t& ndim [[buffer(5)]], \
-      uint gid [[thread_position_in_grid]]);
-
-#define instantiate_general_reduce_helper_nd(name, itype, otype, op, n) \
-  template [[host_name("general_reduce_" #name "_dim_" #n)]] \
-  [[kernel]] void general_reduce<itype, otype, op, n>( \
-      const device itype *in [[buffer(0)]], \
-      device mlx_atomic<otype> *out [[buffer(1)]], \
-      const device int *in_shape [[buffer(2)]], \
-      const device size_t *in_strides [[buffer(3)]], \
-      const device size_t *out_strides [[buffer(4)]], \
-      uint gid [[thread_position_in_grid]]);
-
-#define instantiate_general_reduce(name, itype, otype, op) \
-  instantiate_general_reduce_helper(name, itype, otype, op) \
-  instantiate_general_reduce_helper_nd(name, itype, otype, op, 1) \
-  instantiate_general_reduce_helper_nd(name, itype, otype, op, 2) \
-  instantiate_general_reduce_helper_nd(name, itype, otype, op, 3) \
-  instantiate_general_reduce_helper_nd(name, itype, otype, op, 4)
-
-
 ///////////////////////////////////////////////////////////////////////////////
 // Row atomics
 ///////////////////////////////////////////////////////////////////////////////

 template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
-[[kernel]] void row_reduce(
+[[kernel]] void row_reduce_general(
     const device T *in [[buffer(0)]],
-    device U *out [[buffer(1)]],
-    const device size_t& reduction_size [[buffer(2)]],
-    uint lid [[thread_position_in_threadgroup]],
-    uint lsize [[threads_per_threadgroup]],
-    uint tid [[threadgroup_position_in_grid]],
+    device mlx_atomic<U> *out [[buffer(1)]],
+    const constant size_t& reduction_size [[buffer(2)]],
+    const constant size_t& out_size [[buffer(3)]],
+    const constant int* shape [[buffer(4)]],
+    const constant size_t* strides [[buffer(5)]],
+    const constant int& ndim [[buffer(6)]],
+    uint3 lid [[thread_position_in_threadgroup]],
+    uint3 lsize [[threads_per_threadgroup]],
+    uint3 tid [[threadgroup_position_in_grid]],
     uint simd_lane_id [[thread_index_in_simdgroup]],
     uint simd_per_group [[simdgroups_per_threadgroup]],
     uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
@@ -193,7 +135,10 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
   Op op;

   // Each threadgroup handles 1 reduction
-  in += tid * reduction_size + lid * N_READS;
+  // TODO: Specializing elem_to_loc would be slightly faster
+  int idx = tid.y * out_size + tid.x;
+  int extra_offset = elem_to_loc(idx, shape, strides, ndim);
+  in += extra_offset + lid.x * N_READS;

   // The reduction is accumulated here
   U total_val = Op::init;
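The rewritten kernel body relies on elem_to_loc to turn a linear element index into a memory offset through an arbitrary shape/strides pair; its definition is not part of this hunk. A host-side sketch of the row-major index arithmetic it is assumed to perform:

    #include <cstddef>
    #include <vector>

    // Sketch only: peel off the fastest-moving dimension first and
    // accumulate coordinate * stride for each dimension.
    size_t elem_to_loc_sketch(
        size_t elem,
        const std::vector<int>& shape,
        const std::vector<size_t>& strides) {
      size_t loc = 0;
      for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
        loc += (elem % shape[i]) * strides[i];
        elem /= shape[i];
      }
      return loc;
    }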
@@ -201,7 +146,7 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>

   // Loop over the reduction size within thread group
   int r = 0;
-  for (; r < (int)ceildiv(reduction_size, N_READS*lsize) - 1; r++) {
+  for (; r < (int)ceildiv(reduction_size, N_READS*lsize.x) - 1; r++) {
     T vals[N_READS];
     for(int i = 0; i < N_READS; i++) {
       vals[i] = in[i];
@@ -210,11 +155,11 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
       total_val = op(static_cast<U>(vals[i]), total_val);
     }

-    in += lsize * N_READS;
+    in += lsize.x * N_READS;
   }

-  // Sepate case for the last set as we close the reduction size
-  size_t reduction_index = (lid + (size_t)lsize * r) * N_READS;
+  // Separate case for the last set as we close the reduction size
+  size_t reduction_index = (lid.x + (size_t)lsize.x * r) * N_READS;
   if(reduction_index < reduction_size) {
     int max_reads = reduction_size - reduction_index;

|
|||||||
// Reduction within thread group
|
// Reduction within thread group
|
||||||
// Only needed if multiple simd groups
|
// Only needed if multiple simd groups
|
||||||
if(reduction_size > simd_size) {
|
if(reduction_size > simd_size) {
|
||||||
total_val = lid < simd_per_group ? local_vals[lid] : op.init;
|
total_val = lid.x < simd_per_group ? local_vals[lid.x] : op.init;
|
||||||
total_val = op.simd_reduce(total_val);
|
total_val = op.simd_reduce(total_val);
|
||||||
}
|
}
|
||||||
// Update output
|
// Update output
|
||||||
if (lid == 0) {
|
if (lid.x == 0) {
|
||||||
out[tid] = total_val;
|
op.atomic_update(out, total_val, tid.x);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#define instantiate_row_reduce(name, itype, otype, op) \
|
#define instantiate_row_reduce_general(name, itype, otype, op) \
|
||||||
template [[host_name("row_reduce_" #name)]] \
|
template [[host_name("row_reduce_general_" #name)]] \
|
||||||
[[kernel]] void row_reduce<itype, otype, op>( \
|
[[kernel]] void row_reduce_general<itype, otype, op>( \
|
||||||
const device itype *in [[buffer(0)]], \
|
const device itype *in [[buffer(0)]], \
|
||||||
device otype *out [[buffer(1)]], \
|
device mlx_atomic<otype> *out [[buffer(1)]], \
|
||||||
const device size_t& reduction_size [[buffer(2)]], \
|
const constant size_t& reduction_size [[buffer(2)]], \
|
||||||
uint lid [[thread_position_in_threadgroup]], \
|
const constant size_t& out_size [[buffer(3)]], \
|
||||||
uint lsize [[threads_per_threadgroup]], \
|
const constant int* shape [[buffer(4)]], \
|
||||||
uint tid [[threadgroup_position_in_grid]], \
|
const constant size_t* strides [[buffer(5)]], \
|
||||||
|
const constant int& ndim [[buffer(6)]], \
|
||||||
|
uint3 lid [[thread_position_in_threadgroup]], \
|
||||||
|
uint3 lsize [[threads_per_threadgroup]], \
|
||||||
|
uint3 tid [[threadgroup_position_in_grid]], \
|
||||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||||
uint simd_per_group [[simdgroups_per_threadgroup]], \
|
uint simd_per_group [[simdgroups_per_threadgroup]], \
|
||||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||||
@@ -311,62 +260,26 @@ inline void _contiguous_strided_reduce(
 }

 template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
-[[kernel]] void col_reduce(
+[[kernel]] void col_reduce_general(
     const device T *in [[buffer(0)]],
     device mlx_atomic<U> *out [[buffer(1)]],
     const constant size_t& reduction_size [[buffer(2)]],
     const constant size_t& reduction_stride [[buffer(3)]],
     const constant size_t& out_size [[buffer(4)]],
+    const constant int* shape [[buffer(5)]],
+    const constant size_t* strides [[buffer(6)]],
+    const constant int& ndim [[buffer(7)]],
     threadgroup U *local_data [[threadgroup(0)]],
-    uint2 tid [[threadgroup_position_in_grid]],
-    uint2 lid [[thread_position_in_threadgroup]],
-    uint2 lsize [[threads_per_threadgroup]]) {
+    uint3 tid [[threadgroup_position_in_grid]],
+    uint3 lid [[thread_position_in_threadgroup]],
+    uint3 lsize [[threads_per_threadgroup]]) {
   auto out_idx = tid.x * lsize.x + lid.x;
-  if(out_idx < out_size) {
-    _contiguous_strided_reduce<T, U, Op, N_READS>(
-        in,
-        out,
-        local_data,
-        out_idx,
-        out_idx,
-        reduction_size,
-        reduction_stride,
-        tid,
-        lid,
-        lsize);
-  }
-}
-
-#define instantiate_col_reduce(name, itype, otype, op) \
-  template [[host_name("col_reduce_" #name)]] \
-  [[kernel]] void col_reduce<itype, otype, op>( \
-      const device itype *in [[buffer(0)]], \
-      device mlx_atomic<otype> *out [[buffer(1)]], \
-      const constant size_t& reduction_size [[buffer(2)]], \
-      const constant size_t& reduction_stride [[buffer(3)]], \
-      const constant size_t& out_size [[buffer(4)]], \
-      threadgroup otype *local_data [[threadgroup(0)]], \
-      uint2 tid [[threadgroup_position_in_grid]], \
-      uint2 lid [[thread_position_in_threadgroup]], \
-      uint2 lsize [[threads_per_threadgroup]]);
-
-template <typename T, typename U, typename Op, int NDIM, int N_READS = 16>
-[[kernel]] void contiguous_strided_reduce(
-    const device T *in [[buffer(0)]],
-    device mlx_atomic<U> *out [[buffer(1)]],
-    const constant size_t& reduction_size [[buffer(2)]],
-    const constant size_t& reduction_stride [[buffer(3)]],
-    const constant size_t& out_size [[buffer(4)]],
-    const device int* in_shape [[buffer(5)]],
-    const device size_t* in_strides [[buffer(6)]],
-    threadgroup U *local_data [[threadgroup(0)]],
-    uint2 tid [[threadgroup_position_in_grid]],
-    uint2 lid [[thread_position_in_threadgroup]],
-    uint2 lsize [[threads_per_threadgroup]]) {
-
-  auto out_idx = tid.x * lsize.x + lid.x;
-  auto in_idx = elem_to_loc_nd<NDIM>(out_idx, in_shape, in_strides);
-
+  auto in_idx = elem_to_loc(
+      out_idx + tid.z * out_size,
+      shape,
+      strides,
+      ndim
+  );
   if(out_idx < out_size) {
     _contiguous_strided_reduce<T, U, Op, N_READS>(
|
|||||||
out_idx,
|
out_idx,
|
||||||
reduction_size,
|
reduction_size,
|
||||||
reduction_stride,
|
reduction_stride,
|
||||||
tid,
|
tid.xy,
|
||||||
lid,
|
lid.xy,
|
||||||
lsize);
|
lsize.xy);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
|
#define instantiate_col_reduce_general(name, itype, otype, op) \
|
||||||
[[kernel]] void contiguous_strided_reduce(
|
template [[host_name("col_reduce_general_" #name)]] \
|
||||||
const device T *in [[buffer(0)]],
|
[[kernel]] void col_reduce_general<itype, otype, op>( \
|
||||||
device mlx_atomic<U> *out [[buffer(1)]],
|
|
||||||
const constant size_t& reduction_size [[buffer(2)]],
|
|
||||||
const constant size_t& reduction_stride [[buffer(3)]],
|
|
||||||
const constant size_t& out_size [[buffer(4)]],
|
|
||||||
const device int* in_shape [[buffer(5)]],
|
|
||||||
const device size_t* in_strides [[buffer(6)]],
|
|
||||||
const device size_t& in_dim [[buffer(7)]],
|
|
||||||
threadgroup U *local_data [[threadgroup(0)]],
|
|
||||||
uint2 tid [[threadgroup_position_in_grid]],
|
|
||||||
uint2 lid [[thread_position_in_threadgroup]],
|
|
||||||
uint2 lsize [[threads_per_threadgroup]]) {
|
|
||||||
|
|
||||||
auto out_idx = tid.x * lsize.x + lid.x;
|
|
||||||
auto in_idx = elem_to_loc(out_idx, in_shape, in_strides, in_dim);
|
|
||||||
|
|
||||||
if(out_idx < out_size) {
|
|
||||||
_contiguous_strided_reduce<T, U, Op, N_READS>(
|
|
||||||
in,
|
|
||||||
out,
|
|
||||||
local_data,
|
|
||||||
in_idx,
|
|
||||||
out_idx,
|
|
||||||
reduction_size,
|
|
||||||
reduction_stride,
|
|
||||||
tid,
|
|
||||||
lid,
|
|
||||||
lsize);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#define instantiate_contiguous_strided_helper(name, itype, otype, op) \
|
|
||||||
template [[host_name("contiguous_strided_reduce_" #name)]] \
|
|
||||||
[[kernel]] void contiguous_strided_reduce<itype, otype, op>( \
|
|
||||||
const device itype *in [[buffer(0)]], \
|
const device itype *in [[buffer(0)]], \
|
||||||
device mlx_atomic<otype> *out [[buffer(1)]], \
|
device mlx_atomic<otype> *out [[buffer(1)]], \
|
||||||
const constant size_t& reduction_size [[buffer(2)]], \
|
const constant size_t& reduction_size [[buffer(2)]], \
|
||||||
const constant size_t& reduction_stride [[buffer(3)]], \
|
const constant size_t& reduction_stride [[buffer(3)]], \
|
||||||
const constant size_t& out_size [[buffer(4)]], \
|
const constant size_t& out_size [[buffer(4)]], \
|
||||||
const device int* in_shape [[buffer(5)]], \
|
const constant int* shape [[buffer(5)]], \
|
||||||
const device size_t* in_strides [[buffer(6)]], \
|
const constant size_t* strides [[buffer(6)]], \
|
||||||
const device size_t& in_dim [[buffer(7)]], \
|
const constant int& ndim [[buffer(7)]], \
|
||||||
threadgroup otype *local_data [[threadgroup(0)]], \
|
threadgroup otype *local_data [[threadgroup(0)]], \
|
||||||
uint2 tid [[threadgroup_position_in_grid]], \
|
uint3 tid [[threadgroup_position_in_grid]], \
|
||||||
uint2 lid [[thread_position_in_threadgroup]], \
|
uint3 lid [[thread_position_in_threadgroup]], \
|
||||||
uint2 lsize [[threads_per_threadgroup]]);
|
uint3 lsize [[threads_per_threadgroup]]);
|
||||||
|
|
||||||
#define instantiate_contiguous_strided_helper_nd(name, itype, otype, op, n) \
|
|
||||||
template [[host_name("contiguous_strided_reduce_" #name "_dim_" #n)]] \
|
|
||||||
[[kernel]] void contiguous_strided_reduce<itype, otype, op, n>( \
|
|
||||||
const device itype *in [[buffer(0)]], \
|
|
||||||
device mlx_atomic<otype> *out [[buffer(1)]], \
|
|
||||||
const constant size_t& reduction_size [[buffer(2)]], \
|
|
||||||
const constant size_t& reduction_stride [[buffer(3)]], \
|
|
||||||
const constant size_t& out_size [[buffer(4)]], \
|
|
||||||
const device int* in_shape [[buffer(5)]], \
|
|
||||||
const device size_t* in_strides [[buffer(6)]], \
|
|
||||||
threadgroup otype *local_data [[threadgroup(0)]], \
|
|
||||||
uint2 tid [[threadgroup_position_in_grid]], \
|
|
||||||
uint2 lid [[thread_position_in_threadgroup]], \
|
|
||||||
uint2 lsize [[threads_per_threadgroup]]);
|
|
||||||
|
|
||||||
#define instantiate_contiguous_strided(name, itype, otype, op) \
|
|
||||||
instantiate_contiguous_strided_helper(name, itype, otype, op) \
|
|
||||||
instantiate_contiguous_strided_helper_nd(name, itype, otype, op, 1) \
|
|
||||||
instantiate_contiguous_strided_helper_nd(name, itype, otype, op, 2) \
|
|
||||||
instantiate_contiguous_strided_helper_nd(name, itype, otype, op, 3) \
|
|
||||||
instantiate_contiguous_strided_helper_nd(name, itype, otype, op, 4)
|
|
||||||
|
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
@@ -461,10 +319,8 @@ template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>

 #define instantiate_reduce(name, itype, otype, op) \
   instantiate_all_reduce(name, itype, otype, op) \
-  instantiate_row_reduce(name, itype, otype, op) \
-  instantiate_col_reduce(name, itype, otype, op) \
-  instantiate_contiguous_strided(name, itype, otype, op) \
-  instantiate_general_reduce(name, itype, otype, op)
+  instantiate_row_reduce_general(name, itype, otype, op) \
+  instantiate_col_reduce_general(name, itype, otype, op)

 #define instantiate_same_reduce(name, tname, type, op) \
   instantiate_init_reduce(name ##tname, type, op<type>) \
@@ -50,6 +50,7 @@ std::function<void()> make_task(
     bool retain_graph) {
   auto task =
       [retain_graph, arr, deps = std::move(deps), p = std::move(p)]() mutable {
+        auto pool = new_scoped_memory_pool();
        for (auto& d : deps) {
          d.wait();
        }
|
|||||||
arr.detach();
|
arr.detach();
|
||||||
}
|
}
|
||||||
p->set_value();
|
p->set_value();
|
||||||
// Signal this thread to clear the pool on a synchroniztion.
|
|
||||||
scheduler::enqueue(s, []() {
|
|
||||||
thread_autorelease_pool()->release();
|
|
||||||
thread_autorelease_pool() =
|
|
||||||
NS::AutoreleasePool::alloc()->init();
|
|
||||||
});
|
|
||||||
scheduler::notify_task_completion(s);
|
scheduler::notify_task_completion(s);
|
||||||
});
|
});
|
||||||
metal::device(s.device).commit_command_buffer(s.index);
|
metal::device(s.device).commit_command_buffer(s.index);
|
||||||
|
@@ -20,6 +20,7 @@ constexpr bool is_available() {
 }

 void new_stream(Stream stream);
+std::shared_ptr<void> new_scoped_memory_pool();

 std::function<void()> make_task(
     array& arr,
@@ -2,9 +2,11 @@

 #include <algorithm>
 #include <cassert>
+#include <iostream>
 #include <sstream>

 #include "mlx/backend/common/reduce.h"
+#include "mlx/backend/metal/copy.h"
 #include "mlx/backend/metal/device.h"
 #include "mlx/backend/metal/kernels/defines.h"
 #include "mlx/backend/metal/utils.h"
@@ -61,22 +63,47 @@ void all_reduce_dispatch(
   compute_encoder->dispatchThreads(grid_dims, group_dims);
 }

-void row_reduce_dispatch(
+void row_reduce_general_dispatch(
     const array& in,
     array& out,
     const std::string& op_name,
-    const std::vector<int>& axes_,
+    const ReductionPlan& plan,
+    const std::vector<int>& axes,
     MTL::ComputeCommandEncoder* compute_encoder,
     metal::Device& d) {
-  auto kernel = d.get_kernel("row_reduce_" + op_name + type_to_name(in));
+  auto kernel =
+      d.get_kernel("row_reduce_general_" + op_name + type_to_name(in));
+
+  // Prepare the arguments for the kernel
   int n_reads = REDUCE_N_READS;
-  size_t reduction_size = in.size() / out.size();
+  size_t reduction_size = plan.shape.back();
+  size_t out_size = out.size();
+  auto shape = plan.shape;
+  auto strides = plan.strides;
+  shape.pop_back();
+  strides.pop_back();
+  size_t non_row_reductions = 1;
+  for (auto s : shape) {
+    non_row_reductions *= static_cast<size_t>(s);
+  }
+  auto [rem_shape, rem_strides] = shapes_without_reduction_axes(in, axes);
+  for (auto s : rem_shape) {
+    shape.push_back(s);
+  }
+  for (auto s : rem_strides) {
+    strides.push_back(s);
+  }
+  int ndim = shape.size();
+
+  // Set the arguments for the kernel
   compute_encoder->setComputePipelineState(kernel);
   set_array_buffer(compute_encoder, in, 0);
   set_array_buffer(compute_encoder, out, 1);
   compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
+  compute_encoder->setBytes(&out_size, sizeof(size_t), 3);
+  compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 4);
+  compute_encoder->setBytes(strides.data(), strides.size() * sizeof(size_t), 5);
+  compute_encoder->setBytes(&ndim, sizeof(int), 6);

   // Each thread group is responsible for 1 output
   NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
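shapes_without_reduction_axes is assumed here to return the input's shape and strides with the reduction axes dropped, the same quantity the old col_reduce_dispatch computed inline (the inp_shape_mod / inp_strides_mod loop removed further down). A sketch under that assumption:

    #include <cstddef>
    #include <utility>
    #include <vector>

    // Sketch only: "axes" is assumed sorted; keep every dimension that is
    // not being reduced, in order.
    std::pair<std::vector<int>, std::vector<size_t>> without_reduction_axes_sketch(
        const std::vector<int>& shape,
        const std::vector<size_t>& strides,
        const std::vector<int>& axes) {
      std::vector<int> out_shape;
      std::vector<size_t> out_strides;
      for (size_t i = 0, j = 0; i < shape.size(); ++i) {
        if (j < axes.size() && axes[j] == static_cast<int>(i)) {
          ++j;  // skip reduced dimension
        } else {
          out_shape.push_back(shape[i]);
          out_strides.push_back(strides[i]);
        }
      }
      return {out_shape, out_strides};
    }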
@@ -91,92 +118,54 @@ void row_reduce_dispatch(

   // Launch enough thread groups for each output
   size_t n_threads = out.size() * thread_group_size;
-  MTL::Size grid_dims = MTL::Size(n_threads, 1, 1);
+  MTL::Size grid_dims = MTL::Size(n_threads, non_row_reductions, 1);
   MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);

   compute_encoder->dispatchThreads(grid_dims, group_dims);
 }

-void col_reduce_dispatch(
+void strided_reduce_general_dispatch(
     const array& in,
     array& out,
     const std::string& op_name,
-    const std::vector<int>& axes_,
+    const ReductionPlan& plan,
+    const std::vector<int>& axes,
     MTL::ComputeCommandEncoder* compute_encoder,
     metal::Device& d) {
-  std::ostringstream kernel_name;
-
-  bool encode_in_shape = false;
-  bool encode_ndim = false;
-
-  // If the slowest moving axis can be merged into the reductions,
-  // we call the column reduce kernel
-  // In this case, a linear index in the output corresponds to the
-  // linear index in the input where the reduction starts
-  if (axes_[axes_.size() - 1] == (axes_.size() - 1)) {
-    kernel_name << "col_reduce_" << op_name << type_to_name(in);
-  }
-  // Otherwise, while all the reduction axes can be merged, the mapping between
-  // indices in the output and input require resolving using shapes and strides
-  else {
-    kernel_name << "contiguous_strided_reduce_" << op_name << type_to_name(in);
-    encode_in_shape = true;
-
-    // We check for a viable template with the required number of dimensions
-    // we only care about encoding non-reduced shapes and strides in the input
-    size_t non_reducing_dims = in.ndim() - axes_.size();
-    if (non_reducing_dims >= 1 &&
-        non_reducing_dims <= MAX_REDUCE_SPECIALIZED_DIMS) {
-      kernel_name << "_dim_" << non_reducing_dims;
-    } else {
-      encode_ndim = true;
-    }
-  }
-
-  auto kernel = d.get_kernel(kernel_name.str());
-  size_t in_size = in.size();
+  auto kernel =
+      d.get_kernel("col_reduce_general_" + op_name + type_to_name(in));
+
+  // Prepare the arguments for the kernel
+  size_t reduction_size = plan.shape.back();
+  size_t reduction_stride = plan.strides.back();
   size_t out_size = out.size();
+  auto shape = plan.shape;
+  auto strides = plan.strides;
+  shape.pop_back();
+  strides.pop_back();
+  size_t non_col_reductions = 1;
+  for (auto s : shape) {
+    non_col_reductions *= static_cast<size_t>(s);
+  }
+  auto [rem_shape, rem_strides] = shapes_without_reduction_axes(in, axes);
+  for (auto s : rem_shape) {
+    shape.push_back(s);
+  }
+  for (auto s : rem_strides) {
+    strides.push_back(s);
+  }
+  int ndim = shape.size();
+
+  // Set the arguments for the kernel
   compute_encoder->setComputePipelineState(kernel);
   set_array_buffer(compute_encoder, in, 0);
   set_array_buffer(compute_encoder, out, 1);
-
-  // Calculate the number of inputs to reduce and the stride b/w them
-  size_t reduction_size = 1;
-  size_t in_ndim = in.ndim();
-  size_t reduction_stride = in_size;
-
-  for (int i : axes_) {
-    reduction_size *= in.shape(i);
-    reduction_stride = std::min(reduction_stride, in.strides()[i]);
-  }
-
   compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
   compute_encoder->setBytes(&reduction_stride, sizeof(size_t), 3);
   compute_encoder->setBytes(&out_size, sizeof(size_t), 4);
-  if (encode_in_shape) {
-    // Obtain the non-reducing shape and strides of the input to encode
-    std::vector<int> inp_shape_mod;
-    std::vector<size_t> inp_strides_mod;
-
-    for (size_t i = 0, j = 0; i < in.ndim(); i++) {
-      if (j < axes_.size() && axes_[j] == i) {
-        j++;
-      } else {
-        inp_shape_mod.push_back(in.shape(i));
-        inp_strides_mod.push_back(in.strides()[i]);
-      }
-    }
-
-    size_t ndim = inp_shape_mod.size();
-
-    compute_encoder->setBytes(inp_shape_mod.data(), ndim * sizeof(int), 5);
-    compute_encoder->setBytes(inp_strides_mod.data(), ndim * sizeof(size_t), 6);
-
-    if (encode_ndim) {
-      compute_encoder->setBytes(&ndim, sizeof(size_t), 7);
-    }
-  }
+  compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 5);
+  compute_encoder->setBytes(strides.data(), strides.size() * sizeof(size_t), 6);
+  compute_encoder->setBytes(&ndim, sizeof(int), 7);

   // Select block dimensions

|
|||||||
(n_threads_per_output + threadgroup_dim_y - 1) / threadgroup_dim_y;
|
(n_threads_per_output + threadgroup_dim_y - 1) / threadgroup_dim_y;
|
||||||
|
|
||||||
// Launch enough thread groups for each output
|
// Launch enough thread groups for each output
|
||||||
MTL::Size grid_dims = MTL::Size(n_threadgroups_x, n_threadgroups_y, 1);
|
MTL::Size grid_dims =
|
||||||
|
MTL::Size(n_threadgroups_x, n_threadgroups_y, non_col_reductions);
|
||||||
MTL::Size group_dims = MTL::Size(threadgroup_dim_x, threadgroup_dim_y, 1);
|
MTL::Size group_dims = MTL::Size(threadgroup_dim_x, threadgroup_dim_y, 1);
|
||||||
|
|
||||||
// We set shared memory to be exploited here for reductions within a
|
// We set shared memory to be exploited here for reductions within a
|
||||||
@@ -216,60 +206,6 @@ void col_reduce_dispatch(
   compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
 }

-void general_reduce_dispatch(
-    const array& in,
-    array& out,
-    const std::string& op_name,
-    const std::vector<int>& axes_,
-    MTL::ComputeCommandEncoder* compute_encoder,
-    metal::Device& d) {
-  bool encode_ndim = true;
-  std::ostringstream kernel_name;
-  kernel_name << "general_reduce_" << op_name << type_to_name(in);
-
-  // Check for specialzed kernels for input ndim
-  if (in.ndim() >= 1 && in.ndim() <= MAX_REDUCE_SPECIALIZED_DIMS) {
-    kernel_name << "_dim_" << in.ndim();
-    encode_ndim = false;
-  }
-  auto kernel = d.get_kernel(kernel_name.str());
-  size_t in_size = in.size();
-  size_t ndim = in.ndim();
-
-  // We set the reducing strides to 0 to induce collisions for the reduction
-  std::vector<size_t> out_strides(ndim);
-  size_t stride = 1;
-  for (int i = ndim - 1, j = axes_.size() - 1; i >= 0; --i) {
-    if (j >= 0 && axes_[j] == i) {
-      out_strides[i] = 0;
-      --j;
-    } else {
-      out_strides[i] = stride;
-      stride *= in.shape(i);
-    }
-  }
-
-  compute_encoder->setComputePipelineState(kernel);
-  set_array_buffer(compute_encoder, in, 0);
-  set_array_buffer(compute_encoder, out, 1);
-  compute_encoder->setBytes(in.shape().data(), ndim * sizeof(int), 2);
-  compute_encoder->setBytes(in.strides().data(), ndim * sizeof(size_t), 3);
-  compute_encoder->setBytes(out_strides.data(), ndim * sizeof(size_t), 4);
-  if (encode_ndim) {
-    compute_encoder->setBytes(&ndim, sizeof(size_t), 5);
-  }
-
-  NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
-  if (thread_group_size > in_size) {
-    thread_group_size = in_size;
-  }
-  size_t nthreads = in_size;
-
-  MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
-  MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
-  compute_encoder->dispatchThreads(grid_dims, group_dims);
-}
-
 } // namespace

 //////////////////////////////////////////////////////////////////////
@@ -278,7 +214,7 @@ void general_reduce_dispatch(

 void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
-  auto& in = inputs[0];
+  array in = inputs[0];

   // TODO: Allow specific row and column reductions with types disabled
   // due to atomics ?
@@ -335,37 +271,47 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {

   // Reduce
   {
-    // Check for contiguous data
-    if (in.size() == in.data_size() &&
-        (in.flags().row_contiguous || in.flags().col_contiguous)) {
-      // Go to all reduce if reducing over all axes
-      if (axes_.size() == in.ndim()) {
-        all_reduce_dispatch(in, out, op_name, compute_encoder, d);
-        return;
-      }
-      // Use specialized kernels if the input is row contiguous and
-      // the reducing axes can be merged into one
-      else if (
-          in.flags().row_contiguous && in.strides().back() == 1 &&
-          (axes_.back() - axes_.front()) == axes_.size() - 1) {
-        // If the fastest moving axis is being reduced, go to row reduce
-        if (axes_[0] == (in.ndim() - axes_.size())) {
-          row_reduce_dispatch(in, out, op_name, axes_, compute_encoder, d);
-          return;
-        }
-      }
-      // Otherwise go to to generalized strided reduce
-      // Note: bool isn't support here yet due to the use of atomics
-      // once that is updated, this should be the else condition of this
-      // branch
-      else if (in.dtype() != bool_) {
-        col_reduce_dispatch(in, out, op_name, axes_, compute_encoder, d);
-        return;
-      }
-    }
-    // Fall back to the general case
-    general_reduce_dispatch(in, out, op_name, axes_, compute_encoder, d);
+    std::vector<array> copies;
+    ReductionPlan plan = get_reduction_plan(in, axes_);
+
+    // If it is a general reduce then copy the input to a contiguous array and
+    // recompute the plan.
+    if (plan.type == GeneralReduce) {
+      array in_copy(in.shape(), in.dtype(), nullptr, {});
+      copy_gpu(in, in_copy, CopyType::General, s);
+      copies.push_back(in_copy);
+      in = in_copy;
+      plan = get_reduction_plan(in, axes_);
+    }
+
+    // Reducing over everything and the data is all there no broadcasting or
+    // slicing etc.
+    if (plan.type == ContiguousAllReduce) {
+      all_reduce_dispatch(in, out, op_name, compute_encoder, d);
+    }
+    // At least the last dimension is row contiguous and we are reducing over
+    // the last dim.
+    else if (
+        plan.type == ContiguousReduce || plan.type == GeneralContiguousReduce) {
+      row_reduce_general_dispatch(
+          in, out, op_name, plan, axes_, compute_encoder, d);
+    }
+    // At least the last two dimensions are contiguous and we are doing a
+    // strided reduce over these.
+    else if (
+        plan.type == ContiguousStridedReduce ||
+        plan.type == GeneralStridedReduce) {
+      strided_reduce_general_dispatch(
+          in, out, op_name, plan, axes_, compute_encoder, d);
+    }
+
+    if (!copies.empty()) {
+      d.get_command_buffer(s.index)->addCompletedHandler(
+          [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
+    }
   }
 }

 } // namespace mlx::core
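A subtle point in the rewritten dispatch above: when the input is copied to a contiguous temporary, that temporary has to outlive the asynchronous GPU work, which is why the copies vector is captured by value in the command buffer's completed handler and only cleared there. A small sketch of the lifetime idiom, kept free of Metal types (FakeCommandBuffer is a hypothetical stand-in):

    #include <functional>
    #include <utility>
    #include <vector>

    struct FakeCommandBuffer {
      std::function<void()> on_complete;
      void add_completed_handler(std::function<void()> f) { on_complete = std::move(f); }
      void finish() { if (on_complete) on_complete(); }  // "GPU work done"
    };

    void submit_with_temporaries(FakeCommandBuffer& cb, std::vector<int> temporaries) {
      // The lambda owns the temporaries; they are destroyed only after the
      // buffer reports completion, mirroring addCompletedHandler above.
      cb.add_completed_handler([temps = std::move(temporaries)]() mutable { temps.clear(); });
    }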
@@ -7,6 +7,9 @@
 namespace mlx::core::metal {

 void new_stream(Stream) {}
+std::shared_ptr<void> new_scoped_memory_pool() {
+  return nullptr;
+}

 std::function<void()> make_task(
     array& arr,
mlx/linalg.cpp (new file, 175 lines)
@@ -0,0 +1,175 @@
+// Copyright © 2023 Apple Inc.
+
+#include <numeric>
+#include <ostream>
+#include <vector>
+
+#include "mlx/dtype.h"
+#include "mlx/linalg.h"
+
+namespace mlx::core::linalg {
+
+Dtype at_least_float(const Dtype& d) {
+  return is_floating_point(d) ? d : promote_types(d, float32);
+}
+
+inline array l2_norm(
+    const array& a,
+    const std::vector<int>& axis,
+    bool keepdims,
+    StreamOrDevice s) {
+  if (is_complex(a.dtype())) {
+    return sqrt(sum(abs(a, s) * abs(a, s), axis, keepdims, s), s);
+  } else {
+    return sqrt(sum(square(a, s), axis, keepdims, s), s);
+  }
+}
+
+inline array vector_norm(
+    const array& a,
+    const double ord,
+    const std::vector<int>& axis,
+    bool keepdims,
+    StreamOrDevice s) {
+  auto dtype = at_least_float(a.dtype());
+  if (ord == 0.0) {
+    return astype(sum(not_equal(a, array(0), s), axis, keepdims, s), dtype, s);
+  } else if (ord == 1.0) {
+    return astype(sum(abs(a, s), axis, keepdims, s), dtype, s);
+  } else if (ord == 2.0) {
+    return l2_norm(a, axis, keepdims, s);
+  } else if (ord == std::numeric_limits<double>::infinity()) {
+    return astype(max(abs(a, s), axis, keepdims, s), dtype, s);
+  } else if (ord == -std::numeric_limits<double>::infinity()) {
+    return astype(min(abs(a, s), axis, keepdims, s), dtype, s);
+  } else {
+    return power(
+        sum(power(abs(a, s), array(ord, dtype), s), axis, keepdims, s),
+        array(1.0 / ord, dtype),
+        s);
+  }
+}
+
+inline array matrix_norm(
+    const array& a,
+    const double ord,
+    const std::vector<int>& axis,
+    bool keepdims,
+    StreamOrDevice s) {
+  auto dtype = at_least_float(a.dtype());
+  auto row_axis = axis[0];
+  auto col_axis = axis[1];
+  if (ord == -1.0) {
+    col_axis -= (!keepdims && col_axis > row_axis && col_axis > 0);
+    return astype(
+        min(sum(abs(a, s), row_axis, keepdims, s), col_axis, keepdims, s),
+        dtype,
+        s);
+  } else if (ord == 1.0) {
+    col_axis -= (!keepdims && col_axis > row_axis && col_axis > 0);
+    return astype(
+        max(sum(abs(a, s), row_axis, keepdims, s), col_axis, keepdims, s),
+        dtype,
+        s);
+  } else if (ord == std::numeric_limits<double>::infinity()) {
+    row_axis -= (!keepdims && row_axis > col_axis && row_axis > 0);
+    return astype(
+        max(sum(abs(a, s), col_axis, keepdims, s), row_axis, keepdims, s),
+        dtype,
+        s);
+  } else if (ord == -std::numeric_limits<double>::infinity()) {
+    row_axis -= (!keepdims && row_axis > col_axis && row_axis > 0);
+    return astype(
+        min(sum(abs(a, s), col_axis, keepdims, s), row_axis, keepdims, s),
+        dtype,
+        s);
+  } else if (ord == 2.0 || ord == -2.0) {
+    throw std::runtime_error(
+        "[linalg::norm] Singular value norms are not implemented.");
+  } else {
+    std::ostringstream msg;
+    msg << "[linalg::norm] Invalid ord " << ord << " for matrix norm.";
+    throw std::invalid_argument(msg.str());
+  }
+}
+
+inline array matrix_norm(
+    const array& a,
+    const std::string& ord,
+    const std::vector<int>& axis,
+    bool keepdims,
+    StreamOrDevice s) {
+  if (ord == "f" || ord == "fro") {
+    return l2_norm(a, axis, keepdims, s);
+  } else if (ord == "nuc") {
+    throw std::runtime_error(
+        "[linalg::norm] Nuclear norm not yet implemented.");
+  } else {
+    std::ostringstream msg;
+    msg << "[linalg::norm] Invalid ord value '" << ord << "' for matrix norm.";
+    throw std::invalid_argument(msg.str());
+  }
+}
+
+array norm(
+    const array& a,
+    const std::optional<std::vector<int>>& axis /* = std::nullopt */,
+    bool keepdims /* = false */,
+    StreamOrDevice s /* = {} */) {
+  if (!axis) {
+    return norm(flatten(a, s), std::vector<int>{0}, keepdims, s);
+  }
+
+  if (axis.value().size() > 2) {
+    throw std::invalid_argument(
+        "[linalg::norm] Received too many axes for norm.");
+  }
+  return l2_norm(a, axis.value(), keepdims, s);
+}
+
+array norm(
+    const array& a,
+    const double ord,
+    const std::optional<std::vector<int>>& axis /* = std::nullopt */,
+    bool keepdims /* = false */,
+    StreamOrDevice s /* = {} */) {
+  std::vector<int> ax;
+  if (!axis) {
+    ax.resize(a.ndim());
+    std::iota(ax.begin(), ax.end(), 0);
+  } else {
+    ax = axis.value();
+  }
+  if (ax.size() == 1) {
+    return vector_norm(a, ord, ax, keepdims, s);
+  } else if (ax.size() == 2) {
+    return matrix_norm(a, ord, ax, keepdims, s);
+  } else {
+    throw std::invalid_argument(
+        "[linalg::norm] Received too many axes for norm.");
+  }
+}
+
+array norm(
+    const array& a,
+    const std::string& ord,
+    const std::optional<std::vector<int>>& axis /* = std::nullopt */,
+    bool keepdims /* = false */,
+    StreamOrDevice s /* = {} */) {
+  std::vector<int> ax;
+  if (!axis) {
+    ax.resize(a.ndim());
+    std::iota(ax.begin(), ax.end(), 0);
+  } else {
+    ax = axis.value();
+  }
+  if (ax.size() != 2) {
+    std::ostringstream msg;
+    msg << "[linalg::norm] Norm '" << ord << "' only supported for matrices,"
+        << " but received " << ax.size() << " axis/axes.";
+    throw std::invalid_argument(msg.str());
+  }
+  return matrix_norm(a, ord, ax, keepdims, s);
+}
+
+} // namespace mlx::core::linalg
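For reference (standard definitions, not taken from this commit), vector_norm above implements

    \( \|x\|_0 = \#\{i : x_i \neq 0\},\quad \|x\|_1 = \sum_i |x_i|,\quad \|x\|_2 = \big(\sum_i |x_i|^2\big)^{1/2}, \)
    \( \|x\|_p = \big(\sum_i |x_i|^p\big)^{1/p},\quad \|x\|_\infty = \max_i |x_i|,\quad \|x\|_{-\infty} = \min_i |x_i|, \)

while the matrix branches cover the max/min column-sum (ord = ±1), max/min row-sum (ord = ±∞), and the Frobenius norm \( \|A\|_F = (\sum_{ij} |a_{ij}|^2)^{1/2} \).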
mlx/linalg.h (new file, 63 lines)
@@ -0,0 +1,63 @@
+// Copyright © 2023 Apple Inc.
+
+#pragma once
+
+#include <optional>
+
+#include "mlx/array.h"
+#include "mlx/device.h"
+#include "mlx/ops.h"
+#include "mlx/stream.h"
+
+namespace mlx::core::linalg {
+
+/**
+ * Compute vector or matrix norms.
+ *
+ * - If axis and ord are both unspecified, computes the 2-norm of flatten(x).
+ * - If axis is not provided but ord is, then x must be either 1D or 2D.
+ * - If axis is provided, but ord is not, then the 2-norm (or Frobenius norm
+ *   for matrices) is computed along the given axes. At most 2 axes can be
+ *   specified.
+ * - If both axis and ord are provided, then the corresponding matrix or vector
+ *   norm is computed. At most 2 axes can be specified.
+ */
+array norm(
+    const array& a,
+    const double ord,
+    const std::optional<std::vector<int>>& axis = std::nullopt,
+    bool keepdims = false,
+    StreamOrDevice s = {});
+inline array norm(
+    const array& a,
+    const double ord,
+    int axis,
+    bool keepdims = false,
+    StreamOrDevice s = {}) {
+  return norm(a, ord, std::vector<int>{axis}, keepdims, s);
+}
+array norm(
+    const array& a,
+    const std::string& ord,
+    const std::optional<std::vector<int>>& axis = std::nullopt,
+    bool keepdims = false,
+    StreamOrDevice s = {});
+inline array norm(
+    const array& a,
+    const std::string& ord,
+    int axis,
+    bool keepdims = false,
+    StreamOrDevice s = {}) {
+  return norm(a, ord, std::vector<int>{axis}, keepdims, s);
+}
+array norm(
+    const array& a,
+    const std::optional<std::vector<int>>& axis = std::nullopt,
+    bool keepdims = false,
+    StreamOrDevice s = {});
+inline array
+norm(const array& a, int axis, bool keepdims = false, StreamOrDevice s = {}) {
+  return norm(a, std::vector<int>{axis}, keepdims, s);
+}
+
+} // namespace mlx::core::linalg
@ -6,6 +6,7 @@
|
|||||||
#include "mlx/backend/metal/metal.h"
|
#include "mlx/backend/metal/metal.h"
|
||||||
#include "mlx/device.h"
|
#include "mlx/device.h"
|
||||||
#include "mlx/fft.h"
|
#include "mlx/fft.h"
|
||||||
|
#include "mlx/linalg.h"
|
||||||
#include "mlx/ops.h"
|
#include "mlx/ops.h"
|
||||||
#include "mlx/random.h"
|
#include "mlx/random.h"
|
||||||
#include "mlx/stream.h"
|
#include "mlx/stream.h"
|
||||||
|
@ -103,7 +103,9 @@ array uniform(
|
|||||||
}
|
}
|
||||||
|
|
||||||
auto stream = to_stream(s);
|
auto stream = to_stream(s);
|
||||||
auto range = subtract(high, low, stream);
|
auto lo = astype(low, dtype, stream);
|
||||||
|
auto hi = astype(high, dtype, stream);
|
||||||
|
auto range = subtract(hi, lo, stream);
|
||||||
auto out_shape = broadcast_shapes(shape, range.shape());
|
auto out_shape = broadcast_shapes(shape, range.shape());
|
||||||
if (out_shape != shape) {
|
if (out_shape != shape) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
@ -136,7 +138,7 @@ array uniform(
|
|||||||
auto out = bits(shape, size_of(dtype), key, stream);
|
auto out = bits(shape, size_of(dtype), key, stream);
|
||||||
out = astype(divide(out, maxval, stream), dtype, stream);
|
out = astype(divide(out, maxval, stream), dtype, stream);
|
||||||
out = minimum(out, upper, stream);
|
out = minimum(out, upper, stream);
|
||||||
return add(multiply(range, out, stream), low, stream);
|
return add(multiply(range, out, stream), lo, stream);
|
||||||
}
|
}
|
||||||
|
|
||||||
array uniform(
|
array uniform(
|
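The hunk above casts ``low`` and ``high`` to the requested dtype before the range is computed, so sampled values keep that dtype. A minimal illustrative sketch of the resulting behaviour from Python (not part of the diff; it mirrors the test added further below):

import mlx.core as mx

a = mx.random.uniform(low=-0.1, high=0.1, shape=(4,), dtype=mx.bfloat16)
print(a.dtype)  # bfloat16 rather than a float32 fallback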
@@ -35,6 +35,7 @@ struct StreamThread {
   }

   void thread_fn() {
+    auto thread_pool = metal::new_scoped_memory_pool();
     metal::new_stream(stream);
     while (true) {
       std::function<void()> task;

@@ -33,10 +33,16 @@ from mlx.nn.layers.activations import (
 from mlx.nn.layers.base import Module
 from mlx.nn.layers.containers import Sequential
 from mlx.nn.layers.convolution import Conv1d, Conv2d
-from mlx.nn.layers.dropout import Dropout
+from mlx.nn.layers.dropout import Dropout, Dropout2d
 from mlx.nn.layers.embedding import Embedding
 from mlx.nn.layers.linear import Linear
-from mlx.nn.layers.normalization import GroupNorm, InstanceNorm, LayerNorm, RMSNorm
+from mlx.nn.layers.normalization import (
+    BatchNorm,
+    GroupNorm,
+    InstanceNorm,
+    LayerNorm,
+    RMSNorm,
+)
 from mlx.nn.layers.positional_encoding import ALiBi, RoPE, SinusoidalPositionalEncoding
 from mlx.nn.layers.quantized import QuantizedLinear
 from mlx.nn.layers.transformer import (

@@ -5,7 +5,7 @@ from mlx.nn.layers.base import Module


 class Dropout(Module):
-    """Randomly zero a portion of the elements during training.
+    r"""Randomly zero a portion of the elements during training.

     The remaining elements are multiplied with :math:`\frac{1}{1-p}` where
     :math:`p` is the probability of zeroing an element. This is done so the

@@ -32,4 +32,57 @@ class Dropout(Module):

         mask = mx.random.bernoulli(self._p_1, x.shape)

-        return (1 / self._p_1) * mask.astype(x.dtype) * x
+        return (1 / self._p_1) * mask * x
+
+
+class Dropout2d(Module):
+    r"""Apply 2D channel-wise dropout during training.
+
+    Randomly zero out entire channels independently with probability :math:`p`.
+    This layer expects the channels to be last, i.e. the input shape should be
+    ``NWHC`` or ``WHC`` where ``N`` is the batch dimension, ``H`` is the input
+    image height, ``W`` is the input image width, and ``C`` is the number of
+    input channels.
+
+    The remaining channels are scaled by :math:`\frac{1}{1-p}` to
+    maintain the expected value of each element. Unlike traditional dropout,
+    which zeros individual entries, this layer zeros entire channels. This is
+    beneficial for early convolution layers where adjacent pixels are
+    correlated. In such a case, traditional dropout may not effectively
+    regularize activations. For more details, see [1].
+
+    [1]: Thompson, J., Goroshin, R., Jain, A., LeCun, Y. and Bregler C., 2015.
+    Efficient Object Localization Using Convolutional Networks. CVPR 2015.
+
+    Args:
+        p (float): Probability of zeroing a channel during training.
+    """
+
+    def __init__(self, p: float = 0.5):
+        super().__init__()
+
+        if p < 0 or p >= 1:
+            raise ValueError("The dropout probability should be in [0, 1)")
+
+        self._p_1 = 1 - p
+
+    def _extra_repr(self):
+        return f"p={1-self._p_1}"
+
+    def __call__(self, x):
+        if x.ndim not in (3, 4):
+            raise ValueError(
+                f"Received input with {x.ndim} dimensions. Expected 3 or 4 dimensions."
+            )
+
+        if self._p_1 == 1 or not self.training:
+            return x
+
+        # Dropout is applied on the whole channel
+        # 3D input: (1, 1, C)
+        # 4D input: (B, 1, 1, C)
+        mask_shape = x.shape
+        mask_shape[-2] = mask_shape[-3] = 1
+
+        mask = mx.random.bernoulli(p=self._p_1, shape=mask_shape)
+        return (1 / self._p_1) * mask * x
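A short usage sketch of the new layer (illustrative only, not part of the diff; it assumes the module starts in training mode, which is the ``mlx.nn.Module`` default):

import mlx.core as mx
import mlx.nn as nn

x = mx.random.normal((2, 8, 8, 16))  # NWHC input, 16 channels last
drop = nn.Dropout2d(p=0.25)

y = drop(x)       # training: whole channels are zeroed, the rest scaled by 1 / (1 - p)
drop.eval()
y_eval = drop(x)  # evaluation: the input is returned unchanged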
@@ -1,5 +1,7 @@
 # Copyright © 2023 Apple Inc.

+from typing import Tuple
+
 import mlx.core as mx
 from mlx.nn.layers.base import Module

@@ -252,3 +254,121 @@ class GroupNorm(Module):
         )
         x = group_norm(x)
         return (self.weight * x + self.bias) if "weight" in self else x
+
+
+class BatchNorm(Module):
+    r"""Applies Batch Normalization over a 2D or 3D input.
+
+    Computes
+
+    .. math::
+
+        y = \frac{x - E[x]}{\sqrt{Var[x] + \epsilon}} \gamma + \beta,
+
+    where :math:`\gamma` and :math:`\beta` are learned per feature dimension
+    parameters initialized at 1 and 0 respectively.
+
+    The input shape is specified as ``NC`` or ``NLC``, where ``N`` is the
+    batch, ``C`` is the number of features or channels, and ``L`` is the
+    sequence length. The output has the same shape as the input. For
+    four-dimensional arrays, the shape is ``NHWC``, where ``H`` and ``W`` are
+    the height and width respectively.
+
+    For more information on Batch Normalization, see the original paper `Batch
+    Normalization: Accelerating Deep Network Training by Reducing Internal
+    Covariate Shift <https://arxiv.org/abs/1502.03167>`_.
+
+    Args:
+        num_features (int): The feature dimension to normalize over.
+        eps (float, optional): A small additive constant for numerical
+            stability. Default: ``1e-5``.
+        momentum (float, optional): The momentum for updating the running
+            mean and variance. Default: ``0.1``.
+        affine (bool, optional): If ``True``, apply a learned affine
+            transformation after the normalization. Default: ``True``.
+        track_running_stats (bool, optional): If ``True``, track the
+            running mean and variance. Default: ``True``.
+
+    Examples:
+        >>> import mlx.core as mx
+        >>> import mlx.nn as nn
+        >>> x = mx.random.normal((5, 4))
+        >>> bn = nn.BatchNorm(num_features=4, affine=True)
+        >>> output = bn(x)
+    """
+
+    def __init__(
+        self,
+        num_features: int,
+        eps: float = 1e-5,
+        momentum: float = 0.1,
+        affine: bool = True,
+        track_running_stats: bool = True,
+    ):
+        super().__init__()
+
+        self.num_features = num_features
+        self.eps = eps
+        self.momentum = momentum
+        self.track_running_stats = track_running_stats
+
+        if affine:
+            self.weight = mx.ones((num_features,))
+            self.bias = mx.zeros((num_features,))
+
+        if self.track_running_stats:
+            self._running_mean = mx.zeros((num_features,))
+            self._running_var = mx.ones((num_features,))
+
+    def _extra_repr(self):
+        return (
+            f"{self.num_features}, eps={self.eps}, "
+            f"momentum={self.momentum}, affine={'weight' in self}, "
+            f"track_running_stats={self.track_running_stats}"
+        )
+
+    def _calc_stats(self, x: mx.array) -> Tuple[mx.array, mx.array]:
+        """
+        Calculate the mean and variance of the input tensor.
+
+        Args:
+            x (mx.array): Input tensor.
+
+        Returns:
+            tuple: Tuple containing mean and variance.
+        """
+        reduction_axes = tuple(range(0, x.ndim - 1))
+
+        means = mx.mean(x, axis=reduction_axes, keepdims=True)
+        var = mx.var(x, axis=reduction_axes, keepdims=True)
+
+        if self.track_running_stats and self.training:
+            self._running_mean = (
+                1 - self.momentum
+            ) * self._running_mean + self.momentum * means
+            self._running_var = (
+                1 - self.momentum
+            ) * self._running_var + self.momentum * var
+        return means, var
+
+    def __call__(self, x: mx.array) -> mx.array:
+        """
+        Forward pass of BatchNorm.
+
+        Args:
+            x (mx.array): Input tensor.
+
+        Returns:
+            mx.array: Output tensor.
+        """
+        if x.ndim < 2 or x.ndim > 4:
+            raise ValueError(
+                f"Expected input tensor to have 2, 3 or 4 dimensions, but got {x.ndim}"
+            )
+
+        if self.training or not self.track_running_stats:
+            means, var = self._calc_stats(x)
+        else:
+            means, var = self._running_mean, self._running_var
+        x = (x - means) * mx.rsqrt(var + self.eps)
+        return (self.weight * x + self.bias) if "weight" in self else x
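A brief illustration of the train/eval behaviour above (a sketch, not part of the diff):

import mlx.core as mx
import mlx.nn as nn

bn = nn.BatchNorm(num_features=4)
x = mx.random.normal((8, 4))

y_train = bn(x)  # batch statistics are used; _running_mean/_running_var are updated
bn.eval()
y_eval = bn(x)   # the tracked running statistics are used instead of the batch ones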
@@ -1,5 +1,7 @@
 # Copyright © 2023 Apple Inc.

+import math
+
 import mlx.core as mx
 from mlx.nn.layers.base import Module

@@ -131,10 +133,6 @@ def mse_loss(
         f"targets shape {targets.shape}."
     )

-    assert (
-        predictions.shape == targets.shape
-    ), f"Shape of predictions {predictions.shape} and targets {targets.shape} must match"
-
     loss = mx.square(predictions - targets)
     return _reduce(loss, reduction)

@@ -283,3 +281,94 @@ def _reduce(loss: mx.array, reduction: str = "none"):
         return loss
     else:
         raise ValueError("Invalid reduction. Must be 'none', 'mean', or 'sum'.")
+
+
+def hinge_loss(
+    inputs: mx.array, targets: mx.array, reduction: str = "none"
+) -> mx.array:
+    r"""
+    Computes the hinge loss between inputs and targets.
+
+    .. math::
+
+        \text{hinge}(y, y_{\text{pred}}) = \max(0, 1 - y \cdot y_{\text{pred}})
+
+    Args:
+        inputs (array): The predicted values.
+        targets (array): The target values. They should be -1 or 1.
+        reduction (str, optional): Specifies the reduction to apply to the output:
+            ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.
+
+    Returns:
+        array: The computed hinge loss.
+    """
+    loss = mx.maximum(1 - inputs * targets, 0)
+
+    return _reduce(loss, reduction)
+
+
+def huber_loss(
+    inputs: mx.array, targets: mx.array, delta: float = 1.0, reduction: str = "none"
+) -> mx.array:
+    r"""
+    Computes the Huber loss between inputs and targets.
+
+    .. math::
+
+        L_{\delta}(a) =
+        \left\{ \begin{array}{ll}
+            \frac{1}{2} a^2 & \text{for } |a| \leq \delta, \\
+            \delta \left( |a| - \frac{1}{2} \delta \right) & \text{otherwise.}
+        \end{array} \right.
+
+    Args:
+        inputs (array): The predicted values.
+        targets (array): The target values.
+        delta (float, optional): The threshold at which to change between L1 and L2 loss.
+            Default: ``1.0``.
+        reduction (str, optional): Specifies the reduction to apply to the output:
+            ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.
+
+    Returns:
+        array: The computed Huber loss.
+    """
+    errors = inputs - targets
+    abs_errors = mx.abs(errors)
+    quadratic = mx.minimum(abs_errors, delta)
+    linear = abs_errors - quadratic
+    loss = 0.5 * quadratic**2 + delta * linear
+
+    return _reduce(loss, reduction)
+
+
+def log_cosh_loss(
+    inputs: mx.array, targets: mx.array, reduction: str = "none"
+) -> mx.array:
+    r"""
+    Computes the log cosh loss between inputs and targets.
+
+    Logcosh acts like L2 loss for small errors, ensuring stable gradients,
+    and like the L1 loss for large errors, reducing sensitivity to outliers. This
+    dual behavior offers a balanced, robust approach for regression tasks.
+
+    .. math::
+
+        \text{logcosh}(y_{\text{true}}, y_{\text{pred}}) =
+            \frac{1}{n} \sum_{i=1}^{n}
+            \log(\cosh(y_{\text{pred}}^{(i)} - y_{\text{true}}^{(i)}))
+
+    Args:
+        inputs (array): The predicted values.
+        targets (array): The target values.
+        reduction (str, optional): Specifies the reduction to apply to the output:
+            ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.
+
+    Returns:
+        array: The computed log cosh loss.
+    """
+    errors = inputs - targets
+    loss = mx.logaddexp(errors, -errors) - math.log(2)
+
+    return _reduce(loss, reduction)
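An illustrative check of the three new losses (not part of the diff; the expected values match the unit tests added later in this commit):

import mlx.core as mx
import mlx.nn as nn

preds = mx.ones((2, 4))
targets = mx.zeros((2, 4))

nn.losses.hinge_loss(preds, targets, reduction="mean")     # max(0, 1 - y * y_pred) -> 1.0
nn.losses.huber_loss(preds, targets, reduction="mean")     # |error| <= delta, quadratic branch -> 0.5
nn.losses.log_cosh_loss(preds, targets, reduction="mean")  # log(cosh(1)) ~ 0.4338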
@@ -11,6 +11,7 @@ pybind11_add_module(
   ${CMAKE_CURRENT_SOURCE_DIR}/stream.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/transforms.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/random.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp
 )

 if (NOT MLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY)

@@ -510,6 +510,14 @@ void init_array(py::module_& m) {
           "size", &array::size, R"pbdoc(Number of elements in the array.)pbdoc")
       .def_property_readonly(
           "ndim", &array::ndim, R"pbdoc(The array's dimension.)pbdoc")
+      .def_property_readonly(
+          "itemsize",
+          &array::itemsize,
+          R"pbdoc(The size of the array's datatype in bytes.)pbdoc")
+      .def_property_readonly(
+          "nbytes",
+          &array::nbytes,
+          R"pbdoc(The number of bytes in the array.)pbdoc")
       // TODO, this makes a deep copy of the shape
       // implement alternatives to use reference
       // https://pybind11.readthedocs.io/en/stable/advanced/cast/stl.html
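From Python the two new properties behave as sketched below (illustrative; the values mirror the test added to test_array.py in this commit):

import mlx.core as mx

x = mx.array(1)  # an int32 scalar
x.itemsize       # 4 -> bytes per element
x.nbytes         # 4 -> size * itemsize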
python/src/linalg.cpp (new file)
@@ -0,0 +1,180 @@
// Copyright © 2023 Apple Inc.

#include <variant>

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include "mlx/linalg.h"

#include "python/src/load.h"
#include "python/src/utils.h"

namespace py = pybind11;
using namespace py::literals;

using namespace mlx::core;
using namespace mlx::core::linalg;

void init_linalg(py::module_& parent_module) {
  py::options options;
  options.disable_function_signatures();

  auto m = parent_module.def_submodule(
      "linalg", "mlx.core.linalg: linear algebra routines.");

  m.def(
      "norm",
      [](const array& a,
         const std::variant<std::monostate, int, double, std::string>& ord_,
         const std::variant<std::monostate, int, std::vector<int>>& axis_,
         const bool keepdims,
         const StreamOrDevice stream) {
        std::optional<std::vector<int>> axis = std::nullopt;
        if (auto pv = std::get_if<int>(&axis_); pv) {
          axis = std::vector<int>{*pv};
        } else if (auto pv = std::get_if<std::vector<int>>(&axis_); pv) {
          axis = *pv;
        }

        if (std::holds_alternative<std::monostate>(ord_)) {
          return norm(a, axis, keepdims, stream);
        } else {
          if (auto pv = std::get_if<std::string>(&ord_); pv) {
            return norm(a, *pv, axis, keepdims, stream);
          }
          double ord;
          if (auto pv = std::get_if<int>(&ord_); pv) {
            ord = *pv;
          } else {
            ord = std::get<double>(ord_);
          }
          return norm(a, ord, axis, keepdims, stream);
        }
      },
      "a"_a,
      py::pos_only(),
      "ord"_a = none,
      "axis"_a = none,
      "keepdims"_a = false,
      py::kw_only(),
      "stream"_a = none,
      R"pbdoc(
        norm(a: array, /, ord: Union[None, scalar, str] = None, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array

        Matrix or vector norm.

        This function computes vector or matrix norms depending on the value of
        the ``ord`` and ``axis`` parameters.

        Args:
          a (array): Input array. If ``axis`` is ``None``, ``a`` must be 1-D or 2-D,
            unless ``ord`` is ``None``. If both ``axis`` and ``ord`` are ``None``, the
            2-norm of ``a.flatten`` will be returned.
          ord (scalar or str, optional): Order of the norm (see table under ``Notes``).
            If ``None``, the 2-norm (or Frobenius norm for matrices) will be computed
            along the given ``axis``. Default: ``None``.
          axis (int or list(int), optional): If ``axis`` is an integer, it specifies the
            axis of ``a`` along which to compute the vector norms. If ``axis`` is a
            2-tuple, it specifies the axes that hold 2-D matrices, and the matrix
            norms of these matrices are computed. If ``axis`` is ``None`` then
            either a vector norm (when ``a`` is 1-D) or a matrix norm (when ``a`` is
            2-D) is returned. Default: ``None``.
          keepdims (bool, optional): If ``True``, the axes which are normed over are
            left in the result as dimensions with size one. Default ``False``.

        Returns:
          array: The output containing the norm(s).

        Notes:
          For values of ``ord < 1``, the result is, strictly speaking, not a
          mathematical norm, but it may still be useful for various numerical
          purposes.

          The following norms can be calculated:

          =====  ============================  ==========================
          ord    norm for matrices             norm for vectors
          =====  ============================  ==========================
          None   Frobenius norm                2-norm
          'fro'  Frobenius norm                --
          inf    max(sum(abs(x), axis=1))      max(abs(x))
          -inf   min(sum(abs(x), axis=1))      min(abs(x))
          0      --                            sum(x != 0)
          1      max(sum(abs(x), axis=0))      as below
          -1     min(sum(abs(x), axis=0))      as below
          2      2-norm (largest sing. value)  as below
          -2     smallest singular value       as below
          other  --                            sum(abs(x)**ord)**(1./ord)
          =====  ============================  ==========================

          .. warning::
            Nuclear norm and norms based on singular values are not yet implemented.

          The Frobenius norm is given by [1]_:

          :math:`||A||_F = [\sum_{i,j} abs(a_{i,j})^2]^{1/2}`

          The nuclear norm is the sum of the singular values.

          Both the Frobenius and nuclear norm orders are only defined for
          matrices and raise a ``ValueError`` when ``a.ndim != 2``.

        References:
          .. [1] G. H. Golub and C. F. Van Loan, *Matrix Computations*,
                 Baltimore, MD, Johns Hopkins University Press, 1985, pg. 15

        Examples:
          >>> import mlx.core as mx
          >>> from mlx.core import linalg as la
          >>> a = mx.arange(9) - 4
          >>> a
          array([-4, -3, -2, ..., 2, 3, 4], dtype=int32)
          >>> b = a.reshape((3,3))
          >>> b
          array([[-4, -3, -2],
                 [-1,  0,  1],
                 [ 2,  3,  4]], dtype=int32)
          >>> la.norm(a)
          array(7.74597, dtype=float32)
          >>> la.norm(b)
          array(7.74597, dtype=float32)
          >>> la.norm(b, 'fro')
          array(7.74597, dtype=float32)
          >>> la.norm(a, float("inf"))
          array(4, dtype=float32)
          >>> la.norm(b, float("inf"))
          array(9, dtype=float32)
          >>> la.norm(a, -float("inf"))
          array(0, dtype=float32)
          >>> la.norm(b, -float("inf"))
          array(2, dtype=float32)
          >>> la.norm(a, 1)
          array(20, dtype=float32)
          >>> la.norm(b, 1)
          array(7, dtype=float32)
          >>> la.norm(a, -1)
          array(0, dtype=float32)
          >>> la.norm(b, -1)
          array(6, dtype=float32)
          >>> la.norm(a, 2)
          array(7.74597, dtype=float32)
          >>> la.norm(a, 3)
          array(5.84804, dtype=float32)
          >>> la.norm(a, -3)
          array(0, dtype=float32)
          >>> c = mx.array([[ 1, 2, 3],
          ...               [-1, 1, 4]])
          >>> la.norm(c, axis=0)
          array([1.41421, 2.23607, 5], dtype=float32)
          >>> la.norm(c, axis=1)
          array([3.74166, 4.24264], dtype=float32)
          >>> la.norm(c, ord=1, axis=1)
          array([6, 6], dtype=float32)
          >>> m = mx.arange(8).reshape(2,2,2)
          >>> la.norm(m, axis=(1,2))
          array([3.74166, 11.225], dtype=float32)
          >>> la.norm(m[0, :, :]), la.norm(m[1, :, :])
          (array(3.74166, dtype=float32), array(11.225, dtype=float32))
      )pbdoc");
}

@@ -15,6 +15,7 @@ void init_ops(py::module_&);
 void init_transforms(py::module_&);
 void init_random(py::module_&);
 void init_fft(py::module_&);
+void init_linalg(py::module_&);

 PYBIND11_MODULE(core, m) {
   m.doc() = "mlx: A framework for machine learning on Apple silicon.";
@@ -29,5 +30,6 @@ PYBIND11_MODULE(core, m) {
   init_transforms(m);
   init_random(m);
   init_fft(m);
+  init_linalg(m);
   m.attr("__version__") = TOSTRING(_VERSION_);
 }

@@ -2129,7 +2129,7 @@ void init_ops(py::module_& m) {
             singleton dimensions, defaults to `False`.

         Returns:
-            array: The output array with the indices of the minimum values.
+            array: The output array with the indices of the maximum values.
       )pbdoc");
   m.def(
       "sort",

@@ -569,7 +569,7 @@ void init_transforms(py::module_& m) {
                 return lvalue

             # Returns lvalue, dlvalue/dparams
-            lvalue, grads = mx.value_and_grad(mse)
+            lvalue, grads = mx.value_and_grad(mse)(params, inputs, targets)

             def lasso(params, inputs, targets, a=1.0, b=1.0):
                 outputs = forward(params, inputs)
@@ -580,7 +580,7 @@ void init_transforms(py::module_& m) {

                 return loss, mse, l1

-            (loss, mse, l1), grads = mx.value_and_grad(lasso)
+            (loss, mse, l1), grads = mx.value_and_grad(lasso)(params, inputs, targets)

         Args:
             fun (function): A function which takes a variable number of
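The corrected docstring reflects that ``mx.value_and_grad`` returns a new callable which must then be invoked with the original function's arguments. A minimal sketch (illustrative; the loss function and data here are made up):

import mlx.core as mx

def loss_fn(w, x):
    return mx.mean(mx.square(w * x))

w = mx.array(2.0)
x = mx.array([1.0, 2.0])

value_and_grad_fn = mx.value_and_grad(loss_fn)  # returns a callable, not a value
loss, dloss_dw = value_and_grad_fn(w, x)        # call it exactly like loss_fn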
@@ -84,6 +84,8 @@ class TestArray(mlx_tests.MLXTestCase):
         x = mx.array(1)
         self.assertEqual(x.size, 1)
         self.assertEqual(x.ndim, 0)
+        self.assertEqual(x.itemsize, 4)
+        self.assertEqual(x.nbytes, 4)
         self.assertEqual(x.shape, [])
         self.assertEqual(x.dtype, mx.int32)
         self.assertEqual(x.item(), 1)

python/tests/test_linalg.py (new file)
@@ -0,0 +1,94 @@
# Copyright © 2023 Apple Inc.

import itertools
import math
import unittest

import mlx.core as mx
import mlx_tests
import numpy as np


class TestLinalg(mlx_tests.MLXTestCase):
    def test_norm(self):
        vector_ords = [None, 0.5, 0, 1, 2, 3, -1, float("inf"), -float("inf")]
        matrix_ords = [None, "fro", -1, 1, float("inf"), -float("inf")]

        for shape in [(3,), (2, 3), (2, 3, 3)]:
            x_mx = mx.arange(1, math.prod(shape) + 1).reshape(shape)
            x_np = np.arange(1, math.prod(shape) + 1).reshape(shape)
            # Test when at least one axis is provided
            for num_axes in range(1, len(shape)):
                if num_axes == 1:
                    ords = vector_ords
                else:
                    ords = matrix_ords
                for axis in itertools.combinations(range(len(shape)), num_axes):
                    for keepdims in [True, False]:
                        for o in ords:
                            out_np = np.linalg.norm(
                                x_np, ord=o, axis=axis, keepdims=keepdims
                            )
                            out_mx = mx.linalg.norm(
                                x_mx, ord=o, axis=axis, keepdims=keepdims
                            )
                            with self.subTest(
                                shape=shape, ord=o, axis=axis, keepdims=keepdims
                            ):
                                self.assertTrue(
                                    np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6)
                                )

        # Test only ord provided
        for shape in [(3,), (2, 3)]:
            x_mx = mx.arange(1, math.prod(shape) + 1).reshape(shape)
            x_np = np.arange(1, math.prod(shape) + 1).reshape(shape)
            for o in [None, 1, -1, float("inf"), -float("inf")]:
                for keepdims in [True, False]:
                    out_np = np.linalg.norm(x_np, ord=o, keepdims=keepdims)
                    out_mx = mx.linalg.norm(x_mx, ord=o, keepdims=keepdims)
                    with self.subTest(shape=shape, ord=o, keepdims=keepdims):
                        self.assertTrue(
                            np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6)
                        )

        # Test no ord and no axis provided
        for shape in [(3,), (2, 3), (2, 3, 3)]:
            x_mx = mx.arange(1, math.prod(shape) + 1).reshape(shape)
            x_np = np.arange(1, math.prod(shape) + 1).reshape(shape)
            for keepdims in [True, False]:
                out_np = np.linalg.norm(x_np, keepdims=keepdims)
                out_mx = mx.linalg.norm(x_mx, keepdims=keepdims)
                with self.subTest(shape=shape, keepdims=keepdims):
                    self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6))

    def test_complex_norm(self):
        for shape in [(3,), (2, 3), (2, 3, 3)]:
            x_np = np.random.uniform(size=shape).astype(
                np.float32
            ) + 1j * np.random.uniform(size=shape).astype(np.float32)
            x_mx = mx.array(x_np)
            out_np = np.linalg.norm(x_np)
            out_mx = mx.linalg.norm(x_mx)
            with self.subTest(shape=shape):
                self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6))
            for num_axes in range(1, len(shape)):
                for axis in itertools.combinations(range(len(shape)), num_axes):
                    out_np = np.linalg.norm(x_np, axis=axis)
                    out_mx = mx.linalg.norm(x_mx, axis=axis)
                    with self.subTest(shape=shape, axis=axis):
                        self.assertTrue(
                            np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6)
                        )

        x_np = np.random.uniform(size=(4, 4)).astype(
            np.float32
        ) + 1j * np.random.uniform(size=(4, 4)).astype(np.float32)
        x_mx = mx.array(x_np)
        out_np = np.linalg.norm(x_np, ord="fro")
        out_mx = mx.linalg.norm(x_mx, ord="fro")
        self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6))


if __name__ == "__main__":
    unittest.main()

@@ -511,6 +511,143 @@ class TestNN(mlx_tests.MLXTestCase):
         self.assertTrue(x.shape == y.shape)
         self.assertTrue(np.allclose(y, expected_y, atol=1e-5))

+    def test_batch_norm(self):
+        mx.random.seed(42)
+        x = mx.random.normal((5, 4), dtype=mx.float32)
+
+        # Batch norm
+        bn = nn.BatchNorm(num_features=4, affine=True)
+        self.assertTrue(mx.allclose(bn._running_mean, mx.zeros_like(bn._running_mean)))
+        self.assertTrue(mx.allclose(bn._running_var, mx.ones_like(bn._running_var)))
+        y = bn(x)
+        expected_y = mx.array(
+            [
+                [-0.439520, 1.647328, -0.955515, 1.966031],
+                [-1.726690, -1.449826, -0.234026, -0.723364],
+                [0.938414, -0.349603, -0.354470, -0.175369],
+                [0.305006, 0.234914, -0.393017, -0.459385],
+                [0.922789, -0.082813, 1.937028, -0.607913],
+            ],
+        )
+        expected_mean = mx.array([0.008929, 0.005680, -0.016092, 0.027778])
+        expected_var = mx.array([0.928435, 1.00455, 1.04117, 0.94258])
+        self.assertTrue(x.shape == y.shape)
+        self.assertTrue(mx.allclose(y, expected_y, atol=1e-5))
+        self.assertTrue(mx.allclose(bn._running_mean, expected_mean, atol=1e-5))
+        self.assertTrue(mx.allclose(bn._running_var, expected_var, atol=1e-5))
+
+        # test eval mode
+        bn.eval()
+        y = bn(x)
+        expected_y = mx.array(
+            [
+                [-0.15984, 1.73159, -1.25456, 1.57891],
+                [-0.872193, -1.4281, -0.414439, -0.228678],
+                [0.602743, -0.30566, -0.554687, 0.139639],
+                [0.252199, 0.29066, -0.599572, -0.0512532],
+                [0.594096, -0.0334829, 2.11359, -0.151081],
+            ]
+        )
+
+        self.assertTrue(x.shape == y.shape)
+        self.assertTrue(mx.allclose(y, expected_y, atol=1e-5))
+
+        # test_no_affine
+        bn = nn.BatchNorm(num_features=4, affine=False)
+        y = bn(x)
+        expected_y = mx.array(
+            [
+                [-0.439520, 1.647328, -0.955515, 1.966031],
+                [-1.726690, -1.449826, -0.234026, -0.723364],
+                [0.938414, -0.349603, -0.354470, -0.175369],
+                [0.305006, 0.234914, -0.393017, -0.459385],
+                [0.922789, -0.082813, 1.937028, -0.607913],
+            ]
+        )
+        self.assertTrue(x.shape == y.shape)
+        self.assertTrue(mx.allclose(y, expected_y, atol=1e-5))
+
+        # test with 3D input
+        mx.random.seed(42)
+        N = 2
+        L = 4
+        C = 5
+        x = mx.random.normal((N, L, C), dtype=mx.float32)
+
+        # Batch norm
+        bn = nn.BatchNorm(num_features=C, affine=True)
+        self.assertTrue(mx.allclose(bn._running_mean, mx.zeros_like(bn._running_mean)))
+        self.assertTrue(mx.allclose(bn._running_var, mx.ones_like(bn._running_var)))
+        y = bn(x)
+        self.assertTrue(x.shape == y.shape)
+        expected_y = mx.array(
+            [
+                [
+                    [-0.335754, 0.342054, 1.02653, 0.628588, -1.63899],
+                    [1.92092, 0.432319, 0.343043, 1.95489, 1.0696],
+                    [-0.853748, 1.3661, 0.868569, 0.0199196, -0.887284],
+                    [0.459206, -0.684822, -0.706354, -0.271531, 0.566341],
+                ],
+                [
+                    [-0.921179, 0.684951, -0.77466, -0.490372, -0.247032],
+                    [1.10839, -2.13179, 0.628924, -1.62639, -0.539708],
+                    [-0.348943, 0.412194, -2.03818, 0.524972, 1.64568],
+                    [-1.02889, -0.421, 0.652127, -0.740079, 0.0313996],
+                ],
+            ]
+        )
+        self.assertTrue(mx.allclose(y, expected_y, atol=1e-5))
+        expected_mean = mx.array(
+            [[[0.00207845, -5.3259e-05, 0.04755, -0.0697296, 0.0236228]]]
+        )
+        expected_var = mx.array([[[0.968415, 1.05322, 0.96913, 0.932305, 0.967224]]])
+        self.assertTrue(mx.allclose(bn._running_mean, expected_mean, atol=1e-5))
+        self.assertTrue(mx.allclose(bn._running_var, expected_var, atol=1e-5))
+
+        x = mx.random.normal((N, L, C, L, C), dtype=mx.float32)
+        with self.assertRaises(ValueError):
+            y = bn(x)
+
+    def test_batch_norm_stats(self):
+        batch_size = 2
+        num_features = 4
+        h = 3
+        w = 3
+        momentum = 0.1
+
+        batch_norm = nn.BatchNorm(num_features)
+
+        batch_norm.train()
+        running_mean = np.array(batch_norm._running_mean)
+        running_var = np.array(batch_norm._running_var)
+
+        data = mx.random.normal((batch_size, num_features))
+
+        normalized_data = batch_norm(data)
+        np_data = np.array(data)
+        means = np.mean(np_data, axis=0)
+        variances = np.var(np_data, axis=0)
+        running_mean = (1 - momentum) * running_mean + momentum * means
+        running_var = (1 - momentum) * running_var + momentum * variances
+        self.assertTrue(np.allclose(batch_norm._running_mean, running_mean, atol=1e-5))
+        self.assertTrue(np.allclose(batch_norm._running_var, running_var, atol=1e-5))
+
+        batch_norm = nn.BatchNorm(num_features)
+
+        batch_norm.train()
+        running_mean = np.array(batch_norm._running_mean)
+        running_var = np.array(batch_norm._running_var)
+        data = mx.random.normal((batch_size, h, w, num_features))
+
+        normalized_data = batch_norm(data)
+        np_data = np.array(data)
+        means = np.mean(np_data, axis=(0, 1, 2))
+        variances = np.var(np_data, axis=(0, 1, 2))
+        running_mean = (1 - momentum) * running_mean + momentum * means
+        running_var = (1 - momentum) * running_var + momentum * variances
+        self.assertTrue(np.allclose(batch_norm._running_mean, running_mean, atol=1e-5))
+        self.assertTrue(np.allclose(batch_norm._running_var, running_var, atol=1e-5))
+
     def test_conv1d(self):
         N = 5
         L = 12
@@ -772,6 +909,24 @@ class TestNN(mlx_tests.MLXTestCase):
         y = alibi(x.astype(mx.float16))
         self.assertTrue(y.dtype, mx.float16)

+    def test_hinge_loss(self):
+        inputs = mx.ones((2, 4))
+        targets = mx.zeros((2, 4))
+        loss = nn.losses.hinge_loss(inputs, targets, reduction="mean")
+        self.assertEqual(loss, 1.0)
+
+    def test_huber_loss(self):
+        inputs = mx.ones((2, 4))
+        targets = mx.zeros((2, 4))
+        loss = nn.losses.huber_loss(inputs, targets, reduction="mean")
+        self.assertEqual(loss, 0.5)
+
+    def test_log_cosh_loss(self):
+        inputs = mx.ones((2, 4))
+        targets = mx.zeros((2, 4))
+        loss = nn.losses.log_cosh_loss(inputs, targets, reduction="mean")
+        self.assertAlmostEqual(loss.item(), 0.433781, places=6)
+

 if __name__ == "__main__":
     unittest.main()

@@ -13,7 +13,8 @@ class TestQuantized(mlx_tests.MLXTestCase):
         w_q, scales, biases = mx.quantize(w, 64, b)
         w_hat = mx.dequantize(w_q, scales, biases, 64, b)
         errors = (w - w_hat).abs().reshape(*scales.shape, -1)
-        self.assertTrue((errors <= scales[..., None] / 2).all())
+        eps = 1e-6
+        self.assertTrue((errors <= (scales[..., None] / 2 + eps)).all())

     def test_qmm(self):
         key = mx.random.key(0)

@@ -58,6 +58,9 @@ class TestRandom(mlx_tests.MLXTestCase):
         a = mx.random.uniform(shape=(1000,), low=mx.array(-1), high=5)
         self.assertTrue(mx.all((a > -1) < 5).item())

+        a = mx.random.uniform(low=-0.1, high=0.1, shape=(1,), dtype=mx.bfloat16)
+        self.assertEqual(a.dtype, mx.bfloat16)
+
     def test_normal(self):
         key = mx.random.key(0)
         a = mx.random.normal(key=key)

setup.py
@@ -165,7 +165,7 @@ if __name__ == "__main__":

     setup(
         name="mlx",
-        version=get_version("0.0.5"),
+        version=get_version("0.0.6"),
         author="MLX Contributors",
         author_email="mlx@group.apple.com",
         description="A framework for machine learning on Apple silicon.",

@@ -31,6 +31,7 @@ target_sources(tests PRIVATE
   scheduler_tests.cpp
   utils_tests.cpp
   vmap_tests.cpp
+  linalg_tests.cpp
   ${METAL_TEST_SOURCES}
 )

tests/linalg_tests.cpp (new file)
@@ -0,0 +1,250 @@
// Copyright © 2023 Apple Inc.

#include "doctest/doctest.h"

#include <cmath>

#include "mlx/mlx.h"

using namespace mlx::core;
using namespace mlx::core::linalg;

TEST_CASE("[mlx.core.linalg.norm] no ord") {
  // Zero dimensions
  array x(2.0);
  CHECK_EQ(norm(x).item<float>(), 2.0f);
  CHECK_THROWS(norm(x, 0));

  x = array({1, 2, 3});
  float expected = std::sqrt(1 + 4 + 9);
  CHECK_EQ(norm(x).item<float>(), doctest::Approx(expected));
  CHECK_EQ(norm(x, 0, false).item<float>(), doctest::Approx(expected));
  CHECK_EQ(norm(x, -1, false).item<float>(), doctest::Approx(expected));
  CHECK_EQ(norm(x, -1, true).ndim(), 1);
  CHECK_THROWS(norm(x, 1));

  x = reshape(arange(9), {3, 3});
  expected =
      std::sqrt(0 + 1 + 2 * 2 + 3 * 3 + 4 * 4 + 5 * 5 + 6 * 6 + 7 * 7 + 8 * 8);
  CHECK_EQ(norm(x).item<float>(), doctest::Approx(expected));
  CHECK_EQ(norm(x, std::vector<int>{0, 1}).item<float>(), doctest::Approx(expected));
  CHECK(array_equal(
            norm(x, 0, false),
            array({std::sqrt(0 + 3 * 3 + 6 * 6),
                   std::sqrt(1 + 4 * 4 + 7 * 7),
                   std::sqrt(2 * 2 + 5 * 5 + 8 * 8)}))
            .item<bool>());
  CHECK(allclose(
            norm(x, 1, false),
            array({std::sqrt(0 + 1 + 2 * 2),
                   std::sqrt(3 * 3 + 4 * 4 + 5 * 5),
                   std::sqrt(6 * 6 + 7 * 7 + 8 * 8)}))
            .item<bool>());

  x = reshape(arange(18), {2, 3, 3});
  CHECK(allclose(
            norm(x, 2, false),
            array({std::sqrt(0 + 1 + 2 * 2),
                   std::sqrt(3 * 3 + 4 * 4 + 5 * 5),
                   std::sqrt(6 * 6 + 7 * 7 + 8 * 8),
                   std::sqrt(9 * 9 + 10 * 10 + 11 * 11),
                   std::sqrt(12 * 12 + 13 * 13 + 14 * 14),
                   std::sqrt(15 * 15 + 16 * 16 + 17 * 17)},
                  {2, 3}))
            .item<bool>());
  CHECK(allclose(
            norm(x, std::vector<int>{1, 2}, false),
            array({std::sqrt(0 + 1 + 2 * 2 + 3 * 3 + 4 * 4 + 5 * 5 + 6 * 6 + 7 * 7 + 8 * 8),
                   std::sqrt(9 * 9 + 10 * 10 + 11 * 11 + 12 * 12 + 13 * 13 + 14 * 14 + 15 * 15 + 16 * 16 + 17 * 17)},
                  {2}))
            .item<bool>());
  CHECK_THROWS(norm(x, std::vector<int>{0, 1, 2}));
}

TEST_CASE("[mlx.core.linalg.norm] double ord") {
  CHECK_THROWS(norm(array(0), 2.0));

  array x({1, 2, 3});
  float expected = std::sqrt(1 + 4 + 9);
  CHECK_EQ(norm(x, 2.0).item<float>(), doctest::Approx(expected));
  CHECK_EQ(norm(x, 2.0, 0).item<float>(), doctest::Approx(expected));
  CHECK_THROWS(norm(x, 2.0, 1));

  expected = 1 + 2 + 3;
  CHECK_EQ(norm(x, 1.0).item<float>(), doctest::Approx(expected));

  expected = 3;
  CHECK_EQ(norm(x, 0.0).item<float>(), doctest::Approx(expected));

  expected = 3;
  CHECK_EQ(norm(x, std::numeric_limits<double>::infinity()).item<float>(), doctest::Approx(expected));

  expected = 1;
  CHECK_EQ(norm(x, -std::numeric_limits<double>::infinity()).item<float>(), doctest::Approx(expected));

  x = reshape(arange(9), {3, 3});
  CHECK(allclose(
            norm(x, 2.0, 0, false),
            array({std::sqrt(0 + 3 * 3 + 6 * 6),
                   std::sqrt(1 + 4 * 4 + 7 * 7),
                   std::sqrt(2 * 2 + 5 * 5 + 8 * 8)}))
            .item<bool>());
  CHECK(allclose(
            norm(x, 2.0, 1, false),
            array({sqrt(0 + 1 + 2 * 2),
                   sqrt(3 * 3 + 4 * 4 + 5 * 5),
                   sqrt(6 * 6 + 7 * 7 + 8 * 8)}))
            .item<bool>());

  CHECK_EQ(norm(x, 1.0, std::vector<int>{0, 1}).item<float>(), doctest::Approx(15.0));
  CHECK_EQ(norm(x, 1.0, std::vector<int>{1, 0}).item<float>(), doctest::Approx(21.0));
  CHECK_EQ(norm(x, -1.0, std::vector<int>{0, 1}).item<float>(), doctest::Approx(9.0));
  CHECK_EQ(norm(x, -1.0, std::vector<int>{1, 0}).item<float>(), doctest::Approx(3.0));
  CHECK_EQ(norm(x, 1.0, std::vector<int>{0, 1}, true).shape(), std::vector<int>{1, 1});
  CHECK_EQ(norm(x, 1.0, std::vector<int>{1, 0}, true).shape(), std::vector<int>{1, 1});
  CHECK_EQ(norm(x, -1.0, std::vector<int>{0, 1}, true).shape(), std::vector<int>{1, 1});
  CHECK_EQ(norm(x, -1.0, std::vector<int>{1, 0}, true).shape(), std::vector<int>{1, 1});

  CHECK_EQ(norm(x, -1.0, std::vector<int>{-2, -1}, false).item<float>(), doctest::Approx(9.0));
  CHECK_EQ(norm(x, 1.0, std::vector<int>{-2, -1}, false).item<float>(), doctest::Approx(15.0));

  x = reshape(arange(18), {2, 3, 3});
  CHECK_THROWS(norm(x, 2.0, std::vector{0, 1, 2}));
  CHECK(allclose(
            norm(x, 3.0, 0),
            array({9., 10.00333222, 11.02199456, 12.06217728, 13.12502645,
                   14.2094363, 15.31340617, 16.43469751, 17.57113899},
                  {3, 3}))
            .item<bool>());
  CHECK(allclose(
            norm(x, 3.0, 2),
            array({2.08008382, 6., 10.23127655, 14.5180117, 18.82291607, 23.13593104}, {2, 3}))
            .item<bool>());
  CHECK(allclose(norm(x, 0.0, 0), array({1., 2., 2., 2., 2., 2., 2., 2., 2.}, {3, 3}))
            .item<bool>());
  CHECK(allclose(norm(x, 0.0, 1), array({2., 3., 3., 3., 3., 3.}, {2, 3})).item<bool>());
  CHECK(allclose(norm(x, 0.0, 2), array({2., 3., 3., 3., 3., 3.}, {2, 3})).item<bool>());
  CHECK(allclose(
            norm(x, 1.0, 0),
            array({9., 11., 13., 15., 17., 19., 21., 23., 25.}, {3, 3}))
            .item<bool>());
  CHECK(allclose(norm(x, 1.0, 1), array({9., 12., 15., 36., 39., 42.}, {2, 3})).item<bool>());
  CHECK(allclose(norm(x, 1.0, 2), array({3., 12., 21., 30., 39., 48.}, {2, 3})).item<bool>());

  CHECK(allclose(norm(x, 1.0, std::vector<int>{0, 1}), array({21., 23., 25.})).item<bool>());
  CHECK(allclose(norm(x, 1.0, std::vector<int>{1, 2}), array({15., 42.})).item<bool>());
  CHECK(allclose(norm(x, -1.0, std::vector<int>{0, 1}), array({9., 11., 13.})).item<bool>());
  CHECK(allclose(norm(x, -1.0, std::vector<int>{1, 2}), array({9., 36.})).item<bool>());
  CHECK(allclose(norm(x, -1.0, std::vector<int>{1, 0}), array({9., 12., 15.})).item<bool>());
  CHECK(allclose(norm(x, -1.0, std::vector<int>{2, 1}), array({3, 30})).item<bool>());
  CHECK(allclose(norm(x, -1.0, std::vector<int>{1, 2}), array({9, 36})).item<bool>());
}

TEST_CASE("[mlx.core.linalg.norm] string ord") {
  array x({1, 2, 3});
  CHECK_THROWS(norm(x, "fro"));

  x = reshape(arange(9), {3, 3});
  CHECK_THROWS(norm(x, "bad ord"));

  CHECK_EQ(norm(x, "f", std::vector<int>{0, 1}).item<float>(), doctest::Approx(14.2828568570857));
  CHECK_EQ(norm(x, "fro", std::vector<int>{0, 1}).item<float>(), doctest::Approx(14.2828568570857));

  x = reshape(arange(18), {2, 3, 3});
  CHECK(allclose(
            norm(x, "fro", std::vector<int>{0, 1}),
            array({22.24859546, 24.31049156, 26.43860813}))
            .item<bool>());
  CHECK(allclose(
            norm(x, "fro", std::vector<int>{1, 2}),
            array({14.28285686, 39.7617907}))
            .item<bool>());
  CHECK(allclose(
            norm(x, "f", std::vector<int>{0, 1}),
            array({22.24859546, 24.31049156, 26.43860813}))
            .item<bool>());
  CHECK(allclose(
            norm(x, "f", std::vector<int>{1, 0}),
            array({22.24859546, 24.31049156, 26.43860813}))
            .item<bool>());
  CHECK(allclose(
            norm(x, "f", std::vector<int>{1, 2}),
            array({14.28285686, 39.7617907}))
            .item<bool>());
  CHECK(allclose(
            norm(x, "f", std::vector<int>{2, 1}),
            array({14.28285686, 39.7617907}))
            .item<bool>());
}

@@ -260,6 +260,10 @@ TEST_CASE("test random uniform") {
   // Non float type throws
   CHECK_THROWS_AS(random::uniform({}, int32), std::invalid_argument);

+  // dtype respected
+  x = random::uniform(-.1, .1, {0}, bfloat16);
+  CHECK_EQ(x.dtype(), bfloat16);
+
   // Check broadcasting
   x = random::uniform(zeros({3, 1}), ones({1, 3}), {3, 3});
   CHECK_EQ(x.shape(), std::vector<int>{3, 3});