paramthakkar123 2025-05-06 09:53:10 +05:30
commit 80002ed42f
128 changed files with 2291 additions and 895 deletions

1
.gitignore vendored
View File

@ -36,6 +36,7 @@ share/python-wheels/
.installed.cfg
*.egg
MANIFEST
uv.lock
# vim
*.swp

View File

@ -1,4 +1,6 @@
include CMakeLists.txt
include mlx.pc.in
recursive-include mlx/ *
include cmake/*
include python/src/*
include python/mlx/py.typed # support type hinting as in PEP-561

View File

@ -20,3 +20,5 @@ FFT
irfft2
rfftn
irfftn
fftshift
ifftshift

View File

@ -47,7 +47,10 @@ add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/distributed)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/io)
if(MLX_BUILD_METAL)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/gpu)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/metal)
else()
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_metal)
target_sources(mlx
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/no_metal.cpp)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_gpu)
endif()

View File

@ -356,7 +356,7 @@ class array {
}
enum Status {
// The ouptut of a computation which has not been scheduled.
// The output of a computation which has not been scheduled.
// For example, the status of `x` in `auto x = a + b`.
unscheduled,

View File

@ -99,7 +99,11 @@ inline std::pair<int, int> decompose_hadamard(int n) {
"[hadamard] Only supports n = m*2^k where m in (1, 12, 20, 28).");
}
}
if (n > (1 << 26)) {
throw std::invalid_argument(
"[hadamard] Only supports n = m*2^k where k <= 26");
}
return {n, m};
}
} // namespace mlx::core
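
The added guard rejects sizes whose power-of-two part exceeds 2^26 (per the error message, n = m*2^k with k <= 26). A minimal standalone sketch of just that check, assuming the caller has already divided out m; the helper name is hypothetical and not the library's decompose_hadamard:

// Hypothetical validation of the 2^k factor in n = m * 2^k, k <= 26.
#include <stdexcept>

void check_hadamard_pow2(int two_k) {
  if (two_k <= 0 || (two_k & (two_k - 1)) != 0) {
    throw std::invalid_argument("expected a power of two");
  }
  if (two_k > (1 << 26)) {
    throw std::invalid_argument("only 2^k with k <= 26 is supported");
  }
}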

View File

@ -40,7 +40,8 @@ add_dependencies(mlx cpu_compiled_preamble)
target_sources(
mlx
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/available.cpp
${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp
${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp

View File

@ -0,0 +1,11 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cpu/available.h"
namespace mlx::core::cpu {
bool is_available() {
return true;
}
} // namespace mlx::core::cpu

View File

@ -0,0 +1,9 @@
// Copyright © 2025 Apple Inc.
#pragma once
namespace mlx::core::cpu {
bool is_available();
} // namespace mlx::core::cpu

View File

@ -172,9 +172,12 @@ void binary_float(
case bfloat16:
binary_op<bfloat16_t, Op>(a, b, out, bopt);
break;
case complex64:
binary_op<complex64_t, Op>(a, b, out, bopt);
break;
default:
throw std::runtime_error(
"[binary_float] Only supports non-complex floating point types.");
"[binary_float] Only supports floating point types.");
}
});
}

View File

@ -40,7 +40,10 @@ struct CompilerCache {
std::shared_mutex mtx;
};
static CompilerCache cache{};
static CompilerCache& cache() {
static CompilerCache cache_;
return cache_;
};
// GPU compile is always available if the GPU is available and since we are in
// this file CPU compile is also available.
@ -56,14 +59,16 @@ void* compile(
const std::string& kernel_name,
const std::function<std::string(void)>& source_builder) {
{
std::shared_lock lock(cache.mtx);
if (auto it = cache.kernels.find(kernel_name); it != cache.kernels.end()) {
std::shared_lock lock(cache().mtx);
if (auto it = cache().kernels.find(kernel_name);
it != cache().kernels.end()) {
return it->second;
}
}
std::unique_lock lock(cache.mtx);
if (auto it = cache.kernels.find(kernel_name); it != cache.kernels.end()) {
std::unique_lock lock(cache().mtx);
if (auto it = cache().kernels.find(kernel_name);
it != cache().kernels.end()) {
return it->second;
}
std::string source_code = source_builder();
@ -120,10 +125,10 @@ void* compile(
}
// load library
cache.libs.emplace_back(shared_lib_path);
cache().libs.emplace_back(shared_lib_path);
// Load function
void* fun = dlsym(cache.libs.back().lib, kernel_name.c_str());
void* fun = dlsym(cache().libs.back().lib, kernel_name.c_str());
if (!fun) {
std::ostringstream msg;
msg << "[Compile::eval_cpu] Failed to load compiled function "
@ -131,7 +136,7 @@ void* compile(
<< dlerror();
throw std::runtime_error(msg.str());
}
cache.kernels.insert({kernel_name, fun});
cache().kernels.insert({kernel_name, fun});
return fun;
}
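
Turning the global static CompilerCache into a cache() accessor that owns a function-local static is the usual construct-on-first-use idiom: the cache is built lazily on first call, that initialization is thread-safe since C++11, and it sidesteps static initialization/destruction order issues with other globals. A minimal sketch of the pattern with illustrative names (not MLX's types):

#include <shared_mutex>
#include <string>
#include <unordered_map>

// Illustrative stand-in for the compiler cache.
struct KernelCache {
  std::unordered_map<std::string, void*> kernels;
  std::shared_mutex mtx;
};

KernelCache& kernel_cache() {
  // Constructed on first call; later calls return the same instance.
  static KernelCache cache_;
  return cache_;
}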

View File

@ -330,7 +330,8 @@ void Scan::eval_cpu(const std::vector<array>& inputs, array& out) {
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case complex64:
throw std::runtime_error("Scan ops do not support complex types yet");
scan_dispatch<complex64_t, complex64_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
}
});

View File

@ -88,12 +88,33 @@ DEFAULT_UNARY(expm1, std::expm1)
DEFAULT_UNARY(floor, std::floor)
DEFAULT_UNARY(log, std::log)
DEFAULT_UNARY(log10, std::log10)
DEFAULT_UNARY(log1p, std::log1p)
DEFAULT_UNARY(sinh, std::sinh)
DEFAULT_UNARY(sqrt, std::sqrt)
DEFAULT_UNARY(tan, std::tan)
DEFAULT_UNARY(tanh, std::tanh)
template <typename T>
Simd<T, 1> log1p(Simd<T, 1> in) {
if constexpr (is_complex<T>) {
auto x = in.value.real();
auto y = in.value.imag();
auto zabs = std::abs(in.value);
auto theta = std::atan2(y, x + 1);
if (zabs < 0.5) {
auto r = x * (2 + x) + y * y;
if (r == 0) { // handle underflow
return Simd<T, 1>{T{x, theta}};
}
return Simd<T, 1>{T{((decltype(x))(0.5)) * std::log1p(r), theta}};
} else {
auto z0 = std::hypot(x + 1, y);
return Simd<T, 1>{T{std::log(z0), theta}};
}
} else {
return Simd<T, 1>{std::log1p(in.value)};
}
}
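
For reference, the complex branch above is the standard numerically stable form of log1p; writing z = x + iy,

\log(1+z) = \tfrac{1}{2}\,\log\!\big((1+x)^2 + y^2\big) + i\,\operatorname{atan2}(y,\,1+x),

and for |z| < 0.5 the code expands (1+x)^2 + y^2 = 1 + r with r = x(2+x) + y^2, so the real part becomes \tfrac{1}{2}\log(1+r) evaluated via std::log1p to avoid cancellation; for larger |z| it falls back to \log(\operatorname{hypot}(1+x,\,y)).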
template <typename T>
Simd<T, 1> log2(Simd<T, 1> in) {
if constexpr (is_complex<T>) {

View File

@ -0,0 +1,5 @@
target_sources(
mlx
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp)

View File

@ -0,0 +1,9 @@
// Copyright © 2025 Apple Inc.
#pragma once
namespace mlx::core::gpu {
bool is_available();
} // namespace mlx::core::gpu

49
mlx/backend/gpu/copy.cpp Normal file
View File

@ -0,0 +1,49 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/backend/gpu/copy.h"
#include "mlx/primitives.h"
#include <cassert>
namespace mlx::core {
void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) {
bool donated = set_copy_output_data(in, out, ctype);
if (donated && in.dtype() == out.dtype()) {
// If the output has the same type as the input then there is nothing to
// copy, just use the buffer.
return;
}
if (ctype == CopyType::GeneralGeneral) {
ctype = CopyType::General;
}
copy_gpu_inplace(in, out, ctype, s);
}
void copy_gpu(const array& in, array& out, CopyType ctype) {
copy_gpu(in, out, ctype, out.primitive().stream());
}
void copy_gpu_inplace(
const array& in,
array& out,
CopyType ctype,
const Stream& s) {
assert(in.shape() == out.shape());
return copy_gpu_inplace(
in, out, in.shape(), in.strides(), out.strides(), 0, 0, ctype, s);
}
void copy_gpu_inplace(
const array& in,
array& out,
const Strides& i_strides,
int64_t i_offset,
CopyType ctype,
const Stream& s) {
assert(in.shape() == out.shape());
return copy_gpu_inplace(
in, out, in.shape(), i_strides, out.strides(), i_offset, 0, ctype, s);
}
} // namespace mlx::core

View File

@ -5,6 +5,8 @@
#include "mlx/backend/common/copy.h"
#include "mlx/stream.h"
#include <optional>
namespace mlx::core {
// Generic copy inplace

View File

@ -8,14 +8,11 @@
#include "mlx/array.h"
#include "mlx/stream.h"
namespace mlx::core::metal {
namespace mlx::core::gpu {
void new_stream(Stream stream);
std::unique_ptr<void, std::function<void(void*)>> new_scoped_memory_pool();
void eval(array& arr);
void finalize(Stream s);
void synchronize(Stream s);
} // namespace mlx::core::metal
} // namespace mlx::core::gpu

View File

@ -0,0 +1,217 @@
// Copyright © 2025 Apple Inc.
#include "mlx/primitives.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/gpu/slicing.h"
#include <cassert>
#define MLX_PROFILER_RANGE(message)
namespace mlx::core {
namespace {
void reshape(const array& in, array& out, Stream s) {
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
if (copy_necessary) {
out.set_data(allocator::malloc(out.nbytes()));
copy_gpu_inplace(
in,
out,
in.shape(),
in.strides(),
make_contiguous_strides(in.shape()),
0,
0,
CopyType::General,
s);
} else {
shared_buffer_reshape(in, out_strides, out);
}
}
} // namespace
void AsStrided::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("AsStrided::eval_gpu");
eval(inputs, out);
}
void AsType::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("AsType::eval_gpu");
CopyType ctype =
inputs[0].flags().contiguous ? CopyType::Vector : CopyType::General;
copy_gpu(inputs[0], out, ctype);
}
void Broadcast::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("Broadcast::eval_gpu");
eval(inputs, out);
}
void BroadcastAxes::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("BroadcastAxes::eval_gpu");
eval(inputs, out);
}
void Concatenate::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("Concatenate::eval_gpu");
concatenate_gpu(inputs, out, axis_, stream());
}
void Contiguous::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("Contiguous::eval_gpu");
assert(inputs.size() == 1);
auto& in = inputs[0];
constexpr size_t extra_bytes = 16384;
if (in.buffer_size() <= out.nbytes() + extra_bytes &&
(in.flags().row_contiguous ||
(allow_col_major_ && in.flags().col_contiguous))) {
out.copy_shared_buffer(in);
} else {
copy_gpu(in, out, CopyType::General);
}
}
void Copy::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("Copy::eval_gpu");
eval(inputs, out);
}
void CustomTransforms::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
MLX_PROFILER_RANGE("CustomTransforms::eval_gpu");
eval(inputs, outputs);
}
void Depends::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
MLX_PROFILER_RANGE("Depends::eval_gpu");
eval(inputs, outputs);
}
void ExpandDims::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("ExpandDims::eval_gpu");
eval(inputs, out);
}
void Full::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("Full::eval_gpu");
auto in = inputs[0];
CopyType ctype;
if (in.data_size() == 1) {
ctype = CopyType::Scalar;
} else if (in.flags().contiguous) {
ctype = CopyType::Vector;
} else {
ctype = CopyType::General;
}
copy_gpu(in, out, ctype);
}
void Flatten::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("Flatten::eval_gpu");
reshape(inputs[0], out, stream());
}
void NumberOfElements::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("NumberOfElements::eval_gpu");
eval(inputs, out);
}
void Pad::eval_gpu(const std::vector<array>& inputs, array& out) {
// Inputs must be base input array and scalar val array
assert(inputs.size() == 2);
auto& in = inputs[0];
auto& val = inputs[1];
// Padding value must be a scalar
assert(val.size() == 1);
// Padding value, input and output must be of the same type
assert(val.dtype() == in.dtype() && in.dtype() == out.dtype());
pad_gpu(in, val, out, axes_, low_pad_size_, stream());
}
void Reshape::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("Reshape::eval_gpu");
reshape(inputs[0], out, stream());
}
void Split::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
MLX_PROFILER_RANGE("Split::eval_gpu");
eval(inputs, outputs);
}
void Slice::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("Slice::eval_gpu");
assert(inputs.size() == 1);
if (out.size() == 0) {
out.set_data(nullptr);
return;
}
auto& in = inputs[0];
slice_gpu(in, out, start_indices_, strides_, stream());
}
void Squeeze::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("Squeeze::eval_gpu");
eval(inputs, out);
}
void StopGradient::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("StopGradient::eval_gpu");
eval(inputs, out);
}
void Transpose::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("Transpose::eval_gpu");
eval(inputs, out);
}
void Unflatten::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("Unflatten::eval_gpu");
reshape(inputs[0], out, stream());
}
void View::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("View::eval_gpu");
auto& in = inputs[0];
auto ibytes = size_of(in.dtype());
auto obytes = size_of(out.dtype());
// Conditions for buffer copying (disjunction):
// - type size is the same
// - type size is smaller and the last axis is contiguous
// - the entire array is row contiguous
if (ibytes == obytes || (obytes < ibytes && in.strides().back() == 1) ||
in.flags().row_contiguous) {
auto strides = in.strides();
for (int i = 0; i < static_cast<int>(strides.size()) - 1; ++i) {
strides[i] *= ibytes;
strides[i] /= obytes;
}
out.copy_shared_buffer(
in, strides, in.flags(), in.data_size() * ibytes / obytes);
} else {
auto tmp = array(in.shape(), in.dtype(), nullptr, {});
tmp.set_data(allocator::malloc(tmp.nbytes()));
copy_gpu_inplace(in, tmp, CopyType::General, stream());
auto flags = out.flags();
flags.contiguous = true;
flags.row_contiguous = true;
auto max_dim = std::max_element(out.shape().begin(), out.shape().end());
flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim;
out.copy_shared_buffer(tmp, out.strides(), flags, out.size());
}
}
} // namespace mlx::core
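
The strides adjustment in View::eval_gpu above rescales every stride except the innermost one by the ratio of the element sizes. A small standalone sketch of that arithmetic; the helper function and the example shapes are illustrative only:

#include <cstdint>
#include <vector>

// Rescale strides when reinterpreting ibytes-sized elements as obytes-sized
// elements; all but the last axis change, the last axis stays unit-strided.
std::vector<int64_t> view_strides(std::vector<int64_t> strides, int ibytes, int obytes) {
  for (size_t i = 0; i + 1 < strides.size(); ++i) {
    strides[i] = strides[i] * ibytes / obytes;
  }
  return strides;
}

// Example: a row-contiguous (2, 4) float32 array has strides {4, 1};
// viewed as int16 (2 bytes) it becomes a (2, 8) array with strides {8, 1}.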

View File

@ -0,0 +1,44 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/common/slicing.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/gpu/slicing.h"
namespace mlx::core {
void slice_gpu(
const array& in,
array& out,
const Shape& start_indices,
const Shape& strides,
const Stream& s) {
slice(in, out, start_indices, strides);
}
void pad_gpu(
const array& in,
const array& val,
array& out,
const std::vector<int>& axes,
const Shape& low_pad_size,
const Stream& s) {
// Fill output with val
fill_gpu(val, out, s);
// Find offset for start of input values
size_t data_offset = 0;
for (int i = 0; i < axes.size(); i++) {
auto ax = axes[i] < 0 ? out.ndim() + axes[i] : axes[i];
data_offset += out.strides()[ax] * low_pad_size[i];
}
// Extract slice from output where input will be pasted
array out_slice(in.shape(), out.dtype(), nullptr, {});
out_slice.copy_shared_buffer(
out, out.strides(), out.flags(), out_slice.size(), data_offset);
// Copy input values into the slice
copy_gpu_inplace(in, out_slice, CopyType::GeneralGeneral, s);
}
} // namespace mlx::core
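
The data_offset computed in pad_gpu is the dot product of the low-side pad widths with the output strides (after normalizing negative axes). A simplified sketch that skips the negative-axis handling; the function name and the worked numbers are illustrative:

#include <cstddef>
#include <vector>

// Offset (in elements) of the first input element inside the padded output.
size_t pad_offset(const std::vector<size_t>& out_strides,
                  const std::vector<int>& axes,
                  const std::vector<int>& low_pad) {
  size_t offset = 0;
  for (size_t i = 0; i < axes.size(); ++i) {
    offset += out_strides[axes[i]] * low_pad[i];
  }
  return offset;
}

// Example: padding a (2, 3) array to (4, 7) with low pads (1, 2) gives
// row-major output strides {7, 1} and offset 1 * 7 + 2 * 1 = 9.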

View File

@ -93,6 +93,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/event.cpp
${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp

View File

@ -1,7 +1,6 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/backend/metal/allocator.h"
#include "mlx/backend/metal/metal.h"
#include "mlx/backend/metal/metal_impl.h"
#include "mlx/backend/metal/resident.h"
#include "mlx/memory.h"

View File

@ -90,7 +90,7 @@ void binary_op_gpu_inplace(
work_per_thread = large ? 4 : 2;
} else {
large = out.data_size() > UINT32_MAX;
work_per_thread = 1;
work_per_thread = get_work_per_thread(a.dtype());
}
std::string kernel_name =
get_kernel_name(bopt, op, a, large, shape.size(), work_per_thread);
@ -137,13 +137,20 @@ void binary_op_gpu_inplace(
compute_encoder.dispatch_threads(grid_dims, group_dims);
} else {
// Launch a 1D or 2D grid of threads
size_t nthreads = out.data_size();
size_t nthreads = ceildiv(out.data_size(), work_per_thread);
if (thread_group_size > nthreads) {
thread_group_size = nthreads;
}
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides())
: MTL::Size(nthreads, 1, 1);
MTL::Size grid_dims;
if (large) {
compute_encoder.set_bytes<int64_t>(out.data_size(), arg_idx++);
grid_dims = get_2d_grid_dims(out.shape(), out.strides(), work_per_thread);
} else {
compute_encoder.set_bytes<int>(out.data_size(), arg_idx++);
grid_dims = MTL::Size(nthreads, 1, 1);
}
compute_encoder.dispatch_threads(grid_dims, group_dims);
}
}
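
With work_per_thread > 1, the 1D launch now covers data_size elements with ceil(data_size / work_per_thread) threads, and the size argument passed to the kernel keeps the tail thread from running past the end. A small sketch of the sizing arithmetic, assuming ceildiv is the usual rounding-up division:

#include <cstddef>

// Round-up integer division, as used for the thread count.
constexpr size_t ceildiv(size_t n, size_t d) {
  return (n + d - 1) / d;
}

// Example: data_size = 10 and work_per_thread = 4 launch ceildiv(10, 4) = 3
// threads; the last thread only touches elements 8 and 9 because the
// per-thread loop also checks index + i < size.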

View File

@ -64,6 +64,7 @@ inline void build_kernel(
cnt++);
}
std::string idx_type = use_big_index ? "int64_t" : "uint";
if (add_indices) {
os += fmt::format(
" constant const int64_t* in_strides [[buffer({0})]],\n", cnt++);
@ -83,6 +84,9 @@ inline void build_kernel(
" constant const int64_t* output_strides [[buffer({0})]],\n", cnt++);
os += fmt::format(
" constant const int* output_shape [[buffer({0})]],\n", cnt++);
} else {
os += fmt::format(
" constant const {0}& size [[buffer({1})]],\n", idx_type, cnt++);
}
if (dynamic_dims) {
os += fmt::format(" constant const int& ndim [[buffer({0})]],\n", cnt++);
@ -92,13 +96,14 @@ inline void build_kernel(
os += " uint3 pos [[thread_position_in_grid]],\n";
os += " uint3 grid [[threads_per_grid]]) {\n";
std::string idx_type = use_big_index ? "int64_t" : "uint";
os += fmt::format(" constexpr int N_ = {0};\n", work_per_thread);
if (contiguous && use_big_index) {
// This is only used for contiguous kernels which don't have
// a third grid dimension
os += " int64_t index = pos.x + grid.x * int64_t(pos.y);\n";
os += " int64_t index = N_ * (pos.x + grid.x * int64_t(pos.y));\n";
} else if (contiguous) {
os += " uint index = N_ * pos.x;\n";
} else if (work_per_thread > 1) {
os += fmt::format(" constexpr int N_ = {0};\n", work_per_thread);
os += fmt::format(
" int xshape = output_shape[{0}];\n",
dynamic_dims ? "ndim - 1" : std::to_string(ndim - 1));
@ -110,6 +115,9 @@ inline void build_kernel(
" {0} index = pos.x + grid.x * (pos.y + {0}(grid.y) * pos.z);\n",
idx_type);
}
if (work_per_thread > 1 && contiguous) {
os += " for (int i = 0; i < N_ && index < size; ++i) {\n";
}
// Read constant / contiguous inputs in tmps
std::vector<array> nc_inputs;
@ -193,7 +201,7 @@ inline void build_kernel(
}
// Open per-thread loop
if (work_per_thread > 1) {
if (work_per_thread > 1 && !contiguous) {
os +=
" for (int i = 0; i < N_ && (int(N_ * pos.x) + i) < xshape; ++i) {\n";
}
@ -272,6 +280,7 @@ void Compiled::eval_gpu(
auto& s = stream();
auto& d = metal::device(s.device);
auto lib = d.get_library(kernel_lib_, [&]() {
int work_per_thread = get_work_per_thread(outputs_[0].dtype());
std::string kernel = metal::utils();
concatenate(
kernel, metal::unary_ops(), metal::binary_ops(), metal::ternary_ops());
@ -284,7 +293,9 @@ void Compiled::eval_gpu(
constant_ids_,
/* contiguous = */ true,
/* ndim = */ 0,
/* dynamic_dims = */ false);
/* dynamic_dims = */ false,
/* use_big_index = */ false,
/* work_per_thread = */ work_per_thread);
build_kernel(
kernel,
kernel_lib_ + "_contiguous_large",
@ -295,7 +306,8 @@ void Compiled::eval_gpu(
/* contiguous = */ true,
/* ndim = */ 0,
/* dynamic_dims = */ false,
/* use_big_index = */ true);
/* use_big_index = */ true,
/* work_per_thread = */ work_per_thread);
for (int i = 1; i < 8; i++) {
build_kernel(
kernel,
@ -468,6 +480,13 @@ void Compiled::eval_gpu(
if (!contiguous) {
compute_encoder.set_vector_bytes(strides[0], cnt++);
compute_encoder.set_vector_bytes(shape, cnt++);
} else {
auto size = outputs[0].data_size();
if (large) {
compute_encoder.set_bytes<int64_t>(size, cnt++);
} else {
compute_encoder.set_bytes<int>(size, cnt++);
}
}
// Put the number of dims in if it is dynamic
@ -477,12 +496,13 @@ void Compiled::eval_gpu(
// Launch the kernel
if (contiguous) {
size_t nthreads = outputs[0].data_size();
int work_per_thread = get_work_per_thread(outputs[0].dtype());
size_t nthreads = ceildiv(outputs[0].data_size(), work_per_thread);
MTL::Size group_dims(
std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1);
MTL::Size grid_dims = large
? get_2d_grid_dims(outputs[0].shape(), outputs[0].strides())
? get_2d_grid_dims(
outputs[0].shape(), outputs[0].strides(), work_per_thread)
: MTL::Size(nthreads, 1, 1);
compute_encoder.dispatch_threads(grid_dims, group_dims);
} else {

View File

@ -5,7 +5,7 @@
#include <numeric>
#include <sstream>
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/kernels/defines.h"

View File

@ -1,35 +1,15 @@
// Copyright © 2023-2024 Apple Inc.
#include <sstream>
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"
namespace mlx::core {
constexpr int MAX_COPY_SPECIALIZED_DIMS = 3;
void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) {
bool donated = set_copy_output_data(in, out, ctype);
if (donated && in.dtype() == out.dtype()) {
// If the output has the same type as the input then there is nothing to
// copy, just use the buffer.
return;
}
if (ctype == CopyType::GeneralGeneral) {
ctype = CopyType::General;
}
copy_gpu_inplace(in, out, ctype, s);
}
void copy_gpu(const array& in, array& out, CopyType ctype) {
copy_gpu(in, out, ctype, out.primitive().stream());
}
void copy_gpu_inplace(
const array& in,
array& out,
@ -104,6 +84,8 @@ void copy_gpu_inplace(
"[Copy::eval_gpu] Dynamic output offset requires GeneralGeneral copy");
}
}
} else {
work_per_thread = get_work_per_thread(in.dtype());
}
concatenate(kernel_name, "_copy", type_to_name(in), type_to_name(out));
auto kernel = dynamic ? get_dynamic_copy_kernel(d, kernel_name, in, out)
@ -165,39 +147,23 @@ void copy_gpu_inplace(
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
compute_encoder.dispatch_threads(grid_dims, group_dims);
} else {
size_t nthreads = out.data_size();
size_t nthreads = ceildiv(out.data_size(), work_per_thread);
if (thread_group_size > nthreads) {
thread_group_size = nthreads;
}
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides())
: MTL::Size(nthreads, 1, 1);
MTL::Size grid_dims;
if (large) {
compute_encoder.set_bytes<int64_t>(out.data_size(), 2);
grid_dims = get_2d_grid_dims(out.shape(), out.strides(), work_per_thread);
} else {
compute_encoder.set_bytes<int>(out.data_size(), 2);
grid_dims = MTL::Size(nthreads, 1, 1);
}
compute_encoder.dispatch_threads(grid_dims, group_dims);
}
}
void copy_gpu_inplace(
const array& in,
array& out,
CopyType ctype,
const Stream& s) {
assert(in.shape() == out.shape());
return copy_gpu_inplace(
in, out, in.shape(), in.strides(), out.strides(), 0, 0, ctype, s);
}
void copy_gpu_inplace(
const array& in,
array& out,
const Strides& i_strides,
int64_t i_offset,
CopyType ctype,
const Stream& s) {
assert(in.shape() == out.shape());
return copy_gpu_inplace(
in, out, in.shape(), i_strides, out.strides(), i_offset, 0, ctype, s);
}
void fill_gpu(const array& val, array& out, const Stream& s) {
if (out.size() == 0) {
return;
@ -214,14 +180,21 @@ void fill_gpu(const array& val, array& out, const Stream& s) {
compute_encoder.set_input_array(val, 0);
compute_encoder.set_output_array(out, 1);
int work_per_thread = get_work_per_thread(val.dtype());
auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
size_t nthreads = out.data_size();
size_t nthreads = ceildiv(out.data_size(), work_per_thread);
if (thread_group_size > nthreads) {
thread_group_size = nthreads;
}
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides())
: MTL::Size(nthreads, 1, 1);
MTL::Size grid_dims;
if (large) {
compute_encoder.set_bytes<int64_t>(out.data_size(), 2);
grid_dims = get_2d_grid_dims(out.shape(), out.strides(), work_per_thread);
} else {
compute_encoder.set_bytes<int>(out.data_size(), 2);
grid_dims = MTL::Size(nthreads, 1, 1);
}
compute_encoder.dispatch_threads(grid_dims, group_dims);
}

View File

@ -1,6 +1,6 @@
// Copyright © 2024 Apple Inc.
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/metal/jit/includes.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/fast_primitives.h"

View File

@ -1,20 +1,20 @@
// Copyright © 2023-2024 Apple Inc.
#include <cstdlib>
#include <filesystem>
#include <sstream>
#include <sys/sysctl.h>
#define NS_PRIVATE_IMPLEMENTATION
#define CA_PRIVATE_IMPLEMENTATION
#define MTL_PRIVATE_IMPLEMENTATION
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/metal.h"
#include "mlx/backend/metal/metal_impl.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/utils.h"
namespace fs = std::filesystem;
namespace mlx::core::metal {
namespace {
@ -66,8 +66,8 @@ MTL::Library* try_load_bundle(
if (bundle != nullptr) {
std::string resource_path =
std::string(bundle->resourceURL()->fileSystemRepresentation()) + "/" +
lib_name + ".metallib" auto [lib, error] =
load_library_from_path(device, resource_path.c_str());
lib_name + ".metallib";
auto [lib, error] = load_library_from_path(device, resource_path.c_str());
if (lib) {
return lib;
}
@ -79,12 +79,18 @@ MTL::Library* try_load_bundle(
// Firstly, search for the metallib in the same path as this binary
std::pair<MTL::Library*, NS::Error*> load_colocated_library(
MTL::Device* device,
const std::string& lib_name) {
std::string lib_path = get_colocated_mtllib_path(lib_name);
if (lib_path.size() != 0) {
return load_library_from_path(device, lib_path.c_str());
const std::string& relative_path) {
std::string binary_dir = get_binary_directory();
if (binary_dir.size() == 0) {
return {nullptr, nullptr};
}
return {nullptr, nullptr};
auto path = fs::path(binary_dir) / relative_path;
if (!path.has_extension()) {
path.replace_extension(".metallib");
}
return load_library_from_path(device, path.c_str());
}
std::pair<MTL::Library*, NS::Error*> load_swiftpm_library(
@ -99,7 +105,7 @@ std::pair<MTL::Library*, NS::Error*> load_swiftpm_library(
auto bundles = NS::Bundle::allBundles();
for (int i = 0, c = (int)bundles->count(); i < c; i++) {
auto bundle = reinterpret_cast<NS::Bundle*>(bundles->object(i));
library = try_load_bundle(device, bundle->resourceURL());
library = try_load_bundle(device, bundle->resourceURL(), lib_name);
if (library != nullptr) {
return {library, nullptr};
}
@ -109,33 +115,34 @@ std::pair<MTL::Library*, NS::Error*> load_swiftpm_library(
}
MTL::Library* load_default_library(MTL::Device* device) {
NS::Error *error1, *error2, *error3;
NS::Error* error[4];
MTL::Library* lib;
// First try the colocated mlx.metallib
std::tie(lib, error1) = load_colocated_library(device, "mlx");
std::tie(lib, error[0]) = load_colocated_library(device, "mlx");
if (lib) {
return lib;
}
std::tie(lib, error[1]) = load_colocated_library(device, "Resources/mlx");
if (lib) {
return lib;
}
// Then try default.metallib in a SwiftPM bundle if we have one
std::tie(lib, error2) = load_swiftpm_library(device, "default");
std::tie(lib, error[2]) = load_swiftpm_library(device, "default");
if (lib) {
return lib;
}
// Finally try default_mtllib_path
std::tie(lib, error3) = load_library_from_path(device, default_mtllib_path);
std::tie(lib, error[3]) = load_library_from_path(device, default_mtllib_path);
if (!lib) {
std::ostringstream msg;
msg << "Failed to load the default metallib. ";
if (error1 != nullptr) {
msg << error1->localizedDescription()->utf8String() << " ";
}
if (error2 != nullptr) {
msg << error2->localizedDescription()->utf8String() << " ";
}
if (error3 != nullptr) {
msg << error3->localizedDescription()->utf8String() << " ";
for (int i = 0; i < 4; i++) {
if (error[i] != nullptr) {
msg << error[i]->localizedDescription()->utf8String() << " ";
}
}
throw std::runtime_error(msg.str());
}
@ -156,6 +163,7 @@ MTL::Library* load_library(
<< error->localizedDescription()->utf8String();
throw std::runtime_error(msg.str());
}
return lib;
}
// We have been given a path so try to load from lib_path / lib_name.metallib
@ -168,6 +176,7 @@ MTL::Library* load_library(
<< "> with error " << error->localizedDescription()->utf8String();
throw std::runtime_error(msg.str());
}
return lib;
}
// Try to load the colocated library
@ -188,8 +197,8 @@ MTL::Library* load_library(
std::ostringstream msg;
msg << "Failed to load the metallib " << lib_name << ".metallib. "
<< "We attempted to load it from <" << get_colocated_mtllib_path(lib_name)
<< ">";
<< "We attempted to load it from <" << get_binary_directory() << "/"
<< lib_name << ".metallib" << ">";
#ifdef SWIFTPM_BUNDLE
msg << " and from the Swift PM bundle.";
#endif
@ -760,42 +769,4 @@ std::unique_ptr<void, std::function<void(void*)>> new_scoped_memory_pool() {
NS::AutoreleasePool::alloc()->init(), dtor);
}
void new_stream(Stream stream) {
if (stream.device == mlx::core::Device::gpu) {
device(stream.device).new_queue(stream.index);
}
}
const std::unordered_map<std::string, std::variant<std::string, size_t>>&
device_info() {
auto init_device_info = []()
-> std::unordered_map<std::string, std::variant<std::string, size_t>> {
auto pool = new_scoped_memory_pool();
auto raw_device = device(default_device()).mtl_device();
auto name = std::string(raw_device->name()->utf8String());
auto arch = std::string(raw_device->architecture()->name()->utf8String());
size_t memsize = 0;
size_t length = sizeof(memsize);
sysctlbyname("hw.memsize", &memsize, &length, NULL, 0);
size_t rsrc_limit = 0;
sysctlbyname("iogpu.rsrc_limit", &rsrc_limit, &length, NULL, 0);
if (rsrc_limit == 0) {
rsrc_limit = 499000;
}
return {
{"device_name", name},
{"architecture", arch},
{"max_buffer_length", raw_device->maxBufferLength()},
{"max_recommended_working_set_size",
raw_device->recommendedMaxWorkingSetSize()},
{"memory_size", memsize},
{"resource_limit", rsrc_limit}};
};
static auto device_info_ = init_device_info();
return device_info_;
}
} // namespace mlx::core::metal

View File

@ -21,18 +21,14 @@ namespace mlx::core::metal {
// Note, this function must be left inline in a header so that it is not
// dynamically linked.
inline std::string get_colocated_mtllib_path(const std::string& lib_name) {
inline std::string get_binary_directory() {
Dl_info info;
std::string mtllib_path;
std::string lib_ext = lib_name + ".metallib";
int success = dladdr((void*)get_colocated_mtllib_path, &info);
std::string directory;
int success = dladdr((void*)get_binary_directory, &info);
if (success) {
auto mtllib = fs::path(info.dli_fname).remove_filename() / lib_ext;
mtllib_path = mtllib.c_str();
directory = fs::path(info.dli_fname).remove_filename().c_str();
}
return mtllib_path;
return directory;
}
using MTLFCList =
@ -270,4 +266,6 @@ class Device {
Device& device(mlx::core::Device);
std::unique_ptr<void, std::function<void(void*)>> new_scoped_memory_pool();
} // namespace mlx::core::metal
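
get_binary_directory now returns only the directory containing the running binary, and callers such as load_colocated_library append the metallib name themselves. A hedged sketch of that composition; the helper below is illustrative, not part of the header:

#include <filesystem>
#include <string>

namespace fs = std::filesystem;

// Build <binary dir>/<name>.metallib from a directory string, mirroring how
// load_colocated_library assembles the path.
fs::path colocated_metallib(const std::string& binary_dir, const std::string& name) {
  auto path = fs::path(binary_dir) / name;
  if (!path.has_extension()) {
    path.replace_extension(".metallib");
  }
  return path;
}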

View File

@ -4,7 +4,7 @@
#include "mlx/allocator.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/distributed/ops.h"

102
mlx/backend/metal/eval.cpp Normal file
View File

@ -0,0 +1,102 @@
// Copyright © 2023-2024 Apple Inc.
#include <memory>
#include "mlx/backend/gpu/available.h"
#include "mlx/backend/gpu/eval.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"
#include "mlx/scheduler.h"
namespace mlx::core::gpu {
bool is_available() {
return true;
}
void new_stream(Stream stream) {
if (stream.device == mlx::core::Device::gpu) {
metal::device(stream.device).new_queue(stream.index);
}
}
inline void check_error(MTL::CommandBuffer* cbuf) {
if (cbuf->status() == MTL::CommandBufferStatusError) {
std::ostringstream msg;
msg << "[METAL] Command buffer execution failed: "
<< cbuf->error()->localizedDescription()->utf8String();
throw std::runtime_error(msg.str());
}
}
void eval(array& arr) {
auto pool = metal::new_scoped_memory_pool();
auto s = arr.primitive().stream();
auto& d = metal::device(s.device);
auto command_buffer = d.get_command_buffer(s.index);
auto outputs = arr.outputs();
{
// If the array is a tracer hold a reference
// to its inputs so they don't get donated
std::vector<array> inputs;
if (arr.is_tracer()) {
inputs = arr.inputs();
}
debug_set_primitive_buffer_label(command_buffer, arr.primitive());
arr.primitive().eval_gpu(arr.inputs(), outputs);
}
std::unordered_set<std::shared_ptr<array::Data>> buffers;
for (auto& in : arr.inputs()) {
buffers.insert(in.data_shared_ptr());
}
for (auto& s : arr.siblings()) {
buffers.insert(s.data_shared_ptr());
}
// Remove the output if it was donated to by an input
if (auto it = buffers.find(arr.data_shared_ptr()); it != buffers.end()) {
buffers.erase(it);
}
if (d.command_buffer_needs_commit(s.index)) {
d.end_encoding(s.index);
scheduler::notify_new_task(s);
command_buffer->addCompletedHandler(
[s, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) {
scheduler::notify_task_completion(s);
check_error(cbuf);
});
d.commit_command_buffer(s.index);
d.get_command_buffer(s.index);
} else {
command_buffer->addCompletedHandler(
[s, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) {
check_error(cbuf);
});
}
}
void finalize(Stream s) {
auto pool = metal::new_scoped_memory_pool();
auto& d = metal::device(s.device);
auto cb = d.get_command_buffer(s.index);
d.end_encoding(s.index);
cb->addCompletedHandler([s](MTL::CommandBuffer* cbuf) { check_error(cbuf); });
d.commit_command_buffer(s.index);
d.get_command_buffer(s.index);
}
void synchronize(Stream s) {
auto pool = metal::new_scoped_memory_pool();
auto& d = metal::device(s.device);
auto cb = d.get_command_buffer(s.index);
cb->retain();
d.end_encoding(s.index);
d.commit_command_buffer(s.index);
cb->waitUntilCompleted();
check_error(cb);
cb->release();
}
} // namespace mlx::core::gpu
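
The buffers set captured by the completed handlers above keeps the input and sibling data alive until the command buffer actually finishes; once the handler is destroyed, the shared_ptr reference counts drop and the memory can be reused. A minimal standalone sketch of the same keep-alive idiom, with std::function standing in for the Metal handler:

#include <functional>
#include <memory>
#include <unordered_set>
#include <vector>

int main() {
  auto data = std::make_shared<std::vector<float>>(1024, 0.0f);

  // Capture the shared_ptr by value in the "completion handler"; the buffer
  // cannot be freed while the handler is still pending.
  std::unordered_set<std::shared_ptr<std::vector<float>>> keep_alive{data};
  std::function<void()> on_complete = [buffers = std::move(keep_alive)]() {
    // Work is done; dropping the lambda releases the references.
  };

  data.reset(); // the caller's reference is gone, but the handler still holds one
  on_complete();
  return 0;
}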

View File

@ -2,7 +2,6 @@
#include "mlx/event.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/metal_impl.h"
#include "mlx/scheduler.h"
namespace mlx::core {

View File

@ -1,7 +1,6 @@
// Copyright © 2024 Apple Inc.
#include "mlx/fence.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/metal_impl.h"
#include "mlx/scheduler.h"
#include "mlx/utils.h"
@ -139,7 +138,7 @@ void Fence::update(Stream stream, const array& x) {
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(x, 0);
compute_encoder.set_bytes(nthreads, 1);
compute_encoder.dispatch_threadgroups(group_dims, grid_dims);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
// Barrier on previous kernels
compute_encoder.barrier();

View File

@ -7,10 +7,10 @@
#include "mlx/3rdparty/pocketfft.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/gpu/slicing.h"
#include "mlx/backend/metal/binary.h"
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/slicing.h"
#include "mlx/backend/metal/unary.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/utils.h"

View File

@ -1,11 +1,9 @@
// Copyright © 2024 Apple Inc.
#include <map>
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/common/hadamard.h"
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/jit/includes.h"
#include "mlx/backend/metal/kernels.h"
@ -15,7 +13,6 @@
namespace mlx::core {
constexpr int MAX_HADAMARD_THREADS_PER_GROUP = 256;
constexpr int MAX_HADAMARD_BYTES = 32768; // 32KB
std::string gen_hadamard_codelet(int m) {
// Generate a O(m^2) hadamard codelet for a given M
@ -60,121 +57,142 @@ std::string gen_hadamard_codelet(int m) {
return source.str();
}
void Hadamard::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& s = stream();
void hadamard_mn_contiguous(
const array& x,
array& y,
int m,
int n1,
int n2,
float scale,
metal::Device& d,
const Stream& s) {
int n = n1 * n2;
int read_width_n1 = n1 == 2 ? 2 : 4;
int read_width_n2 = n2 == 2 ? 2 : 4;
int read_width_m = (n == 2 || m == 28) ? 2 : 4;
int max_radix_1 = std::min(n1, 16);
int max_radix_2 = std::min(n2, 16);
float scale_n1 = 1.0;
float scale_n2 = (m == 1) ? scale : 1.0;
float scale_m = scale;
auto& in = inputs[0];
// n2 is a row contiguous power of 2 hadamard transform
MTL::Size group_dims_n2(n2 / max_radix_2, 1, 1);
MTL::Size grid_dims_n2(n2 / max_radix_2, x.size() / n2, 1);
std::vector<array> copies;
// Only support the last axis for now
int axis = in.ndim() - 1;
auto check_input = [&copies, &s](const array& x) {
// TODO(alexbarron) pass strides to kernel to relax this constraint
bool no_copy = x.flags().row_contiguous;
if (no_copy) {
return x;
} else {
copies.push_back(array(x.shape(), x.dtype(), nullptr, {}));
copy_gpu(x, copies.back(), CopyType::General, s);
return copies.back();
// n1 is a strided power of 2 hadamard transform with stride n2
MTL::Size group_dims_n1(n1 / max_radix_1, 1, 1);
MTL::Size grid_dims_n1(n1 / max_radix_1, x.size() / n, n2);
// m is a strided hadamard transform with stride n = n1 * n2
MTL::Size group_dims_m(
std::min(n / read_width_m, MAX_HADAMARD_THREADS_PER_GROUP), 1, 1);
MTL::Size grid_dims_m(
group_dims_m.width, x.size() / m / read_width_m / group_dims_m.width, 1);
// Make the kernel
std::string kname;
kname.reserve(32);
concatenate(kname, "hadamard_", n * m, "_", type_to_name(x));
auto lib = d.get_library(kname, [&]() {
std::string kernel;
concatenate(
kernel,
metal::utils(),
gen_hadamard_codelet(m),
metal::hadamard(),
get_template_definition(
"n2" + kname,
"hadamard_n",
get_type_string(x.dtype()),
n2,
max_radix_2,
read_width_n2));
if (n1 > 1) {
kernel += get_template_definition(
"n1" + kname,
"hadamard_n",
get_type_string(x.dtype()),
n1,
max_radix_1,
read_width_n1,
n2);
}
};
const array& in_contiguous = check_input(in);
if (in_contiguous.is_donatable()) {
out.copy_shared_buffer(in_contiguous);
} else {
out.set_data(allocator::malloc(out.nbytes()));
}
int n, m;
std::tie(n, m) = decompose_hadamard(in.shape(axis));
if (n * (int)size_of(in.dtype()) > MAX_HADAMARD_BYTES) {
throw std::invalid_argument(
"[hadamard] For n = m*2^k, 2^k > 8192 for FP32 or 2^k > 16384 for FP16/BF16 NYI");
}
int max_radix = std::min(n, 16);
// Use read_width 2 for m = 28 to avoid register spilling
int read_width = (n == 2 || m == 28) ? 2 : 4;
std::ostringstream kname;
kname << "hadamard_" << n * m << "_" << type_to_name(out);
auto kernel_name = kname.str();
auto& d = metal::device(s.device);
const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name, [&]() {
std::ostringstream kernel_source;
auto codelet = gen_hadamard_codelet(m);
kernel_source << metal::utils() << codelet << metal::hadamard();
kernel_source << get_template_definition(
"n" + kernel_name,
"hadamard_n",
get_type_string(in.dtype()),
n,
max_radix,
read_width);
kernel_source << get_template_definition(
"m" + kernel_name,
"hadamard_m",
get_type_string(in.dtype()),
n,
m,
read_width);
return kernel_source.str();
if (m > 1) {
kernel += get_template_definition(
"m" + kname,
"hadamard_m",
get_type_string(x.dtype()),
n,
m,
read_width_m);
}
return kernel;
});
int batch_size = in.size() / n;
int threads_per = n / max_radix;
auto& compute_encoder = d.get_command_encoder(s.index);
auto launch_hadamard = [&](const array& in,
array& out,
const std::string& kernel_name,
float scale) {
auto kernel = d.get_kernel(kernel_name, lib);
assert(threads_per <= kernel->maxTotalThreadsPerThreadgroup());
// Launch the strided transform for n1
if (n1 > 1) {
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel("n1" + kname, lib);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(out, 1);
compute_encoder.set_bytes(scale, 2);
MTL::Size group_dims = MTL::Size(1, threads_per, 1);
MTL::Size grid_dims = MTL::Size(batch_size, threads_per, 1);
compute_encoder.dispatch_threads(grid_dims, group_dims);
};
if (m > 1) {
// When m is greater than 1, we decompose the
// computation into two uploads to the GPU:
//
// e.g. len(x) = 12*4 = 48, m = 12, n = 4
//
// y = h48 @ x
//
// Upload 1:
// tmp = a.reshape(12, 4) @ h4
//
// Upload 2:
// y = h12 @ tmp
array temp(in.shape(), in.dtype(), nullptr, {});
temp.set_data(allocator::malloc(temp.nbytes()));
copies.push_back(temp);
launch_hadamard(in_contiguous, temp, "n" + kernel_name, 1.0);
// Metal sometimes reports 256 max threads per group for hadamard_m kernel
threads_per = std::min(n / read_width, MAX_HADAMARD_THREADS_PER_GROUP);
batch_size = in.size() / m / read_width / threads_per;
launch_hadamard(temp, out, "m" + kernel_name, scale_);
} else {
launch_hadamard(in_contiguous, out, "n" + kernel_name, scale_);
compute_encoder.set_input_array(x, 0);
compute_encoder.set_output_array(y, 1);
compute_encoder.set_bytes(scale_n1, 2);
compute_encoder.dispatch_threads(grid_dims_n1, group_dims_n1);
}
d.add_temporaries(std::move(copies), s.index);
// Launch the transform for n2
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel("n2" + kname, lib);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(n1 > 1 ? y : x, 0);
compute_encoder.set_output_array(y, 1);
compute_encoder.set_bytes(scale_n2, 2);
compute_encoder.dispatch_threads(grid_dims_n2, group_dims_n2);
// Launch the strided transform for m
if (m > 1) {
auto kernel = d.get_kernel("m" + kname, lib);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(y, 0);
compute_encoder.set_output_array(y, 1);
compute_encoder.set_bytes(scale_m, 2);
compute_encoder.dispatch_threads(grid_dims_m, group_dims_m);
}
}
void Hadamard::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& s = stream();
auto& d = metal::device(s.device);
auto& in = inputs[0];
// Split the hadamard transform so that all of them work on vectors smaller
// than 8192 elements.
//
// We decompose it in the following way:
//
// n = m * n1 * n2 = m * 2^k1 * 2^k2
//
// where m is in (1, 12, 20, 28) and n1 and n2 <= 8192
auto [n, m] = decompose_hadamard(in.shape().back());
int n1 = 1, n2 = n;
if (n > 8192) {
for (n2 = 2; n2 * n2 < n; n2 *= 2) {
}
n1 = n / n2;
}
if (in.flags().row_contiguous) {
if (in.is_donatable()) {
out.copy_shared_buffer(in);
} else {
out.set_data(allocator::malloc(out.nbytes()));
}
hadamard_mn_contiguous(in, out, m, n1, n2, scale_, d, s);
} else {
copy_gpu(in, out, CopyType::General, s);
hadamard_mn_contiguous(out, out, m, n1, n2, scale_, d, s);
}
}
} // namespace mlx::core
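
A worked instance of the split above, under the same rule (n2 is the smallest power of two with n2 * n2 >= n): for n = 32768 the loop stops at n2 = 256, leaving n1 = 128, so both factors stay within the 8192-element limit. The same arithmetic in isolation (the function name is illustrative):

#include <utility>

// Split a power of two n into n1 * n2 with n2 the smallest power of two
// satisfying n2 * n2 >= n (mirrors the loop in Hadamard::eval_gpu).
std::pair<int, int> split_pow2(int n) {
  int n1 = 1, n2 = n;
  if (n > 8192) {
    for (n2 = 2; n2 * n2 < n; n2 *= 2) {
    }
    n1 = n / n2;
  }
  return {n1, n2};
}

// split_pow2(32768) == {128, 256}; split_pow2(4096) == {1, 4096}.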

View File

@ -2,7 +2,7 @@
#include <fmt/format.h>
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/jit/includes.h"
#include "mlx/backend/metal/jit/indexing.h"

View File

@ -9,64 +9,85 @@ template <typename T, typename U, typename Op>
c[index] = Op()(a[0], b[0]);
}
template <typename T, typename U, typename Op>
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
[[kernel]] void binary_sv(
device const T* a,
device const T* b,
device U* c,
constant uint& size,
uint index [[thread_position_in_grid]]) {
c[index] = Op()(a[0], b[index]);
index *= N;
for (int i = 0; i < N && (index + i) < size; ++i) {
c[index + i] = Op()(a[0], b[index + i]);
}
}
template <typename T, typename U, typename Op>
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
[[kernel]] void binary_vs(
device const T* a,
device const T* b,
device U* c,
constant uint& size,
uint index [[thread_position_in_grid]]) {
c[index] = Op()(a[index], b[0]);
index *= N;
for (int i = 0; i < N && (index + i) < size; ++i) {
c[index + i] = Op()(a[index + i], b[0]);
}
}
template <typename T, typename U, typename Op>
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
[[kernel]] void binary_vv(
device const T* a,
device const T* b,
device U* c,
constant uint& size,
uint index [[thread_position_in_grid]]) {
c[index] = Op()(a[index], b[index]);
index *= N;
for (int i = 0; i < N && (index + i) < size; ++i) {
c[index + i] = Op()(a[index + i], b[index + i]);
}
}
template <typename T, typename U, typename Op>
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
[[kernel]] void binary_sv2(
device const T* a,
device const T* b,
device U* c,
constant int64_t& size,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
int64_t offset = index.x + grid_dim.x * int64_t(index.y);
c[offset] = Op()(a[0], b[offset]);
int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
for (int i = 0; i < N && (offset + i) < size; ++i) {
c[offset + i] = Op()(a[0], b[offset + i]);
}
}
template <typename T, typename U, typename Op>
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
[[kernel]] void binary_vs2(
device const T* a,
device const T* b,
device U* c,
constant int64_t& size,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
int64_t offset = index.x + grid_dim.x * int64_t(index.y);
c[offset] = Op()(a[offset], b[0]);
int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
for (int i = 0; i < N && (offset + i) < size; ++i) {
c[offset + i] = Op()(a[offset + i], b[0]);
}
}
template <typename T, typename U, typename Op>
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
[[kernel]] void binary_vv2(
device const T* a,
device const T* b,
device U* c,
constant int64_t& size,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
int64_t offset = index.x + grid_dim.x * int64_t(index.y);
c[offset] = Op()(a[offset], b[offset]);
int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
for (int i = 0; i < N && (offset + i) < size; ++i) {
c[offset + i] = Op()(a[offset + i], b[offset + i]);
}
}
template <typename T, typename U, typename Op, typename IdxT = int64_t>

View File

@ -71,6 +71,7 @@ instantiate_binary_types_bool(Less)
instantiate_binary_types_bool(LessEqual)
instantiate_binary_types_bool(NotEqual)
instantiate_binary_float(LogAddExp)
instantiate_binary_all(LogAddExp, complex64, complex64_t, complex64_t)
instantiate_binary_types(Maximum)
instantiate_binary_types(Minimum)
instantiate_binary_types(Multiply)

View File

@ -130,6 +130,24 @@ struct LogAddExp {
? maxval
: (maxval + log1p(metal::exp(minval - maxval)));
};
complex64_t operator()(complex64_t x, complex64_t y) {
if (metal::isnan(x.real) || metal::isnan(x.imag) || metal::isnan(y.real) ||
metal::isnan(y.imag)) {
return metal::numeric_limits<float>::quiet_NaN();
}
constexpr float inf = metal::numeric_limits<float>::infinity();
complex64_t maxval = x > y ? x : y;
complex64_t minval = x < y ? x : y;
if (minval.real == -inf || maxval.real == inf)
return maxval;
float m = metal::exp(minval.real - maxval.real);
complex64_t dexp{
m * metal::cos(minval.imag - maxval.imag),
m * metal::sin(minval.imag - maxval.imag),
};
return maxval + log1p(dexp);
}
};
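
The complex overload follows the same log-sum-exp rearrangement as the real case, assuming the complex comparison picks the operand with the larger real part (ties broken on the imaginary part):

\log\!\big(e^{x} + e^{y}\big) = x_{\max} + \log\!\big(1 + e^{\,x_{\min} - x_{\max}}\big),
\qquad e^{a + ib} = e^{a}\,(\cos b + i\sin b),

so the code forms m = exp(Re(x_min) - Re(x_max)), the phase difference Δ = Im(x_min) - Im(x_max), dexp = m cos Δ + i m sin Δ, and returns x_max + log1p(dexp), with the NaN and ±inf cases handled up front.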
struct Maximum {

View File

@ -12,82 +12,103 @@ template <typename T, typename U, typename Op>
d[index] = out[1];
}
template <typename T, typename U, typename Op>
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
[[kernel]] void binary_sv(
device const T* a,
device const T* b,
device U* c,
device U* d,
constant uint& size,
uint index [[thread_position_in_grid]]) {
auto out = Op()(a[0], b[index]);
c[index] = out[0];
d[index] = out[1];
index *= N;
for (int i = 0; i < N && (index + i) < size; ++i) {
auto out = Op()(a[0], b[index + i]);
c[index + i] = out[0];
d[index + i] = out[1];
}
}
template <typename T, typename U, typename Op>
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
[[kernel]] void binary_vs(
device const T* a,
device const T* b,
device U* c,
device U* d,
constant uint& size,
uint index [[thread_position_in_grid]]) {
auto out = Op()(a[index], b[0]);
c[index] = out[0];
d[index] = out[1];
index *= N;
for (int i = 0; i < N && (index + i) < size; ++i) {
auto out = Op()(a[index + i], b[0]);
c[index + i] = out[0];
d[index + i] = out[1];
}
}
template <typename T, typename U, typename Op>
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
[[kernel]] void binary_vv(
device const T* a,
device const T* b,
device U* c,
device U* d,
constant uint& size,
uint index [[thread_position_in_grid]]) {
auto out = Op()(a[index], b[index]);
c[index] = out[0];
d[index] = out[1];
index *= N;
for (int i = 0; i < N && (index + i) < size; ++i) {
auto out = Op()(a[index + i], b[index + i]);
c[index + i] = out[0];
d[index + i] = out[1];
}
}
template <typename T, typename U, typename Op>
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
[[kernel]] void binary_sv2(
device const T* a,
device const T* b,
device U* c,
device U* d,
constant int64_t& size,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto offset = index.x + grid_dim.x * int64_t(index.y);
auto out = Op()(a[0], b[offset]);
c[offset] = out[0];
d[offset] = out[1];
auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
for (int i = 0; i < N && (offset + i) < size; ++i) {
auto out = Op()(a[0], b[offset + i]);
c[offset + i] = out[0];
d[offset + i] = out[1];
}
}
template <typename T, typename U, typename Op>
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
[[kernel]] void binary_vs2(
device const T* a,
device const T* b,
device U* c,
device U* d,
constant int64_t& size,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto offset = index.x + grid_dim.x * int64_t(index.y);
auto out = Op()(a[offset], b[0]);
c[offset] = out[0];
d[offset] = out[1];
auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
for (int i = 0; i < N && (offset + i) < size; ++i) {
auto out = Op()(a[offset + i], b[0]);
c[offset + i] = out[0];
d[offset + i] = out[1];
}
}
template <typename T, typename U, typename Op>
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
[[kernel]] void binary_vv2(
device const T* a,
device const T* b,
device U* c,
device U* d,
constant int64_t& size,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto offset = index.x + grid_dim.x * int64_t(index.y);
auto out = Op()(a[offset], b[offset]);
c[offset] = out[0];
d[offset] = out[1];
auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
for (int i = 0; i < N && (offset + i) < size; ++i) {
auto out = Op()(a[offset + i], b[offset + i]);
c[offset + i] = out[0];
d[offset + i] = out[1];
}
}
template <typename T, typename U, typename Op, typename IdxT = int64_t>

View File

@ -1,39 +1,53 @@
// Copyright © 2024 Apple Inc.
template <typename T, typename U>
template <typename T, typename U, int N = WorkPerThread<T>::n>
[[kernel]] void copy_s(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant uint& size,
uint index [[thread_position_in_grid]]) {
dst[index] = static_cast<U>(src[0]);
index *= N;
for (int i = 0; i < N && (index + i) < size; ++i) {
dst[index + i] = static_cast<U>(src[0]);
}
}
template <typename T, typename U>
template <typename T, typename U, int N = WorkPerThread<T>::n>
[[kernel]] void copy_v(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant uint& size,
uint index [[thread_position_in_grid]]) {
dst[index] = static_cast<U>(src[index]);
index *= N;
for (int i = 0; i < N && (index + i) < size; ++i) {
dst[index + i] = static_cast<U>(src[index + i]);
}
}
template <typename T, typename U>
template <typename T, typename U, int N = WorkPerThread<T>::n>
[[kernel]] void copy_s2(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant int64_t& size,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto offset = index.x + grid_dim.x * int64_t(index.y);
dst[offset] = static_cast<U>(src[0]);
auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
for (int i = 0; i < N && (offset + i) < size; ++i) {
dst[offset + i] = static_cast<U>(src[0]);
}
}
template <typename T, typename U>
template <typename T, typename U, int N = WorkPerThread<T>::n>
[[kernel]] void copy_v2(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant int64_t& size,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto offset = index.x + grid_dim.x * int64_t(index.y);
dst[offset] = static_cast<U>(src[offset]);
auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
for (int i = 0; i < N && (offset + i) < size; ++i) {
dst[offset + i] = static_cast<U>(src[offset + i]);
}
}
template <typename T, typename U, typename IdxT = int64_t>

View File

@ -10,7 +10,7 @@ For many sizes, GPU FFTs are memory bandwidth bound so
read/write performance is important.
Where possible, we read 128 bits sequentially in each thread,
coalesced with accesses from adajcent threads for optimal performance.
coalesced with accesses from adjacent threads for optimal performance.
We implement specialized reading/writing for:
- FFT

View File

@ -26,7 +26,7 @@ METAL_FUNC void radix_func(thread float* x) {
}
}
template <typename T, int N, int max_radix, int read_width>
template <typename T, int N, int max_radix, int read_width, int stride = 1>
[[kernel]] void hadamard_n(
const device T* in [[buffer(0)]],
device T* out [[buffer(1)]],
@ -46,18 +46,25 @@ template <typename T, int N, int max_radix, int read_width>
constexpr short logFinal = logN % logR;
constexpr short final_radix = 1 << (logFinal);
int batch_idx = elem.x * N;
short i = elem.y;
int batch_idx = elem.y * N * stride + elem.z;
short i = elem.x;
threadgroup T buf[N];
// Read values from device
STEEL_PRAGMA_UNROLL
for (short j = 0; j < max_radix / read_width; j++) {
short index = j * read_width * num_threads + i * read_width;
if (stride == 1) {
STEEL_PRAGMA_UNROLL
for (short r = 0; r < read_width; r++) {
buf[index + r] = in[batch_idx + index + r];
for (short j = 0; j < max_radix / read_width; j++) {
short index = j * read_width * num_threads + i * read_width;
STEEL_PRAGMA_UNROLL
for (short r = 0; r < read_width; r++) {
buf[index + r] = in[batch_idx + index + r];
}
}
} else {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < max_radix; j++) {
buf[j * num_threads + i] = in[batch_idx + (j * num_threads + i) * stride];
}
}
@ -113,12 +120,20 @@ template <typename T, int N, int max_radix, int read_width>
}
// Write values to device
STEEL_PRAGMA_UNROLL
for (short j = 0; j < max_radix / read_width; j++) {
short index = j * read_width * num_threads + i * read_width;
if (stride == 1) {
STEEL_PRAGMA_UNROLL
for (short r = 0; r < read_width; r++) {
out[batch_idx + index + r] = T(buf[index + r] * scale);
for (short j = 0; j < max_radix / read_width; j++) {
short index = j * read_width * num_threads + i * read_width;
STEEL_PRAGMA_UNROLL
for (short r = 0; r < read_width; r++) {
out[batch_idx + index + r] = T(buf[index + r] * scale);
}
}
} else {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < max_radix; j++) {
out[batch_idx + (j * num_threads + i) * stride] =
buf[j * num_threads + i];
}
}
}

View File

@ -1008,11 +1008,11 @@ METAL_FUNC void qmm_t_impl(
auto wl = (const device uint8_t*)w;
x += y_row * K;
x += y_row * static_cast<int64_t>(K);
wl += y_col * K_w;
scales += y_col * K_g;
biases += y_col * K_g;
y += y_row * N + y_col;
y += y_row * static_cast<int64_t>(N) + y_col;
// Make the x loader and mma operation
const short num_els = min(BM, M - y_row);
@ -1132,11 +1132,11 @@ METAL_FUNC void qmm_n_impl(
// Set the block
const int y_row = tid.y * BM;
const int y_col = tid.x * BN;
x += y_row * K;
x += y_row * static_cast<int64_t>(K);
wl += y_col * bytes_per_pack / pack_factor;
scales += y_col / group_size;
biases += y_col / group_size;
y += y_row * N + y_col;
y += y_row * static_cast<int64_t>(N) + y_col;
// Make the x loader and mma operation
const short num_els = min(BM, M - y_row);
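
The static_cast<int64_t> matters because y_row, K, and N are 32-bit ints: for large matrices the product overflows before it ever reaches the pointer arithmetic. A small host-side illustration with hypothetical sizes:

#include <cstdint>
#include <iostream>

int main() {
  int y_row = 70000;
  int K = 40000;
  // With 32-bit ints the product 2'800'000'000 does not fit (INT_MAX is
  // 2'147'483'647), so y_row * K would overflow before being added to the
  // pointer. Widening one operand first keeps the full value.
  int64_t offset = static_cast<int64_t>(y_row) * K;
  std::cout << offset << "\n"; // 2800000000
  return 0;
}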

View File

@ -104,4 +104,5 @@ instantiate_scan_helper(min_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumMi
instantiate_scan_helper(min_complex64_complex64, complex64_t, complex64_t, CumMin, 2)
instantiate_scan_helper(logaddexp_float16_float16, half, half, CumLogaddexp, 4)
instantiate_scan_helper(logaddexp_float32_float32, float, float, CumLogaddexp, 4)
instantiate_scan_helper(logaddexp_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumLogaddexp, 4) // clang-format on
instantiate_scan_helper(logaddexp_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumLogaddexp, 4)
instantiate_scan_helper(logaddexp_complex64_complex64, complex64_t, complex64_t, CumLogaddexp, 2) // clang-format on

View File

@ -56,9 +56,9 @@ template <typename T, int D, int V = D>
const int head_idx = tid.x;
const int q_seq_idx = tid.y;
const int kv_head_idx = head_idx / gqa_factor;
const int o_offset = tpg.x * q_seq_idx + head_idx;
const int o_offset = head_idx * tpg.y + q_seq_idx;
const int q_offset =
query_transposed ? o_offset : head_idx * tpg.y + q_seq_idx;
query_transposed ? tpg.x * q_seq_idx + head_idx : o_offset;
queries += q_offset * D + simd_lid * qk_per_thread;
keys += kv_head_idx * k_head_stride + simd_gid * k_seq_stride +
simd_lid * qk_per_thread;
@ -213,9 +213,9 @@ template <typename T, int D, int V = D>
const int block_idx = tid.z;
const int head_idx = tid.x;
const int q_seq_idx = tid.y;
const int o_offset = tpg.x * q_seq_idx + head_idx;
const int o_offset = head_idx * tpg.y + q_seq_idx;
const int q_offset =
query_transposed ? o_offset : head_idx * tpg.y + q_seq_idx;
query_transposed ? tpg.x * q_seq_idx + head_idx : o_offset;
const int kv_head_idx = head_idx / gqa_factor;
queries += q_offset * D + simd_lid * qk_per_thread;
@ -358,8 +358,8 @@ template <typename T, int D>
// Adjust positions
const int head_idx = tid.x;
const int q_seq_idx = tid.y;
const int n_heads = tpg.x;
const int q_offset = n_heads * q_seq_idx + head_idx;
const int q_offset = head_idx * tpg.y + q_seq_idx;
;
partials += q_offset * blocks * D + simd_gid * D + simd_lid * elem_per_thread;
sums += q_offset * blocks;
maxs += q_offset * blocks;

View File

@ -95,7 +95,7 @@ template <
Q += tidl.z * params->Q_strides[0] + // Batch
tidl.y * params->Q_strides[1] + // Head
tidl.x * BQ * params->Q_strides[2]; // Seqeunce
tidl.x * BQ * params->Q_strides[2]; // Sequence
ulong kv_head_idx = int(tid.y) / params->gqa_factor;
K += tidl.z * params->K_strides[0] + // Batch
@ -106,7 +106,7 @@ template <
O += tidl.z * params->O_strides[0] + // Batch
tidl.y * params->O_strides[1] + // Head
tidl.x * BQ * params->O_strides[2]; // Seqeunce
tidl.x * BQ * params->O_strides[2]; // Sequence
if (has_mask) {
mask += tidl.z * mask_params->M_strides[0] + // Batch

View File

@ -113,7 +113,7 @@ struct BlockLoader {
tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)];
}
// Zero out uneeded values
// Zero out unneeded values
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0);
@ -240,7 +240,7 @@ struct BlockLoaderT {
tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)];
}
// Zero out uneeded values
// Zero out unneeded values
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0);

View File

@ -141,7 +141,7 @@ implicit_gemm_conv_2d_general(
// Store results to device memory
{
// Adjust for simdgroup and thread locatio
// Adjust for simdgroup and thread location
int offset_m = c_row + mma_op.sm;
int offset_n = c_col + mma_op.sn;
C += offset_n;

View File

@ -113,7 +113,7 @@ struct BlockLoader {
tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)];
}
// Zero out uneeded values
// Zero out unneeded values
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0);

View File

@ -1,25 +1,32 @@
// Copyright © 2024 Apple Inc.
template <typename T, typename Op>
template <typename T, typename Op, int N = WorkPerThread<T>::n>
[[kernel]] void ternary_v(
device const bool* a,
device const T* b,
device const T* c,
device T* d,
constant uint& size,
uint index [[thread_position_in_grid]]) {
d[index] = Op()(a[index], b[index], c[index]);
index *= N;
for (int i = 0; i < N && (index + i) < size; ++i) {
d[index + i] = Op()(a[index + i], b[index + i], c[index + i]);
}
}
template <typename T, typename Op>
template <typename T, typename Op, int N = WorkPerThread<T>::n>
[[kernel]] void ternary_v2(
device const bool* a,
device const T* b,
device const T* c,
device T* d,
constant int64_t& size,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto offset = index.x + grid_dim.x * int64_t(index.y);
d[offset] = Op()(a[offset], b[offset], c[offset]);
auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
for (int i = 0; i < N && (offset + i) < size; ++i) {
d[offset + i] = Op()(a[offset + i], b[offset + i], c[offset + i]);
}
}
template <typename T, typename Op, typename IdxT = int64_t>

View File

@ -1,21 +1,28 @@
// Copyright © 2024 Apple Inc.
template <typename T, typename U, typename Op>
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
[[kernel]] void unary_v(
device const T* in,
device U* out,
constant uint& size,
uint index [[thread_position_in_grid]]) {
out[index] = Op()(in[index]);
index *= N;
for (int i = 0; i < N && (index + i) < size; ++i) {
out[index + i] = Op()(in[index + i]);
}
}
template <typename T, typename U, typename Op>
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
[[kernel]] void unary_v2(
device const T* in,
device U* out,
constant int64_t& size,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto offset = index.x + grid_dim.x * int64_t(index.y);
out[offset] = Op()(in[offset]);
auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
for (int i = 0; i < N && (offset + i) < size; ++i) {
out[offset + i] = Op()(in[offset + i]);
}
}
template <

View File

@ -77,6 +77,7 @@ instantiate_unary_all_same(Cos, complex64, complex64_t)
instantiate_unary_all_same(Cosh, complex64, complex64_t)
instantiate_unary_all_same(Exp, complex64, complex64_t)
instantiate_unary_all_same(Log, complex64, complex64_t)
instantiate_unary_all_same(Log1p, complex64, complex64_t)
instantiate_unary_all_same(Log2, complex64, complex64_t)
instantiate_unary_all_same(Log10, complex64, complex64_t)
instantiate_unary_all_same(Negative, complex64, complex64_t)

View File

@ -15,6 +15,14 @@
typedef half float16_t;
// Work per thread values for different types. The values here are expected to
// match get_work_per_thread in mlx/backend/metal/utils.h
template <typename U>
struct WorkPerThread {
static_assert(sizeof(U) <= 8, "Type too large");
static constexpr int constant n = 8 / sizeof(U);
};
///////////////////////////////////////////////////////////////////////////////
// Type limits utils
///////////////////////////////////////////////////////////////////////////////
@ -328,6 +336,23 @@ inline bfloat16_t log1p(bfloat16_t x) {
return bfloat16_t(x * (metal::log(xp1) / (xp1 - 1.0f)));
}
inline complex64_t log1p(complex64_t in) {
float x = in.real;
float y = in.imag;
float zabs = metal::precise::sqrt(x * x + y * y);
float theta = metal::atan2(y, x + 1);
if (zabs < 0.5f) {
float r = x * (2 + x) + y * y;
if (r == 0) { // handle underflow
return {x, theta};
}
return {0.5f * log1p(r), theta};
} else {
auto z0 = metal::sqrt((x + 1) * (x + 1) + y * y);
return {metal::log(z0), theta};
}
}
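
The identity behind the new complex log1p: log(1 + z) = log|1 + z| + i*atan2(y, x + 1), and |1 + z|^2 = 1 + (x*(2 + x) + y*y), so the real part is 0.5*log1p(r) with r = x*(2 + x) + y*y; when r underflows to 0, 0.5*r is approximately x, hence the early return. A standalone check against std::complex (host C++, illustration only):

#include <cmath>
#include <complex>
#include <cstdio>

int main() {
  std::complex<float> z{0.1f, -0.2f};
  auto ref = std::log(std::complex<float>{1.0f, 0.0f} + z);  // reference value
  float x = z.real(), y = z.imag();
  float re = 0.5f * std::log1p(x * (2.0f + x) + y * y);      // small-|z| branch
  float im = std::atan2(y, x + 1.0f);
  std::printf("ref = (%g, %g)  branch = (%g, %g)\n", ref.real(), ref.imag(), re, im);
}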
///////////////////////////////////////////////////////////////////////////////
// SIMD shuffle ops
///////////////////////////////////////////////////////////////////////////////

View File

@ -1,7 +1,7 @@
// Copyright © 2023-2024 Apple Inc.
#include <algorithm>
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/utils.h"

View File

@ -7,7 +7,7 @@
#include "mlx/backend/common/broadcasting.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/kernels/defines.h"

View File

@ -1,11 +1,11 @@
// Copyright © 2023-2024 Apple Inc.
#include <memory>
#include <sys/sysctl.h>
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/metal.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"
#include "mlx/scheduler.h"
#include "mlx/utils.h"
namespace mlx::core::metal {
@ -13,85 +13,6 @@ bool is_available() {
return true;
}
inline void check_error(MTL::CommandBuffer* cbuf) {
if (cbuf->status() == MTL::CommandBufferStatusError) {
std::ostringstream msg;
msg << "[METAL] Command buffer execution failed: "
<< cbuf->error()->localizedDescription()->utf8String();
throw std::runtime_error(msg.str());
}
}
void eval(array& arr) {
auto pool = new_scoped_memory_pool();
auto s = arr.primitive().stream();
auto& d = metal::device(s.device);
auto command_buffer = d.get_command_buffer(s.index);
auto outputs = arr.outputs();
{
// If the array is a tracer hold a reference
// to its inputs so they don't get donated
std::vector<array> inputs;
if (arr.is_tracer()) {
inputs = arr.inputs();
}
debug_set_primitive_buffer_label(command_buffer, arr.primitive());
arr.primitive().eval_gpu(arr.inputs(), outputs);
}
std::unordered_set<std::shared_ptr<array::Data>> buffers;
for (auto& in : arr.inputs()) {
buffers.insert(in.data_shared_ptr());
}
for (auto& s : arr.siblings()) {
buffers.insert(s.data_shared_ptr());
}
// Remove the output if it was donated to by an input
if (auto it = buffers.find(arr.data_shared_ptr()); it != buffers.end()) {
buffers.erase(it);
}
if (d.command_buffer_needs_commit(s.index)) {
d.end_encoding(s.index);
scheduler::notify_new_task(s);
command_buffer->addCompletedHandler(
[s, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) {
scheduler::notify_task_completion(s);
check_error(cbuf);
});
d.commit_command_buffer(s.index);
d.get_command_buffer(s.index);
} else {
command_buffer->addCompletedHandler(
[s, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) {
check_error(cbuf);
});
}
}
void finalize(Stream s) {
auto pool = new_scoped_memory_pool();
auto& d = metal::device(s.device);
auto cb = d.get_command_buffer(s.index);
d.end_encoding(s.index);
cb->addCompletedHandler([s](MTL::CommandBuffer* cbuf) { check_error(cbuf); });
d.commit_command_buffer(s.index);
d.get_command_buffer(s.index);
}
void synchronize(Stream s) {
auto pool = new_scoped_memory_pool();
auto& d = metal::device(s.device);
auto cb = d.get_command_buffer(s.index);
cb->retain();
d.end_encoding(s.index);
d.commit_command_buffer(s.index);
cb->waitUntilCompleted();
check_error(cb);
cb->release();
}
void start_capture(std::string path, id object) {
auto pool = new_scoped_memory_pool();
@ -128,4 +49,36 @@ void stop_capture() {
manager->stopCapture();
}
const std::unordered_map<std::string, std::variant<std::string, size_t>>&
device_info() {
auto init_device_info = []()
-> std::unordered_map<std::string, std::variant<std::string, size_t>> {
auto pool = new_scoped_memory_pool();
auto raw_device = device(default_device()).mtl_device();
auto name = std::string(raw_device->name()->utf8String());
auto arch = std::string(raw_device->architecture()->name()->utf8String());
size_t memsize = 0;
size_t length = sizeof(memsize);
sysctlbyname("hw.memsize", &memsize, &length, NULL, 0);
size_t rsrc_limit = 0;
sysctlbyname("iogpu.rsrc_limit", &rsrc_limit, &length, NULL, 0);
if (rsrc_limit == 0) {
rsrc_limit = 499000;
}
return {
{"device_name", name},
{"architecture", arch},
{"max_buffer_length", raw_device->maxBufferLength()},
{"max_recommended_working_set_size",
raw_device->recommendedMaxWorkingSetSize()},
{"memory_size", memsize},
{"resource_limit", rsrc_limit}};
};
static auto device_info_ = init_device_info();
return device_info_;
}
} // namespace mlx::core::metal
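
Usage sketch for the device_info() map registered above (keys as listed; each value is either a string or a size):

#include <cstddef>
#include <iostream>
#include <string>
#include <variant>

#include "mlx/backend/metal/metal.h"

int main() {
  const auto& info = mlx::core::metal::device_info();
  std::cout << std::get<std::string>(info.at("device_name")) << ", "
            << std::get<std::size_t>(info.at("memory_size")) << " bytes of memory\n";
}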

View File

@ -2,11 +2,10 @@
#pragma once
#include <string>
#include <unordered_map>
#include <variant>
#include "mlx/array.h"
namespace mlx::core::metal {
/* Check if the Metal backend is available. */

View File

@ -0,0 +1,22 @@
// Copyright © 2025 Apple Inc.
#include <stdexcept>
#include "mlx/backend/metal/metal.h"
namespace mlx::core::metal {
bool is_available() {
return false;
}
void start_capture(std::string) {}
void stop_capture() {}
const std::unordered_map<std::string, std::variant<std::string, size_t>>&
device_info() {
throw std::runtime_error(
"[metal::device_info] Cannot get device info without metal backend");
};
} // namespace mlx::core::metal

View File

@ -1,7 +1,7 @@
// Copyright © 2024 Apple Inc.
#include <algorithm>
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/reduce.h"

View File

@ -7,10 +7,10 @@
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/common/slicing.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/gpu/slicing.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/slicing.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"
#include "mlx/scheduler.h"
@ -25,25 +25,6 @@ void arange_set_scalars(T start, T next, metal::CommandEncoder& enc) {
enc.set_bytes(step, 1);
}
void reshape(const array& in, array& out, Stream s) {
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
if (copy_necessary) {
out.set_data(allocator::malloc(out.nbytes()));
copy_gpu_inplace(
in,
out,
in.shape(),
in.strides(),
make_contiguous_strides(in.shape()),
0,
0,
CopyType::General,
s);
} else {
shared_buffer_reshape(in, out_strides, out);
}
}
static array compute_dynamic_offset(
const array& indices,
const Strides& strides,
@ -226,105 +207,10 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
}
}
void AsType::eval_gpu(const std::vector<array>& inputs, array& out) {
CopyType ctype =
inputs[0].flags().contiguous ? CopyType::Vector : CopyType::General;
copy_gpu(inputs[0], out, ctype);
}
void AsStrided::eval_gpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
void Broadcast::eval_gpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
void BroadcastAxes::eval_gpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
void Concatenate::eval_gpu(const std::vector<array>& inputs, array& out) {
concatenate_gpu(inputs, out, axis_, stream());
}
void Contiguous::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
constexpr size_t extra_bytes = 16384;
if (in.buffer_size() <= out.nbytes() + extra_bytes &&
(in.flags().row_contiguous ||
(allow_col_major_ && in.flags().col_contiguous))) {
out.copy_shared_buffer(in);
} else {
copy_gpu(in, out, CopyType::General);
}
}
void Copy::eval_gpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
void CustomTransforms::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
eval(inputs, outputs);
}
void Depends::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
eval(inputs, outputs);
}
void Full::eval_gpu(const std::vector<array>& inputs, array& out) {
auto in = inputs[0];
CopyType ctype;
if (in.data_size() == 1) {
ctype = CopyType::Scalar;
} else if (in.flags().contiguous) {
ctype = CopyType::Vector;
} else {
ctype = CopyType::General;
}
copy_gpu(in, out, ctype);
}
void ExpandDims::eval_gpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
void Flatten::eval_gpu(const std::vector<array>& inputs, array& out) {
reshape(inputs[0], out, stream());
}
void Unflatten::eval_gpu(const std::vector<array>& inputs, array& out) {
reshape(inputs[0], out, stream());
}
void Load::eval_gpu(const std::vector<array>& inputs, array& out) {
throw std::runtime_error("[Load::eval_gpu] Not implemented.");
}
void NumberOfElements::eval_gpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
void Pad::eval_gpu(const std::vector<array>& inputs, array& out) {
// Inputs must be base input array and scalar val array
assert(inputs.size() == 2);
auto& in = inputs[0];
auto& val = inputs[1];
// Padding value must be a scalar
assert(val.size() == 1);
// Padding value, input and output must be of the same type
assert(val.dtype() == in.dtype() && in.dtype() == out.dtype());
pad_gpu(in, val, out, axes_, low_pad_size_, stream());
}
void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
@ -370,27 +256,6 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.dispatch_threads(grid_dims, group_dims);
}
void Reshape::eval_gpu(const std::vector<array>& inputs, array& out) {
reshape(inputs[0], out, stream());
}
void Split::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
eval(inputs, outputs);
}
void Slice::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
if (out.size() == 0) {
out.set_data(nullptr);
return;
}
auto& in = inputs[0];
slice_gpu(in, out, start_indices_, strides_, stream());
}
void DynamicSlice::eval_gpu(const std::vector<array>& inputs, array& out) {
if (out.size() == 0) {
out.set_data(nullptr);
@ -492,18 +357,6 @@ void SliceUpdate::eval_gpu(const std::vector<array>& inputs, array& out) {
/* const Stream& s = */ stream());
}
void Squeeze::eval_gpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
void StopGradient::eval_gpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
void Transpose::eval_gpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
void QRF::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
@ -537,35 +390,4 @@ void LUF::eval_gpu(
throw std::runtime_error("[LUF::eval_gpu] Metal LU factorization NYI.");
}
void View::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& in = inputs[0];
auto ibytes = size_of(in.dtype());
auto obytes = size_of(out.dtype());
// Conditions for buffer copying (disjunction):
// - type size is the same
// - type size is smaller and the last axis is contiguous
// - the entire array is row contiguous
if (ibytes == obytes || (obytes < ibytes && in.strides().back() == 1) ||
in.flags().row_contiguous) {
auto strides = in.strides();
for (int i = 0; i < static_cast<int>(strides.size()) - 1; ++i) {
strides[i] *= ibytes;
strides[i] /= obytes;
}
out.copy_shared_buffer(
in, strides, in.flags(), in.data_size() * ibytes / obytes);
} else {
auto tmp = array(in.shape(), in.dtype(), nullptr, {});
tmp.set_data(allocator::malloc(tmp.nbytes()));
copy_gpu_inplace(in, tmp, CopyType::General, stream());
auto flags = out.flags();
flags.contiguous = true;
flags.row_contiguous = true;
auto max_dim = std::max_element(out.shape().begin(), out.shape().end());
flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim;
out.copy_shared_buffer(tmp, out.strides(), flags, out.size());
}
}
} // namespace mlx::core

View File

@ -4,7 +4,7 @@
#include "mlx/backend/common/broadcasting.h"
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/reduce.h"

View File

@ -3,7 +3,7 @@
#include <algorithm>
#include <cassert>
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/kernels/defines.h"

View File

@ -1,7 +1,6 @@
// Copyright © 2024 Apple Inc.
#include "mlx/backend/metal/resident.h"
#include "mlx/backend/metal/metal_impl.h"
namespace mlx::core::metal {

View File

@ -1,5 +1,5 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/fast_primitives.h"

View File

@ -2,7 +2,7 @@
#include <sstream>
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels/steel/attn/params.h"
@ -154,9 +154,9 @@ void sdpa_vector(
int gqa_factor = q.shape(1) / k.shape(1);
int N = k.shape(2);
int B = q.shape(0) * q.shape(1);
size_t k_head_stride = k.strides()[1];
size_t k_head_stride = k.shape(1) == 1 ? k.strides(0) : k.strides(1);
size_t k_seq_stride = k.strides()[2];
size_t v_head_stride = v.strides()[1];
size_t v_head_stride = v.shape(1) == 1 ? v.strides(0) : v.strides(1);
size_t v_seq_stride = v.strides()[2];
MTL::Size group_dims(1024, 1, 1);
@ -199,11 +199,10 @@ void sdpa_vector(
if (has_mask) {
auto& m = *mask;
compute_encoder.set_input_array(m, 11 + float_mask);
auto nd = m.ndim();
int32_t kv_seq_stride =
nd >= 1 && m.shape(-1) > 1 ? m.strides()[nd - 1] : 0;
int32_t q_seq_stride = nd >= 2 && m.shape(-2) > 1 ? m.strides()[nd - 2] : 0;
int32_t head_stride = nd >= 3 && m.shape(-3) > 1 ? m.strides()[nd - 3] : 0;
int32_t kv_seq_stride = m.shape(3) > 1 ? m.strides(3) : 0;
int32_t q_seq_stride = m.shape(2) > 1 ? m.strides(2) : 0;
int32_t head_stride =
m.shape(1) > 1 ? m.strides(1) : (m.shape(0) > 1 ? m.strides(0) : 0);
compute_encoder.set_bytes(kv_seq_stride, 13);
compute_encoder.set_bytes(q_seq_stride, 14);
compute_encoder.set_bytes(head_stride, 15);
@ -238,9 +237,10 @@ void sdpa_vector_2pass(
int N = k.shape(2);
int blocks = 32;
int B = q.shape(0) * q.shape(1);
size_t k_head_stride = k.strides()[1];
size_t k_head_stride = k.shape(1) == 1 ? k.strides(0) : k.strides(1);
size_t k_seq_stride = k.strides()[2];
size_t v_head_stride = v.strides()[1];
size_t v_head_stride = v.shape(1) == 1 ? v.strides(0) : v.strides(1);
size_t v_seq_stride = v.strides()[2];
MTL::Size group_dims(8 * 32, 1, 1);
MTL::Size grid_dims(B, q.shape(2), blocks);
@ -302,11 +302,10 @@ void sdpa_vector_2pass(
if (has_mask) {
auto& m = *mask;
compute_encoder.set_input_array(m, 13 + float_mask);
auto nd = m.ndim();
int32_t kv_seq_stride =
nd >= 1 && m.shape(-1) > 1 ? m.strides()[nd - 1] : 0;
int32_t q_seq_stride = nd >= 2 && m.shape(-2) > 1 ? m.strides()[nd - 2] : 0;
int32_t head_stride = nd >= 3 && m.shape(-3) > 1 ? m.strides()[nd - 3] : 0;
int32_t kv_seq_stride = m.shape(3) > 1 ? m.strides(3) : 0;
int32_t q_seq_stride = m.shape(2) > 1 ? m.strides(2) : 0;
int32_t head_stride =
m.shape(1) > 1 ? m.strides(1) : (m.shape(0) > 1 ? m.strides(0) : 0);
compute_encoder.set_bytes(kv_seq_stride, 15);
compute_encoder.set_bytes(q_seq_stride, 16);
compute_encoder.set_bytes(head_stride, 17);
@ -368,18 +367,6 @@ void ScaledDotProductAttention::eval_gpu(
}
};
// Checks if arr is row contiguous or the sequence and head dimension are
// transposed
auto is_contiguous_or_head_seq_transposed = [](const array& arr) {
if (arr.flags().row_contiguous) {
return true;
}
auto& strides = arr.strides();
auto& shape = arr.shape();
return (strides[3] == 1) && (strides[2] == shape[3] * shape[1]) &&
(strides[1] == shape[3]) && (strides[0] == strides[2] * shape[2]);
};
// Checks that the headdim dimension has stride 1.
auto is_matrix_contiguous = [](const array& arr) {
return arr.strides(-1) == 1;
@ -387,30 +374,58 @@ void ScaledDotProductAttention::eval_gpu(
// We are in vector mode ie single query
if (q_pre.shape(2) <= 8) {
const auto& q = copy_unless(is_contiguous_or_head_seq_transposed, q_pre);
const auto& k = copy_unless(is_matrix_contiguous, k_pre);
const auto& v = copy_unless(is_matrix_contiguous, v_pre);
auto q_copy_unless = [](const array& arr) {
if (arr.flags().row_contiguous) {
return true;
}
auto& strides = arr.strides();
auto& shape = arr.shape();
if (shape[0] == 1 || shape[1] == 1) {
// If either the batch or head dimension is a singleton, the other can
// be transposed with the sequence dimension
auto bidx = shape[0] == 1 ? 1 : 0;
return (strides[3] == 1) && (strides[2] == shape[3] * shape[bidx]) &&
(strides[bidx] == shape[3]);
}
return false;
};
auto kv_copy_unless = [](const array& arr) {
// keys and values should be copied if:
// - the last dimension is not contiguous
// - the batch and head dim are not contiguous
auto& strides = arr.strides();
auto& shape = arr.shape();
if (strides.back() != 1) {
return false;
}
if (shape[0] == 1 || shape[1] == 1) {
return true;
}
return (strides[0] == strides[1] * shape[1]);
};
const auto& q = copy_unless(q_copy_unless, q_pre);
const auto& k = copy_unless(kv_copy_unless, k_pre);
const auto& v = copy_unless(kv_copy_unless, v_pre);
// Donate the query if possible
if (q.is_donatable() && (q.shape(2) == 1 || !q.flags().row_contiguous) &&
q.size() == o.size()) {
if (q.is_donatable() && q.flags().row_contiguous && q.size() == o.size()) {
o.copy_shared_buffer(q);
} else {
if (o.shape(2) == 1) {
o.set_data(allocator::malloc(o.nbytes()));
} else {
auto strides = o.strides();
strides[2] = o.shape(1) * o.shape(3);
strides[1] = o.shape(3);
auto flags = q.flags();
flags.row_contiguous = q.shape(1) == 1;
o.set_data(
allocator::malloc(o.nbytes()), o.size(), std::move(strides), flags);
}
o.set_data(allocator::malloc(o.nbytes()));
}
auto mask =
inputs.size() > 3 ? std::optional<array>{inputs[3]} : std::nullopt;
auto mask_copy_unless = [&q](const array& arr) {
auto& strides = arr.strides();
auto& shape = arr.shape();
return arr.flags().row_contiguous || q.shape(0) == 1 || q.shape(1) == 1 ||
(strides[0] == strides[1] * shape[1]);
};
auto mask = inputs.size() > 3
? std::optional<array>{copy_unless(mask_copy_unless, inputs[3])}
: std::nullopt;
// We route to the 2 pass fused attention if
// - The device is large and the sequence length long

View File

@ -3,7 +3,7 @@
#include <cassert>
#include <sstream>
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/utils.h"

View File

@ -2,21 +2,12 @@
#include <numeric>
#include "mlx/backend/common/slicing.h"
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/gpu/slicing.h"
#include "mlx/backend/metal/device.h"
namespace mlx::core {
void slice_gpu(
const array& in,
array& out,
const Shape& start_indices,
const Shape& strides,
const Stream& s) {
slice(in, out, start_indices, strides);
}
void concatenate_gpu(
const std::vector<array>& inputs,
array& out,
@ -48,30 +39,4 @@ void concatenate_gpu(
}
}
void pad_gpu(
const array& in,
const array& val,
array& out,
const std::vector<int>& axes,
const Shape& low_pad_size,
const Stream& s) {
// Fill output with val
fill_gpu(val, out, s);
// Find offset for start of input values
size_t data_offset = 0;
for (int i = 0; i < axes.size(); i++) {
auto ax = axes[i] < 0 ? out.ndim() + axes[i] : axes[i];
data_offset += out.strides()[ax] * low_pad_size[i];
}
// Extract slice from output where input will be pasted
array out_slice(in.shape(), out.dtype(), nullptr, {});
out_slice.copy_shared_buffer(
out, out.strides(), out.flags(), out_slice.size(), data_offset);
// Copy input values into the slice
copy_gpu_inplace(in, out_slice, CopyType::GeneralGeneral, s);
}
} // namespace mlx::core

View File

@ -1,7 +1,7 @@
// Copyright © 2023-2024 Apple Inc.
#include <algorithm>
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/kernels/defines.h"

View File

@ -2,7 +2,7 @@
#include <algorithm>
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/utils.h"

View File

@ -45,7 +45,7 @@ void ternary_op_gpu_inplace(
work_per_thread = large ? 4 : 2;
} else {
large = out.data_size() > INT32_MAX;
work_per_thread = 1;
work_per_thread = get_work_per_thread(b.dtype());
}
std::string kernel_name;
if (topt == TernaryOpType::General) {
@ -106,13 +106,19 @@ void ternary_op_gpu_inplace(
compute_encoder.dispatch_threads(grid_dims, group_dims);
} else {
// Launch a 1D or 2D grid of threads
size_t nthreads = out.data_size();
size_t nthreads = ceildiv(out.data_size(), work_per_thread);
if (thread_group_size > nthreads) {
thread_group_size = nthreads;
}
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides())
: MTL::Size(nthreads, 1, 1);
MTL::Size grid_dims;
if (large) {
compute_encoder.set_bytes<int64_t>(out.data_size(), 4);
grid_dims = get_2d_grid_dims(out.shape(), out.strides(), work_per_thread);
} else {
compute_encoder.set_bytes<int>(out.data_size(), 4);
grid_dims = MTL::Size(nthreads, 1, 1);
}
compute_encoder.dispatch_threads(grid_dims, group_dims);
}
}

View File

@ -34,18 +34,19 @@ void unary_op_gpu_inplace(
};
auto [shape, strides] = maybe_collapse();
int ndim = shape.size();
size_t nthreads = contig ? in.data_size() : in.size();
bool large;
if (!contig) {
large = in.data_size() > INT32_MAX || out.size() > INT32_MAX;
} else {
large = in.data_size() > UINT32_MAX;
}
int work_per_thread = !contig && large ? 4 : 1;
int work_per_thread;
std::string kernel_name;
if (contig) {
work_per_thread = get_work_per_thread(in.dtype());
kernel_name = (large ? "v2" : "v");
} else {
work_per_thread = large ? 4 : 1;
kernel_name = "gn" + std::to_string(work_per_thread);
if (large) {
kernel_name += "large";
@ -75,12 +76,20 @@ void unary_op_gpu_inplace(
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
compute_encoder.dispatch_threads(grid_dims, group_dims);
} else {
size_t nthreads = ceildiv(in.data_size(), work_per_thread);
if (thread_group_size > nthreads) {
thread_group_size = nthreads;
}
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides())
: MTL::Size(nthreads, 1, 1);
MTL::Size grid_dims;
if (large) {
compute_encoder.set_bytes<int64_t>(in.data_size(), 2);
grid_dims = get_2d_grid_dims(out.shape(), out.strides(), work_per_thread);
} else {
compute_encoder.set_bytes<int>(in.data_size(), 2);
grid_dims = MTL::Size(nthreads, 1, 1);
}
compute_encoder.dispatch_threads(grid_dims, group_dims);
}
}

View File

@ -84,4 +84,12 @@ void concatenate(std::string& acc, T first, Args... args) {
concatenate(acc, args...);
}
inline int get_work_per_thread(Dtype dtype) {
return std::max(1, 8 / dtype.size());
}
inline size_t ceildiv(size_t n, size_t m) {
return (n + m - 1) / m;
}
} // namespace mlx::core
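
The two helpers above size contiguous dispatches so each thread handles roughly 8 bytes of data. A standalone sketch of the resulting thread count (hypothetical mirrors of the helpers, values assumed):

#include <algorithm>
#include <cstddef>
#include <cstdio>

inline int work_per_thread(std::size_t dtype_size) {
  return std::max<int>(1, 8 / static_cast<int>(dtype_size));
}
inline std::size_t ceil_div(std::size_t n, std::size_t m) {
  return (n + m - 1) / m;
}

int main() {
  std::size_t n = 1000;                       // contiguous float32 elements
  int wpt = work_per_thread(sizeof(float));   // 2 for float32, 4 for float16, 8 for int8
  std::printf("threads dispatched: %zu\n", ceil_div(n, wpt));  // 500
}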

View File

@ -1,6 +1,7 @@
target_sources(
mlx
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/available.cpp
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
${CMAKE_CURRENT_SOURCE_DIR}/../cpu/eval.cpp
${CMAKE_CURRENT_SOURCE_DIR}/../cpu/encoder.cpp
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp)

View File

@ -0,0 +1,11 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cpu/available.h"
namespace mlx::core::cpu {
bool is_available() {
return false;
}
} // namespace mlx::core::cpu

View File

@ -18,7 +18,7 @@ void Compiled::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
throw std::runtime_error(
"[Compiled::eval_cpu] CPU compialtion not supported on the platform.");
"[Compiled::eval_cpu] CPU compilation not supported on the platform.");
}
} // namespace mlx::core

View File

@ -3,5 +3,5 @@ target_sources(
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
${CMAKE_CURRENT_SOURCE_DIR}/event.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp
${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp)

View File

@ -6,9 +6,9 @@
#include "mlx/allocator.h"
#ifdef __APPLE__
#include "mlx/backend/no_metal/apple_memory.h"
#include "mlx/backend/no_gpu/apple_memory.h"
#elif defined(__linux__)
#include "mlx/backend/no_metal/linux_memory.h"
#include "mlx/backend/no_gpu/linux_memory.h"
#else
size_t get_memory_size() {
return 0;

View File

@ -0,0 +1,28 @@
// Copyright © 2025 Apple Inc.
#include <stdexcept>
#include "mlx/backend/gpu/available.h"
#include "mlx/backend/gpu/eval.h"
namespace mlx::core::gpu {
bool is_available() {
return false;
}
void new_stream(Stream) {}
void eval(array&) {
throw std::runtime_error("[gpu::eval] GPU backend is not available");
}
void finalize(Stream) {
throw std::runtime_error("[gpu::finalize] GPU backend is not available");
}
void synchronize(Stream) {
throw std::runtime_error("[gpu::synchronize] GPU backend is not available");
}
} // namespace mlx::core::gpu

View File

@ -1,43 +0,0 @@
// Copyright © 2023-2024 Apple Inc.
#include <stdexcept>
#include "mlx/backend/metal/metal.h"
#include "mlx/backend/metal/metal_impl.h"
namespace mlx::core::metal {
bool is_available() {
return false;
}
void new_stream(Stream) {}
std::unique_ptr<void, std::function<void(void*)>> new_scoped_memory_pool() {
return nullptr;
}
void eval(array&) {
throw std::runtime_error(
"[metal::eval] Cannot eval on GPU without metal backend");
}
void finalize(Stream) {
throw std::runtime_error(
"[metal::finalize] Cannot finalize GPU without metal backend");
}
void synchronize(Stream) {
throw std::runtime_error(
"[metal::synchronize] Cannot synchronize GPU without metal backend");
}
void start_capture(std::string) {}
void stop_capture() {}
const std::unordered_map<std::string, std::variant<std::string, size_t>>&
device_info() {
throw std::runtime_error(
"[metal::device_info] Cannot get device info without metal backend");
};
} // namespace mlx::core::metal

View File

@ -168,6 +168,15 @@ void merge_one(array& dst, array& src, ParentsMap& parents_map) {
parent.first.inputs()[parent.second] = dst;
pairs.push_back(parent);
}
// If src is a parent of dst, remove it from dst's parents
for (auto it = pairs.begin(); it != pairs.end();) {
if (it->first.id() == src.id()) {
it = pairs.erase(it);
} else {
it++;
}
}
// Remove the source from the map to avoid fusing with it again
parents_map.erase(src_parents);
}

View File

@ -1,23 +1,28 @@
// Copyright © 2023 Apple Inc.
#include <stdexcept>
#include "mlx/backend/cpu/available.h"
#include "mlx/backend/gpu/available.h"
#include "mlx/device.h"
#include "mlx/backend/metal/metal.h"
namespace mlx::core {
static Device default_device_{
metal::is_available() ? Device::gpu : Device::cpu};
Device& mutable_default_device() {
static Device default_device{gpu::is_available() ? Device::gpu : Device::cpu};
return default_device;
}
const Device& default_device() {
return default_device_;
return mutable_default_device();
}
void set_default_device(const Device& d) {
if (!metal::is_available() && d == Device::gpu) {
if (!gpu::is_available() && d == Device::gpu) {
throw std::invalid_argument(
"[set_default_device] Cannot set gpu device without gpu backend.");
}
default_device_ = d;
mutable_default_device() = d;
}
bool operator==(const Device& lhs, const Device& rhs) {
@ -28,4 +33,15 @@ bool operator!=(const Device& lhs, const Device& rhs) {
return !(lhs == rhs);
}
bool is_available(const Device& d) {
switch (d.type) {
case Device::cpu:
return cpu::is_available();
case Device::gpu:
return gpu::is_available();
}
// appease compiler
return false;
}
} // namespace mlx::core
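
Usage sketch for the new per-device availability query (the GPU branch reports false when only the stub GPU backend is built):

#include <cstdio>

#include "mlx/device.h"

int main() {
  using namespace mlx::core;
  std::printf("cpu available: %d, gpu available: %d\n",
              static_cast<int>(is_available(Device(Device::cpu))),
              static_cast<int>(is_available(Device(Device::gpu))));
}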

View File

@ -26,4 +26,6 @@ void set_default_device(const Device& d);
bool operator==(const Device& lhs, const Device& rhs);
bool operator!=(const Device& lhs, const Device& rhs);
bool is_available(const Device& d);
} // namespace mlx::core

View File

@ -470,6 +470,9 @@ bool FunctionTable::match(
if (x.dtype() != y.dtype()) {
return false;
}
if (x.ndim() != y.ndim()) {
return false;
}
if (!shapeless && x.shape() != y.shape()) {
return false;
}

View File

@ -186,6 +186,7 @@ array irfftn(
StreamOrDevice s /* = {} */) {
return fft_impl(a, axes, true, true, s);
}
array irfftn(const array& a, StreamOrDevice s /* = {} */) {
return fft_impl(a, true, true, s);
}
@ -308,4 +309,73 @@ array istft(
return signal;
}
array fftshift(
const array& a,
const std::vector<int>& axes,
StreamOrDevice s /* = {} */) {
if (axes.empty()) {
return a;
}
Shape shifts;
for (int ax : axes) {
// Convert negative axes to positive
int axis = ax < 0 ? ax + a.ndim() : ax;
if (axis < 0 || axis >= a.ndim()) {
std::ostringstream msg;
msg << "[fftshift] Invalid axis " << ax << " for array with " << a.ndim()
<< " dimensions.";
throw std::invalid_argument(msg.str());
}
// Match NumPy's implementation
shifts.push_back(a.shape(axis) / 2);
}
return roll(a, shifts, axes, s);
}
array ifftshift(
const array& a,
const std::vector<int>& axes,
StreamOrDevice s /* = {} */) {
if (axes.empty()) {
return a;
}
Shape shifts;
for (int ax : axes) {
// Convert negative axes to positive
int axis = ax < 0 ? ax + a.ndim() : ax;
if (axis < 0 || axis >= a.ndim()) {
std::ostringstream msg;
msg << "[ifftshift] Invalid axis " << ax << " for array with " << a.ndim()
<< " dimensions.";
throw std::invalid_argument(msg.str());
}
// Match NumPy's implementation
int size = a.shape(axis);
shifts.push_back(-(size / 2));
}
return roll(a, shifts, axes, s);
}
// Default versions that operate on all axes
array fftshift(const array& a, StreamOrDevice s /* = {} */) {
if (a.ndim() < 1) {
return a;
}
std::vector<int> axes(a.ndim());
std::iota(axes.begin(), axes.end(), 0);
return fftshift(a, axes, s);
}
array ifftshift(const array& a, StreamOrDevice s /* = {} */) {
if (a.ndim() < 1) {
return a;
}
std::vector<int> axes(a.ndim());
std::iota(axes.begin(), axes.end(), 0);
return ifftshift(a, axes, s);
}
} // namespace mlx::core::fft
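
A brief usage sketch of the new shifts; the values follow from the roll amounts above, and odd lengths show why a separate ifftshift is needed (fftshift is not its own inverse there):

#include "mlx/mlx.h"

int main() {
  using namespace mlx::core;
  auto x = arange(5);           // [0 1 2 3 4]
  auto y = fft::fftshift(x);    // roll by n/2 = 2     -> [3 4 0 1 2]
  auto z = fft::ifftshift(y);   // roll by -(n/2) = -2 -> [0 1 2 3 4]
  eval(z);
}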

View File

@ -148,6 +148,24 @@ inline array irfft2(
StreamOrDevice s = {}) {
return irfftn(a, axes, s);
}
/** Shift the zero-frequency component to the center of the spectrum. */
array fftshift(const array& a, StreamOrDevice s = {});
/** Shift the zero-frequency component to the center of the spectrum along
* specified axes. */
array fftshift(
const array& a,
const std::vector<int>& axes,
StreamOrDevice s = {});
/** The inverse of fftshift. */
array ifftshift(const array& a, StreamOrDevice s = {});
/** The inverse of fftshift along specified axes. */
array ifftshift(
const array& a,
const std::vector<int>& axes,
StreamOrDevice s = {});
inline array stft(
const array& x,

View File

@ -335,7 +335,10 @@ ThreadPool& thread_pool() {
return pool_;
}
ThreadPool ParallelFileReader::thread_pool_{4};
ThreadPool& ParallelFileReader::thread_pool() {
static ThreadPool thread_pool{4};
return thread_pool;
}
void ParallelFileReader::read(char* data, size_t n) {
while (n != 0) {
@ -371,7 +374,8 @@ void ParallelFileReader::read(char* data, size_t n, size_t offset) {
break;
} else {
size_t m = batch_size_;
futs.emplace_back(thread_pool_.enqueue(readfn, offset, m, data));
futs.emplace_back(
ParallelFileReader::thread_pool().enqueue(readfn, offset, m, data));
data += m;
n -= m;
offset += m;

View File

@ -101,7 +101,7 @@ class ParallelFileReader : public Reader {
private:
static constexpr size_t batch_size_ = 1 << 25;
static ThreadPool thread_pool_;
static ThreadPool& thread_pool();
int fd_;
std::string label_;
};
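
Replacing the class-scope static member with a function-local static is the usual fix for static initialization/destruction order: the pool is now constructed lazily on first use (thread-safe since C++11) instead of during global initialization. A minimal sketch of the pattern (hypothetical type, not MLX code):

#include <cstdio>

struct Pool {
  Pool() { std::puts("pool constructed"); }
};

Pool& pool() {
  static Pool p;   // built on first call, not at program start
  return p;
}

int main() {
  pool();   // constructs the pool
  pool();   // reuses it
}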

View File

@ -379,7 +379,12 @@ array pinv(const array& a, StreamOrDevice s /* = {} */) {
// Prepare S
S = expand_dims(S, -2, s);
return matmul(divide(V, S, s), U);
auto rcond = 10. * std::max(m, n) * finfo(a.dtype()).eps;
auto cutoff = multiply(array(rcond, a.dtype()), max(S, -1, true, s), s);
auto rS =
where(greater(S, cutoff, s), reciprocal(S, s), array(0.0f, a.dtype()), s);
return matmul(multiply(V, rS, s), U, s);
}
array cholesky_inv(
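
The pinv change applies the usual pseudoinverse cutoff: singular values at or below rcond * s_max, with rcond = 10 * max(m, n) * eps, are mapped to zero instead of being inverted, so rank-deficient inputs no longer produce infinities. A small usage sketch (shapes assumed):

#include "mlx/mlx.h"

int main() {
  using namespace mlx::core;
  auto A  = ones({3, 2});      // rank 1, so one singular value is (numerically) zero
  auto Ap = linalg::pinv(A);   // that singular value becomes 0 in the inverse, not 1/0
  auto P  = matmul(Ap, A);     // orthogonal projector onto A's row space
  eval(P);
}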

View File

@ -473,8 +473,19 @@ array hadamard_transform(
std::optional<float> scale_ /* = std::nullopt */,
StreamOrDevice s /* = {} */) {
// Default to an orthonormal Hadamard matrix scaled by 1/sqrt(N)
float scale = scale_.has_value() ? *scale_ : 1.0f / std::sqrt(a.shape(-1));
int n = a.ndim() > 0 ? a.shape(-1) : 1;
float scale = scale_.has_value() ? *scale_ : 1.0f / std::sqrt(n);
auto dtype = issubdtype(a.dtype(), floating) ? a.dtype() : float32;
// Nothing to do for a scalar
if (n == 1) {
if (scale == 1) {
return a;
}
return multiply(a, array(scale, dtype), s);
}
return array(
a.shape(),
dtype,
@ -3769,6 +3780,7 @@ array conv_transpose_general(
std::vector<int> stride,
std::vector<int> padding,
std::vector<int> dilation,
std::vector<int> output_padding,
int groups,
StreamOrDevice s) {
std::vector<int> padding_lo(padding.size());
@ -3782,7 +3794,8 @@ array conv_transpose_general(
int in_size = 1 + (conv_output_shape - 1);
int out_size = 1 + stride[i] * (input.shape(1 + i) - 1);
padding_hi[i] = in_size - out_size + padding[i];
padding_hi[i] = in_size - out_size + padding[i] +
output_padding[i]; // Adjust with output_padding
}
return conv_general(
@ -3805,10 +3818,11 @@ array conv_transpose1d(
int stride /* = 1 */,
int padding /* = 0 */,
int dilation /* = 1 */,
int output_padding /* = 0 */,
int groups /* = 1 */,
StreamOrDevice s /* = {} */) {
return conv_transpose_general(
in_, wt_, {stride}, {padding}, {dilation}, groups, s);
in_, wt_, {stride}, {padding}, {dilation}, {output_padding}, groups, s);
}
/** 2D transposed convolution with a filter */
@ -3818,6 +3832,7 @@ array conv_transpose2d(
const std::pair<int, int>& stride /* = {1, 1} */,
const std::pair<int, int>& padding /* = {0, 0} */,
const std::pair<int, int>& dilation /* = {1, 1} */,
const std::pair<int, int>& output_padding /* = {0, 0} */,
int groups /* = 1 */,
StreamOrDevice s /* = {} */) {
return conv_transpose_general(
@ -3826,6 +3841,7 @@ array conv_transpose2d(
{stride.first, stride.second},
{padding.first, padding.second},
{dilation.first, dilation.second},
{output_padding.first, output_padding.second},
groups,
s);
}
@ -3837,6 +3853,7 @@ array conv_transpose3d(
const std::tuple<int, int, int>& stride /* = {1, 1, 1} */,
const std::tuple<int, int, int>& padding /* = {0, 0, 0} */,
const std::tuple<int, int, int>& dilation /* = {1, 1, 1} */,
const std::tuple<int, int, int>& output_padding /* = {0, 0, 0} */,
int groups /* = 1 */,
StreamOrDevice s /* = {} */) {
return conv_transpose_general(
@ -3845,6 +3862,9 @@ array conv_transpose3d(
{std::get<0>(stride), std::get<1>(stride), std::get<2>(stride)},
{std::get<0>(padding), std::get<1>(padding), std::get<2>(padding)},
{std::get<0>(dilation), std::get<1>(dilation), std::get<2>(dilation)},
{std::get<0>(output_padding),
std::get<1>(output_padding),
std::get<2>(output_padding)},
groups,
s);
}
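
output_padding enlarges the high-side padding, which grows the transposed-convolution output by that many positions per axis and resolves the output-size ambiguity when stride > 1. A worked sketch assuming the conventional sizing formula (the formula below is an assumption for illustration, not taken from this diff):

#include <cstdio>

// out = (in - 1) * stride - 2 * pad + dilation * (kernel - 1) + output_padding + 1
int out_len(int in, int kernel, int stride, int pad, int dilation, int out_pad) {
  return (in - 1) * stride - 2 * pad + dilation * (kernel - 1) + out_pad + 1;
}

int main() {
  std::printf("%d %d\n",
              out_len(5, 3, 2, 1, 1, 0),   // 9
              out_len(5, 3, 2, 1, 1, 1));  // 10: output_padding adds one position
}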
@ -4873,8 +4893,9 @@ array bitwise_impl(
const array& b,
BitwiseBinary::Op op,
const std::string& op_name,
const StreamOrDevice& s) {
auto out_type = promote_types(a.dtype(), b.dtype());
const StreamOrDevice& s,
std::optional<Dtype> out_type_ = std::nullopt) {
auto out_type = out_type_ ? *out_type_ : promote_types(a.dtype(), b.dtype());
if (!(issubdtype(out_type, integer) || out_type == bool_)) {
std::ostringstream msg;
msg << "[" << op_name
@ -4919,12 +4940,7 @@ array left_shift(const array& a, const array& b, StreamOrDevice s /* = {} */) {
if (t == bool_) {
t = uint8;
}
return bitwise_impl(
astype(a, t, s),
astype(b, t, s),
BitwiseBinary::Op::LeftShift,
"left_shift",
s);
return bitwise_impl(a, b, BitwiseBinary::Op::LeftShift, "left_shift", s, t);
}
array operator<<(const array& a, const array& b) {
return left_shift(a, b);
@ -4940,7 +4956,8 @@ array right_shift(const array& a, const array& b, StreamOrDevice s /* = {} */) {
astype(b, t, s),
BitwiseBinary::Op::RightShift,
"right_shift",
s);
s,
t);
}
array operator>>(const array& a, const array& b) {
return right_shift(a, b);
@ -5019,8 +5036,11 @@ array roll(
}
auto sh = shift[i];
auto split_index =
(sh < 0) ? (-sh) % a.shape(ax) : a.shape(ax) - sh % a.shape(ax);
auto size = a.shape(ax);
if (size == 0) {
continue; // skip rolling this axis if it has size 0
}
auto split_index = (sh < 0) ? (-sh) % size : size - sh % size;
auto parts = split(result, Shape{split_index}, ax, s);
std::swap(parts[0], parts[1]);

View File

@ -569,7 +569,7 @@ inline array std(const array& a, StreamOrDevice s = {}) {
return std(a, false, 0, to_stream(s));
}
/** Computes the standard deviatoin of the elements of an array along the given
/** Computes the standard deviation of the elements of an array along the given
* axes */
array std(
const array& a,
@ -1291,6 +1291,7 @@ array conv_transpose1d(
int stride = 1,
int padding = 0,
int dilation = 1,
int output_padding = 0,
int groups = 1,
StreamOrDevice s = {});
@ -1301,6 +1302,7 @@ array conv_transpose2d(
const std::pair<int, int>& stride = {1, 1},
const std::pair<int, int>& padding = {0, 0},
const std::pair<int, int>& dilation = {1, 1},
const std::pair<int, int>& output_padding = {0, 0},
int groups = 1,
StreamOrDevice s = {});
@ -1311,6 +1313,7 @@ array conv_transpose3d(
const std::tuple<int, int, int>& stride = {1, 1, 1},
const std::tuple<int, int, int>& padding = {0, 0, 0},
const std::tuple<int, int, int>& dilation = {1, 1, 1},
const std::tuple<int, int, int>& output_padding = {0, 0, 0},
int groups = 1,
StreamOrDevice s = {});

View File

@ -3056,6 +3056,7 @@ std::vector<array> QuantizedMatmul::vjp(
std::vector<array> vjps;
// We rely on the fact that w is always 2D so transpose is simple
std::optional<array> dsb = std::nullopt;
for (auto arg : argnums) {
// gradient wrt to x
if (arg == 0) {
@ -3071,9 +3072,34 @@ std::vector<array> QuantizedMatmul::vjp(
}
// gradient wrt to w_q, scales or biases
else {
else if (arg == 1) {
throw std::runtime_error(
"[QuantizedMatmul::vjp] no gradient wrt the quantized matrix yet.");
"[QuantizedMatmul::vjp] no gradient wrt the quantized weights.");
} else {
if (!dsb) {
auto fc = flatten(cotangents[0], 0, -2, stream());
auto fx = flatten(primals[0], 0, -2, stream());
auto dw = transpose_
? matmul(swapaxes(fc, -1, -2, stream()), fx, stream())
: matmul(swapaxes(fx, -1, -2, stream()), fc, stream());
dsb = unflatten(dw, -1, {-1, group_size_}, stream());
}
if (arg == 3) {
// biases
vjps.push_back(sum(*dsb, -1, false, stream()));
} else {
// scales
auto s = stream();
auto wq = dequantize(
primals[1],
ones_like(primals[2], stream()),
zeros_like(primals[3], stream()),
group_size_,
bits_,
stream());
wq = unflatten(wq, -1, {-1, group_size_}, stream());
vjps.push_back(sum(multiply(*dsb, wq, stream()), -1, false, stream()));
}
}
}
return vjps;
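
Schematically, with per-group dequantization w_hat = s * q + b and G = x^T (dL/dy) reshaped into groups of size group_size (the dsb term above), the new cotangents are

    dL/db_g = sum_{j in g} G_j        dL/ds_g = sum_{j in g} G_j * q_j

which is what the two sum(..., -1) reductions compute; q itself is recovered by dequantizing with unit scales and zero biases.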

View File

@ -223,7 +223,7 @@ array multivariate_normal(
auto n = mean.shape(-1);
// Check shapes comatibility of mean and cov
// Check shapes compatibility of mean and cov
if (cov.shape(-1) != cov.shape(-2)) {
throw std::invalid_argument(
"[multivariate_normal] last two dimensions of cov must be equal.");
@ -402,7 +402,7 @@ array categorical(
if (broadcast_shapes(shape, reduced_shape) != shape) {
std::ostringstream msg;
msg << "[categorical] Requested shape " << shape
<< " is not broadcast compatable with reduced logits shape"
<< " is not broadcast compatible with reduced logits shape"
<< reduced_shape << ".";
throw std::invalid_argument(msg.str());
}

View File

@ -1,12 +1,13 @@
// Copyright © 2023 Apple Inc.
#include "mlx/scheduler.h"
#include "mlx/backend/metal/metal.h"
#include "mlx/backend/gpu/available.h"
#include "mlx/backend/gpu/eval.h"
namespace mlx::core {
Stream default_stream(Device d) {
if (!metal::is_available() && d == Device::gpu) {
if (!gpu::is_available() && d == Device::gpu) {
throw std::invalid_argument(
"[default_stream] Cannot get gpu stream without gpu backend.");
}
@ -14,7 +15,7 @@ Stream default_stream(Device d) {
}
void set_default_stream(Stream s) {
if (!metal::is_available() && s.device == Device::gpu) {
if (!gpu::is_available() && s.device == Device::gpu) {
throw std::invalid_argument(
"[set_default_stream] Cannot set gpu stream without gpu backend.");
}
@ -26,7 +27,7 @@ Stream get_stream(int index) {
}
Stream new_stream(Device d) {
if (!metal::is_available() && d == Device::gpu) {
if (!gpu::is_available() && d == Device::gpu) {
throw std::invalid_argument(
"[new_stream] Cannot make gpu stream without gpu backend.");
}
@ -44,7 +45,7 @@ void synchronize(Stream s) {
scheduler::enqueue(s, [p = std::move(p)]() { p->set_value(); });
f.wait();
} else {
metal::synchronize(s);
gpu::synchronize(s);
}
}

Some files were not shown because too many files have changed in this diff.