paramthakkar123 2025-05-06 09:53:10 +05:30
commit 80002ed42f
128 changed files with 2291 additions and 895 deletions

.gitignore
@@ -36,6 +36,7 @@ share/python-wheels/
 .installed.cfg
 *.egg
 MANIFEST
+uv.lock
 # vim
 *.swp

@@ -1,4 +1,6 @@
 include CMakeLists.txt
+include mlx.pc.in
 recursive-include mlx/ *
+include cmake/*
 include python/src/*
 include python/mlx/py.typed # support type hinting as in PEP-561

@@ -20,3 +20,5 @@ FFT
 irfft2
 rfftn
 irfftn
+fftshift
+ifftshift

@@ -47,7 +47,10 @@ add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/distributed)
 add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/io)
 if(MLX_BUILD_METAL)
+  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/gpu)
   add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/metal)
 else()
-  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_metal)
+  target_sources(mlx
+                 PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/no_metal.cpp)
+  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_gpu)
 endif()

@@ -356,7 +356,7 @@ class array {
   }
   enum Status {
-    // The ouptut of a computation which has not been scheduled.
+    // The output of a computation which has not been scheduled.
     // For example, the status of `x` in `auto x = a + b`.
     unscheduled,

@@ -99,7 +99,11 @@ inline std::pair<int, int> decompose_hadamard(int n) {
           "[hadamard] Only supports n = m*2^k where m in (1, 12, 20, 28).");
     }
   }
+  if (n > (1 << 26)) {
+    throw std::invalid_argument(
+        "[hadamard] Only supports n = m*2^k where k <= 26");
+  }
   return {n, m};
 }

 } // namespace mlx::core

@@ -40,7 +40,8 @@ add_dependencies(mlx cpu_compiled_preamble)
 target_sources(
   mlx
-  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp
+  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/available.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp

@@ -0,0 +1,11 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cpu/available.h"
namespace mlx::core::cpu {
bool is_available() {
return true;
}
} // namespace mlx::core::cpu

@@ -0,0 +1,9 @@
// Copyright © 2025 Apple Inc.
#pragma once
namespace mlx::core::cpu {
bool is_available();
} // namespace mlx::core::cpu

@@ -172,9 +172,12 @@ void binary_float(
     case bfloat16:
       binary_op<bfloat16_t, Op>(a, b, out, bopt);
       break;
+    case complex64:
+      binary_op<complex64_t, Op>(a, b, out, bopt);
+      break;
     default:
       throw std::runtime_error(
-          "[binary_float] Only supports non-complex floating point types.");
+          "[binary_float] Only supports floating point types.");
   }
 });
}

@@ -40,7 +40,10 @@ struct CompilerCache {
   std::shared_mutex mtx;
 };

-static CompilerCache cache{};
+static CompilerCache& cache() {
+  static CompilerCache cache_;
+  return cache_;
+};

 // GPU compile is always available if the GPU is available and since we are in
 // this file CPU compile is also available.
@@ -56,14 +59,16 @@ void* compile(
     const std::string& kernel_name,
     const std::function<std::string(void)>& source_builder) {
   {
-    std::shared_lock lock(cache.mtx);
-    if (auto it = cache.kernels.find(kernel_name); it != cache.kernels.end()) {
+    std::shared_lock lock(cache().mtx);
+    if (auto it = cache().kernels.find(kernel_name);
+        it != cache().kernels.end()) {
       return it->second;
     }
   }
-  std::unique_lock lock(cache.mtx);
-  if (auto it = cache.kernels.find(kernel_name); it != cache.kernels.end()) {
+  std::unique_lock lock(cache().mtx);
+  if (auto it = cache().kernels.find(kernel_name);
+      it != cache().kernels.end()) {
     return it->second;
   }
   std::string source_code = source_builder();
@@ -120,10 +125,10 @@ void* compile(
   }

   // load library
-  cache.libs.emplace_back(shared_lib_path);
+  cache().libs.emplace_back(shared_lib_path);

   // Load function
-  void* fun = dlsym(cache.libs.back().lib, kernel_name.c_str());
+  void* fun = dlsym(cache().libs.back().lib, kernel_name.c_str());
   if (!fun) {
     std::ostringstream msg;
     msg << "[Compile::eval_cpu] Failed to load compiled function "
@@ -131,7 +136,7 @@ void* compile(
         << dlerror();
     throw std::runtime_error(msg.str());
   }
-  cache.kernels.insert({kernel_name, fun});
+  cache().kernels.insert({kernel_name, fun});
   return fun;
 }
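Note: the cache() change above swaps a namespace-scope global for a function-local static. As a point of reference only, here is a minimal standalone sketch of that pattern using a simplified, hypothetical DemoCache in place of the real CompilerCache; initialization is deferred to first use and is thread-safe since C++11, which sidesteps static initialization and destruction order problems.

#include <mutex>
#include <string>
#include <unordered_map>

// Hypothetical, simplified stand-in for the real CompilerCache.
struct DemoCache {
  std::unordered_map<std::string, void*> kernels;
  std::mutex mtx;
};

// Function-local static ("Meyers singleton"): constructed on first call.
static DemoCache& demo_cache() {
  static DemoCache cache_;
  return cache_;
}

void* lookup(const std::string& name) {
  std::lock_guard<std::mutex> lock(demo_cache().mtx);
  auto it = demo_cache().kernels.find(name);
  return it == demo_cache().kernels.end() ? nullptr : it->second;
}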

@@ -330,7 +330,8 @@ void Scan::eval_cpu(const std::vector<array>& inputs, array& out) {
           reduce_type_, in, out, axis_, reverse_, inclusive_);
       break;
     case complex64:
-      throw std::runtime_error("Scan ops do not support complex types yet");
+      scan_dispatch<complex64_t, complex64_t>(
+          reduce_type_, in, out, axis_, reverse_, inclusive_);
       break;
   }
 });

@@ -88,12 +88,33 @@ DEFAULT_UNARY(expm1, std::expm1)
 DEFAULT_UNARY(floor, std::floor)
 DEFAULT_UNARY(log, std::log)
 DEFAULT_UNARY(log10, std::log10)
-DEFAULT_UNARY(log1p, std::log1p)
 DEFAULT_UNARY(sinh, std::sinh)
 DEFAULT_UNARY(sqrt, std::sqrt)
 DEFAULT_UNARY(tan, std::tan)
 DEFAULT_UNARY(tanh, std::tanh)

+template <typename T>
+Simd<T, 1> log1p(Simd<T, 1> in) {
+  if constexpr (is_complex<T>) {
+    auto x = in.value.real();
+    auto y = in.value.imag();
+    auto zabs = std::abs(in.value);
+    auto theta = std::atan2(y, x + 1);
+    if (zabs < 0.5) {
+      auto r = x * (2 + x) + y * y;
+      if (r == 0) { // handle underflow
+        return Simd<T, 1>{T{x, theta}};
+      }
+      return Simd<T, 1>{T{((typeof(x))(0.5)) * std::log1p(r), theta}};
+    } else {
+      auto z0 = std::hypot(x + 1, y);
+      return Simd<T, 1>{T{std::log(z0), theta}};
+    }
+  } else {
+    return Simd<T, 1>{std::log1p(in.value)};
+  }
+}
+
 template <typename T>
 Simd<T, 1> log2(Simd<T, 1> in) {
   if constexpr (is_complex<T>) {
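For reference, a small host-side check (not from the commit) of the identity the complex branch above relies on: log(1 + z) = 0.5 * log1p(x*(2 + x) + y*y) + i * atan2(y, 1 + x), since |1 + z|^2 = (1 + x)^2 + y^2 = 1 + x*(2 + x) + y^2.

#include <cmath>
#include <complex>
#include <cstdio>

int main() {
  // Small z, where the log1p-based form keeps precision.
  std::complex<float> z{1e-4f, -2e-5f};
  float x = z.real(), y = z.imag();
  float re = 0.5f * std::log1p(x * (2.0f + x) + y * y);
  float im = std::atan2(y, 1.0f + x);
  auto direct = std::log(std::complex<float>{1.0f + x, y});
  std::printf("formula: %.9g %+.9gi\n", re, im);
  std::printf("direct : %.9g %+.9gi\n", direct.real(), direct.imag());
  return 0;
}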

@@ -0,0 +1,5 @@
target_sources(
mlx
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp)

@@ -0,0 +1,9 @@
// Copyright © 2025 Apple Inc.
#pragma once
namespace mlx::core::gpu {
bool is_available();
} // namespace mlx::core::gpu

mlx/backend/gpu/copy.cpp (new file)
@@ -0,0 +1,49 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/backend/gpu/copy.h"
#include "mlx/primitives.h"
#include <cassert>
namespace mlx::core {
void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) {
bool donated = set_copy_output_data(in, out, ctype);
if (donated && in.dtype() == out.dtype()) {
// If the output has the same type as the input then there is nothing to
// copy, just use the buffer.
return;
}
if (ctype == CopyType::GeneralGeneral) {
ctype = CopyType::General;
}
copy_gpu_inplace(in, out, ctype, s);
}
void copy_gpu(const array& in, array& out, CopyType ctype) {
copy_gpu(in, out, ctype, out.primitive().stream());
}
void copy_gpu_inplace(
const array& in,
array& out,
CopyType ctype,
const Stream& s) {
assert(in.shape() == out.shape());
return copy_gpu_inplace(
in, out, in.shape(), in.strides(), out.strides(), 0, 0, ctype, s);
}
void copy_gpu_inplace(
const array& in,
array& out,
const Strides& i_strides,
int64_t i_offset,
CopyType ctype,
const Stream& s) {
assert(in.shape() == out.shape());
return copy_gpu_inplace(
in, out, in.shape(), i_strides, out.strides(), i_offset, 0, ctype, s);
}
} // namespace mlx::core

@@ -5,6 +5,8 @@
 #include "mlx/backend/common/copy.h"
 #include "mlx/stream.h"

+#include <optional>
+
 namespace mlx::core {

 // Generic copy inplace

@@ -8,14 +8,11 @@
 #include "mlx/array.h"
 #include "mlx/stream.h"

-namespace mlx::core::metal {
+namespace mlx::core::gpu {

 void new_stream(Stream stream);
-std::unique_ptr<void, std::function<void(void*)>> new_scoped_memory_pool();
 void eval(array& arr);
 void finalize(Stream s);
 void synchronize(Stream s);

-} // namespace mlx::core::metal
+} // namespace mlx::core::gpu

@@ -0,0 +1,217 @@
// Copyright © 2025 Apple Inc.
#include "mlx/primitives.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/gpu/slicing.h"
#include <cassert>
#define MLX_PROFILER_RANGE(message)
namespace mlx::core {
namespace {
void reshape(const array& in, array& out, Stream s) {
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
if (copy_necessary) {
out.set_data(allocator::malloc(out.nbytes()));
copy_gpu_inplace(
in,
out,
in.shape(),
in.strides(),
make_contiguous_strides(in.shape()),
0,
0,
CopyType::General,
s);
} else {
shared_buffer_reshape(in, out_strides, out);
}
}
} // namespace
void AsStrided::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("AsStrided::eval_gpu");
eval(inputs, out);
}
void AsType::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("AsType::eval_gpu");
CopyType ctype =
inputs[0].flags().contiguous ? CopyType::Vector : CopyType::General;
copy_gpu(inputs[0], out, ctype);
}
void Broadcast::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("Broadcast::eval_gpu");
eval(inputs, out);
}
void BroadcastAxes::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("BroadcastAxes::eval_gpu");
eval(inputs, out);
}
void Concatenate::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("Concatenate::eval_gpu");
concatenate_gpu(inputs, out, axis_, stream());
}
void Contiguous::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("Contiguous::eval_gpu");
assert(inputs.size() == 1);
auto& in = inputs[0];
constexpr size_t extra_bytes = 16384;
if (in.buffer_size() <= out.nbytes() + extra_bytes &&
(in.flags().row_contiguous ||
(allow_col_major_ && in.flags().col_contiguous))) {
out.copy_shared_buffer(in);
} else {
copy_gpu(in, out, CopyType::General);
}
}
void Copy::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("Copy::eval_gpu");
eval(inputs, out);
}
void CustomTransforms::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
MLX_PROFILER_RANGE("CustomTransforms::eval_gpu");
eval(inputs, outputs);
}
void Depends::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
MLX_PROFILER_RANGE("Depends::eval_gpu");
eval(inputs, outputs);
}
void ExpandDims::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("ExpandDims::eval_gpu");
eval(inputs, out);
}
void Full::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("Full::eval_gpu");
auto in = inputs[0];
CopyType ctype;
if (in.data_size() == 1) {
ctype = CopyType::Scalar;
} else if (in.flags().contiguous) {
ctype = CopyType::Vector;
} else {
ctype = CopyType::General;
}
copy_gpu(in, out, ctype);
}
void Flatten::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("Flatten::eval_gpu");
reshape(inputs[0], out, stream());
}
void NumberOfElements::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("NumberOfElements::eval_gpu");
eval(inputs, out);
}
void Pad::eval_gpu(const std::vector<array>& inputs, array& out) {
// Inputs must be base input array and scalar val array
assert(inputs.size() == 2);
auto& in = inputs[0];
auto& val = inputs[1];
// Padding value must be a scalar
assert(val.size() == 1);
// Padding value, input and output must be of the same type
assert(val.dtype() == in.dtype() && in.dtype() == out.dtype());
pad_gpu(in, val, out, axes_, low_pad_size_, stream());
}
void Reshape::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("Reshape::eval_gpu");
reshape(inputs[0], out, stream());
}
void Split::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
MLX_PROFILER_RANGE("Split::eval_gpu");
eval(inputs, outputs);
}
void Slice::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("Slice::eval_gpu");
assert(inputs.size() == 1);
if (out.size() == 0) {
out.set_data(nullptr);
return;
}
auto& in = inputs[0];
slice_gpu(in, out, start_indices_, strides_, stream());
}
void Squeeze::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("Squeeze::eval_gpu");
eval(inputs, out);
}
void StopGradient::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("StopGradient::eval_gpu");
eval(inputs, out);
}
void Transpose::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("Transpose::eval_gpu");
eval(inputs, out);
}
void Unflatten::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("Unflatten::eval_gpu");
reshape(inputs[0], out, stream());
}
void View::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("View::eval_gpu");
auto& in = inputs[0];
auto ibytes = size_of(in.dtype());
auto obytes = size_of(out.dtype());
// Conditions for buffer copying (disjunction):
// - type size is the same
// - type size is smaller and the last axis is contiguous
// - the entire array is row contiguous
if (ibytes == obytes || (obytes < ibytes && in.strides().back() == 1) ||
in.flags().row_contiguous) {
auto strides = in.strides();
for (int i = 0; i < static_cast<int>(strides.size()) - 1; ++i) {
strides[i] *= ibytes;
strides[i] /= obytes;
}
out.copy_shared_buffer(
in, strides, in.flags(), in.data_size() * ibytes / obytes);
} else {
auto tmp = array(in.shape(), in.dtype(), nullptr, {});
tmp.set_data(allocator::malloc(tmp.nbytes()));
copy_gpu_inplace(in, tmp, CopyType::General, stream());
auto flags = out.flags();
flags.contiguous = true;
flags.row_contiguous = true;
auto max_dim = std::max_element(out.shape().begin(), out.shape().end());
flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim;
out.copy_shared_buffer(tmp, out.strides(), flags, out.size());
}
}
} // namespace mlx::core
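The stride adjustment in View::eval_gpu above can be checked on the CPU with plain integers. This sketch uses made-up sizes: reinterpreting a 4-byte element type as a 2-byte one rescales every stride except the last by ibytes / obytes, which is what the loop over strides does.

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // Hypothetical row-contiguous float32 array of shape (2, 3): element
  // strides {3, 1}. Viewing it as int16 rescales all but the last stride by
  // ibytes / obytes = 4 / 2 = 2.
  std::vector<int64_t> strides{3, 1};
  int ibytes = 4; // size of the input dtype
  int obytes = 2; // size of the output dtype
  for (size_t i = 0; i + 1 < strides.size(); ++i) {
    strides[i] = strides[i] * ibytes / obytes;
  }
  std::printf("new strides: {%lld, %lld}\n",
              (long long)strides[0], (long long)strides[1]);
  return 0;
}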

@@ -0,0 +1,44 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/common/slicing.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/gpu/slicing.h"
namespace mlx::core {
void slice_gpu(
const array& in,
array& out,
const Shape& start_indices,
const Shape& strides,
const Stream& s) {
slice(in, out, start_indices, strides);
}
void pad_gpu(
const array& in,
const array& val,
array& out,
const std::vector<int>& axes,
const Shape& low_pad_size,
const Stream& s) {
// Fill output with val
fill_gpu(val, out, s);
// Find offset for start of input values
size_t data_offset = 0;
for (int i = 0; i < axes.size(); i++) {
auto ax = axes[i] < 0 ? out.ndim() + axes[i] : axes[i];
data_offset += out.strides()[ax] * low_pad_size[i];
}
// Extract slice from output where input will be pasted
array out_slice(in.shape(), out.dtype(), nullptr, {});
out_slice.copy_shared_buffer(
out, out.strides(), out.flags(), out_slice.size(), data_offset);
// Copy input values into the slice
copy_gpu_inplace(in, out_slice, CopyType::GeneralGeneral, s);
}
} // namespace mlx::core

@@ -93,6 +93,7 @@ target_sources(
          ${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/event.cpp
+         ${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp

@@ -1,7 +1,6 @@
 // Copyright © 2023-2024 Apple Inc.

 #include "mlx/backend/metal/allocator.h"
 #include "mlx/backend/metal/metal.h"
-#include "mlx/backend/metal/metal_impl.h"
 #include "mlx/backend/metal/resident.h"
 #include "mlx/memory.h"

@@ -90,7 +90,7 @@ void binary_op_gpu_inplace(
     work_per_thread = large ? 4 : 2;
   } else {
     large = out.data_size() > UINT32_MAX;
-    work_per_thread = 1;
+    work_per_thread = get_work_per_thread(a.dtype());
   }
   std::string kernel_name =
       get_kernel_name(bopt, op, a, large, shape.size(), work_per_thread);
@@ -137,13 +137,20 @@ void binary_op_gpu_inplace(
     compute_encoder.dispatch_threads(grid_dims, group_dims);
   } else {
     // Launch a 1D or 2D grid of threads
-    size_t nthreads = out.data_size();
+    size_t nthreads = ceildiv(out.data_size(), work_per_thread);
     if (thread_group_size > nthreads) {
       thread_group_size = nthreads;
     }
     MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
-    MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides())
-                                : MTL::Size(nthreads, 1, 1);
+    MTL::Size grid_dims;
+    if (large) {
+      compute_encoder.set_bytes<int64_t>(out.data_size(), arg_idx++);
+      grid_dims = get_2d_grid_dims(out.shape(), out.strides(), work_per_thread);
+    } else {
+      compute_encoder.set_bytes<int>(out.data_size(), arg_idx++);
+      grid_dims = MTL::Size(nthreads, 1, 1);
+    }
     compute_encoder.dispatch_threads(grid_dims, group_dims);
   }
 }
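A host-side sketch (ordinary C++, with assumed values) of the launch arithmetic above: with N elements of work per thread, the 1D grid shrinks to ceildiv(size, N) threads and the per-thread loop in the kernel covers the tail.

#include <cstddef>
#include <cstdio>

// Same rounding-up division the dispatch code relies on.
static size_t ceildiv(size_t a, size_t b) {
  return (a + b - 1) / b;
}

int main() {
  size_t data_size = 1000003; // assumed element count
  int work_per_thread = 4;    // e.g. what get_work_per_thread might return
  size_t nthreads = ceildiv(data_size, work_per_thread);
  // Each thread handles indices [i*N, i*N + N) clamped to data_size, which is
  // exactly the "for (i < N && index + i < size)" loop in the kernels.
  size_t covered = nthreads * work_per_thread;
  std::printf("threads: %zu, covered: %zu (>= %zu)\n",
              nthreads, covered, data_size);
  return 0;
}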

@@ -64,6 +64,7 @@ inline void build_kernel(
         cnt++);
   }

+  std::string idx_type = use_big_index ? "int64_t" : "uint";
   if (add_indices) {
     os += fmt::format(
         " constant const int64_t* in_strides [[buffer({0})]],\n", cnt++);
@@ -83,6 +84,9 @@ inline void build_kernel(
         " constant const int64_t* output_strides [[buffer({0})]],\n", cnt++);
     os += fmt::format(
         " constant const int* output_shape [[buffer({0})]],\n", cnt++);
+  } else {
+    os += fmt::format(
+        " constant const {0}& size [[buffer({1})]],\n", idx_type, cnt++);
   }
   if (dynamic_dims) {
     os += fmt::format(" constant const int& ndim [[buffer({0})]],\n", cnt++);
@@ -92,13 +96,14 @@ inline void build_kernel(
   os += " uint3 pos [[thread_position_in_grid]],\n";
   os += " uint3 grid [[threads_per_grid]]) {\n";

-  std::string idx_type = use_big_index ? "int64_t" : "uint";
+  os += fmt::format(" constexpr int N_ = {0};\n", work_per_thread);
   if (contiguous && use_big_index) {
     // This is only used for contiguous kernels which don't have
     // a third grid dimension
-    os += " int64_t index = pos.x + grid.x * int64_t(pos.y);\n";
+    os += " int64_t index = N_ * (pos.x + grid.x * int64_t(pos.y));\n";
+  } else if (contiguous) {
+    os += " uint index = N_ * pos.x;\n";
   } else if (work_per_thread > 1) {
-    os += fmt::format(" constexpr int N_ = {0};\n", work_per_thread);
     os += fmt::format(
         " int xshape = output_shape[{0}];\n",
         dynamic_dims ? "ndim - 1" : std::to_string(ndim - 1));
@@ -110,6 +115,9 @@ inline void build_kernel(
         " {0} index = pos.x + grid.x * (pos.y + {0}(grid.y) * pos.z);\n",
         idx_type);
   }
+  if (work_per_thread > 1 && contiguous) {
+    os += " for (int i = 0; i < N_ && index < size; ++i) {\n";
+  }

   // Read constant / contiguous inputs in tmps
   std::vector<array> nc_inputs;
@@ -193,7 +201,7 @@ inline void build_kernel(
   }

   // Open per-thread loop
-  if (work_per_thread > 1) {
+  if (work_per_thread > 1 && !contiguous) {
     os +=
         " for (int i = 0; i < N_ && (int(N_ * pos.x) + i) < xshape; ++i) {\n";
   }
@@ -272,6 +280,7 @@ void Compiled::eval_gpu(
   auto& s = stream();
   auto& d = metal::device(s.device);
   auto lib = d.get_library(kernel_lib_, [&]() {
+    int work_per_thread = get_work_per_thread(outputs_[0].dtype());
     std::string kernel = metal::utils();
     concatenate(
         kernel, metal::unary_ops(), metal::binary_ops(), metal::ternary_ops());
@@ -284,7 +293,9 @@ void Compiled::eval_gpu(
         constant_ids_,
         /* contiguous = */ true,
         /* ndim = */ 0,
-        /* dynamic_dims = */ false);
+        /* dynamic_dims = */ false,
+        /* use_big_index = */ false,
+        /* work_per_thread = */ work_per_thread);
     build_kernel(
         kernel,
         kernel_lib_ + "_contiguous_large",
@@ -295,7 +306,8 @@ void Compiled::eval_gpu(
         /* contiguous = */ true,
         /* ndim = */ 0,
         /* dynamic_dims = */ false,
-        /* use_big_index = */ true);
+        /* use_big_index = */ true,
+        /* work_per_thread = */ work_per_thread);
     for (int i = 1; i < 8; i++) {
       build_kernel(
           kernel,
@@ -468,6 +480,13 @@ void Compiled::eval_gpu(
   if (!contiguous) {
     compute_encoder.set_vector_bytes(strides[0], cnt++);
     compute_encoder.set_vector_bytes(shape, cnt++);
+  } else {
+    auto size = outputs[0].data_size();
+    if (large) {
+      compute_encoder.set_bytes<int64_t>(size, cnt++);
+    } else {
+      compute_encoder.set_bytes<int>(size, cnt++);
+    }
   }

   // Put the number of dims in if it is dynamic
@@ -477,12 +496,13 @@ void Compiled::eval_gpu(
   // Launch the kernel
   if (contiguous) {
-    size_t nthreads = outputs[0].data_size();
+    int work_per_thread = get_work_per_thread(outputs[0].dtype());
+    size_t nthreads = ceildiv(outputs[0].data_size(), work_per_thread);
     MTL::Size group_dims(
         std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1);
     MTL::Size grid_dims = large
-        ? get_2d_grid_dims(outputs[0].shape(), outputs[0].strides())
+        ? get_2d_grid_dims(
+              outputs[0].shape(), outputs[0].strides(), work_per_thread)
         : MTL::Size(nthreads, 1, 1);
     compute_encoder.dispatch_threads(grid_dims, group_dims);
   } else {

@@ -5,7 +5,7 @@
 #include <numeric>
 #include <sstream>

-#include "mlx/backend/metal/copy.h"
+#include "mlx/backend/gpu/copy.h"
 #include "mlx/backend/metal/device.h"
 #include "mlx/backend/metal/kernels.h"
 #include "mlx/backend/metal/kernels/defines.h"

@@ -1,35 +1,15 @@
 // Copyright © 2023-2024 Apple Inc.
-#include <sstream>
+#include "mlx/backend/gpu/copy.h"

 #include "mlx/backend/common/utils.h"
-#include "mlx/backend/metal/copy.h"
 #include "mlx/backend/metal/device.h"
 #include "mlx/backend/metal/kernels.h"
 #include "mlx/backend/metal/utils.h"
-#include "mlx/primitives.h"

 namespace mlx::core {

 constexpr int MAX_COPY_SPECIALIZED_DIMS = 3;

-void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) {
-  bool donated = set_copy_output_data(in, out, ctype);
-  if (donated && in.dtype() == out.dtype()) {
-    // If the output has the same type as the input then there is nothing to
-    // copy, just use the buffer.
-    return;
-  }
-  if (ctype == CopyType::GeneralGeneral) {
-    ctype = CopyType::General;
-  }
-  copy_gpu_inplace(in, out, ctype, s);
-}
-
-void copy_gpu(const array& in, array& out, CopyType ctype) {
-  copy_gpu(in, out, ctype, out.primitive().stream());
-}
-
 void copy_gpu_inplace(
     const array& in,
     array& out,
@@ -104,6 +84,8 @@ void copy_gpu_inplace(
           "[Copy::eval_gpu] Dynamic output offset requires GeneralGeneral copy");
     }
   }
+  } else {
+    work_per_thread = get_work_per_thread(in.dtype());
   }
   concatenate(kernel_name, "_copy", type_to_name(in), type_to_name(out));
   auto kernel = dynamic ? get_dynamic_copy_kernel(d, kernel_name, in, out)
@@ -165,39 +147,23 @@ void copy_gpu_inplace(
     MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
     compute_encoder.dispatch_threads(grid_dims, group_dims);
   } else {
-    size_t nthreads = out.data_size();
+    size_t nthreads = ceildiv(out.data_size(), work_per_thread);
     if (thread_group_size > nthreads) {
       thread_group_size = nthreads;
     }
     MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
-    MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides())
-                                : MTL::Size(nthreads, 1, 1);
+    MTL::Size grid_dims;
+    if (large) {
+      compute_encoder.set_bytes<int64_t>(out.data_size(), 2);
+      grid_dims = get_2d_grid_dims(out.shape(), out.strides(), work_per_thread);
+    } else {
+      compute_encoder.set_bytes<int>(out.data_size(), 2);
+      grid_dims = MTL::Size(nthreads, 1, 1);
+    }
     compute_encoder.dispatch_threads(grid_dims, group_dims);
   }
 }

-void copy_gpu_inplace(
-    const array& in,
-    array& out,
-    CopyType ctype,
-    const Stream& s) {
-  assert(in.shape() == out.shape());
-  return copy_gpu_inplace(
-      in, out, in.shape(), in.strides(), out.strides(), 0, 0, ctype, s);
-}
-
-void copy_gpu_inplace(
-    const array& in,
-    array& out,
-    const Strides& i_strides,
-    int64_t i_offset,
-    CopyType ctype,
-    const Stream& s) {
-  assert(in.shape() == out.shape());
-  return copy_gpu_inplace(
-      in, out, in.shape(), i_strides, out.strides(), i_offset, 0, ctype, s);
-}
-
 void fill_gpu(const array& val, array& out, const Stream& s) {
   if (out.size() == 0) {
     return;
@@ -214,14 +180,21 @@ void fill_gpu(const array& val, array& out, const Stream& s) {
   compute_encoder.set_input_array(val, 0);
   compute_encoder.set_output_array(out, 1);

+  int work_per_thread = get_work_per_thread(val.dtype());
   auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
-  size_t nthreads = out.data_size();
+  size_t nthreads = ceildiv(out.data_size(), work_per_thread);
   if (thread_group_size > nthreads) {
     thread_group_size = nthreads;
   }
   MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
-  MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides())
-                              : MTL::Size(nthreads, 1, 1);
+  MTL::Size grid_dims;
+  if (large) {
+    compute_encoder.set_bytes<int64_t>(out.data_size(), 2);
+    grid_dims = get_2d_grid_dims(out.shape(), out.strides(), work_per_thread);
+  } else {
+    compute_encoder.set_bytes<int>(out.data_size(), 2);
+    grid_dims = MTL::Size(nthreads, 1, 1);
+  }
   compute_encoder.dispatch_threads(grid_dims, group_dims);
 }

@@ -1,6 +1,6 @@
 // Copyright © 2024 Apple Inc.

-#include "mlx/backend/metal/copy.h"
+#include "mlx/backend/gpu/copy.h"
 #include "mlx/backend/metal/jit/includes.h"
 #include "mlx/backend/metal/utils.h"
 #include "mlx/fast_primitives.h"

@@ -1,20 +1,20 @@
 // Copyright © 2023-2024 Apple Inc.

 #include <cstdlib>
+#include <filesystem>
 #include <sstream>
-#include <sys/sysctl.h>

 #define NS_PRIVATE_IMPLEMENTATION
 #define CA_PRIVATE_IMPLEMENTATION
 #define MTL_PRIVATE_IMPLEMENTATION

 #include "mlx/backend/metal/device.h"
 #include "mlx/backend/metal/metal.h"
-#include "mlx/backend/metal/metal_impl.h"
 #include "mlx/backend/metal/utils.h"
 #include "mlx/utils.h"

+namespace fs = std::filesystem;
+
 namespace mlx::core::metal {

 namespace {
@@ -66,8 +66,8 @@ MTL::Library* try_load_bundle(
   if (bundle != nullptr) {
     std::string resource_path =
         std::string(bundle->resourceURL()->fileSystemRepresentation()) + "/" +
-        lib_name + ".metallib"
-    auto [lib, error] =
-        load_library_from_path(device, resource_path.c_str());
+        lib_name + ".metallib";
+    auto [lib, error] = load_library_from_path(device, resource_path.c_str());
     if (lib) {
       return lib;
     }
@@ -79,12 +79,18 @@ MTL::Library* try_load_bundle(
 // Firstly, search for the metallib in the same path as this binary
 std::pair<MTL::Library*, NS::Error*> load_colocated_library(
     MTL::Device* device,
-    const std::string& lib_name) {
-  std::string lib_path = get_colocated_mtllib_path(lib_name);
-  if (lib_path.size() != 0) {
-    return load_library_from_path(device, lib_path.c_str());
+    const std::string& relative_path) {
+  std::string binary_dir = get_binary_directory();
+  if (binary_dir.size() == 0) {
+    return {nullptr, nullptr};
   }
-  return {nullptr, nullptr};
+
+  auto path = fs::path(binary_dir) / relative_path;
+  if (!path.has_extension()) {
+    path.replace_extension(".metallib");
+  }
+
+  return load_library_from_path(device, path.c_str());
 }

 std::pair<MTL::Library*, NS::Error*> load_swiftpm_library(
@@ -99,7 +105,7 @@ std::pair<MTL::Library*, NS::Error*> load_swiftpm_library(
   auto bundles = NS::Bundle::allBundles();
   for (int i = 0, c = (int)bundles->count(); i < c; i++) {
     auto bundle = reinterpret_cast<NS::Bundle*>(bundles->object(i));
-    library = try_load_bundle(device, bundle->resourceURL());
+    library = try_load_bundle(device, bundle->resourceURL(), lib_name);
     if (library != nullptr) {
       return {library, nullptr};
     }
@@ -109,33 +115,34 @@ std::pair<MTL::Library*, NS::Error*> load_swiftpm_library(
 }

 MTL::Library* load_default_library(MTL::Device* device) {
-  NS::Error *error1, *error2, *error3;
+  NS::Error* error[4];
   MTL::Library* lib;
   // First try the colocated mlx.metallib
-  std::tie(lib, error1) = load_colocated_library(device, "mlx");
+  std::tie(lib, error[0]) = load_colocated_library(device, "mlx");
+  if (lib) {
+    return lib;
+  }
+
+  std::tie(lib, error[1]) = load_colocated_library(device, "Resources/mlx");
   if (lib) {
     return lib;
   }

   // Then try default.metallib in a SwiftPM bundle if we have one
-  std::tie(lib, error2) = load_swiftpm_library(device, "default");
+  std::tie(lib, error[2]) = load_swiftpm_library(device, "default");
   if (lib) {
     return lib;
   }

   // Finally try default_mtllib_path
-  std::tie(lib, error3) = load_library_from_path(device, default_mtllib_path);
+  std::tie(lib, error[3]) = load_library_from_path(device, default_mtllib_path);
   if (!lib) {
     std::ostringstream msg;
     msg << "Failed to load the default metallib. ";
-    if (error1 != nullptr) {
-      msg << error1->localizedDescription()->utf8String() << " ";
-    }
-    if (error2 != nullptr) {
-      msg << error2->localizedDescription()->utf8String() << " ";
-    }
-    if (error3 != nullptr) {
-      msg << error3->localizedDescription()->utf8String() << " ";
+    for (int i = 0; i < 4; i++) {
+      if (error[i] != nullptr) {
+        msg << error[i]->localizedDescription()->utf8String() << " ";
+      }
     }
     throw std::runtime_error(msg.str());
   }
@@ -156,6 +163,7 @@ MTL::Library* load_library(
         << error->localizedDescription()->utf8String();
     throw std::runtime_error(msg.str());
   }
+  return lib;
 }

 // We have been given a path so try to load from lib_path / lib_name.metallib
@@ -168,6 +176,7 @@ MTL::Library* load_library(
         << "> with error " << error->localizedDescription()->utf8String();
     throw std::runtime_error(msg.str());
   }
+  return lib;
 }

 // Try to load the colocated library
@@ -188,8 +197,8 @@ MTL::Library* load_library(
   std::ostringstream msg;
   msg << "Failed to load the metallib " << lib_name << ".metallib. "
-      << "We attempted to load it from <" << get_colocated_mtllib_path(lib_name)
-      << ">";
+      << "We attempted to load it from <" << get_binary_directory() << "/"
+      << lib_name << ".metallib" << ">";
 #ifdef SWIFTPM_BUNDLE
   msg << " and from the Swift PM bundle.";
 #endif
@@ -760,42 +769,4 @@ std::unique_ptr<void, std::function<void(void*)>> new_scoped_memory_pool() {
       NS::AutoreleasePool::alloc()->init(), dtor);
 }

-void new_stream(Stream stream) {
-  if (stream.device == mlx::core::Device::gpu) {
-    device(stream.device).new_queue(stream.index);
-  }
-}
-
-const std::unordered_map<std::string, std::variant<std::string, size_t>>&
-device_info() {
-  auto init_device_info = []()
-      -> std::unordered_map<std::string, std::variant<std::string, size_t>> {
-    auto pool = new_scoped_memory_pool();
-    auto raw_device = device(default_device()).mtl_device();
-    auto name = std::string(raw_device->name()->utf8String());
-    auto arch = std::string(raw_device->architecture()->name()->utf8String());
-
-    size_t memsize = 0;
-    size_t length = sizeof(memsize);
-    sysctlbyname("hw.memsize", &memsize, &length, NULL, 0);
-
-    size_t rsrc_limit = 0;
-    sysctlbyname("iogpu.rsrc_limit", &rsrc_limit, &length, NULL, 0);
-    if (rsrc_limit == 0) {
-      rsrc_limit = 499000;
-    }
-
-    return {
-        {"device_name", name},
-        {"architecture", arch},
-        {"max_buffer_length", raw_device->maxBufferLength()},
-        {"max_recommended_working_set_size",
-         raw_device->recommendedMaxWorkingSetSize()},
-        {"memory_size", memsize},
-        {"resource_limit", rsrc_limit}};
-  };
-  static auto device_info_ = init_device_info();
-  return device_info_;
-}
-
 } // namespace mlx::core::metal

@@ -21,18 +21,14 @@ namespace mlx::core::metal {

 // Note, this function must be left inline in a header so that it is not
 // dynamically linked.
-inline std::string get_colocated_mtllib_path(const std::string& lib_name) {
+inline std::string get_binary_directory() {
   Dl_info info;
-  std::string mtllib_path;
-  std::string lib_ext = lib_name + ".metallib";
-
-  int success = dladdr((void*)get_colocated_mtllib_path, &info);
+  std::string directory;
+  int success = dladdr((void*)get_binary_directory, &info);
   if (success) {
-    auto mtllib = fs::path(info.dli_fname).remove_filename() / lib_ext;
-    mtllib_path = mtllib.c_str();
+    directory = fs::path(info.dli_fname).remove_filename().c_str();
   }
-
-  return mtllib_path;
+  return directory;
 }

 using MTLFCList =
@@ -270,4 +266,6 @@ class Device {

 Device& device(mlx::core::Device);

+std::unique_ptr<void, std::function<void(void*)>> new_scoped_memory_pool();
+
 } // namespace mlx::core::metal

@@ -4,7 +4,7 @@

 #include "mlx/allocator.h"
 #include "mlx/backend/common/utils.h"
-#include "mlx/backend/metal/copy.h"
+#include "mlx/backend/gpu/copy.h"
 #include "mlx/backend/metal/device.h"
 #include "mlx/backend/metal/utils.h"
 #include "mlx/distributed/ops.h"

mlx/backend/metal/eval.cpp (new file)
@@ -0,0 +1,102 @@
// Copyright © 2023-2024 Apple Inc.
#include <memory>
#include "mlx/backend/gpu/available.h"
#include "mlx/backend/gpu/eval.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"
#include "mlx/scheduler.h"
namespace mlx::core::gpu {
bool is_available() {
return true;
}
void new_stream(Stream stream) {
if (stream.device == mlx::core::Device::gpu) {
metal::device(stream.device).new_queue(stream.index);
}
}
inline void check_error(MTL::CommandBuffer* cbuf) {
if (cbuf->status() == MTL::CommandBufferStatusError) {
std::ostringstream msg;
msg << "[METAL] Command buffer execution failed: "
<< cbuf->error()->localizedDescription()->utf8String();
throw std::runtime_error(msg.str());
}
}
void eval(array& arr) {
auto pool = metal::new_scoped_memory_pool();
auto s = arr.primitive().stream();
auto& d = metal::device(s.device);
auto command_buffer = d.get_command_buffer(s.index);
auto outputs = arr.outputs();
{
// If the array is a tracer hold a reference
// to its inputs so they don't get donated
std::vector<array> inputs;
if (arr.is_tracer()) {
inputs = arr.inputs();
}
debug_set_primitive_buffer_label(command_buffer, arr.primitive());
arr.primitive().eval_gpu(arr.inputs(), outputs);
}
std::unordered_set<std::shared_ptr<array::Data>> buffers;
for (auto& in : arr.inputs()) {
buffers.insert(in.data_shared_ptr());
}
for (auto& s : arr.siblings()) {
buffers.insert(s.data_shared_ptr());
}
// Remove the output if it was donated to by an input
if (auto it = buffers.find(arr.data_shared_ptr()); it != buffers.end()) {
buffers.erase(it);
}
if (d.command_buffer_needs_commit(s.index)) {
d.end_encoding(s.index);
scheduler::notify_new_task(s);
command_buffer->addCompletedHandler(
[s, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) {
scheduler::notify_task_completion(s);
check_error(cbuf);
});
d.commit_command_buffer(s.index);
d.get_command_buffer(s.index);
} else {
command_buffer->addCompletedHandler(
[s, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) {
check_error(cbuf);
});
}
}
void finalize(Stream s) {
auto pool = metal::new_scoped_memory_pool();
auto& d = metal::device(s.device);
auto cb = d.get_command_buffer(s.index);
d.end_encoding(s.index);
cb->addCompletedHandler([s](MTL::CommandBuffer* cbuf) { check_error(cbuf); });
d.commit_command_buffer(s.index);
d.get_command_buffer(s.index);
}
void synchronize(Stream s) {
auto pool = metal::new_scoped_memory_pool();
auto& d = metal::device(s.device);
auto cb = d.get_command_buffer(s.index);
cb->retain();
d.end_encoding(s.index);
d.commit_command_buffer(s.index);
cb->waitUntilCompleted();
check_error(cb);
cb->release();
}
} // namespace mlx::core::gpu
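The buffer bookkeeping in eval() above exists so that input data outlives the asynchronous GPU work. A rough standalone illustration (ordinary C++, with a hypothetical Data type standing in for array::Data) of how move-capturing the shared pointers into the completion handler extends their lifetime:

#include <cstdio>
#include <functional>
#include <memory>
#include <unordered_set>
#include <vector>

int main() {
  using Data = std::vector<float>;
  std::unordered_set<std::shared_ptr<Data>> buffers;
  auto in = std::make_shared<Data>(1024, 1.0f);
  buffers.insert(in);

  // Stands in for the Metal addCompletedHandler callback: the captured set
  // keeps the data alive until the handler runs.
  std::function<void()> completed = [buffers = std::move(buffers)]() {
    std::printf("GPU done, %zu buffer(s) released now\n", buffers.size());
  };

  in.reset();  // caller drops its reference; data stays alive in the handler
  completed(); // simulate command-buffer completion
  return 0;
}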

@@ -2,7 +2,6 @@
 #include "mlx/event.h"

 #include "mlx/backend/metal/device.h"
-#include "mlx/backend/metal/metal_impl.h"
 #include "mlx/scheduler.h"

 namespace mlx::core {

@@ -1,7 +1,6 @@
 // Copyright © 2024 Apple Inc.
 #include "mlx/fence.h"

 #include "mlx/backend/metal/device.h"
-#include "mlx/backend/metal/metal_impl.h"
 #include "mlx/scheduler.h"
 #include "mlx/utils.h"
@@ -139,7 +138,7 @@ void Fence::update(Stream stream, const array& x) {
   compute_encoder.set_compute_pipeline_state(kernel);
   compute_encoder.set_input_array(x, 0);
   compute_encoder.set_bytes(nthreads, 1);
-  compute_encoder.dispatch_threadgroups(group_dims, grid_dims);
+  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);

   // Barrier on previous kernels
   compute_encoder.barrier();

@@ -7,10 +7,10 @@
 #include "mlx/3rdparty/pocketfft.h"
 #include "mlx/backend/common/utils.h"
+#include "mlx/backend/gpu/copy.h"
+#include "mlx/backend/gpu/slicing.h"
 #include "mlx/backend/metal/binary.h"
-#include "mlx/backend/metal/copy.h"
 #include "mlx/backend/metal/kernels.h"
-#include "mlx/backend/metal/slicing.h"
 #include "mlx/backend/metal/unary.h"
 #include "mlx/backend/metal/utils.h"
 #include "mlx/utils.h"

@@ -1,11 +1,9 @@
 // Copyright © 2024 Apple Inc.

-#include <map>
-
-#include "mlx/backend/common/compiled.h"
 #include "mlx/backend/common/hadamard.h"
+#include "mlx/backend/common/compiled.h"
 #include "mlx/backend/common/utils.h"
-#include "mlx/backend/metal/copy.h"
+#include "mlx/backend/gpu/copy.h"
 #include "mlx/backend/metal/device.h"
 #include "mlx/backend/metal/jit/includes.h"
 #include "mlx/backend/metal/kernels.h"
@@ -15,7 +13,6 @@
 namespace mlx::core {

 constexpr int MAX_HADAMARD_THREADS_PER_GROUP = 256;
-constexpr int MAX_HADAMARD_BYTES = 32768; // 32KB

 std::string gen_hadamard_codelet(int m) {
   // Generate a O(m^2) hadamard codelet for a given M
@@ -60,121 +57,142 @@ std::string gen_hadamard_codelet(int m) {
   return source.str();
 }

-void Hadamard::eval_gpu(const std::vector<array>& inputs, array& out) {
-  auto& s = stream();
-
-  auto& in = inputs[0];
-
-  std::vector<array> copies;
-  // Only support the last axis for now
-  int axis = in.ndim() - 1;
-  auto check_input = [&copies, &s](const array& x) {
-    // TODO(alexbarron) pass strides to kernel to relax this constraint
-    bool no_copy = x.flags().row_contiguous;
-    if (no_copy) {
-      return x;
-    } else {
-      copies.push_back(array(x.shape(), x.dtype(), nullptr, {}));
-      copy_gpu(x, copies.back(), CopyType::General, s);
-      return copies.back();
-    }
-  };
-  const array& in_contiguous = check_input(in);
-
-  if (in_contiguous.is_donatable()) {
-    out.copy_shared_buffer(in_contiguous);
-  } else {
-    out.set_data(allocator::malloc(out.nbytes()));
-  }
-
-  int n, m;
-  std::tie(n, m) = decompose_hadamard(in.shape(axis));
-
-  if (n * (int)size_of(in.dtype()) > MAX_HADAMARD_BYTES) {
-    throw std::invalid_argument(
-        "[hadamard] For n = m*2^k, 2^k > 8192 for FP32 or 2^k > 16384 for FP16/BF16 NYI");
-  }
-
-  int max_radix = std::min(n, 16);
-  // Use read_width 2 for m = 28 to avoid register spilling
-  int read_width = (n == 2 || m == 28) ? 2 : 4;
-
-  std::ostringstream kname;
-  kname << "hadamard_" << n * m << "_" << type_to_name(out);
-  auto kernel_name = kname.str();
-  auto& d = metal::device(s.device);
-  const auto& lib_name = kernel_name;
-  auto lib = d.get_library(lib_name, [&]() {
-    std::ostringstream kernel_source;
-    auto codelet = gen_hadamard_codelet(m);
-    kernel_source << metal::utils() << codelet << metal::hadamard();
-    kernel_source << get_template_definition(
-        "n" + kernel_name,
-        "hadamard_n",
-        get_type_string(in.dtype()),
-        n,
-        max_radix,
-        read_width);
-    kernel_source << get_template_definition(
-        "m" + kernel_name,
-        "hadamard_m",
-        get_type_string(in.dtype()),
-        n,
-        m,
-        read_width);
-    return kernel_source.str();
-  });
-
-  int batch_size = in.size() / n;
-  int threads_per = n / max_radix;
-
-  auto& compute_encoder = d.get_command_encoder(s.index);
-  auto launch_hadamard = [&](const array& in,
-                             array& out,
-                             const std::string& kernel_name,
-                             float scale) {
-    auto kernel = d.get_kernel(kernel_name, lib);
-    assert(threads_per <= kernel->maxTotalThreadsPerThreadgroup());
-
-    compute_encoder.set_compute_pipeline_state(kernel);
-    compute_encoder.set_input_array(in, 0);
-    compute_encoder.set_output_array(out, 1);
-    compute_encoder.set_bytes(scale, 2);
-
-    MTL::Size group_dims = MTL::Size(1, threads_per, 1);
-    MTL::Size grid_dims = MTL::Size(batch_size, threads_per, 1);
-    compute_encoder.dispatch_threads(grid_dims, group_dims);
-  };
-
-  if (m > 1) {
-    // When m is greater than 1, we decompose the
-    // computation into two uploads to the GPU:
-    //
-    // e.g. len(x) = 12*4 = 48, m = 12, n = 4
-    //
-    // y = h48 @ x
-    //
-    // Upload 1:
-    // tmp = a.reshape(12, 4) @ h4
-    //
-    // Upload 2:
-    // y = h12 @ tmp
-    array temp(in.shape(), in.dtype(), nullptr, {});
-    temp.set_data(allocator::malloc(temp.nbytes()));
-    copies.push_back(temp);
-
-    launch_hadamard(in_contiguous, temp, "n" + kernel_name, 1.0);
-
-    // Metal sometimes reports 256 max threads per group for hadamard_m kernel
-    threads_per = std::min(n / read_width, MAX_HADAMARD_THREADS_PER_GROUP);
-    batch_size = in.size() / m / read_width / threads_per;
-    launch_hadamard(temp, out, "m" + kernel_name, scale_);
-  } else {
-    launch_hadamard(in_contiguous, out, "n" + kernel_name, scale_);
-  }
-
-  d.add_temporaries(std::move(copies), s.index);
+void hadamard_mn_contiguous(
+    const array& x,
+    array& y,
+    int m,
+    int n1,
+    int n2,
+    float scale,
+    metal::Device& d,
+    const Stream& s) {
+  int n = n1 * n2;
+  int read_width_n1 = n1 == 2 ? 2 : 4;
+  int read_width_n2 = n2 == 2 ? 2 : 4;
+  int read_width_m = (n == 2 || m == 28) ? 2 : 4;
+  int max_radix_1 = std::min(n1, 16);
+  int max_radix_2 = std::min(n2, 16);
+  float scale_n1 = 1.0;
+  float scale_n2 = (m == 1) ? scale : 1.0;
+  float scale_m = scale;
+
+  // n2 is a row contiguous power of 2 hadamard transform
+  MTL::Size group_dims_n2(n2 / max_radix_2, 1, 1);
+  MTL::Size grid_dims_n2(n2 / max_radix_2, x.size() / n2, 1);
+
+  // n1 is a strided power of 2 hadamard transform with stride n2
+  MTL::Size group_dims_n1(n1 / max_radix_1, 1, 1);
+  MTL::Size grid_dims_n1(n1 / max_radix_1, x.size() / n, n2);
+
+  // m is a strided hadamard transform with stride n = n1 * n2
+  MTL::Size group_dims_m(
+      std::min(n / read_width_m, MAX_HADAMARD_THREADS_PER_GROUP), 1, 1);
+  MTL::Size grid_dims_m(
+      group_dims_m.width, x.size() / m / read_width_m / group_dims_m.width, 1);
+
+  // Make the kernel
+  std::string kname;
+  kname.reserve(32);
+  concatenate(kname, "hadamard_", n * m, "_", type_to_name(x));
+  auto lib = d.get_library(kname, [&]() {
+    std::string kernel;
+    concatenate(
+        kernel,
+        metal::utils(),
+        gen_hadamard_codelet(m),
+        metal::hadamard(),
+        get_template_definition(
+            "n2" + kname,
+            "hadamard_n",
+            get_type_string(x.dtype()),
+            n2,
+            max_radix_2,
+            read_width_n2));
+    if (n1 > 1) {
+      kernel += get_template_definition(
+          "n1" + kname,
+          "hadamard_n",
+          get_type_string(x.dtype()),
+          n1,
+          max_radix_1,
+          read_width_n1,
+          n2);
+    }
+    if (m > 1) {
+      kernel += get_template_definition(
+          "m" + kname,
+          "hadamard_m",
+          get_type_string(x.dtype()),
+          n,
+          m,
+          read_width_m);
+    }
+    return kernel;
+  });
+
+  // Launch the strided transform for n1
+  if (n1 > 1) {
+    auto& compute_encoder = d.get_command_encoder(s.index);
+    auto kernel = d.get_kernel("n1" + kname, lib);
+    compute_encoder.set_compute_pipeline_state(kernel);
+    compute_encoder.set_input_array(x, 0);
+    compute_encoder.set_output_array(y, 1);
+    compute_encoder.set_bytes(scale_n1, 2);
+    compute_encoder.dispatch_threads(grid_dims_n1, group_dims_n1);
+  }
+
+  // Launch the transform for n2
+  auto& compute_encoder = d.get_command_encoder(s.index);
+  auto kernel = d.get_kernel("n2" + kname, lib);
+  compute_encoder.set_compute_pipeline_state(kernel);
+  compute_encoder.set_input_array(n1 > 1 ? y : x, 0);
+  compute_encoder.set_output_array(y, 1);
+  compute_encoder.set_bytes(scale_n2, 2);
+  compute_encoder.dispatch_threads(grid_dims_n2, group_dims_n2);
+
+  // Launch the strided transform for m
+  if (m > 1) {
+    auto kernel = d.get_kernel("m" + kname, lib);
+    compute_encoder.set_compute_pipeline_state(kernel);
+    compute_encoder.set_input_array(y, 0);
+    compute_encoder.set_output_array(y, 1);
+    compute_encoder.set_bytes(scale_m, 2);
+    compute_encoder.dispatch_threads(grid_dims_m, group_dims_m);
+  }
+}
+
+void Hadamard::eval_gpu(const std::vector<array>& inputs, array& out) {
+  auto& s = stream();
+  auto& d = metal::device(s.device);
+
+  auto& in = inputs[0];
+
+  // Split the hadamard transform so that all of them work on vectors smaller
+  // than 8192 elements.
+  //
+  // We decompose it in the following way:
+  //
+  //   n = m * n1 * n2 = m * 2^k1 * 2^k2
+  //
+  // where m is in (1, 12, 20, 28) and n1 and n2 <= 8192
  auto [n, m] = decompose_hadamard(in.shape().back());
+
+  int n1 = 1, n2 = n;
+  if (n > 8192) {
+    for (n2 = 2; n2 * n2 < n; n2 *= 2) {
+    }
+    n1 = n / n2;
+  }
+
+  if (in.flags().row_contiguous) {
+    if (in.is_donatable()) {
+      out.copy_shared_buffer(in);
+    } else {
+      out.set_data(allocator::malloc(out.nbytes()));
+    }
+    hadamard_mn_contiguous(in, out, m, n1, n2, scale_, d, s);
+  } else {
+    copy_gpu(in, out, CopyType::General, s);
+    hadamard_mn_contiguous(out, out, m, n1, n2, scale_, d, s);
+  }
 }

 } // namespace mlx::core
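To make the decomposition above concrete, a small host-side sketch (plain C++, values assumed) of how n1 and n2 are chosen for an oversized power of two:

#include <cstdio>

int main() {
  // Assumed example: n = 2^15 = 32768 (> 8192), m = 1.
  int n = 32768;
  int n1 = 1, n2 = n;
  if (n > 8192) {
    // Smallest power of two n2 with n2 * n2 >= n, mirroring the loop above.
    for (n2 = 2; n2 * n2 < n; n2 *= 2) {
    }
    n1 = n / n2;
  }
  // Prints n1 = 128, n2 = 256; both factors are <= 8192 and n1 * n2 = n.
  std::printf("n1 = %d, n2 = %d\n", n1, n2);
  return 0;
}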

@@ -2,7 +2,7 @@
 #include <fmt/format.h>

 #include "mlx/backend/common/compiled.h"
-#include "mlx/backend/metal/copy.h"
+#include "mlx/backend/gpu/copy.h"
 #include "mlx/backend/metal/device.h"
 #include "mlx/backend/metal/jit/includes.h"
 #include "mlx/backend/metal/jit/indexing.h"

@@ -9,64 +9,85 @@ template <typename T, typename U, typename Op>
   c[index] = Op()(a[0], b[0]);
 }

-template <typename T, typename U, typename Op>
+template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
 [[kernel]] void binary_sv(
     device const T* a,
     device const T* b,
     device U* c,
+    constant uint& size,
     uint index [[thread_position_in_grid]]) {
-  c[index] = Op()(a[0], b[index]);
+  index *= N;
+  for (int i = 0; i < N && (index + i) < size; ++i) {
+    c[index + i] = Op()(a[0], b[index + i]);
+  }
 }

-template <typename T, typename U, typename Op>
+template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
 [[kernel]] void binary_vs(
     device const T* a,
     device const T* b,
     device U* c,
+    constant uint& size,
     uint index [[thread_position_in_grid]]) {
-  c[index] = Op()(a[index], b[0]);
+  index *= N;
+  for (int i = 0; i < N && (index + i) < size; ++i) {
+    c[index + i] = Op()(a[index + i], b[0]);
+  }
 }

-template <typename T, typename U, typename Op>
+template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
 [[kernel]] void binary_vv(
     device const T* a,
     device const T* b,
     device U* c,
+    constant uint& size,
     uint index [[thread_position_in_grid]]) {
-  c[index] = Op()(a[index], b[index]);
+  index *= N;
+  for (int i = 0; i < N && (index + i) < size; ++i) {
+    c[index + i] = Op()(a[index + i], b[index + i]);
+  }
 }

-template <typename T, typename U, typename Op>
+template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
 [[kernel]] void binary_sv2(
     device const T* a,
     device const T* b,
     device U* c,
+    constant int64_t& size,
     uint2 index [[thread_position_in_grid]],
     uint2 grid_dim [[threads_per_grid]]) {
-  int64_t offset = index.x + grid_dim.x * int64_t(index.y);
-  c[offset] = Op()(a[0], b[offset]);
+  int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
+  for (int i = 0; i < N && (offset + i) < size; ++i) {
+    c[offset + i] = Op()(a[0], b[offset + i]);
+  }
 }

-template <typename T, typename U, typename Op>
+template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
 [[kernel]] void binary_vs2(
     device const T* a,
     device const T* b,
     device U* c,
+    constant int64_t& size,
     uint2 index [[thread_position_in_grid]],
     uint2 grid_dim [[threads_per_grid]]) {
-  int64_t offset = index.x + grid_dim.x * int64_t(index.y);
-  c[offset] = Op()(a[offset], b[0]);
+  int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
+  for (int i = 0; i < N && (offset + i) < size; ++i) {
+    c[offset + i] = Op()(a[offset + i], b[0]);
+  }
 }

-template <typename T, typename U, typename Op>
+template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
 [[kernel]] void binary_vv2(
     device const T* a,
     device const T* b,
     device U* c,
+    constant int64_t& size,
     uint2 index [[thread_position_in_grid]],
     uint2 grid_dim [[threads_per_grid]]) {
-  int64_t offset = index.x + grid_dim.x * int64_t(index.y);
-  c[offset] = Op()(a[offset], b[offset]);
+  int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
+  for (int i = 0; i < N && (offset + i) < size; ++i) {
+    c[offset + i] = Op()(a[offset + i], b[offset + i]);
+  }
 }

 template <typename T, typename U, typename Op, typename IdxT = int64_t>

@@ -71,6 +71,7 @@ instantiate_binary_types_bool(Less)
 instantiate_binary_types_bool(LessEqual)
 instantiate_binary_types_bool(NotEqual)
 instantiate_binary_float(LogAddExp)
+instantiate_binary_all(LogAddExp, complex64, complex64_t, complex64_t)
 instantiate_binary_types(Maximum)
 instantiate_binary_types(Minimum)
 instantiate_binary_types(Multiply)

View File

@ -130,6 +130,24 @@ struct LogAddExp {
? maxval ? maxval
: (maxval + log1p(metal::exp(minval - maxval))); : (maxval + log1p(metal::exp(minval - maxval)));
}; };
complex64_t operator()(complex64_t x, complex64_t y) {
if (metal::isnan(x.real) || metal::isnan(x.imag) || metal::isnan(y.real) ||
metal::isnan(y.imag)) {
return metal::numeric_limits<float>::quiet_NaN();
}
constexpr float inf = metal::numeric_limits<float>::infinity();
complex64_t maxval = x > y ? x : y;
complex64_t minval = x < y ? x : y;
if (minval.real == -inf || maxval.real == inf)
return maxval;
float m = metal::exp(minval.real - maxval.real);
complex64_t dexp{
m * metal::cos(minval.imag - maxval.imag),
m * metal::sin(minval.imag - maxval.imag),
};
return maxval + log1p(dexp);
}
}; };
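
Note: the new complex overload above reuses the max-shifted identity log(e^x + e^y) = max + log1p(e^(min - max)), expanding the complex exponential into magnitude and phase; the final log1p uses the dedicated complex log1p added further below in this commit. A host-side sketch of the same arithmetic with std::complex (illustrative only, not the kernel code; NaN propagation omitted, and std::log(1 + w) stands in for the complex log1p):

#include <cmath>
#include <complex>
#include <limits>

// Host-side analogue of the complex LogAddExp kernel op. Ordering follows the
// kernel: compare by real part, then by imaginary part.
std::complex<float> logaddexp(std::complex<float> x, std::complex<float> y) {
  constexpr float inf = std::numeric_limits<float>::infinity();
  auto less = [](const std::complex<float>& a, const std::complex<float>& b) {
    return a.real() < b.real() ||
        (a.real() == b.real() && a.imag() < b.imag());
  };
  std::complex<float> maxval = less(x, y) ? y : x;
  std::complex<float> minval = less(x, y) ? x : y;
  if (minval.real() == -inf || maxval.real() == inf) {
    return maxval;
  }
  // exp(min - max) written out as magnitude * (cos(dphi) + i*sin(dphi)).
  float m = std::exp(minval.real() - maxval.real());
  std::complex<float> dexp{
      m * std::cos(minval.imag() - maxval.imag()),
      m * std::sin(minval.imag() - maxval.imag())};
  return maxval + std::log(1.0f + dexp);
}
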
struct Maximum { struct Maximum {

View File

@ -12,82 +12,103 @@ template <typename T, typename U, typename Op>
d[index] = out[1]; d[index] = out[1];
} }
template <typename T, typename U, typename Op> template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
[[kernel]] void binary_sv( [[kernel]] void binary_sv(
device const T* a, device const T* a,
device const T* b, device const T* b,
device U* c, device U* c,
device U* d, device U* d,
constant uint& size,
uint index [[thread_position_in_grid]]) { uint index [[thread_position_in_grid]]) {
auto out = Op()(a[0], b[index]); index *= N;
c[index] = out[0]; for (int i = 0; i < N && (index + i) < size; ++i) {
d[index] = out[1]; auto out = Op()(a[0], b[index + i]);
c[index + i] = out[0];
d[index + i] = out[1];
}
} }
template <typename T, typename U, typename Op> template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
[[kernel]] void binary_vs( [[kernel]] void binary_vs(
device const T* a, device const T* a,
device const T* b, device const T* b,
device U* c, device U* c,
device U* d, device U* d,
constant uint& size,
uint index [[thread_position_in_grid]]) { uint index [[thread_position_in_grid]]) {
auto out = Op()(a[index], b[0]); index *= N;
c[index] = out[0]; for (int i = 0; i < N && (index + i) < size; ++i) {
d[index] = out[1]; auto out = Op()(a[index + i], b[0]);
c[index + i] = out[0];
d[index + i] = out[1];
}
} }
template <typename T, typename U, typename Op> template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
[[kernel]] void binary_vv( [[kernel]] void binary_vv(
device const T* a, device const T* a,
device const T* b, device const T* b,
device U* c, device U* c,
device U* d, device U* d,
constant uint& size,
uint index [[thread_position_in_grid]]) { uint index [[thread_position_in_grid]]) {
auto out = Op()(a[index], b[index]); index *= N;
c[index] = out[0]; for (int i = 0; i < N && (index + i) < size; ++i) {
d[index] = out[1]; auto out = Op()(a[index + i], b[index + i]);
c[index + i] = out[0];
d[index + i] = out[1];
}
} }
template <typename T, typename U, typename Op> template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
[[kernel]] void binary_sv2( [[kernel]] void binary_sv2(
device const T* a, device const T* a,
device const T* b, device const T* b,
device U* c, device U* c,
device U* d, device U* d,
constant int64_t& size,
uint2 index [[thread_position_in_grid]], uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) { uint2 grid_dim [[threads_per_grid]]) {
auto offset = index.x + grid_dim.x * int64_t(index.y); auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
auto out = Op()(a[0], b[offset]); for (int i = 0; i < N && (offset + i) < size; ++i) {
c[offset] = out[0]; auto out = Op()(a[0], b[offset + i]);
d[offset] = out[1]; c[offset + i] = out[0];
d[offset + i] = out[1];
}
} }
template <typename T, typename U, typename Op> template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
[[kernel]] void binary_vs2( [[kernel]] void binary_vs2(
device const T* a, device const T* a,
device const T* b, device const T* b,
device U* c, device U* c,
device U* d, device U* d,
constant int64_t& size,
uint2 index [[thread_position_in_grid]], uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) { uint2 grid_dim [[threads_per_grid]]) {
auto offset = index.x + grid_dim.x * int64_t(index.y); auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
auto out = Op()(a[offset], b[0]); for (int i = 0; i < N && (offset + i) < size; ++i) {
c[offset] = out[0]; auto out = Op()(a[offset + i], b[0]);
d[offset] = out[1]; c[offset + i] = out[0];
d[offset + i] = out[1];
}
} }
template <typename T, typename U, typename Op> template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
[[kernel]] void binary_vv2( [[kernel]] void binary_vv2(
device const T* a, device const T* a,
device const T* b, device const T* b,
device U* c, device U* c,
device U* d, device U* d,
constant int64_t& size,
uint2 index [[thread_position_in_grid]], uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) { uint2 grid_dim [[threads_per_grid]]) {
auto offset = index.x + grid_dim.x * int64_t(index.y); auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
auto out = Op()(a[offset], b[offset]); for (int i = 0; i < N && (offset + i) < size; ++i) {
c[offset] = out[0]; auto out = Op()(a[offset + i], b[offset + i]);
d[offset] = out[1]; c[offset + i] = out[0];
d[offset + i] = out[1];
}
} }
template <typename T, typename U, typename Op, typename IdxT = int64_t> template <typename T, typename U, typename Op, typename IdxT = int64_t>

View File

@ -1,39 +1,53 @@
// Copyright © 2024 Apple Inc. // Copyright © 2024 Apple Inc.
template <typename T, typename U> template <typename T, typename U, int N = WorkPerThread<T>::n>
[[kernel]] void copy_s( [[kernel]] void copy_s(
device const T* src [[buffer(0)]], device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]], device U* dst [[buffer(1)]],
constant uint& size,
uint index [[thread_position_in_grid]]) { uint index [[thread_position_in_grid]]) {
dst[index] = static_cast<U>(src[0]); index *= N;
for (int i = 0; i < N && (index + i) < size; ++i) {
dst[index + i] = static_cast<U>(src[0]);
}
} }
template <typename T, typename U> template <typename T, typename U, int N = WorkPerThread<T>::n>
[[kernel]] void copy_v( [[kernel]] void copy_v(
device const T* src [[buffer(0)]], device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]], device U* dst [[buffer(1)]],
constant uint& size,
uint index [[thread_position_in_grid]]) { uint index [[thread_position_in_grid]]) {
dst[index] = static_cast<U>(src[index]); index *= N;
for (int i = 0; i < N && (index + i) < size; ++i) {
dst[index + i] = static_cast<U>(src[index + i]);
}
} }
template <typename T, typename U> template <typename T, typename U, int N = WorkPerThread<T>::n>
[[kernel]] void copy_s2( [[kernel]] void copy_s2(
device const T* src [[buffer(0)]], device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]], device U* dst [[buffer(1)]],
constant int64_t& size,
uint2 index [[thread_position_in_grid]], uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) { uint2 grid_dim [[threads_per_grid]]) {
auto offset = index.x + grid_dim.x * int64_t(index.y); auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
dst[offset] = static_cast<U>(src[0]); for (int i = 0; i < N && (offset + i) < size; ++i) {
dst[offset + i] = static_cast<U>(src[0]);
}
} }
template <typename T, typename U> template <typename T, typename U, int N = WorkPerThread<T>::n>
[[kernel]] void copy_v2( [[kernel]] void copy_v2(
device const T* src [[buffer(0)]], device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]], device U* dst [[buffer(1)]],
constant int64_t& size,
uint2 index [[thread_position_in_grid]], uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) { uint2 grid_dim [[threads_per_grid]]) {
auto offset = index.x + grid_dim.x * int64_t(index.y); auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
dst[offset] = static_cast<U>(src[offset]); for (int i = 0; i < N && (offset + i) < size; ++i) {
dst[offset + i] = static_cast<U>(src[offset + i]);
}
} }
template <typename T, typename U, typename IdxT = int64_t> template <typename T, typename U, typename IdxT = int64_t>

View File

@ -10,7 +10,7 @@ For many sizes, GPU FFTs are memory bandwidth bound so
read/write performance is important. read/write performance is important.
Where possible, we read 128 bits sequentially in each thread, Where possible, we read 128 bits sequentially in each thread,
coalesced with accesses from adajcent threads for optimal performance. coalesced with accesses from adjacent threads for optimal performance.
We implement specialized reading/writing for: We implement specialized reading/writing for:
- FFT - FFT

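Rough illustration of the coalesced 128-bit access pattern described above (a hypothetical indexing sketch, not the actual FFT reader): each thread copies a contiguous chunk of read_width = 4 float32 values per step, so neighbouring threads touch neighbouring 16-byte segments and the accesses combine into wide sequential transactions.

#include <cstddef>

// Hypothetical sketch: thread `i` of `num_threads` reads one contiguous
// 4-float (128-bit) chunk per iteration of the outer loop.
void coalesced_read(const float* in, float* buf, int i, int num_threads, int n) {
  constexpr int read_width = 4; // 4 x float32 = 128 bits
  for (int j = 0; j * read_width * num_threads < n; ++j) {
    int index = (j * num_threads + i) * read_width;
    for (int r = 0; r < read_width && index + r < n; ++r) {
      buf[index + r] = in[index + r];
    }
  }
}
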
View File

@ -26,7 +26,7 @@ METAL_FUNC void radix_func(thread float* x) {
} }
} }
template <typename T, int N, int max_radix, int read_width> template <typename T, int N, int max_radix, int read_width, int stride = 1>
[[kernel]] void hadamard_n( [[kernel]] void hadamard_n(
const device T* in [[buffer(0)]], const device T* in [[buffer(0)]],
device T* out [[buffer(1)]], device T* out [[buffer(1)]],
@ -46,18 +46,25 @@ template <typename T, int N, int max_radix, int read_width>
constexpr short logFinal = logN % logR; constexpr short logFinal = logN % logR;
constexpr short final_radix = 1 << (logFinal); constexpr short final_radix = 1 << (logFinal);
int batch_idx = elem.x * N; int batch_idx = elem.y * N * stride + elem.z;
short i = elem.y; short i = elem.x;
threadgroup T buf[N]; threadgroup T buf[N];
// Read values from device // Read values from device
STEEL_PRAGMA_UNROLL if (stride == 1) {
for (short j = 0; j < max_radix / read_width; j++) {
short index = j * read_width * num_threads + i * read_width;
STEEL_PRAGMA_UNROLL STEEL_PRAGMA_UNROLL
for (short r = 0; r < read_width; r++) { for (short j = 0; j < max_radix / read_width; j++) {
buf[index + r] = in[batch_idx + index + r]; short index = j * read_width * num_threads + i * read_width;
STEEL_PRAGMA_UNROLL
for (short r = 0; r < read_width; r++) {
buf[index + r] = in[batch_idx + index + r];
}
}
} else {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < max_radix; j++) {
buf[j * num_threads + i] = in[batch_idx + (j * num_threads + i) * stride];
} }
} }
@ -113,12 +120,20 @@ template <typename T, int N, int max_radix, int read_width>
} }
// Write values to device // Write values to device
STEEL_PRAGMA_UNROLL if (stride == 1) {
for (short j = 0; j < max_radix / read_width; j++) {
short index = j * read_width * num_threads + i * read_width;
STEEL_PRAGMA_UNROLL STEEL_PRAGMA_UNROLL
for (short r = 0; r < read_width; r++) { for (short j = 0; j < max_radix / read_width; j++) {
out[batch_idx + index + r] = T(buf[index + r] * scale); short index = j * read_width * num_threads + i * read_width;
STEEL_PRAGMA_UNROLL
for (short r = 0; r < read_width; r++) {
out[batch_idx + index + r] = T(buf[index + r] * scale);
}
}
} else {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < max_radix; j++) {
out[batch_idx + (j * num_threads + i) * stride] =
buf[j * num_threads + i];
} }
} }
} }

View File

@ -1008,11 +1008,11 @@ METAL_FUNC void qmm_t_impl(
auto wl = (const device uint8_t*)w; auto wl = (const device uint8_t*)w;
x += y_row * K; x += y_row * static_cast<int64_t>(K);
wl += y_col * K_w; wl += y_col * K_w;
scales += y_col * K_g; scales += y_col * K_g;
biases += y_col * K_g; biases += y_col * K_g;
y += y_row * N + y_col; y += y_row * static_cast<int64_t>(N) + y_col;
// Make the x loader and mma operation // Make the x loader and mma operation
const short num_els = min(BM, M - y_row); const short num_els = min(BM, M - y_row);
@ -1132,11 +1132,11 @@ METAL_FUNC void qmm_n_impl(
// Set the block // Set the block
const int y_row = tid.y * BM; const int y_row = tid.y * BM;
const int y_col = tid.x * BN; const int y_col = tid.x * BN;
x += y_row * K; x += y_row * static_cast<int64_t>(K);
wl += y_col * bytes_per_pack / pack_factor; wl += y_col * bytes_per_pack / pack_factor;
scales += y_col / group_size; scales += y_col / group_size;
biases += y_col / group_size; biases += y_col / group_size;
y += y_row * N + y_col; y += y_row * static_cast<int64_t>(N) + y_col;
// Make the x loader and mma operation // Make the x loader and mma operation
const short num_els = min(BM, M - y_row); const short num_els = min(BM, M - y_row);

View File

@ -104,4 +104,5 @@ instantiate_scan_helper(min_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumMi
instantiate_scan_helper(min_complex64_complex64, complex64_t, complex64_t, CumMin, 2) instantiate_scan_helper(min_complex64_complex64, complex64_t, complex64_t, CumMin, 2)
instantiate_scan_helper(logaddexp_float16_float16, half, half, CumLogaddexp, 4) instantiate_scan_helper(logaddexp_float16_float16, half, half, CumLogaddexp, 4)
instantiate_scan_helper(logaddexp_float32_float32, float, float, CumLogaddexp, 4) instantiate_scan_helper(logaddexp_float32_float32, float, float, CumLogaddexp, 4)
instantiate_scan_helper(logaddexp_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumLogaddexp, 4) // clang-format on instantiate_scan_helper(logaddexp_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumLogaddexp, 4)
instantiate_scan_helper(logaddexp_complex64_complex64, complex64_t, complex64_t, CumLogaddexp, 2) // clang-format on

View File

@ -56,9 +56,9 @@ template <typename T, int D, int V = D>
const int head_idx = tid.x; const int head_idx = tid.x;
const int q_seq_idx = tid.y; const int q_seq_idx = tid.y;
const int kv_head_idx = head_idx / gqa_factor; const int kv_head_idx = head_idx / gqa_factor;
const int o_offset = tpg.x * q_seq_idx + head_idx; const int o_offset = head_idx * tpg.y + q_seq_idx;
const int q_offset = const int q_offset =
query_transposed ? o_offset : head_idx * tpg.y + q_seq_idx; query_transposed ? tpg.x * q_seq_idx + head_idx : o_offset;
queries += q_offset * D + simd_lid * qk_per_thread; queries += q_offset * D + simd_lid * qk_per_thread;
keys += kv_head_idx * k_head_stride + simd_gid * k_seq_stride + keys += kv_head_idx * k_head_stride + simd_gid * k_seq_stride +
simd_lid * qk_per_thread; simd_lid * qk_per_thread;
@ -213,9 +213,9 @@ template <typename T, int D, int V = D>
const int block_idx = tid.z; const int block_idx = tid.z;
const int head_idx = tid.x; const int head_idx = tid.x;
const int q_seq_idx = tid.y; const int q_seq_idx = tid.y;
const int o_offset = tpg.x * q_seq_idx + head_idx; const int o_offset = head_idx * tpg.y + q_seq_idx;
const int q_offset = const int q_offset =
query_transposed ? o_offset : head_idx * tpg.y + q_seq_idx; query_transposed ? tpg.x * q_seq_idx + head_idx : o_offset;
const int kv_head_idx = head_idx / gqa_factor; const int kv_head_idx = head_idx / gqa_factor;
queries += q_offset * D + simd_lid * qk_per_thread; queries += q_offset * D + simd_lid * qk_per_thread;
@ -358,8 +358,8 @@ template <typename T, int D>
// Adjust positions // Adjust positions
const int head_idx = tid.x; const int head_idx = tid.x;
const int q_seq_idx = tid.y; const int q_seq_idx = tid.y;
const int n_heads = tpg.x; const int q_offset = head_idx * tpg.y + q_seq_idx;
const int q_offset = n_heads * q_seq_idx + head_idx; ;
partials += q_offset * blocks * D + simd_gid * D + simd_lid * elem_per_thread; partials += q_offset * blocks * D + simd_gid * D + simd_lid * elem_per_thread;
sums += q_offset * blocks; sums += q_offset * blocks;
maxs += q_offset * blocks; maxs += q_offset * blocks;

View File

@ -95,7 +95,7 @@ template <
Q += tidl.z * params->Q_strides[0] + // Batch Q += tidl.z * params->Q_strides[0] + // Batch
tidl.y * params->Q_strides[1] + // Head tidl.y * params->Q_strides[1] + // Head
tidl.x * BQ * params->Q_strides[2]; // Seqeunce tidl.x * BQ * params->Q_strides[2]; // Sequence
ulong kv_head_idx = int(tid.y) / params->gqa_factor; ulong kv_head_idx = int(tid.y) / params->gqa_factor;
K += tidl.z * params->K_strides[0] + // Batch K += tidl.z * params->K_strides[0] + // Batch
@ -106,7 +106,7 @@ template <
O += tidl.z * params->O_strides[0] + // Batch O += tidl.z * params->O_strides[0] + // Batch
tidl.y * params->O_strides[1] + // Head tidl.y * params->O_strides[1] + // Head
tidl.x * BQ * params->O_strides[2]; // Seqeunce tidl.x * BQ * params->O_strides[2]; // Sequence
if (has_mask) { if (has_mask) {
mask += tidl.z * mask_params->M_strides[0] + // Batch mask += tidl.z * mask_params->M_strides[0] + // Batch

View File

@ -113,7 +113,7 @@ struct BlockLoader {
tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)]; tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)];
} }
// Zero out uneeded values // Zero out unneeded values
STEEL_PRAGMA_UNROLL STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) { for (short j = 0; j < vec_size; j++) {
tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0); tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0);
@ -240,7 +240,7 @@ struct BlockLoaderT {
tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)]; tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)];
} }
// Zero out uneeded values // Zero out unneeded values
STEEL_PRAGMA_UNROLL STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) { for (short j = 0; j < vec_size; j++) {
tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0); tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0);

View File

@ -141,7 +141,7 @@ implicit_gemm_conv_2d_general(
// Store results to device memory // Store results to device memory
{ {
// Adjust for simdgroup and thread locatio // Adjust for simdgroup and thread location
int offset_m = c_row + mma_op.sm; int offset_m = c_row + mma_op.sm;
int offset_n = c_col + mma_op.sn; int offset_n = c_col + mma_op.sn;
C += offset_n; C += offset_n;

View File

@ -113,7 +113,7 @@ struct BlockLoader {
tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)]; tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)];
} }
// Zero out uneeded values // Zero out unneeded values
STEEL_PRAGMA_UNROLL STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) { for (short j = 0; j < vec_size; j++) {
tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0); tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0);

View File

@ -1,25 +1,32 @@
// Copyright © 2024 Apple Inc. // Copyright © 2024 Apple Inc.
template <typename T, typename Op> template <typename T, typename Op, int N = WorkPerThread<T>::n>
[[kernel]] void ternary_v( [[kernel]] void ternary_v(
device const bool* a, device const bool* a,
device const T* b, device const T* b,
device const T* c, device const T* c,
device T* d, device T* d,
constant uint& size,
uint index [[thread_position_in_grid]]) { uint index [[thread_position_in_grid]]) {
d[index] = Op()(a[index], b[index], c[index]); index *= N;
for (int i = 0; i < N && (index + i) < size; ++i) {
d[index + i] = Op()(a[index + i], b[index + i], c[index + i]);
}
} }
template <typename T, typename Op> template <typename T, typename Op, int N = WorkPerThread<T>::n>
[[kernel]] void ternary_v2( [[kernel]] void ternary_v2(
device const bool* a, device const bool* a,
device const T* b, device const T* b,
device const T* c, device const T* c,
device T* d, device T* d,
constant int64_t& size,
uint2 index [[thread_position_in_grid]], uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) { uint2 grid_dim [[threads_per_grid]]) {
auto offset = index.x + grid_dim.x * int64_t(index.y); auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
d[offset] = Op()(a[offset], b[offset], c[offset]); for (int i = 0; i < N && (offset + i) < size; ++i) {
d[offset + i] = Op()(a[offset + i], b[offset + i], c[offset + i]);
}
} }
template <typename T, typename Op, typename IdxT = int64_t> template <typename T, typename Op, typename IdxT = int64_t>

View File

@ -1,21 +1,28 @@
// Copyright © 2024 Apple Inc. // Copyright © 2024 Apple Inc.
template <typename T, typename U, typename Op> template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
[[kernel]] void unary_v( [[kernel]] void unary_v(
device const T* in, device const T* in,
device U* out, device U* out,
constant uint& size,
uint index [[thread_position_in_grid]]) { uint index [[thread_position_in_grid]]) {
out[index] = Op()(in[index]); index *= N;
for (int i = 0; i < N && (index + i) < size; ++i) {
out[index + i] = Op()(in[index + i]);
}
} }
template <typename T, typename U, typename Op> template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
[[kernel]] void unary_v2( [[kernel]] void unary_v2(
device const T* in, device const T* in,
device U* out, device U* out,
constant int64_t& size,
uint2 index [[thread_position_in_grid]], uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) { uint2 grid_dim [[threads_per_grid]]) {
auto offset = index.x + grid_dim.x * int64_t(index.y); auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
out[offset] = Op()(in[offset]); for (int i = 0; i < N && (offset + i) < size; ++i) {
out[offset + i] = Op()(in[offset + i]);
}
} }
template < template <

View File

@ -77,6 +77,7 @@ instantiate_unary_all_same(Cos, complex64, complex64_t)
instantiate_unary_all_same(Cosh, complex64, complex64_t) instantiate_unary_all_same(Cosh, complex64, complex64_t)
instantiate_unary_all_same(Exp, complex64, complex64_t) instantiate_unary_all_same(Exp, complex64, complex64_t)
instantiate_unary_all_same(Log, complex64, complex64_t) instantiate_unary_all_same(Log, complex64, complex64_t)
instantiate_unary_all_same(Log1p, complex64, complex64_t)
instantiate_unary_all_same(Log2, complex64, complex64_t) instantiate_unary_all_same(Log2, complex64, complex64_t)
instantiate_unary_all_same(Log10, complex64, complex64_t) instantiate_unary_all_same(Log10, complex64, complex64_t)
instantiate_unary_all_same(Negative, complex64, complex64_t) instantiate_unary_all_same(Negative, complex64, complex64_t)

View File

@ -15,6 +15,14 @@
typedef half float16_t; typedef half float16_t;
// Work per thread values for different types. The values here are expected to
// match get_work_per_thread in mlx/backend/metal/utils.h
template <typename U>
struct WorkPerThread {
static_assert(sizeof(U) <= 8, "Type too large");
static constexpr int constant n = 8 / sizeof(U);
};
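
The struct above gives every thread roughly 8 bytes of elements, i.e. n = 8 / sizeof(U), so wider types get fewer elements per thread. A standalone host-side restatement of that arithmetic (purely illustrative, not the kernel header):

#include <cstdint>
#include <cstdio>

// Same rule as the Metal-side WorkPerThread<U>: n = 8 / sizeof(U).
template <typename U>
constexpr int work_per_thread() {
  static_assert(sizeof(U) <= 8, "Type too large");
  return 8 / sizeof(U);
}

int main() {
  std::printf("1-byte types: %d per thread\n", work_per_thread<int8_t>());  // 8
  std::printf("2-byte types: %d per thread\n", work_per_thread<int16_t>()); // 4
  std::printf("4-byte types: %d per thread\n", work_per_thread<float>());   // 2
  std::printf("8-byte types: %d per thread\n", work_per_thread<int64_t>()); // 1
  return 0;
}
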
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
// Type limits utils // Type limits utils
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
@ -328,6 +336,23 @@ inline bfloat16_t log1p(bfloat16_t x) {
return bfloat16_t(x * (metal::log(xp1) / (xp1 - 1.0f))); return bfloat16_t(x * (metal::log(xp1) / (xp1 - 1.0f)));
} }
inline complex64_t log1p(complex64_t in) {
float x = in.real;
float y = in.imag;
float zabs = metal::precise::sqrt(x * x + y * y);
float theta = metal::atan2(y, x + 1);
if (zabs < 0.5f) {
float r = x * (2 + x) + y * y;
if (r == 0) { // handle underflow
return {x, theta};
}
return {0.5f * log1p(r), theta};
} else {
auto z0 = metal::sqrt((x + 1) * (x + 1) + y * y);
return {metal::log(z0), theta};
}
}
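
The complex log1p above splits log(1 + z) into real part 0.5 * log1p(|1 + z|^2 - 1) and imaginary part atan2(y, x + 1), using the expanded form r = x * (2 + x) + y * y near zero to avoid cancellation. A host-side sketch of the same computation with std::complex (illustrative only):

#include <cmath>
#include <complex>

// Host-side analogue of the Metal complex log1p above: accurate near z = 0
// because |1 + z|^2 - 1 is computed without cancellation.
std::complex<float> log1p_complex(std::complex<float> z) {
  float x = z.real();
  float y = z.imag();
  float theta = std::atan2(y, x + 1.0f);
  float zabs = std::sqrt(x * x + y * y);
  if (zabs < 0.5f) {
    float r = x * (2.0f + x) + y * y;
    if (r == 0.0f) { // handle underflow as the kernel does
      return {x, theta};
    }
    return {0.5f * std::log1p(r), theta};
  }
  float z0 = std::sqrt((x + 1.0f) * (x + 1.0f) + y * y);
  return {std::log(z0), theta};
}
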
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
// SIMD shuffle ops // SIMD shuffle ops
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////

View File

@ -1,7 +1,7 @@
// Copyright © 2023-2024 Apple Inc. // Copyright © 2023-2024 Apple Inc.
#include <algorithm> #include <algorithm>
#include "mlx/backend/metal/copy.h" #include "mlx/backend/gpu/copy.h"
#include "mlx/backend/metal/device.h" #include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h" #include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/utils.h" #include "mlx/backend/metal/utils.h"

View File

@ -7,7 +7,7 @@
#include "mlx/backend/common/broadcasting.h" #include "mlx/backend/common/broadcasting.h"
#include "mlx/backend/common/utils.h" #include "mlx/backend/common/utils.h"
#include "mlx/backend/metal/copy.h" #include "mlx/backend/gpu/copy.h"
#include "mlx/backend/metal/device.h" #include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h" #include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/kernels/defines.h" #include "mlx/backend/metal/kernels/defines.h"

View File

@ -1,11 +1,11 @@
// Copyright © 2023-2024 Apple Inc. // Copyright © 2023-2024 Apple Inc.
#include <memory> #include <memory>
#include <sys/sysctl.h>
#include "mlx/backend/metal/device.h" #include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/metal.h"
#include "mlx/backend/metal/utils.h" #include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"
#include "mlx/scheduler.h"
#include "mlx/utils.h"
namespace mlx::core::metal { namespace mlx::core::metal {
@ -13,85 +13,6 @@ bool is_available() {
return true; return true;
} }
inline void check_error(MTL::CommandBuffer* cbuf) {
if (cbuf->status() == MTL::CommandBufferStatusError) {
std::ostringstream msg;
msg << "[METAL] Command buffer execution failed: "
<< cbuf->error()->localizedDescription()->utf8String();
throw std::runtime_error(msg.str());
}
}
void eval(array& arr) {
auto pool = new_scoped_memory_pool();
auto s = arr.primitive().stream();
auto& d = metal::device(s.device);
auto command_buffer = d.get_command_buffer(s.index);
auto outputs = arr.outputs();
{
// If the array is a tracer hold a reference
// to its inputs so they don't get donated
std::vector<array> inputs;
if (arr.is_tracer()) {
inputs = arr.inputs();
}
debug_set_primitive_buffer_label(command_buffer, arr.primitive());
arr.primitive().eval_gpu(arr.inputs(), outputs);
}
std::unordered_set<std::shared_ptr<array::Data>> buffers;
for (auto& in : arr.inputs()) {
buffers.insert(in.data_shared_ptr());
}
for (auto& s : arr.siblings()) {
buffers.insert(s.data_shared_ptr());
}
// Remove the output if it was donated to by an input
if (auto it = buffers.find(arr.data_shared_ptr()); it != buffers.end()) {
buffers.erase(it);
}
if (d.command_buffer_needs_commit(s.index)) {
d.end_encoding(s.index);
scheduler::notify_new_task(s);
command_buffer->addCompletedHandler(
[s, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) {
scheduler::notify_task_completion(s);
check_error(cbuf);
});
d.commit_command_buffer(s.index);
d.get_command_buffer(s.index);
} else {
command_buffer->addCompletedHandler(
[s, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) {
check_error(cbuf);
});
}
}
void finalize(Stream s) {
auto pool = new_scoped_memory_pool();
auto& d = metal::device(s.device);
auto cb = d.get_command_buffer(s.index);
d.end_encoding(s.index);
cb->addCompletedHandler([s](MTL::CommandBuffer* cbuf) { check_error(cbuf); });
d.commit_command_buffer(s.index);
d.get_command_buffer(s.index);
}
void synchronize(Stream s) {
auto pool = new_scoped_memory_pool();
auto& d = metal::device(s.device);
auto cb = d.get_command_buffer(s.index);
cb->retain();
d.end_encoding(s.index);
d.commit_command_buffer(s.index);
cb->waitUntilCompleted();
check_error(cb);
cb->release();
}
void start_capture(std::string path, id object) { void start_capture(std::string path, id object) {
auto pool = new_scoped_memory_pool(); auto pool = new_scoped_memory_pool();
@ -128,4 +49,36 @@ void stop_capture() {
manager->stopCapture(); manager->stopCapture();
} }
const std::unordered_map<std::string, std::variant<std::string, size_t>>&
device_info() {
auto init_device_info = []()
-> std::unordered_map<std::string, std::variant<std::string, size_t>> {
auto pool = new_scoped_memory_pool();
auto raw_device = device(default_device()).mtl_device();
auto name = std::string(raw_device->name()->utf8String());
auto arch = std::string(raw_device->architecture()->name()->utf8String());
size_t memsize = 0;
size_t length = sizeof(memsize);
sysctlbyname("hw.memsize", &memsize, &length, NULL, 0);
size_t rsrc_limit = 0;
sysctlbyname("iogpu.rsrc_limit", &rsrc_limit, &length, NULL, 0);
if (rsrc_limit == 0) {
rsrc_limit = 499000;
}
return {
{"device_name", name},
{"architecture", arch},
{"max_buffer_length", raw_device->maxBufferLength()},
{"max_recommended_working_set_size",
raw_device->recommendedMaxWorkingSetSize()},
{"memory_size", memsize},
{"resource_limit", rsrc_limit}};
};
static auto device_info_ = init_device_info();
return device_info_;
}
} // namespace mlx::core::metal } // namespace mlx::core::metal
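
Callers receive the map by const reference and read entries with std::get on the variant. A minimal usage sketch (assumes a build with the Metal backend and linking against MLX; without Metal the stub below throws):

#include <cstdio>
#include <string>
#include <variant>

#include "mlx/backend/metal/metal.h"

int main() {
  const auto& info = mlx::core::metal::device_info();
  std::printf(
      "device: %s\n", std::get<std::string>(info.at("device_name")).c_str());
  std::printf(
      "memory: %zu bytes\n", std::get<size_t>(info.at("memory_size")));
  return 0;
}
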

View File

@ -2,11 +2,10 @@
#pragma once #pragma once
#include <string>
#include <unordered_map> #include <unordered_map>
#include <variant> #include <variant>
#include "mlx/array.h"
namespace mlx::core::metal { namespace mlx::core::metal {
/* Check if the Metal backend is available. */ /* Check if the Metal backend is available. */

View File

@ -0,0 +1,22 @@
// Copyright © 2025 Apple Inc.
#include <stdexcept>
#include "mlx/backend/metal/metal.h"
namespace mlx::core::metal {
bool is_available() {
return false;
}
void start_capture(std::string) {}
void stop_capture() {}
const std::unordered_map<std::string, std::variant<std::string, size_t>>&
device_info() {
throw std::runtime_error(
"[metal::device_info] Cannot get device info without metal backend");
};
} // namespace mlx::core::metal

View File

@ -1,7 +1,7 @@
// Copyright © 2024 Apple Inc. // Copyright © 2024 Apple Inc.
#include <algorithm> #include <algorithm>
#include "mlx/backend/metal/copy.h" #include "mlx/backend/gpu/copy.h"
#include "mlx/backend/metal/device.h" #include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels/defines.h" #include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/reduce.h" #include "mlx/backend/metal/reduce.h"

View File

@ -7,10 +7,10 @@
#include "mlx/backend/common/compiled.h" #include "mlx/backend/common/compiled.h"
#include "mlx/backend/common/slicing.h" #include "mlx/backend/common/slicing.h"
#include "mlx/backend/common/utils.h" #include "mlx/backend/common/utils.h"
#include "mlx/backend/metal/copy.h" #include "mlx/backend/gpu/copy.h"
#include "mlx/backend/gpu/slicing.h"
#include "mlx/backend/metal/device.h" #include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h" #include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/slicing.h"
#include "mlx/backend/metal/utils.h" #include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"
#include "mlx/scheduler.h" #include "mlx/scheduler.h"
@ -25,25 +25,6 @@ void arange_set_scalars(T start, T next, metal::CommandEncoder& enc) {
enc.set_bytes(step, 1); enc.set_bytes(step, 1);
} }
void reshape(const array& in, array& out, Stream s) {
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
if (copy_necessary) {
out.set_data(allocator::malloc(out.nbytes()));
copy_gpu_inplace(
in,
out,
in.shape(),
in.strides(),
make_contiguous_strides(in.shape()),
0,
0,
CopyType::General,
s);
} else {
shared_buffer_reshape(in, out_strides, out);
}
}
static array compute_dynamic_offset( static array compute_dynamic_offset(
const array& indices, const array& indices,
const Strides& strides, const Strides& strides,
@ -226,105 +207,10 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
} }
} }
void AsType::eval_gpu(const std::vector<array>& inputs, array& out) {
CopyType ctype =
inputs[0].flags().contiguous ? CopyType::Vector : CopyType::General;
copy_gpu(inputs[0], out, ctype);
}
void AsStrided::eval_gpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
void Broadcast::eval_gpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
void BroadcastAxes::eval_gpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
void Concatenate::eval_gpu(const std::vector<array>& inputs, array& out) {
concatenate_gpu(inputs, out, axis_, stream());
}
void Contiguous::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
constexpr size_t extra_bytes = 16384;
if (in.buffer_size() <= out.nbytes() + extra_bytes &&
(in.flags().row_contiguous ||
(allow_col_major_ && in.flags().col_contiguous))) {
out.copy_shared_buffer(in);
} else {
copy_gpu(in, out, CopyType::General);
}
}
void Copy::eval_gpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
void CustomTransforms::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
eval(inputs, outputs);
}
void Depends::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
eval(inputs, outputs);
}
void Full::eval_gpu(const std::vector<array>& inputs, array& out) {
auto in = inputs[0];
CopyType ctype;
if (in.data_size() == 1) {
ctype = CopyType::Scalar;
} else if (in.flags().contiguous) {
ctype = CopyType::Vector;
} else {
ctype = CopyType::General;
}
copy_gpu(in, out, ctype);
}
void ExpandDims::eval_gpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
void Flatten::eval_gpu(const std::vector<array>& inputs, array& out) {
reshape(inputs[0], out, stream());
}
void Unflatten::eval_gpu(const std::vector<array>& inputs, array& out) {
reshape(inputs[0], out, stream());
}
void Load::eval_gpu(const std::vector<array>& inputs, array& out) { void Load::eval_gpu(const std::vector<array>& inputs, array& out) {
throw std::runtime_error("[Load::eval_gpu] Not implemented."); throw std::runtime_error("[Load::eval_gpu] Not implemented.");
} }
void NumberOfElements::eval_gpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
void Pad::eval_gpu(const std::vector<array>& inputs, array& out) {
// Inputs must be base input array and scalar val array
assert(inputs.size() == 2);
auto& in = inputs[0];
auto& val = inputs[1];
// Padding value must be a scalar
assert(val.size() == 1);
// Padding value, input and output must be of the same type
assert(val.dtype() == in.dtype() && in.dtype() == out.dtype());
pad_gpu(in, val, out, axes_, low_pad_size_, stream());
}
void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) { void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
@ -370,27 +256,6 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.dispatch_threads(grid_dims, group_dims); compute_encoder.dispatch_threads(grid_dims, group_dims);
} }
void Reshape::eval_gpu(const std::vector<array>& inputs, array& out) {
reshape(inputs[0], out, stream());
}
void Split::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
eval(inputs, outputs);
}
void Slice::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
if (out.size() == 0) {
out.set_data(nullptr);
return;
}
auto& in = inputs[0];
slice_gpu(in, out, start_indices_, strides_, stream());
}
void DynamicSlice::eval_gpu(const std::vector<array>& inputs, array& out) { void DynamicSlice::eval_gpu(const std::vector<array>& inputs, array& out) {
if (out.size() == 0) { if (out.size() == 0) {
out.set_data(nullptr); out.set_data(nullptr);
@ -492,18 +357,6 @@ void SliceUpdate::eval_gpu(const std::vector<array>& inputs, array& out) {
/* const Stream& s = */ stream()); /* const Stream& s = */ stream());
} }
void Squeeze::eval_gpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
void StopGradient::eval_gpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
void Transpose::eval_gpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
void QRF::eval_gpu( void QRF::eval_gpu(
const std::vector<array>& inputs, const std::vector<array>& inputs,
std::vector<array>& outputs) { std::vector<array>& outputs) {
@ -537,35 +390,4 @@ void LUF::eval_gpu(
throw std::runtime_error("[LUF::eval_gpu] Metal LU factorization NYI."); throw std::runtime_error("[LUF::eval_gpu] Metal LU factorization NYI.");
} }
void View::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& in = inputs[0];
auto ibytes = size_of(in.dtype());
auto obytes = size_of(out.dtype());
// Conditions for buffer copying (disjunction):
// - type size is the same
// - type size is smaller and the last axis is contiguous
// - the entire array is row contiguous
if (ibytes == obytes || (obytes < ibytes && in.strides().back() == 1) ||
in.flags().row_contiguous) {
auto strides = in.strides();
for (int i = 0; i < static_cast<int>(strides.size()) - 1; ++i) {
strides[i] *= ibytes;
strides[i] /= obytes;
}
out.copy_shared_buffer(
in, strides, in.flags(), in.data_size() * ibytes / obytes);
} else {
auto tmp = array(in.shape(), in.dtype(), nullptr, {});
tmp.set_data(allocator::malloc(tmp.nbytes()));
copy_gpu_inplace(in, tmp, CopyType::General, stream());
auto flags = out.flags();
flags.contiguous = true;
flags.row_contiguous = true;
auto max_dim = std::max_element(out.shape().begin(), out.shape().end());
flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim;
out.copy_shared_buffer(tmp, out.strides(), flags, out.size());
}
}
} // namespace mlx::core } // namespace mlx::core

View File

@ -4,7 +4,7 @@
#include "mlx/backend/common/broadcasting.h" #include "mlx/backend/common/broadcasting.h"
#include "mlx/backend/common/compiled.h" #include "mlx/backend/common/compiled.h"
#include "mlx/backend/metal/copy.h" #include "mlx/backend/gpu/copy.h"
#include "mlx/backend/metal/device.h" #include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h" #include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/reduce.h" #include "mlx/backend/metal/reduce.h"

View File

@ -3,7 +3,7 @@
#include <algorithm> #include <algorithm>
#include <cassert> #include <cassert>
#include "mlx/backend/metal/copy.h" #include "mlx/backend/gpu/copy.h"
#include "mlx/backend/metal/device.h" #include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h" #include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/kernels/defines.h" #include "mlx/backend/metal/kernels/defines.h"

View File

@ -1,7 +1,6 @@
// Copyright © 2024 Apple Inc. // Copyright © 2024 Apple Inc.
#include "mlx/backend/metal/resident.h" #include "mlx/backend/metal/resident.h"
#include "mlx/backend/metal/metal_impl.h"
namespace mlx::core::metal { namespace mlx::core::metal {

View File

@ -1,5 +1,5 @@
// Copyright © 2023-2024 Apple Inc. // Copyright © 2023-2024 Apple Inc.
#include "mlx/backend/metal/copy.h" #include "mlx/backend/gpu/copy.h"
#include "mlx/backend/metal/utils.h" #include "mlx/backend/metal/utils.h"
#include "mlx/fast_primitives.h" #include "mlx/fast_primitives.h"

View File

@ -2,7 +2,7 @@
#include <sstream> #include <sstream>
#include "mlx/backend/common/compiled.h" #include "mlx/backend/common/compiled.h"
#include "mlx/backend/metal/copy.h" #include "mlx/backend/gpu/copy.h"
#include "mlx/backend/metal/device.h" #include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels/steel/attn/params.h" #include "mlx/backend/metal/kernels/steel/attn/params.h"
@ -154,9 +154,9 @@ void sdpa_vector(
int gqa_factor = q.shape(1) / k.shape(1); int gqa_factor = q.shape(1) / k.shape(1);
int N = k.shape(2); int N = k.shape(2);
int B = q.shape(0) * q.shape(1); int B = q.shape(0) * q.shape(1);
size_t k_head_stride = k.strides()[1]; size_t k_head_stride = k.shape(1) == 1 ? k.strides(0) : k.strides(1);
size_t k_seq_stride = k.strides()[2]; size_t k_seq_stride = k.strides()[2];
size_t v_head_stride = v.strides()[1]; size_t v_head_stride = v.shape(1) == 1 ? v.strides(0) : v.strides(1);
size_t v_seq_stride = v.strides()[2]; size_t v_seq_stride = v.strides()[2];
MTL::Size group_dims(1024, 1, 1); MTL::Size group_dims(1024, 1, 1);
@ -199,11 +199,10 @@ void sdpa_vector(
if (has_mask) { if (has_mask) {
auto& m = *mask; auto& m = *mask;
compute_encoder.set_input_array(m, 11 + float_mask); compute_encoder.set_input_array(m, 11 + float_mask);
auto nd = m.ndim(); int32_t kv_seq_stride = m.shape(3) > 1 ? m.strides(3) : 0;
int32_t kv_seq_stride = int32_t q_seq_stride = m.shape(2) > 1 ? m.strides(2) : 0;
nd >= 1 && m.shape(-1) > 1 ? m.strides()[nd - 1] : 0; int32_t head_stride =
int32_t q_seq_stride = nd >= 2 && m.shape(-2) > 1 ? m.strides()[nd - 2] : 0; m.shape(1) > 1 ? m.strides(1) : (m.shape(0) > 1 ? m.strides(0) : 0);
int32_t head_stride = nd >= 3 && m.shape(-3) > 1 ? m.strides()[nd - 3] : 0;
compute_encoder.set_bytes(kv_seq_stride, 13); compute_encoder.set_bytes(kv_seq_stride, 13);
compute_encoder.set_bytes(q_seq_stride, 14); compute_encoder.set_bytes(q_seq_stride, 14);
compute_encoder.set_bytes(head_stride, 15); compute_encoder.set_bytes(head_stride, 15);
@ -238,9 +237,10 @@ void sdpa_vector_2pass(
int N = k.shape(2); int N = k.shape(2);
int blocks = 32; int blocks = 32;
int B = q.shape(0) * q.shape(1); int B = q.shape(0) * q.shape(1);
size_t k_head_stride = k.strides()[1]; size_t k_head_stride = k.shape(1) == 1 ? k.strides(0) : k.strides(1);
size_t k_head_stride = k.strides()[1]; size_t k_head_stride = k.shape(1) == 1 ? k.strides(0) : k.strides(1);
size_t k_seq_stride = k.strides()[2]; size_t k_seq_stride = k.strides()[2];
size_t v_head_stride = v.strides()[1]; size_t v_head_stride = v.shape(1) == 1 ? v.strides(0) : v.strides(1);
size_t v_seq_stride = v.strides()[2]; size_t v_seq_stride = v.strides()[2];
MTL::Size group_dims(8 * 32, 1, 1); MTL::Size group_dims(8 * 32, 1, 1);
MTL::Size grid_dims(B, q.shape(2), blocks); MTL::Size grid_dims(B, q.shape(2), blocks);
@ -302,11 +302,10 @@ void sdpa_vector_2pass(
if (has_mask) { if (has_mask) {
auto& m = *mask; auto& m = *mask;
compute_encoder.set_input_array(m, 13 + float_mask); compute_encoder.set_input_array(m, 13 + float_mask);
auto nd = m.ndim(); int32_t kv_seq_stride = m.shape(3) > 1 ? m.strides(3) : 0;
int32_t kv_seq_stride = int32_t q_seq_stride = m.shape(2) > 1 ? m.strides(2) : 0;
nd >= 1 && m.shape(-1) > 1 ? m.strides()[nd - 1] : 0; int32_t head_stride =
int32_t q_seq_stride = nd >= 2 && m.shape(-2) > 1 ? m.strides()[nd - 2] : 0; m.shape(1) > 1 ? m.strides(1) : (m.shape(0) > 1 ? m.strides(0) : 0);
int32_t head_stride = nd >= 3 && m.shape(-3) > 1 ? m.strides()[nd - 3] : 0;
compute_encoder.set_bytes(kv_seq_stride, 15); compute_encoder.set_bytes(kv_seq_stride, 15);
compute_encoder.set_bytes(q_seq_stride, 16); compute_encoder.set_bytes(q_seq_stride, 16);
compute_encoder.set_bytes(head_stride, 17); compute_encoder.set_bytes(head_stride, 17);
@ -368,18 +367,6 @@ void ScaledDotProductAttention::eval_gpu(
} }
}; };
// Checks if arr is row contiguous or the sequence and head dimension are
// transposed
auto is_contiguous_or_head_seq_transposed = [](const array& arr) {
if (arr.flags().row_contiguous) {
return true;
}
auto& strides = arr.strides();
auto& shape = arr.shape();
return (strides[3] == 1) && (strides[2] == shape[3] * shape[1]) &&
(strides[1] == shape[3]) && (strides[0] == strides[2] * shape[2]);
};
// Checks that the headdim dimension has stride 1. // Checks that the headdim dimension has stride 1.
auto is_matrix_contiguous = [](const array& arr) { auto is_matrix_contiguous = [](const array& arr) {
return arr.strides(-1) == 1; return arr.strides(-1) == 1;
@ -387,30 +374,58 @@ void ScaledDotProductAttention::eval_gpu(
// We are in vector mode ie single query // We are in vector mode ie single query
if (q_pre.shape(2) <= 8) { if (q_pre.shape(2) <= 8) {
const auto& q = copy_unless(is_contiguous_or_head_seq_transposed, q_pre); auto q_copy_unless = [](const array& arr) {
const auto& k = copy_unless(is_matrix_contiguous, k_pre); if (arr.flags().row_contiguous) {
const auto& v = copy_unless(is_matrix_contiguous, v_pre); return true;
}
auto& strides = arr.strides();
auto& shape = arr.shape();
if (shape[0] == 1 || shape[1] == 1) {
// If either the batch or head dimension is a singleton, the other can
// be transposed with the sequence dimension
auto bidx = shape[0] == 1 ? 1 : 0;
return (strides[3] == 1) && (strides[2] == shape[3] * shape[bidx]) &&
(strides[bidx] == shape[3]);
}
return false;
};
auto kv_copy_unless = [](const array& arr) {
// keys and values should be copied if:
// - the last dimension is not contiguous
// - the batch and head dim are not contiguous
auto& strides = arr.strides();
auto& shape = arr.shape();
if (strides.back() != 1) {
return false;
}
if (shape[0] == 1 || shape[1] == 1) {
return true;
}
return (strides[0] == strides[1] * shape[1]);
};
const auto& q = copy_unless(q_copy_unless, q_pre);
const auto& k = copy_unless(kv_copy_unless, k_pre);
const auto& v = copy_unless(kv_copy_unless, v_pre);
// Donate the query if possible // Donate the query if possible
if (q.is_donatable() && (q.shape(2) == 1 || !q.flags().row_contiguous) && if (q.is_donatable() && q.flags().row_contiguous && q.size() == o.size()) {
q.size() == o.size()) {
o.copy_shared_buffer(q); o.copy_shared_buffer(q);
} else { } else {
if (o.shape(2) == 1) { o.set_data(allocator::malloc(o.nbytes()));
o.set_data(allocator::malloc(o.nbytes()));
} else {
auto strides = o.strides();
strides[2] = o.shape(1) * o.shape(3);
strides[1] = o.shape(3);
auto flags = q.flags();
flags.row_contiguous = q.shape(1) == 1;
o.set_data(
allocator::malloc(o.nbytes()), o.size(), std::move(strides), flags);
}
} }
auto mask = auto mask_copy_unless = [&q](const array& arr) {
inputs.size() > 3 ? std::optional<array>{inputs[3]} : std::nullopt; auto& strides = arr.strides();
auto& shape = arr.shape();
return arr.flags().row_contiguous || q.shape(0) == 1 || q.shape(1) == 1 ||
(strides[0] == strides[1] * shape[1]);
};
auto mask = inputs.size() > 3
? std::optional<array>{copy_unless(mask_copy_unless, inputs[3])}
: std::nullopt;
// We route to the 2 pass fused attention if // We route to the 2 pass fused attention if
// - The device is large and the sequence length long // - The device is large and the sequence length long

View File

@ -3,7 +3,7 @@
#include <cassert> #include <cassert>
#include <sstream> #include <sstream>
#include "mlx/backend/metal/copy.h" #include "mlx/backend/gpu/copy.h"
#include "mlx/backend/metal/device.h" #include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h" #include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/utils.h" #include "mlx/backend/metal/utils.h"

View File

@ -2,21 +2,12 @@
#include <numeric> #include <numeric>
#include "mlx/backend/common/slicing.h" #include "mlx/backend/gpu/copy.h"
#include "mlx/backend/metal/copy.h" #include "mlx/backend/gpu/slicing.h"
#include "mlx/backend/metal/device.h" #include "mlx/backend/metal/device.h"
namespace mlx::core { namespace mlx::core {
void slice_gpu(
const array& in,
array& out,
const Shape& start_indices,
const Shape& strides,
const Stream& s) {
slice(in, out, start_indices, strides);
}
void concatenate_gpu( void concatenate_gpu(
const std::vector<array>& inputs, const std::vector<array>& inputs,
array& out, array& out,
@ -48,30 +39,4 @@ void concatenate_gpu(
} }
} }
void pad_gpu(
const array& in,
const array& val,
array& out,
const std::vector<int>& axes,
const Shape& low_pad_size,
const Stream& s) {
// Fill output with val
fill_gpu(val, out, s);
// Find offset for start of input values
size_t data_offset = 0;
for (int i = 0; i < axes.size(); i++) {
auto ax = axes[i] < 0 ? out.ndim() + axes[i] : axes[i];
data_offset += out.strides()[ax] * low_pad_size[i];
}
// Extract slice from output where input will be pasted
array out_slice(in.shape(), out.dtype(), nullptr, {});
out_slice.copy_shared_buffer(
out, out.strides(), out.flags(), out_slice.size(), data_offset);
// Copy input values into the slice
copy_gpu_inplace(in, out_slice, CopyType::GeneralGeneral, s);
}
} // namespace mlx::core } // namespace mlx::core

View File

@ -1,7 +1,7 @@
// Copyright © 2023-2024 Apple Inc. // Copyright © 2023-2024 Apple Inc.
#include <algorithm> #include <algorithm>
#include "mlx/backend/metal/copy.h" #include "mlx/backend/gpu/copy.h"
#include "mlx/backend/metal/device.h" #include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h" #include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/kernels/defines.h" #include "mlx/backend/metal/kernels/defines.h"

View File

@ -2,7 +2,7 @@
#include <algorithm> #include <algorithm>
#include "mlx/backend/metal/copy.h" #include "mlx/backend/gpu/copy.h"
#include "mlx/backend/metal/device.h" #include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h" #include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/utils.h" #include "mlx/backend/metal/utils.h"

View File

@ -45,7 +45,7 @@ void ternary_op_gpu_inplace(
work_per_thread = large ? 4 : 2; work_per_thread = large ? 4 : 2;
} else { } else {
large = out.data_size() > INT32_MAX; large = out.data_size() > INT32_MAX;
work_per_thread = 1; work_per_thread = get_work_per_thread(b.dtype());
} }
std::string kernel_name; std::string kernel_name;
if (topt == TernaryOpType::General) { if (topt == TernaryOpType::General) {
@ -106,13 +106,19 @@ void ternary_op_gpu_inplace(
compute_encoder.dispatch_threads(grid_dims, group_dims); compute_encoder.dispatch_threads(grid_dims, group_dims);
} else { } else {
// Launch a 1D or 2D grid of threads // Launch a 1D or 2D grid of threads
size_t nthreads = out.data_size(); size_t nthreads = ceildiv(out.data_size(), work_per_thread);
if (thread_group_size > nthreads) { if (thread_group_size > nthreads) {
thread_group_size = nthreads; thread_group_size = nthreads;
} }
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides()) MTL::Size grid_dims;
: MTL::Size(nthreads, 1, 1); if (large) {
compute_encoder.set_bytes<int64_t>(out.data_size(), 4);
grid_dims = get_2d_grid_dims(out.shape(), out.strides(), work_per_thread);
} else {
compute_encoder.set_bytes<int>(out.data_size(), 4);
grid_dims = MTL::Size(nthreads, 1, 1);
}
compute_encoder.dispatch_threads(grid_dims, group_dims); compute_encoder.dispatch_threads(grid_dims, group_dims);
} }
} }

View File

@ -34,18 +34,19 @@ void unary_op_gpu_inplace(
}; };
auto [shape, strides] = maybe_collapse(); auto [shape, strides] = maybe_collapse();
int ndim = shape.size(); int ndim = shape.size();
size_t nthreads = contig ? in.data_size() : in.size();
bool large; bool large;
if (!contig) { if (!contig) {
large = in.data_size() > INT32_MAX || out.size() > INT32_MAX; large = in.data_size() > INT32_MAX || out.size() > INT32_MAX;
} else { } else {
large = in.data_size() > UINT32_MAX; large = in.data_size() > UINT32_MAX;
} }
int work_per_thread = !contig && large ? 4 : 1; int work_per_thread;
std::string kernel_name; std::string kernel_name;
if (contig) { if (contig) {
work_per_thread = get_work_per_thread(in.dtype());
kernel_name = (large ? "v2" : "v"); kernel_name = (large ? "v2" : "v");
} else { } else {
work_per_thread = large ? 4 : 1;
kernel_name = "gn" + std::to_string(work_per_thread); kernel_name = "gn" + std::to_string(work_per_thread);
if (large) { if (large) {
kernel_name += "large"; kernel_name += "large";
@ -75,12 +76,20 @@ void unary_op_gpu_inplace(
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest); MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
compute_encoder.dispatch_threads(grid_dims, group_dims); compute_encoder.dispatch_threads(grid_dims, group_dims);
} else { } else {
size_t nthreads = ceildiv(in.data_size(), work_per_thread);
if (thread_group_size > nthreads) { if (thread_group_size > nthreads) {
thread_group_size = nthreads; thread_group_size = nthreads;
} }
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides()) MTL::Size grid_dims;
: MTL::Size(nthreads, 1, 1); if (large) {
compute_encoder.set_bytes<int64_t>(in.data_size(), 2);
grid_dims = get_2d_grid_dims(out.shape(), out.strides(), work_per_thread);
} else {
compute_encoder.set_bytes<int>(in.data_size(), 2);
grid_dims = MTL::Size(nthreads, 1, 1);
}
compute_encoder.dispatch_threads(grid_dims, group_dims); compute_encoder.dispatch_threads(grid_dims, group_dims);
} }
} }

View File

@ -84,4 +84,12 @@ void concatenate(std::string& acc, T first, Args... args) {
concatenate(acc, args...); concatenate(acc, args...);
} }
inline int get_work_per_thread(Dtype dtype) {
return std::max(1, 8 / dtype.size());
}
inline size_t ceildiv(size_t n, size_t m) {
return (n + m - 1) / m;
}
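
Together these two helpers determine the 1D launch size for the vectorized kernels: the thread count is the element count divided, rounding up, by the per-thread width. A small worked sketch of that arithmetic (the two helpers are re-declared locally so it compiles standalone):

#include <algorithm>
#include <cstdio>

// Mirrors get_work_per_thread / ceildiv above: with 8 bytes of work per
// thread, a float32 buffer of 1'000'001 elements needs ceildiv(1'000'001, 2)
// = 500'001 threads.
inline int get_work_per_thread(int dtype_size) {
  return std::max(1, 8 / dtype_size);
}

inline size_t ceildiv(size_t n, size_t m) {
  return (n + m - 1) / m;
}

int main() {
  size_t data_size = 1000001;
  int wpt = get_work_per_thread(/* sizeof(float) */ 4); // 2
  std::printf("threads: %zu\n", ceildiv(data_size, wpt)); // 500001
  return 0;
}
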
} // namespace mlx::core } // namespace mlx::core

View File

@ -1,6 +1,7 @@
target_sources( target_sources(
mlx mlx
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/available.cpp
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
${CMAKE_CURRENT_SOURCE_DIR}/../cpu/eval.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../cpu/eval.cpp
${CMAKE_CURRENT_SOURCE_DIR}/../cpu/encoder.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../cpu/encoder.cpp
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp) ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp)

View File

@ -0,0 +1,11 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cpu/available.h"
namespace mlx::core::cpu {
bool is_available() {
return false;
}
} // namespace mlx::core::cpu

View File

@ -18,7 +18,7 @@ void Compiled::eval_cpu(
const std::vector<array>& inputs, const std::vector<array>& inputs,
std::vector<array>& outputs) { std::vector<array>& outputs) {
throw std::runtime_error( throw std::runtime_error(
"[Compiled::eval_cpu] CPU compialtion not supported on the platform."); "[Compiled::eval_cpu] CPU compilation not supported on the platform.");
} }
} // namespace mlx::core } // namespace mlx::core

View File

@ -3,5 +3,5 @@ target_sources(
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
${CMAKE_CURRENT_SOURCE_DIR}/event.cpp ${CMAKE_CURRENT_SOURCE_DIR}/event.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp ${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp ${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp) ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp)

View File

@ -6,9 +6,9 @@
#include "mlx/allocator.h" #include "mlx/allocator.h"
#ifdef __APPLE__ #ifdef __APPLE__
#include "mlx/backend/no_metal/apple_memory.h" #include "mlx/backend/no_gpu/apple_memory.h"
#elif defined(__linux__) #elif defined(__linux__)
#include "mlx/backend/no_metal/linux_memory.h" #include "mlx/backend/no_gpu/linux_memory.h"
#else #else
size_t get_memory_size() { size_t get_memory_size() {
return 0; return 0;

View File

@ -0,0 +1,28 @@
// Copyright © 2025 Apple Inc.
#include <stdexcept>
#include "mlx/backend/gpu/available.h"
#include "mlx/backend/gpu/eval.h"
namespace mlx::core::gpu {
bool is_available() {
return false;
}
void new_stream(Stream) {}
void eval(array&) {
throw std::runtime_error("[gpu::eval] GPU backend is not available");
}
void finalize(Stream) {
throw std::runtime_error("[gpu::finalize] GPU backend is not available");
}
void synchronize(Stream) {
throw std::runtime_error("[gpu::synchronize] GPU backend is not available");
}
} // namespace mlx::core::gpu

View File

@ -1,43 +0,0 @@
// Copyright © 2023-2024 Apple Inc.
#include <stdexcept>
#include "mlx/backend/metal/metal.h"
#include "mlx/backend/metal/metal_impl.h"
namespace mlx::core::metal {
bool is_available() {
return false;
}
void new_stream(Stream) {}
std::unique_ptr<void, std::function<void(void*)>> new_scoped_memory_pool() {
return nullptr;
}
void eval(array&) {
throw std::runtime_error(
"[metal::eval] Cannot eval on GPU without metal backend");
}
void finalize(Stream) {
throw std::runtime_error(
"[metal::finalize] Cannot finalize GPU without metal backend");
}
void synchronize(Stream) {
throw std::runtime_error(
"[metal::synchronize] Cannot synchronize GPU without metal backend");
}
void start_capture(std::string) {}
void stop_capture() {}
const std::unordered_map<std::string, std::variant<std::string, size_t>>&
device_info() {
throw std::runtime_error(
"[metal::device_info] Cannot get device info without metal backend");
};
} // namespace mlx::core::metal

View File

@ -168,6 +168,15 @@ void merge_one(array& dst, array& src, ParentsMap& parents_map) {
parent.first.inputs()[parent.second] = dst; parent.first.inputs()[parent.second] = dst;
pairs.push_back(parent); pairs.push_back(parent);
} }
// If src is a parent of dst, remove it from dst's parents
for (auto it = pairs.begin(); it != pairs.end();) {
if (it->first.id() == src.id()) {
it = pairs.erase(it);
} else {
it++;
}
}
// Remove the source from the map to avoid fusing with it again // Remove the source from the map to avoid fusing with it again
parents_map.erase(src_parents); parents_map.erase(src_parents);
} }

View File

@ -1,23 +1,28 @@
// Copyright © 2023 Apple Inc. // Copyright © 2023 Apple Inc.
#include <stdexcept>
#include "mlx/backend/cpu/available.h"
#include "mlx/backend/gpu/available.h"
#include "mlx/device.h" #include "mlx/device.h"
#include "mlx/backend/metal/metal.h"
namespace mlx::core { namespace mlx::core {
static Device default_device_{ Device& mutable_default_device() {
metal::is_available() ? Device::gpu : Device::cpu}; static Device default_device{gpu::is_available() ? Device::gpu : Device::cpu};
return default_device;
}
const Device& default_device() { const Device& default_device() {
return default_device_; return mutable_default_device();
} }
void set_default_device(const Device& d) { void set_default_device(const Device& d) {
if (!metal::is_available() && d == Device::gpu) { if (!gpu::is_available() && d == Device::gpu) {
throw std::invalid_argument( throw std::invalid_argument(
"[set_default_device] Cannot set gpu device without gpu backend."); "[set_default_device] Cannot set gpu device without gpu backend.");
} }
default_device_ = d; mutable_default_device() = d;
} }
bool operator==(const Device& lhs, const Device& rhs) { bool operator==(const Device& lhs, const Device& rhs) {
@ -28,4 +33,15 @@ bool operator!=(const Device& lhs, const Device& rhs) {
return !(lhs == rhs); return !(lhs == rhs);
} }
bool is_available(const Device& d) {
switch (d.type) {
case Device::cpu:
return cpu::is_available();
case Device::gpu:
return gpu::is_available();
}
// appease compiler
return false;
}
} // namespace mlx::core } // namespace mlx::core
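A hedged usage sketch of the new device-level query (is_available is declared in mlx/device.h in the next hunk):
#include "mlx/device.h"
// Switch the default device to the GPU only when that backend is
// actually compiled in; otherwise set_default_device would throw.
void prefer_gpu() {
  using namespace mlx::core;
  if (is_available(Device::gpu)) {
    set_default_device(Device::gpu);
  }
}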

View File

@ -26,4 +26,6 @@ void set_default_device(const Device& d);
bool operator==(const Device& lhs, const Device& rhs); bool operator==(const Device& lhs, const Device& rhs);
bool operator!=(const Device& lhs, const Device& rhs); bool operator!=(const Device& lhs, const Device& rhs);
bool is_available(const Device& d);
} // namespace mlx::core } // namespace mlx::core

View File

@ -470,6 +470,9 @@ bool FunctionTable::match(
if (x.dtype() != y.dtype()) { if (x.dtype() != y.dtype()) {
return false; return false;
} }
if (x.ndim() != y.ndim()) {
return false;
}
if (!shapeless && x.shape() != y.shape()) { if (!shapeless && x.shape() != y.shape()) {
return false; return false;
} }

View File

@ -186,6 +186,7 @@ array irfftn(
StreamOrDevice s /* = {} */) { StreamOrDevice s /* = {} */) {
return fft_impl(a, axes, true, true, s); return fft_impl(a, axes, true, true, s);
} }
array irfftn(const array& a, StreamOrDevice s /* = {} */) { array irfftn(const array& a, StreamOrDevice s /* = {} */) {
return fft_impl(a, true, true, s); return fft_impl(a, true, true, s);
} }
@ -308,4 +309,73 @@ array istft(
return signal; return signal;
} }
array fftshift(
const array& a,
const std::vector<int>& axes,
StreamOrDevice s /* = {} */) {
if (axes.empty()) {
return a;
}
Shape shifts;
for (int ax : axes) {
// Convert negative axes to positive
int axis = ax < 0 ? ax + a.ndim() : ax;
if (axis < 0 || axis >= a.ndim()) {
std::ostringstream msg;
msg << "[fftshift] Invalid axis " << ax << " for array with " << a.ndim()
<< " dimensions.";
throw std::invalid_argument(msg.str());
}
// Match NumPy's implementation
shifts.push_back(a.shape(axis) / 2);
}
return roll(a, shifts, axes, s);
}
array ifftshift(
const array& a,
const std::vector<int>& axes,
StreamOrDevice s /* = {} */) {
if (axes.empty()) {
return a;
}
Shape shifts;
for (int ax : axes) {
// Convert negative axes to positive
int axis = ax < 0 ? ax + a.ndim() : ax;
if (axis < 0 || axis >= a.ndim()) {
std::ostringstream msg;
msg << "[ifftshift] Invalid axis " << ax << " for array with " << a.ndim()
<< " dimensions.";
throw std::invalid_argument(msg.str());
}
// Match NumPy's implementation
int size = a.shape(axis);
shifts.push_back(-(size / 2));
}
return roll(a, shifts, axes, s);
}
// Default versions that operate on all axes
array fftshift(const array& a, StreamOrDevice s /* = {} */) {
if (a.ndim() < 1) {
return a;
}
std::vector<int> axes(a.ndim());
std::iota(axes.begin(), axes.end(), 0);
return fftshift(a, axes, s);
}
array ifftshift(const array& a, StreamOrDevice s /* = {} */) {
if (a.ndim() < 1) {
return a;
}
std::vector<int> axes(a.ndim());
std::iota(axes.begin(), axes.end(), 0);
return ifftshift(a, axes, s);
}
} // namespace mlx::core::fft } // namespace mlx::core::fft
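A short usage sketch of the new shifts, assuming the umbrella header mlx/mlx.h; an odd length is used so the two shifts differ:
#include "mlx/mlx.h"
using namespace mlx::core;
// For n = 5 the unshifted frequency order is [0, 1, 2, -2, -1];
// fftshift rolls by n/2 = 2 to center it, and ifftshift rolls by
// -(n/2) = -2 to undo it.
void shift_example() {
  auto f = arange(5);                        // [0, 1, 2, 3, 4]
  auto centered = fft::fftshift(f);          // [3, 4, 0, 1, 2]
  auto restored = fft::ifftshift(centered);  // [0, 1, 2, 3, 4]
}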

View File

@ -148,6 +148,24 @@ inline array irfft2(
StreamOrDevice s = {}) { StreamOrDevice s = {}) {
return irfftn(a, axes, s); return irfftn(a, axes, s);
} }
/** Shift the zero-frequency component to the center of the spectrum. */
array fftshift(const array& a, StreamOrDevice s = {});
/** Shift the zero-frequency component to the center of the spectrum along
* specified axes. */
array fftshift(
const array& a,
const std::vector<int>& axes,
StreamOrDevice s = {});
/** The inverse of fftshift. */
array ifftshift(const array& a, StreamOrDevice s = {});
/** The inverse of fftshift along specified axes. */
array ifftshift(
const array& a,
const std::vector<int>& axes,
StreamOrDevice s = {});
inline array stft( inline array stft(
const array& x, const array& x,

View File

@ -335,7 +335,10 @@ ThreadPool& thread_pool() {
return pool_; return pool_;
} }
ThreadPool ParallelFileReader::thread_pool_{4}; ThreadPool& ParallelFileReader::thread_pool() {
static ThreadPool thread_pool{4};
return thread_pool;
}
void ParallelFileReader::read(char* data, size_t n) { void ParallelFileReader::read(char* data, size_t n) {
while (n != 0) { while (n != 0) {
@ -371,7 +374,8 @@ void ParallelFileReader::read(char* data, size_t n, size_t offset) {
break; break;
} else { } else {
size_t m = batch_size_; size_t m = batch_size_;
futs.emplace_back(thread_pool_.enqueue(readfn, offset, m, data)); futs.emplace_back(
ParallelFileReader::thread_pool().enqueue(readfn, offset, m, data));
data += m; data += m;
n -= m; n -= m;
offset += m; offset += m;

View File

@ -101,7 +101,7 @@ class ParallelFileReader : public Reader {
private: private:
static constexpr size_t batch_size_ = 1 << 25; static constexpr size_t batch_size_ = 1 << 25;
static ThreadPool thread_pool_; static ThreadPool& thread_pool();
int fd_; int fd_;
std::string label_; std::string label_;
}; };
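The accessor replaces a class-level static member with a construct-on-first-use static; a generic sketch of the pattern, independent of the MLX types:
// The pool is built the first time the accessor runs, which sidesteps the
// static initialization/destruction order issues a static data member can hit.
struct Pool {
  explicit Pool(int workers) { /* spin up workers */ }
};
Pool& shared_pool() {
  static Pool p{4};
  return p;
}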

View File

@ -379,7 +379,12 @@ array pinv(const array& a, StreamOrDevice s /* = {} */) {
// Prepare S // Prepare S
S = expand_dims(S, -2, s); S = expand_dims(S, -2, s);
return matmul(divide(V, S, s), U); auto rcond = 10. * std::max(m, n) * finfo(a.dtype()).eps;
auto cutoff = multiply(array(rcond, a.dtype()), max(S, -1, true, s), s);
auto rS =
where(greater(S, cutoff, s), reciprocal(S, s), array(0.0f, a.dtype()), s);
return matmul(multiply(V, rS, s), U, s);
} }
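A hedged sketch of the cutoff rule above in plain C++ (not the MLX API), assuming a non-empty list of singular values: anything at or below rcond * sigma_max has its reciprocal replaced by zero before forming V * diag(1/sigma) * U, which keeps pinv finite for rank-deficient inputs.
#include <algorithm>
#include <limits>
#include <vector>
// rcond = 10 * max(m, n) * eps, cutoff = rcond * max(sigma), as above.
std::vector<double> clipped_reciprocal(std::vector<double> sigma, int m, int n) {
  double eps = std::numeric_limits<double>::epsilon();
  double cutoff =
      10.0 * std::max(m, n) * eps * *std::max_element(sigma.begin(), sigma.end());
  for (auto& v : sigma) {
    v = (v > cutoff) ? 1.0 / v : 0.0;
  }
  return sigma;
}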
array cholesky_inv( array cholesky_inv(

View File

@ -473,8 +473,19 @@ array hadamard_transform(
std::optional<float> scale_ /* = std::nullopt */, std::optional<float> scale_ /* = std::nullopt */,
StreamOrDevice s /* = {} */) { StreamOrDevice s /* = {} */) {
// Default to an orthonormal Hadamard matrix scaled by 1/sqrt(N) // Default to an orthonormal Hadamard matrix scaled by 1/sqrt(N)
float scale = scale_.has_value() ? *scale_ : 1.0f / std::sqrt(a.shape(-1)); int n = a.ndim() > 0 ? a.shape(-1) : 1;
float scale = scale_.has_value() ? *scale_ : 1.0f / std::sqrt(n);
auto dtype = issubdtype(a.dtype(), floating) ? a.dtype() : float32; auto dtype = issubdtype(a.dtype(), floating) ? a.dtype() : float32;
// Nothing to do for a scalar
if (n == 1) {
if (scale == 1) {
return a;
}
return multiply(a, array(scale, dtype), s);
}
return array( return array(
a.shape(), a.shape(),
dtype, dtype,
@ -3769,6 +3780,7 @@ array conv_transpose_general(
std::vector<int> stride, std::vector<int> stride,
std::vector<int> padding, std::vector<int> padding,
std::vector<int> dilation, std::vector<int> dilation,
std::vector<int> output_padding,
int groups, int groups,
StreamOrDevice s) { StreamOrDevice s) {
std::vector<int> padding_lo(padding.size()); std::vector<int> padding_lo(padding.size());
@ -3782,7 +3794,8 @@ array conv_transpose_general(
int in_size = 1 + (conv_output_shape - 1); int in_size = 1 + (conv_output_shape - 1);
int out_size = 1 + stride[i] * (input.shape(1 + i) - 1); int out_size = 1 + stride[i] * (input.shape(1 + i) - 1);
padding_hi[i] = in_size - out_size + padding[i]; padding_hi[i] = in_size - out_size + padding[i] +
output_padding[i]; // Adjust with output_padding
} }
return conv_general( return conv_general(
@ -3805,10 +3818,11 @@ array conv_transpose1d(
int stride /* = 1 */, int stride /* = 1 */,
int padding /* = 0 */, int padding /* = 0 */,
int dilation /* = 1 */, int dilation /* = 1 */,
int output_padding /* = 0 */,
int groups /* = 1 */, int groups /* = 1 */,
StreamOrDevice s /* = {} */) { StreamOrDevice s /* = {} */) {
return conv_transpose_general( return conv_transpose_general(
in_, wt_, {stride}, {padding}, {dilation}, groups, s); in_, wt_, {stride}, {padding}, {dilation}, {output_padding}, groups, s);
} }
/** 2D transposed convolution with a filter */ /** 2D transposed convolution with a filter */
@ -3818,6 +3832,7 @@ array conv_transpose2d(
const std::pair<int, int>& stride /* = {1, 1} */, const std::pair<int, int>& stride /* = {1, 1} */,
const std::pair<int, int>& padding /* = {0, 0} */, const std::pair<int, int>& padding /* = {0, 0} */,
const std::pair<int, int>& dilation /* = {1, 1} */, const std::pair<int, int>& dilation /* = {1, 1} */,
const std::pair<int, int>& output_padding /* = {0, 0} */,
int groups /* = 1 */, int groups /* = 1 */,
StreamOrDevice s /* = {} */) { StreamOrDevice s /* = {} */) {
return conv_transpose_general( return conv_transpose_general(
@ -3826,6 +3841,7 @@ array conv_transpose2d(
{stride.first, stride.second}, {stride.first, stride.second},
{padding.first, padding.second}, {padding.first, padding.second},
{dilation.first, dilation.second}, {dilation.first, dilation.second},
{output_padding.first, output_padding.second},
groups, groups,
s); s);
} }
@ -3837,6 +3853,7 @@ array conv_transpose3d(
const std::tuple<int, int, int>& stride /* = {1, 1, 1} */, const std::tuple<int, int, int>& stride /* = {1, 1, 1} */,
const std::tuple<int, int, int>& padding /* = {0, 0, 0} */, const std::tuple<int, int, int>& padding /* = {0, 0, 0} */,
const std::tuple<int, int, int>& dilation /* = {1, 1, 1} */, const std::tuple<int, int, int>& dilation /* = {1, 1, 1} */,
const std::tuple<int, int, int>& output_padding /* = {0, 0, 0} */,
int groups /* = 1 */, int groups /* = 1 */,
StreamOrDevice s /* = {} */) { StreamOrDevice s /* = {} */) {
return conv_transpose_general( return conv_transpose_general(
@ -3845,6 +3862,9 @@ array conv_transpose3d(
{std::get<0>(stride), std::get<1>(stride), std::get<2>(stride)}, {std::get<0>(stride), std::get<1>(stride), std::get<2>(stride)},
{std::get<0>(padding), std::get<1>(padding), std::get<2>(padding)}, {std::get<0>(padding), std::get<1>(padding), std::get<2>(padding)},
{std::get<0>(dilation), std::get<1>(dilation), std::get<2>(dilation)}, {std::get<0>(dilation), std::get<1>(dilation), std::get<2>(dilation)},
{std::get<0>(output_padding),
std::get<1>(output_padding),
std::get<2>(output_padding)},
groups, groups,
s); s);
} }
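For reference, a hedged helper (not part of the MLX API) showing where output_padding enters the usual transposed-convolution output-size formula; the padding_hi adjustment above adds the same extra amount on the high side of each spatial axis:
// out = (in - 1) * stride - 2 * padding + dilation * (kernel - 1)
//       + output_padding + 1, per spatial axis.
int conv_transpose_out_size(
    int in, int kernel, int stride, int padding, int dilation, int output_padding) {
  return (in - 1) * stride - 2 * padding + dilation * (kernel - 1) +
      output_padding + 1;
}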
@ -4873,8 +4893,9 @@ array bitwise_impl(
const array& b, const array& b,
BitwiseBinary::Op op, BitwiseBinary::Op op,
const std::string& op_name, const std::string& op_name,
const StreamOrDevice& s) { const StreamOrDevice& s,
auto out_type = promote_types(a.dtype(), b.dtype()); std::optional<Dtype> out_type_ = std::nullopt) {
auto out_type = out_type_ ? *out_type_ : promote_types(a.dtype(), b.dtype());
if (!(issubdtype(out_type, integer) || out_type == bool_)) { if (!(issubdtype(out_type, integer) || out_type == bool_)) {
std::ostringstream msg; std::ostringstream msg;
msg << "[" << op_name msg << "[" << op_name
@ -4919,12 +4940,7 @@ array left_shift(const array& a, const array& b, StreamOrDevice s /* = {} */) {
if (t == bool_) { if (t == bool_) {
t = uint8; t = uint8;
} }
return bitwise_impl( return bitwise_impl(a, b, BitwiseBinary::Op::LeftShift, "left_shift", s, t);
astype(a, t, s),
astype(b, t, s),
BitwiseBinary::Op::LeftShift,
"left_shift",
s);
} }
array operator<<(const array& a, const array& b) { array operator<<(const array& a, const array& b) {
return left_shift(a, b); return left_shift(a, b);
@ -4940,7 +4956,8 @@ array right_shift(const array& a, const array& b, StreamOrDevice s /* = {} */) {
astype(b, t, s), astype(b, t, s),
BitwiseBinary::Op::RightShift, BitwiseBinary::Op::RightShift,
"right_shift", "right_shift",
s); s,
t);
} }
array operator>>(const array& a, const array& b) { array operator>>(const array& a, const array& b) {
return right_shift(a, b); return right_shift(a, b);
@ -5019,8 +5036,11 @@ array roll(
} }
auto sh = shift[i]; auto sh = shift[i];
auto split_index = auto size = a.shape(ax);
(sh < 0) ? (-sh) % a.shape(ax) : a.shape(ax) - sh % a.shape(ax); if (size == 0) {
continue; // skip rolling this axis if it has size 0
}
auto split_index = (sh < 0) ? (-sh) % size : size - sh % size;
auto parts = split(result, Shape{split_index}, ax, s); auto parts = split(result, Shape{split_index}, ax, s);
std::swap(parts[0], parts[1]); std::swap(parts[0], parts[1]);
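A standalone illustration of the roll-by-split trick above, with the new size-0 guard: a shift sh on an axis of length size splits at (-sh) % size for negative shifts or size - sh % size otherwise, then swaps the two pieces; an empty axis is skipped since the modulo would otherwise divide by zero.
#include <algorithm>
#include <vector>
// Plain-C++ sketch for a single 1-D axis; roll_right({0, 1, 2, 3, 4}, 2)
// returns {3, 4, 0, 1, 2}.
std::vector<int> roll_right(std::vector<int> v, int sh) {
  if (v.empty()) {
    return v;  // mirrors the new size == 0 skip
  }
  int size = static_cast<int>(v.size());
  int split = (sh < 0) ? (-sh) % size : size - sh % size;
  std::rotate(v.begin(), v.begin() + split, v.end());
  return v;
}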

View File

@ -569,7 +569,7 @@ inline array std(const array& a, StreamOrDevice s = {}) {
return std(a, false, 0, to_stream(s)); return std(a, false, 0, to_stream(s));
} }
/** Computes the standard deviatoin of the elements of an array along the given /** Computes the standard deviation of the elements of an array along the given
* axes */ * axes */
array std( array std(
const array& a, const array& a,
@ -1291,6 +1291,7 @@ array conv_transpose1d(
int stride = 1, int stride = 1,
int padding = 0, int padding = 0,
int dilation = 1, int dilation = 1,
int output_padding = 0,
int groups = 1, int groups = 1,
StreamOrDevice s = {}); StreamOrDevice s = {});
@ -1301,6 +1302,7 @@ array conv_transpose2d(
const std::pair<int, int>& stride = {1, 1}, const std::pair<int, int>& stride = {1, 1},
const std::pair<int, int>& padding = {0, 0}, const std::pair<int, int>& padding = {0, 0},
const std::pair<int, int>& dilation = {1, 1}, const std::pair<int, int>& dilation = {1, 1},
const std::pair<int, int>& output_padding = {0, 0},
int groups = 1, int groups = 1,
StreamOrDevice s = {}); StreamOrDevice s = {});
@ -1311,6 +1313,7 @@ array conv_transpose3d(
const std::tuple<int, int, int>& stride = {1, 1, 1}, const std::tuple<int, int, int>& stride = {1, 1, 1},
const std::tuple<int, int, int>& padding = {0, 0, 0}, const std::tuple<int, int, int>& padding = {0, 0, 0},
const std::tuple<int, int, int>& dilation = {1, 1, 1}, const std::tuple<int, int, int>& dilation = {1, 1, 1},
const std::tuple<int, int, int>& output_padding = {0, 0, 0},
int groups = 1, int groups = 1,
StreamOrDevice s = {}); StreamOrDevice s = {});

View File

@ -3056,6 +3056,7 @@ std::vector<array> QuantizedMatmul::vjp(
std::vector<array> vjps; std::vector<array> vjps;
// We rely on the fact that w is always 2D so transpose is simple // We rely on the fact that w is always 2D so transpose is simple
std::optional<array> dsb = std::nullopt;
for (auto arg : argnums) { for (auto arg : argnums) {
// gradient wrt to x // gradient wrt to x
if (arg == 0) { if (arg == 0) {
@ -3071,9 +3072,34 @@ std::vector<array> QuantizedMatmul::vjp(
} }
// gradient wrt to w_q, scales or biases // gradient wrt to w_q, scales or biases
else { else if (arg == 1) {
throw std::runtime_error( throw std::runtime_error(
"[QuantizedMatmul::vjp] no gradient wrt the quantized matrix yet."); "[QuantizedMatmul::vjp] no gradient wrt the quantized weights.");
} else {
if (!dsb) {
auto fc = flatten(cotangents[0], 0, -2, stream());
auto fx = flatten(primals[0], 0, -2, stream());
auto dw = transpose_
? matmul(swapaxes(fc, -1, -2, stream()), fx, stream())
: matmul(swapaxes(fx, -1, -2, stream()), fc, stream());
dsb = unflatten(dw, -1, {-1, group_size_}, stream());
}
if (arg == 3) {
// biases
vjps.push_back(sum(*dsb, -1, false, stream()));
} else {
// scales
auto s = stream();
auto wq = dequantize(
primals[1],
ones_like(primals[2], stream()),
zeros_like(primals[3], stream()),
group_size_,
bits_,
stream());
wq = unflatten(wq, -1, {-1, group_size_}, stream());
vjps.push_back(sum(multiply(*dsb, wq, stream()), -1, false, stream()));
}
} }
} }
return vjps; return vjps;
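A hedged reading of the new scale and bias gradients: writing the dequantized weight of one group as w = s * w_q + b (one scale s and bias b per group of group_size_ elements), the chain rule gives
  dL/db = sum over the group of dL/dw
  dL/ds = sum over the group of dL/dw * w_q
which matches the branch above: dsb holds dL/dw (the flattened inputs and cotangents multiplied in the order dictated by transpose_) reshaped into groups, the bias gradient is its per-group sum (arg == 3), and the scale gradient multiplies by w_q, recovered by dequantizing with unit scales and zero biases, before taking the same per-group sum.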

View File

@ -223,7 +223,7 @@ array multivariate_normal(
auto n = mean.shape(-1); auto n = mean.shape(-1);
// Check shapes comatibility of mean and cov // Check shapes compatibility of mean and cov
if (cov.shape(-1) != cov.shape(-2)) { if (cov.shape(-1) != cov.shape(-2)) {
throw std::invalid_argument( throw std::invalid_argument(
"[multivariate_normal] last two dimensions of cov must be equal."); "[multivariate_normal] last two dimensions of cov must be equal.");
@ -402,7 +402,7 @@ array categorical(
if (broadcast_shapes(shape, reduced_shape) != shape) { if (broadcast_shapes(shape, reduced_shape) != shape) {
std::ostringstream msg; std::ostringstream msg;
msg << "[categorical] Requested shape " << shape msg << "[categorical] Requested shape " << shape
<< " is not broadcast compatable with reduced logits shape" << " is not broadcast compatible with reduced logits shape"
<< reduced_shape << "."; << reduced_shape << ".";
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
} }

View File

@ -1,12 +1,13 @@
// Copyright © 2023 Apple Inc. // Copyright © 2023 Apple Inc.
#include "mlx/scheduler.h" #include "mlx/scheduler.h"
#include "mlx/backend/metal/metal.h" #include "mlx/backend/gpu/available.h"
#include "mlx/backend/gpu/eval.h"
namespace mlx::core { namespace mlx::core {
Stream default_stream(Device d) { Stream default_stream(Device d) {
if (!metal::is_available() && d == Device::gpu) { if (!gpu::is_available() && d == Device::gpu) {
throw std::invalid_argument( throw std::invalid_argument(
"[default_stream] Cannot get gpu stream without gpu backend."); "[default_stream] Cannot get gpu stream without gpu backend.");
} }
@ -14,7 +15,7 @@ Stream default_stream(Device d) {
} }
void set_default_stream(Stream s) { void set_default_stream(Stream s) {
if (!metal::is_available() && s.device == Device::gpu) { if (!gpu::is_available() && s.device == Device::gpu) {
throw std::invalid_argument( throw std::invalid_argument(
"[set_default_stream] Cannot set gpu stream without gpu backend."); "[set_default_stream] Cannot set gpu stream without gpu backend.");
} }
@ -26,7 +27,7 @@ Stream get_stream(int index) {
} }
Stream new_stream(Device d) { Stream new_stream(Device d) {
if (!metal::is_available() && d == Device::gpu) { if (!gpu::is_available() && d == Device::gpu) {
throw std::invalid_argument( throw std::invalid_argument(
"[new_stream] Cannot make gpu stream without gpu backend."); "[new_stream] Cannot make gpu stream without gpu backend.");
} }
@ -44,7 +45,7 @@ void synchronize(Stream s) {
scheduler::enqueue(s, [p = std::move(p)]() { p->set_value(); }); scheduler::enqueue(s, [p = std::move(p)]() { p->set_value(); });
f.wait(); f.wait();
} else { } else {
metal::synchronize(s); gpu::synchronize(s);
} }
} }

Some files were not shown because too many files have changed in this diff