diff --git a/.gitignore b/.gitignore index e748ee2bf..43629548d 100644 --- a/.gitignore +++ b/.gitignore @@ -36,6 +36,7 @@ share/python-wheels/ .installed.cfg *.egg MANIFEST +uv.lock # vim *.swp diff --git a/MANIFEST.in b/MANIFEST.in index 9faafee45..d0daeb7ae 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,4 +1,6 @@ include CMakeLists.txt +include mlx.pc.in recursive-include mlx/ * +include cmake/* include python/src/* include python/mlx/py.typed # support type hinting as in PEP-561 diff --git a/docs/src/python/fft.rst b/docs/src/python/fft.rst index 9e4be084b..36d9d7838 100644 --- a/docs/src/python/fft.rst +++ b/docs/src/python/fft.rst @@ -20,3 +20,5 @@ FFT irfft2 rfftn irfftn + fftshift + ifftshift diff --git a/mlx/CMakeLists.txt b/mlx/CMakeLists.txt index abf46a7d5..00898e73e 100644 --- a/mlx/CMakeLists.txt +++ b/mlx/CMakeLists.txt @@ -47,7 +47,10 @@ add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/distributed) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/io) if(MLX_BUILD_METAL) + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/gpu) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/metal) else() - add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_metal) + target_sources(mlx + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/no_metal.cpp) + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_gpu) endif() diff --git a/mlx/array.h b/mlx/array.h index 66a4702a6..d9fcfc58e 100644 --- a/mlx/array.h +++ b/mlx/array.h @@ -356,7 +356,7 @@ class array { } enum Status { - // The ouptut of a computation which has not been scheduled. + // The output of a computation which has not been scheduled. // For example, the status of `x` in `auto x = a + b`. unscheduled, diff --git a/mlx/backend/common/hadamard.h b/mlx/backend/common/hadamard.h index a8fed76b0..ba5c4e41e 100644 --- a/mlx/backend/common/hadamard.h +++ b/mlx/backend/common/hadamard.h @@ -99,7 +99,11 @@ inline std::pair decompose_hadamard(int n) { "[hadamard] Only supports n = m*2^k where m in (1, 12, 20, 28)."); } } + if (n > (1 << 26)) { + throw std::invalid_argument( + "[hadamard] Only supports n = m*2^k where k <= 26"); + } return {n, m}; } -} // namespace mlx::core \ No newline at end of file +} // namespace mlx::core diff --git a/mlx/backend/cpu/CMakeLists.txt b/mlx/backend/cpu/CMakeLists.txt index 152f33b17..96b3f1313 100644 --- a/mlx/backend/cpu/CMakeLists.txt +++ b/mlx/backend/cpu/CMakeLists.txt @@ -40,7 +40,8 @@ add_dependencies(mlx cpu_compiled_preamble) target_sources( mlx - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/available.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp ${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp diff --git a/mlx/backend/cpu/available.cpp b/mlx/backend/cpu/available.cpp new file mode 100644 index 000000000..0449d49b9 --- /dev/null +++ b/mlx/backend/cpu/available.cpp @@ -0,0 +1,11 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cpu/available.h" + +namespace mlx::core::cpu { + +bool is_available() { + return true; +} + +} // namespace mlx::core::cpu diff --git a/mlx/backend/cpu/available.h b/mlx/backend/cpu/available.h new file mode 100644 index 000000000..1df95def2 --- /dev/null +++ b/mlx/backend/cpu/available.h @@ -0,0 +1,9 @@ +// Copyright © 2025 Apple Inc. 
+ +#pragma once + +namespace mlx::core::cpu { + +bool is_available(); + +} // namespace mlx::core::cpu diff --git a/mlx/backend/cpu/binary.cpp b/mlx/backend/cpu/binary.cpp index dbdab6a06..35aa2a3e0 100644 --- a/mlx/backend/cpu/binary.cpp +++ b/mlx/backend/cpu/binary.cpp @@ -172,9 +172,12 @@ void binary_float( case bfloat16: binary_op(a, b, out, bopt); break; + case complex64: + binary_op(a, b, out, bopt); + break; default: throw std::runtime_error( - "[binary_float] Only supports non-complex floating point types."); + "[binary_float] Only supports floating point types."); } }); } diff --git a/mlx/backend/cpu/compiled.cpp b/mlx/backend/cpu/compiled.cpp index 9da9c14e8..e389e0df5 100644 --- a/mlx/backend/cpu/compiled.cpp +++ b/mlx/backend/cpu/compiled.cpp @@ -40,7 +40,10 @@ struct CompilerCache { std::shared_mutex mtx; }; -static CompilerCache cache{}; +static CompilerCache& cache() { + static CompilerCache cache_; + return cache_; +}; // GPU compile is always available if the GPU is available and since we are in // this file CPU compile is also available. @@ -56,14 +59,16 @@ void* compile( const std::string& kernel_name, const std::function& source_builder) { { - std::shared_lock lock(cache.mtx); - if (auto it = cache.kernels.find(kernel_name); it != cache.kernels.end()) { + std::shared_lock lock(cache().mtx); + if (auto it = cache().kernels.find(kernel_name); + it != cache().kernels.end()) { return it->second; } } - std::unique_lock lock(cache.mtx); - if (auto it = cache.kernels.find(kernel_name); it != cache.kernels.end()) { + std::unique_lock lock(cache().mtx); + if (auto it = cache().kernels.find(kernel_name); + it != cache().kernels.end()) { return it->second; } std::string source_code = source_builder(); @@ -120,10 +125,10 @@ void* compile( } // load library - cache.libs.emplace_back(shared_lib_path); + cache().libs.emplace_back(shared_lib_path); // Load function - void* fun = dlsym(cache.libs.back().lib, kernel_name.c_str()); + void* fun = dlsym(cache().libs.back().lib, kernel_name.c_str()); if (!fun) { std::ostringstream msg; msg << "[Compile::eval_cpu] Failed to load compiled function " @@ -131,7 +136,7 @@ void* compile( << dlerror(); throw std::runtime_error(msg.str()); } - cache.kernels.insert({kernel_name, fun}); + cache().kernels.insert({kernel_name, fun}); return fun; } diff --git a/mlx/backend/cpu/scan.cpp b/mlx/backend/cpu/scan.cpp index 199dbab35..33addd161 100644 --- a/mlx/backend/cpu/scan.cpp +++ b/mlx/backend/cpu/scan.cpp @@ -330,7 +330,8 @@ void Scan::eval_cpu(const std::vector& inputs, array& out) { reduce_type_, in, out, axis_, reverse_, inclusive_); break; case complex64: - throw std::runtime_error("Scan ops do not support complex types yet"); + scan_dispatch( + reduce_type_, in, out, axis_, reverse_, inclusive_); break; } }); diff --git a/mlx/backend/cpu/simd/base_simd.h b/mlx/backend/cpu/simd/base_simd.h index 7e82a4d56..17cd35b9a 100644 --- a/mlx/backend/cpu/simd/base_simd.h +++ b/mlx/backend/cpu/simd/base_simd.h @@ -88,12 +88,33 @@ DEFAULT_UNARY(expm1, std::expm1) DEFAULT_UNARY(floor, std::floor) DEFAULT_UNARY(log, std::log) DEFAULT_UNARY(log10, std::log10) -DEFAULT_UNARY(log1p, std::log1p) DEFAULT_UNARY(sinh, std::sinh) DEFAULT_UNARY(sqrt, std::sqrt) DEFAULT_UNARY(tan, std::tan) DEFAULT_UNARY(tanh, std::tanh) +template +Simd log1p(Simd in) { + if constexpr (is_complex) { + auto x = in.value.real(); + auto y = in.value.imag(); + auto zabs = std::abs(in.value); + auto theta = std::atan2(y, x + 1); + if (zabs < 0.5) { + auto r = x * (2 + x) + y * y; + if (r == 0) 
{ // handle underflow + return Simd{T{x, theta}}; + } + return Simd{T{((typeof(x))(0.5)) * std::log1p(r), theta}}; + } else { + auto z0 = std::hypot(x + 1, y); + return Simd{T{std::log(z0), theta}}; + } + } else { + return Simd{std::log1p(in.value)}; + } +} + template Simd log2(Simd in) { if constexpr (is_complex) { diff --git a/mlx/backend/gpu/CMakeLists.txt b/mlx/backend/gpu/CMakeLists.txt new file mode 100644 index 000000000..0396ae03a --- /dev/null +++ b/mlx/backend/gpu/CMakeLists.txt @@ -0,0 +1,5 @@ +target_sources( + mlx + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp) diff --git a/mlx/backend/gpu/available.h b/mlx/backend/gpu/available.h new file mode 100644 index 000000000..476c7acf2 --- /dev/null +++ b/mlx/backend/gpu/available.h @@ -0,0 +1,9 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +namespace mlx::core::gpu { + +bool is_available(); + +} // namespace mlx::core::gpu diff --git a/mlx/backend/gpu/copy.cpp b/mlx/backend/gpu/copy.cpp new file mode 100644 index 000000000..6127ac921 --- /dev/null +++ b/mlx/backend/gpu/copy.cpp @@ -0,0 +1,49 @@ +// Copyright © 2023-2024 Apple Inc. + +#include "mlx/backend/gpu/copy.h" +#include "mlx/primitives.h" + +#include + +namespace mlx::core { + +void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) { + bool donated = set_copy_output_data(in, out, ctype); + if (donated && in.dtype() == out.dtype()) { + // If the output has the same type as the input then there is nothing to + // copy, just use the buffer. + return; + } + if (ctype == CopyType::GeneralGeneral) { + ctype = CopyType::General; + } + copy_gpu_inplace(in, out, ctype, s); +} + +void copy_gpu(const array& in, array& out, CopyType ctype) { + copy_gpu(in, out, ctype, out.primitive().stream()); +} + +void copy_gpu_inplace( + const array& in, + array& out, + CopyType ctype, + const Stream& s) { + assert(in.shape() == out.shape()); + return copy_gpu_inplace( + in, out, in.shape(), in.strides(), out.strides(), 0, 0, ctype, s); +} + +void copy_gpu_inplace( + const array& in, + array& out, + const Strides& i_strides, + int64_t i_offset, + CopyType ctype, + const Stream& s) { + assert(in.shape() == out.shape()); + return copy_gpu_inplace( + in, out, in.shape(), i_strides, out.strides(), i_offset, 0, ctype, s); +} + +} // namespace mlx::core diff --git a/mlx/backend/metal/copy.h b/mlx/backend/gpu/copy.h similarity index 98% rename from mlx/backend/metal/copy.h rename to mlx/backend/gpu/copy.h index 37c60df42..020f579e4 100644 --- a/mlx/backend/metal/copy.h +++ b/mlx/backend/gpu/copy.h @@ -5,6 +5,8 @@ #include "mlx/backend/common/copy.h" #include "mlx/stream.h" +#include + namespace mlx::core { // Generic copy inplace diff --git a/mlx/backend/metal/metal_impl.h b/mlx/backend/gpu/eval.h similarity index 63% rename from mlx/backend/metal/metal_impl.h rename to mlx/backend/gpu/eval.h index 9ca8d2f80..f646c2ec9 100644 --- a/mlx/backend/metal/metal_impl.h +++ b/mlx/backend/gpu/eval.h @@ -8,14 +8,11 @@ #include "mlx/array.h" #include "mlx/stream.h" -namespace mlx::core::metal { +namespace mlx::core::gpu { void new_stream(Stream stream); - -std::unique_ptr> new_scoped_memory_pool(); - void eval(array& arr); void finalize(Stream s); void synchronize(Stream s); -} // namespace mlx::core::metal +} // namespace mlx::core::gpu diff --git a/mlx/backend/gpu/primitives.cpp b/mlx/backend/gpu/primitives.cpp new file mode 100644 index 000000000..cd9296075 --- /dev/null +++ 
b/mlx/backend/gpu/primitives.cpp @@ -0,0 +1,217 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/primitives.h" +#include "mlx/backend/common/utils.h" +#include "mlx/backend/gpu/copy.h" +#include "mlx/backend/gpu/slicing.h" + +#include + +#define MLX_PROFILER_RANGE(message) + +namespace mlx::core { + +namespace { + +void reshape(const array& in, array& out, Stream s) { + auto [copy_necessary, out_strides] = prepare_reshape(in, out); + if (copy_necessary) { + out.set_data(allocator::malloc(out.nbytes())); + copy_gpu_inplace( + in, + out, + in.shape(), + in.strides(), + make_contiguous_strides(in.shape()), + 0, + 0, + CopyType::General, + s); + } else { + shared_buffer_reshape(in, out_strides, out); + } +} + +} // namespace + +void AsStrided::eval_gpu(const std::vector& inputs, array& out) { + MLX_PROFILER_RANGE("AsStrided::eval_gpu"); + eval(inputs, out); +} + +void AsType::eval_gpu(const std::vector& inputs, array& out) { + MLX_PROFILER_RANGE("AsType::eval_gpu"); + CopyType ctype = + inputs[0].flags().contiguous ? CopyType::Vector : CopyType::General; + copy_gpu(inputs[0], out, ctype); +} + +void Broadcast::eval_gpu(const std::vector& inputs, array& out) { + MLX_PROFILER_RANGE("Broadcast::eval_gpu"); + eval(inputs, out); +} + +void BroadcastAxes::eval_gpu(const std::vector& inputs, array& out) { + MLX_PROFILER_RANGE("BroadcastAxes::eval_gpu"); + eval(inputs, out); +} + +void Concatenate::eval_gpu(const std::vector& inputs, array& out) { + MLX_PROFILER_RANGE("Concatenate::eval_gpu"); + concatenate_gpu(inputs, out, axis_, stream()); +} + +void Contiguous::eval_gpu(const std::vector& inputs, array& out) { + MLX_PROFILER_RANGE("Contiguous::eval_gpu"); + assert(inputs.size() == 1); + auto& in = inputs[0]; + constexpr size_t extra_bytes = 16384; + if (in.buffer_size() <= out.nbytes() + extra_bytes && + (in.flags().row_contiguous || + (allow_col_major_ && in.flags().col_contiguous))) { + out.copy_shared_buffer(in); + } else { + copy_gpu(in, out, CopyType::General); + } +} + +void Copy::eval_gpu(const std::vector& inputs, array& out) { + MLX_PROFILER_RANGE("Copy::eval_gpu"); + eval(inputs, out); +} + +void CustomTransforms::eval_gpu( + const std::vector& inputs, + std::vector& outputs) { + MLX_PROFILER_RANGE("CustomTransforms::eval_gpu"); + eval(inputs, outputs); +} + +void Depends::eval_gpu( + const std::vector& inputs, + std::vector& outputs) { + MLX_PROFILER_RANGE("Depends::eval_gpu"); + eval(inputs, outputs); +} + +void ExpandDims::eval_gpu(const std::vector& inputs, array& out) { + MLX_PROFILER_RANGE("ExpandDims::eval_gpu"); + eval(inputs, out); +} + +void Full::eval_gpu(const std::vector& inputs, array& out) { + MLX_PROFILER_RANGE("Full::eval_gpu"); + auto in = inputs[0]; + CopyType ctype; + if (in.data_size() == 1) { + ctype = CopyType::Scalar; + } else if (in.flags().contiguous) { + ctype = CopyType::Vector; + } else { + ctype = CopyType::General; + } + copy_gpu(in, out, ctype); +} + +void Flatten::eval_gpu(const std::vector& inputs, array& out) { + MLX_PROFILER_RANGE("Flatten::eval_gpu"); + reshape(inputs[0], out, stream()); +} + +void NumberOfElements::eval_gpu(const std::vector& inputs, array& out) { + MLX_PROFILER_RANGE("NumberOfElements::eval_gpu"); + eval(inputs, out); +} + +void Pad::eval_gpu(const std::vector& inputs, array& out) { + // Inputs must be base input array and scalar val array + assert(inputs.size() == 2); + auto& in = inputs[0]; + auto& val = inputs[1]; + + // Padding value must be a scalar + assert(val.size() == 1); + + // Padding value, input and output must be of 
the same type + assert(val.dtype() == in.dtype() && in.dtype() == out.dtype()); + + pad_gpu(in, val, out, axes_, low_pad_size_, stream()); +} + +void Reshape::eval_gpu(const std::vector& inputs, array& out) { + MLX_PROFILER_RANGE("Reshape::eval_gpu"); + reshape(inputs[0], out, stream()); +} + +void Split::eval_gpu( + const std::vector& inputs, + std::vector& outputs) { + MLX_PROFILER_RANGE("Split::eval_gpu"); + eval(inputs, outputs); +} + +void Slice::eval_gpu(const std::vector& inputs, array& out) { + MLX_PROFILER_RANGE("Slice::eval_gpu"); + assert(inputs.size() == 1); + if (out.size() == 0) { + out.set_data(nullptr); + return; + } + + auto& in = inputs[0]; + slice_gpu(in, out, start_indices_, strides_, stream()); +} + +void Squeeze::eval_gpu(const std::vector& inputs, array& out) { + MLX_PROFILER_RANGE("Squeeze::eval_gpu"); + eval(inputs, out); +} + +void StopGradient::eval_gpu(const std::vector& inputs, array& out) { + MLX_PROFILER_RANGE("StopGradient::eval_gpu"); + eval(inputs, out); +} + +void Transpose::eval_gpu(const std::vector& inputs, array& out) { + MLX_PROFILER_RANGE("Transpose::eval_gpu"); + eval(inputs, out); +} + +void Unflatten::eval_gpu(const std::vector& inputs, array& out) { + MLX_PROFILER_RANGE("Unflatten::eval_gpu"); + reshape(inputs[0], out, stream()); +} + +void View::eval_gpu(const std::vector& inputs, array& out) { + MLX_PROFILER_RANGE("View::eval_gpu"); + auto& in = inputs[0]; + auto ibytes = size_of(in.dtype()); + auto obytes = size_of(out.dtype()); + // Conditions for buffer copying (disjunction): + // - type size is the same + // - type size is smaller and the last axis is contiguous + // - the entire array is row contiguous + if (ibytes == obytes || (obytes < ibytes && in.strides().back() == 1) || + in.flags().row_contiguous) { + auto strides = in.strides(); + for (int i = 0; i < static_cast(strides.size()) - 1; ++i) { + strides[i] *= ibytes; + strides[i] /= obytes; + } + out.copy_shared_buffer( + in, strides, in.flags(), in.data_size() * ibytes / obytes); + } else { + auto tmp = array(in.shape(), in.dtype(), nullptr, {}); + tmp.set_data(allocator::malloc(tmp.nbytes())); + copy_gpu_inplace(in, tmp, CopyType::General, stream()); + + auto flags = out.flags(); + flags.contiguous = true; + flags.row_contiguous = true; + auto max_dim = std::max_element(out.shape().begin(), out.shape().end()); + flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim; + out.copy_shared_buffer(tmp, out.strides(), flags, out.size()); + } +} + +} // namespace mlx::core diff --git a/mlx/backend/gpu/slicing.cpp b/mlx/backend/gpu/slicing.cpp new file mode 100644 index 000000000..fde2a01cd --- /dev/null +++ b/mlx/backend/gpu/slicing.cpp @@ -0,0 +1,44 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/common/slicing.h" +#include "mlx/backend/gpu/copy.h" +#include "mlx/backend/gpu/slicing.h" + +namespace mlx::core { + +void slice_gpu( + const array& in, + array& out, + const Shape& start_indices, + const Shape& strides, + const Stream& s) { + slice(in, out, start_indices, strides); +} + +void pad_gpu( + const array& in, + const array& val, + array& out, + const std::vector& axes, + const Shape& low_pad_size, + const Stream& s) { + // Fill output with val + fill_gpu(val, out, s); + + // Find offset for start of input values + size_t data_offset = 0; + for (int i = 0; i < axes.size(); i++) { + auto ax = axes[i] < 0 ? 
out.ndim() + axes[i] : axes[i]; + data_offset += out.strides()[ax] * low_pad_size[i]; + } + + // Extract slice from output where input will be pasted + array out_slice(in.shape(), out.dtype(), nullptr, {}); + out_slice.copy_shared_buffer( + out, out.strides(), out.flags(), out_slice.size(), data_offset); + + // Copy input values into the slice + copy_gpu_inplace(in, out_slice, CopyType::GeneralGeneral, s); +} + +} // namespace mlx::core diff --git a/mlx/backend/metal/slicing.h b/mlx/backend/gpu/slicing.h similarity index 100% rename from mlx/backend/metal/slicing.h rename to mlx/backend/gpu/slicing.h diff --git a/mlx/backend/metal/CMakeLists.txt b/mlx/backend/metal/CMakeLists.txt index 332c560f8..d0c872451 100644 --- a/mlx/backend/metal/CMakeLists.txt +++ b/mlx/backend/metal/CMakeLists.txt @@ -93,6 +93,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp ${CMAKE_CURRENT_SOURCE_DIR}/event.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp ${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp ${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp diff --git a/mlx/backend/metal/allocator.cpp b/mlx/backend/metal/allocator.cpp index 0a69dd261..5d8bd90d5 100644 --- a/mlx/backend/metal/allocator.cpp +++ b/mlx/backend/metal/allocator.cpp @@ -1,7 +1,6 @@ // Copyright © 2023-2024 Apple Inc. #include "mlx/backend/metal/allocator.h" #include "mlx/backend/metal/metal.h" -#include "mlx/backend/metal/metal_impl.h" #include "mlx/backend/metal/resident.h" #include "mlx/memory.h" diff --git a/mlx/backend/metal/binary.cpp b/mlx/backend/metal/binary.cpp index f80f8c3e4..c3c67e4d5 100644 --- a/mlx/backend/metal/binary.cpp +++ b/mlx/backend/metal/binary.cpp @@ -90,7 +90,7 @@ void binary_op_gpu_inplace( work_per_thread = large ? 4 : 2; } else { large = out.data_size() > UINT32_MAX; - work_per_thread = 1; + work_per_thread = get_work_per_thread(a.dtype()); } std::string kernel_name = get_kernel_name(bopt, op, a, large, shape.size(), work_per_thread); @@ -137,13 +137,20 @@ void binary_op_gpu_inplace( compute_encoder.dispatch_threads(grid_dims, group_dims); } else { // Launch a 1D or 2D grid of threads - size_t nthreads = out.data_size(); + size_t nthreads = ceildiv(out.data_size(), work_per_thread); if (thread_group_size > nthreads) { thread_group_size = nthreads; } + MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); - MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides()) - : MTL::Size(nthreads, 1, 1); + MTL::Size grid_dims; + if (large) { + compute_encoder.set_bytes(out.data_size(), arg_idx++); + grid_dims = get_2d_grid_dims(out.shape(), out.strides(), work_per_thread); + } else { + compute_encoder.set_bytes(out.data_size(), arg_idx++); + grid_dims = MTL::Size(nthreads, 1, 1); + } compute_encoder.dispatch_threads(grid_dims, group_dims); } } diff --git a/mlx/backend/metal/compiled.cpp b/mlx/backend/metal/compiled.cpp index 154273233..db20f938c 100644 --- a/mlx/backend/metal/compiled.cpp +++ b/mlx/backend/metal/compiled.cpp @@ -64,6 +64,7 @@ inline void build_kernel( cnt++); } + std::string idx_type = use_big_index ? 
"int64_t" : "uint"; if (add_indices) { os += fmt::format( " constant const int64_t* in_strides [[buffer({0})]],\n", cnt++); @@ -83,6 +84,9 @@ inline void build_kernel( " constant const int64_t* output_strides [[buffer({0})]],\n", cnt++); os += fmt::format( " constant const int* output_shape [[buffer({0})]],\n", cnt++); + } else { + os += fmt::format( + " constant const {0}& size [[buffer({1})]],\n", idx_type, cnt++); } if (dynamic_dims) { os += fmt::format(" constant const int& ndim [[buffer({0})]],\n", cnt++); @@ -92,13 +96,14 @@ inline void build_kernel( os += " uint3 pos [[thread_position_in_grid]],\n"; os += " uint3 grid [[threads_per_grid]]) {\n"; - std::string idx_type = use_big_index ? "int64_t" : "uint"; + os += fmt::format(" constexpr int N_ = {0};\n", work_per_thread); if (contiguous && use_big_index) { // This is only used for contiguous kernels which don't have // a third grid dimension - os += " int64_t index = pos.x + grid.x * int64_t(pos.y);\n"; + os += " int64_t index = N_ * (pos.x + grid.x * int64_t(pos.y));\n"; + } else if (contiguous) { + os += " uint index = N_ * pos.x;\n"; } else if (work_per_thread > 1) { - os += fmt::format(" constexpr int N_ = {0};\n", work_per_thread); os += fmt::format( " int xshape = output_shape[{0}];\n", dynamic_dims ? "ndim - 1" : std::to_string(ndim - 1)); @@ -110,6 +115,9 @@ inline void build_kernel( " {0} index = pos.x + grid.x * (pos.y + {0}(grid.y) * pos.z);\n", idx_type); } + if (work_per_thread > 1 && contiguous) { + os += " for (int i = 0; i < N_ && index < size; ++i) {\n"; + } // Read constant / contiguous inputs in tmps std::vector nc_inputs; @@ -193,7 +201,7 @@ inline void build_kernel( } // Open per-thread loop - if (work_per_thread > 1) { + if (work_per_thread > 1 && !contiguous) { os += " for (int i = 0; i < N_ && (int(N_ * pos.x) + i) < xshape; ++i) {\n"; } @@ -272,6 +280,7 @@ void Compiled::eval_gpu( auto& s = stream(); auto& d = metal::device(s.device); auto lib = d.get_library(kernel_lib_, [&]() { + int work_per_thread = get_work_per_thread(outputs_[0].dtype()); std::string kernel = metal::utils(); concatenate( kernel, metal::unary_ops(), metal::binary_ops(), metal::ternary_ops()); @@ -284,7 +293,9 @@ void Compiled::eval_gpu( constant_ids_, /* contiguous = */ true, /* ndim = */ 0, - /* dynamic_dims = */ false); + /* dynamic_dims = */ false, + /* use_big_index = */ false, + /* work_per_thread = */ work_per_thread); build_kernel( kernel, kernel_lib_ + "_contiguous_large", @@ -295,7 +306,8 @@ void Compiled::eval_gpu( /* contiguous = */ true, /* ndim = */ 0, /* dynamic_dims = */ false, - /* use_big_index = */ true); + /* use_big_index = */ true, + /* work_per_thread = */ work_per_thread); for (int i = 1; i < 8; i++) { build_kernel( kernel, @@ -468,6 +480,13 @@ void Compiled::eval_gpu( if (!contiguous) { compute_encoder.set_vector_bytes(strides[0], cnt++); compute_encoder.set_vector_bytes(shape, cnt++); + } else { + auto size = outputs[0].data_size(); + if (large) { + compute_encoder.set_bytes(size, cnt++); + } else { + compute_encoder.set_bytes(size, cnt++); + } } // Put the number of dims in if it is dynamic @@ -477,12 +496,13 @@ void Compiled::eval_gpu( // Launch the kernel if (contiguous) { - size_t nthreads = outputs[0].data_size(); + int work_per_thread = get_work_per_thread(outputs[0].dtype()); + size_t nthreads = ceildiv(outputs[0].data_size(), work_per_thread); MTL::Size group_dims( std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1); - MTL::Size grid_dims = large - ? 
get_2d_grid_dims(outputs[0].shape(), outputs[0].strides()) + ? get_2d_grid_dims( + outputs[0].shape(), outputs[0].strides(), work_per_thread) : MTL::Size(nthreads, 1, 1); compute_encoder.dispatch_threads(grid_dims, group_dims); } else { diff --git a/mlx/backend/metal/conv.cpp b/mlx/backend/metal/conv.cpp index 9075ea4c5..ae31a6cff 100644 --- a/mlx/backend/metal/conv.cpp +++ b/mlx/backend/metal/conv.cpp @@ -5,7 +5,7 @@ #include #include -#include "mlx/backend/metal/copy.h" +#include "mlx/backend/gpu/copy.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/kernels.h" #include "mlx/backend/metal/kernels/defines.h" diff --git a/mlx/backend/metal/copy.cpp b/mlx/backend/metal/copy.cpp index 3399201de..8dfe15c11 100644 --- a/mlx/backend/metal/copy.cpp +++ b/mlx/backend/metal/copy.cpp @@ -1,35 +1,15 @@ // Copyright © 2023-2024 Apple Inc. -#include - +#include "mlx/backend/gpu/copy.h" #include "mlx/backend/common/utils.h" -#include "mlx/backend/metal/copy.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/kernels.h" #include "mlx/backend/metal/utils.h" -#include "mlx/primitives.h" namespace mlx::core { constexpr int MAX_COPY_SPECIALIZED_DIMS = 3; -void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) { - bool donated = set_copy_output_data(in, out, ctype); - if (donated && in.dtype() == out.dtype()) { - // If the output has the same type as the input then there is nothing to - // copy, just use the buffer. - return; - } - if (ctype == CopyType::GeneralGeneral) { - ctype = CopyType::General; - } - copy_gpu_inplace(in, out, ctype, s); -} - -void copy_gpu(const array& in, array& out, CopyType ctype) { - copy_gpu(in, out, ctype, out.primitive().stream()); -} - void copy_gpu_inplace( const array& in, array& out, @@ -104,6 +84,8 @@ void copy_gpu_inplace( "[Copy::eval_gpu] Dynamic output offset requires GeneralGeneral copy"); } } + } else { + work_per_thread = get_work_per_thread(in.dtype()); } concatenate(kernel_name, "_copy", type_to_name(in), type_to_name(out)); auto kernel = dynamic ? get_dynamic_copy_kernel(d, kernel_name, in, out) @@ -165,39 +147,23 @@ void copy_gpu_inplace( MTL::Size grid_dims = MTL::Size(dim0, dim1, rest); compute_encoder.dispatch_threads(grid_dims, group_dims); } else { - size_t nthreads = out.data_size(); + size_t nthreads = ceildiv(out.data_size(), work_per_thread); if (thread_group_size > nthreads) { thread_group_size = nthreads; } MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); - MTL::Size grid_dims = large ? 
get_2d_grid_dims(out.shape(), out.strides()) - : MTL::Size(nthreads, 1, 1); + MTL::Size grid_dims; + if (large) { + compute_encoder.set_bytes(out.data_size(), 2); + grid_dims = get_2d_grid_dims(out.shape(), out.strides(), work_per_thread); + } else { + compute_encoder.set_bytes(out.data_size(), 2); + grid_dims = MTL::Size(nthreads, 1, 1); + } compute_encoder.dispatch_threads(grid_dims, group_dims); } } -void copy_gpu_inplace( - const array& in, - array& out, - CopyType ctype, - const Stream& s) { - assert(in.shape() == out.shape()); - return copy_gpu_inplace( - in, out, in.shape(), in.strides(), out.strides(), 0, 0, ctype, s); -} - -void copy_gpu_inplace( - const array& in, - array& out, - const Strides& i_strides, - int64_t i_offset, - CopyType ctype, - const Stream& s) { - assert(in.shape() == out.shape()); - return copy_gpu_inplace( - in, out, in.shape(), i_strides, out.strides(), i_offset, 0, ctype, s); -} - void fill_gpu(const array& val, array& out, const Stream& s) { if (out.size() == 0) { return; @@ -214,14 +180,21 @@ void fill_gpu(const array& val, array& out, const Stream& s) { compute_encoder.set_input_array(val, 0); compute_encoder.set_output_array(out, 1); + int work_per_thread = get_work_per_thread(val.dtype()); auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); - size_t nthreads = out.data_size(); + size_t nthreads = ceildiv(out.data_size(), work_per_thread); if (thread_group_size > nthreads) { thread_group_size = nthreads; } MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); - MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides()) - : MTL::Size(nthreads, 1, 1); + MTL::Size grid_dims; + if (large) { + compute_encoder.set_bytes(out.data_size(), 2); + grid_dims = get_2d_grid_dims(out.shape(), out.strides(), work_per_thread); + } else { + compute_encoder.set_bytes(out.data_size(), 2); + grid_dims = MTL::Size(nthreads, 1, 1); + } compute_encoder.dispatch_threads(grid_dims, group_dims); } diff --git a/mlx/backend/metal/custom_kernel.cpp b/mlx/backend/metal/custom_kernel.cpp index 8a672289a..ea4f258cc 100644 --- a/mlx/backend/metal/custom_kernel.cpp +++ b/mlx/backend/metal/custom_kernel.cpp @@ -1,6 +1,6 @@ // Copyright © 2024 Apple Inc. -#include "mlx/backend/metal/copy.h" +#include "mlx/backend/gpu/copy.h" #include "mlx/backend/metal/jit/includes.h" #include "mlx/backend/metal/utils.h" #include "mlx/fast_primitives.h" diff --git a/mlx/backend/metal/device.cpp b/mlx/backend/metal/device.cpp index 95aeb1cc9..ebc3cc77f 100644 --- a/mlx/backend/metal/device.cpp +++ b/mlx/backend/metal/device.cpp @@ -1,20 +1,20 @@ // Copyright © 2023-2024 Apple Inc. 
#include +#include #include -#include - #define NS_PRIVATE_IMPLEMENTATION #define CA_PRIVATE_IMPLEMENTATION #define MTL_PRIVATE_IMPLEMENTATION #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/metal.h" -#include "mlx/backend/metal/metal_impl.h" #include "mlx/backend/metal/utils.h" #include "mlx/utils.h" +namespace fs = std::filesystem; + namespace mlx::core::metal { namespace { @@ -66,8 +66,8 @@ MTL::Library* try_load_bundle( if (bundle != nullptr) { std::string resource_path = std::string(bundle->resourceURL()->fileSystemRepresentation()) + "/" + - lib_name + ".metallib" auto [lib, error] = - load_library_from_path(device, resource_path.c_str()); + lib_name + ".metallib"; + auto [lib, error] = load_library_from_path(device, resource_path.c_str()); if (lib) { return lib; } @@ -79,12 +79,18 @@ MTL::Library* try_load_bundle( // Firstly, search for the metallib in the same path as this binary std::pair load_colocated_library( MTL::Device* device, - const std::string& lib_name) { - std::string lib_path = get_colocated_mtllib_path(lib_name); - if (lib_path.size() != 0) { - return load_library_from_path(device, lib_path.c_str()); + const std::string& relative_path) { + std::string binary_dir = get_binary_directory(); + if (binary_dir.size() == 0) { + return {nullptr, nullptr}; } - return {nullptr, nullptr}; + + auto path = fs::path(binary_dir) / relative_path; + if (!path.has_extension()) { + path.replace_extension(".metallib"); + } + + return load_library_from_path(device, path.c_str()); } std::pair load_swiftpm_library( @@ -99,7 +105,7 @@ std::pair load_swiftpm_library( auto bundles = NS::Bundle::allBundles(); for (int i = 0, c = (int)bundles->count(); i < c; i++) { auto bundle = reinterpret_cast(bundles->object(i)); - library = try_load_bundle(device, bundle->resourceURL()); + library = try_load_bundle(device, bundle->resourceURL(), lib_name); if (library != nullptr) { return {library, nullptr}; } @@ -109,33 +115,34 @@ std::pair load_swiftpm_library( } MTL::Library* load_default_library(MTL::Device* device) { - NS::Error *error1, *error2, *error3; + NS::Error* error[4]; MTL::Library* lib; // First try the colocated mlx.metallib - std::tie(lib, error1) = load_colocated_library(device, "mlx"); + std::tie(lib, error[0]) = load_colocated_library(device, "mlx"); + if (lib) { + return lib; + } + + std::tie(lib, error[1]) = load_colocated_library(device, "Resources/mlx"); if (lib) { return lib; } // Then try default.metallib in a SwiftPM bundle if we have one - std::tie(lib, error2) = load_swiftpm_library(device, "default"); + std::tie(lib, error[2]) = load_swiftpm_library(device, "default"); if (lib) { return lib; } // Finally try default_mtllib_path - std::tie(lib, error3) = load_library_from_path(device, default_mtllib_path); + std::tie(lib, error[3]) = load_library_from_path(device, default_mtllib_path); if (!lib) { std::ostringstream msg; msg << "Failed to load the default metallib. 
"; - if (error1 != nullptr) { - msg << error1->localizedDescription()->utf8String() << " "; - } - if (error2 != nullptr) { - msg << error2->localizedDescription()->utf8String() << " "; - } - if (error3 != nullptr) { - msg << error3->localizedDescription()->utf8String() << " "; + for (int i = 0; i < 4; i++) { + if (error[i] != nullptr) { + msg << error[i]->localizedDescription()->utf8String() << " "; + } } throw std::runtime_error(msg.str()); } @@ -156,6 +163,7 @@ MTL::Library* load_library( << error->localizedDescription()->utf8String(); throw std::runtime_error(msg.str()); } + return lib; } // We have been given a path so try to load from lib_path / lib_name.metallib @@ -168,6 +176,7 @@ MTL::Library* load_library( << "> with error " << error->localizedDescription()->utf8String(); throw std::runtime_error(msg.str()); } + return lib; } // Try to load the colocated library @@ -188,8 +197,8 @@ MTL::Library* load_library( std::ostringstream msg; msg << "Failed to load the metallib " << lib_name << ".metallib. " - << "We attempted to load it from <" << get_colocated_mtllib_path(lib_name) - << ">"; + << "We attempted to load it from <" << get_binary_directory() << "/" + << lib_name << ".metallib" << ">"; #ifdef SWIFTPM_BUNDLE msg << " and from the Swift PM bundle."; #endif @@ -760,42 +769,4 @@ std::unique_ptr> new_scoped_memory_pool() { NS::AutoreleasePool::alloc()->init(), dtor); } -void new_stream(Stream stream) { - if (stream.device == mlx::core::Device::gpu) { - device(stream.device).new_queue(stream.index); - } -} - -const std::unordered_map>& -device_info() { - auto init_device_info = []() - -> std::unordered_map> { - auto pool = new_scoped_memory_pool(); - auto raw_device = device(default_device()).mtl_device(); - auto name = std::string(raw_device->name()->utf8String()); - auto arch = std::string(raw_device->architecture()->name()->utf8String()); - - size_t memsize = 0; - size_t length = sizeof(memsize); - sysctlbyname("hw.memsize", &memsize, &length, NULL, 0); - - size_t rsrc_limit = 0; - sysctlbyname("iogpu.rsrc_limit", &rsrc_limit, &length, NULL, 0); - if (rsrc_limit == 0) { - rsrc_limit = 499000; - } - - return { - {"device_name", name}, - {"architecture", arch}, - {"max_buffer_length", raw_device->maxBufferLength()}, - {"max_recommended_working_set_size", - raw_device->recommendedMaxWorkingSetSize()}, - {"memory_size", memsize}, - {"resource_limit", rsrc_limit}}; - }; - static auto device_info_ = init_device_info(); - return device_info_; -} - } // namespace mlx::core::metal diff --git a/mlx/backend/metal/device.h b/mlx/backend/metal/device.h index bb0e93147..26c9a0a28 100644 --- a/mlx/backend/metal/device.h +++ b/mlx/backend/metal/device.h @@ -21,18 +21,14 @@ namespace mlx::core::metal { // Note, this function must be left inline in a header so that it is not // dynamically linked. 
-inline std::string get_colocated_mtllib_path(const std::string& lib_name) { +inline std::string get_binary_directory() { Dl_info info; - std::string mtllib_path; - std::string lib_ext = lib_name + ".metallib"; - - int success = dladdr((void*)get_colocated_mtllib_path, &info); + std::string directory; + int success = dladdr((void*)get_binary_directory, &info); if (success) { - auto mtllib = fs::path(info.dli_fname).remove_filename() / lib_ext; - mtllib_path = mtllib.c_str(); + directory = fs::path(info.dli_fname).remove_filename().c_str(); } - - return mtllib_path; + return directory; } using MTLFCList = @@ -270,4 +266,6 @@ class Device { Device& device(mlx::core::Device); +std::unique_ptr> new_scoped_memory_pool(); + } // namespace mlx::core::metal diff --git a/mlx/backend/metal/distributed.cpp b/mlx/backend/metal/distributed.cpp index 82e8fff7d..a800d2e0f 100644 --- a/mlx/backend/metal/distributed.cpp +++ b/mlx/backend/metal/distributed.cpp @@ -4,7 +4,7 @@ #include "mlx/allocator.h" #include "mlx/backend/common/utils.h" -#include "mlx/backend/metal/copy.h" +#include "mlx/backend/gpu/copy.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/utils.h" #include "mlx/distributed/ops.h" diff --git a/mlx/backend/metal/eval.cpp b/mlx/backend/metal/eval.cpp new file mode 100644 index 000000000..49783200a --- /dev/null +++ b/mlx/backend/metal/eval.cpp @@ -0,0 +1,102 @@ +// Copyright © 2023-2024 Apple Inc. +#include + +#include "mlx/backend/gpu/available.h" +#include "mlx/backend/gpu/eval.h" +#include "mlx/backend/metal/device.h" +#include "mlx/backend/metal/utils.h" +#include "mlx/primitives.h" +#include "mlx/scheduler.h" + +namespace mlx::core::gpu { + +bool is_available() { + return true; +} + +void new_stream(Stream stream) { + if (stream.device == mlx::core::Device::gpu) { + metal::device(stream.device).new_queue(stream.index); + } +} + +inline void check_error(MTL::CommandBuffer* cbuf) { + if (cbuf->status() == MTL::CommandBufferStatusError) { + std::ostringstream msg; + msg << "[METAL] Command buffer execution failed: " + << cbuf->error()->localizedDescription()->utf8String(); + throw std::runtime_error(msg.str()); + } +} + +void eval(array& arr) { + auto pool = metal::new_scoped_memory_pool(); + auto s = arr.primitive().stream(); + auto& d = metal::device(s.device); + auto command_buffer = d.get_command_buffer(s.index); + + auto outputs = arr.outputs(); + { + // If the array is a tracer hold a reference + // to its inputs so they don't get donated + std::vector inputs; + if (arr.is_tracer()) { + inputs = arr.inputs(); + } + + debug_set_primitive_buffer_label(command_buffer, arr.primitive()); + arr.primitive().eval_gpu(arr.inputs(), outputs); + } + std::unordered_set> buffers; + for (auto& in : arr.inputs()) { + buffers.insert(in.data_shared_ptr()); + } + for (auto& s : arr.siblings()) { + buffers.insert(s.data_shared_ptr()); + } + // Remove the output if it was donated to by an input + if (auto it = buffers.find(arr.data_shared_ptr()); it != buffers.end()) { + buffers.erase(it); + } + + if (d.command_buffer_needs_commit(s.index)) { + d.end_encoding(s.index); + scheduler::notify_new_task(s); + command_buffer->addCompletedHandler( + [s, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) { + scheduler::notify_task_completion(s); + check_error(cbuf); + }); + d.commit_command_buffer(s.index); + d.get_command_buffer(s.index); + } else { + command_buffer->addCompletedHandler( + [s, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) { + check_error(cbuf); + }); + } +} + 
+void finalize(Stream s) { + auto pool = metal::new_scoped_memory_pool(); + auto& d = metal::device(s.device); + auto cb = d.get_command_buffer(s.index); + d.end_encoding(s.index); + cb->addCompletedHandler([s](MTL::CommandBuffer* cbuf) { check_error(cbuf); }); + d.commit_command_buffer(s.index); + d.get_command_buffer(s.index); +} + +void synchronize(Stream s) { + auto pool = metal::new_scoped_memory_pool(); + auto& d = metal::device(s.device); + auto cb = d.get_command_buffer(s.index); + cb->retain(); + d.end_encoding(s.index); + d.commit_command_buffer(s.index); + cb->waitUntilCompleted(); + check_error(cb); + cb->release(); +} + +} // namespace mlx::core::gpu diff --git a/mlx/backend/metal/event.cpp b/mlx/backend/metal/event.cpp index 246d6bcc5..eb7f1b58a 100644 --- a/mlx/backend/metal/event.cpp +++ b/mlx/backend/metal/event.cpp @@ -2,7 +2,6 @@ #include "mlx/event.h" #include "mlx/backend/metal/device.h" -#include "mlx/backend/metal/metal_impl.h" #include "mlx/scheduler.h" namespace mlx::core { diff --git a/mlx/backend/metal/fence.cpp b/mlx/backend/metal/fence.cpp index e784d34ae..5abdf7309 100644 --- a/mlx/backend/metal/fence.cpp +++ b/mlx/backend/metal/fence.cpp @@ -1,7 +1,6 @@ // Copyright © 2024 Apple Inc. #include "mlx/fence.h" #include "mlx/backend/metal/device.h" -#include "mlx/backend/metal/metal_impl.h" #include "mlx/scheduler.h" #include "mlx/utils.h" @@ -139,7 +138,7 @@ void Fence::update(Stream stream, const array& x) { compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_input_array(x, 0); compute_encoder.set_bytes(nthreads, 1); - compute_encoder.dispatch_threadgroups(group_dims, grid_dims); + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); // Barrier on previous kernels compute_encoder.barrier(); diff --git a/mlx/backend/metal/fft.cpp b/mlx/backend/metal/fft.cpp index 153c62c02..011eb7ebb 100644 --- a/mlx/backend/metal/fft.cpp +++ b/mlx/backend/metal/fft.cpp @@ -7,10 +7,10 @@ #include "mlx/3rdparty/pocketfft.h" #include "mlx/backend/common/utils.h" +#include "mlx/backend/gpu/copy.h" +#include "mlx/backend/gpu/slicing.h" #include "mlx/backend/metal/binary.h" -#include "mlx/backend/metal/copy.h" #include "mlx/backend/metal/kernels.h" -#include "mlx/backend/metal/slicing.h" #include "mlx/backend/metal/unary.h" #include "mlx/backend/metal/utils.h" #include "mlx/utils.h" diff --git a/mlx/backend/metal/hadamard.cpp b/mlx/backend/metal/hadamard.cpp index a7dfc5f17..65a877151 100644 --- a/mlx/backend/metal/hadamard.cpp +++ b/mlx/backend/metal/hadamard.cpp @@ -1,11 +1,9 @@ // Copyright © 2024 Apple Inc. 
-#include - -#include "mlx/backend/common/compiled.h" #include "mlx/backend/common/hadamard.h" +#include "mlx/backend/common/compiled.h" #include "mlx/backend/common/utils.h" -#include "mlx/backend/metal/copy.h" +#include "mlx/backend/gpu/copy.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/jit/includes.h" #include "mlx/backend/metal/kernels.h" @@ -15,7 +13,6 @@ namespace mlx::core { constexpr int MAX_HADAMARD_THREADS_PER_GROUP = 256; -constexpr int MAX_HADAMARD_BYTES = 32768; // 32KB std::string gen_hadamard_codelet(int m) { // Generate a O(m^2) hadamard codelet for a given M @@ -60,121 +57,142 @@ std::string gen_hadamard_codelet(int m) { return source.str(); } -void Hadamard::eval_gpu(const std::vector& inputs, array& out) { - auto& s = stream(); +void hadamard_mn_contiguous( + const array& x, + array& y, + int m, + int n1, + int n2, + float scale, + metal::Device& d, + const Stream& s) { + int n = n1 * n2; + int read_width_n1 = n1 == 2 ? 2 : 4; + int read_width_n2 = n2 == 2 ? 2 : 4; + int read_width_m = (n == 2 || m == 28) ? 2 : 4; + int max_radix_1 = std::min(n1, 16); + int max_radix_2 = std::min(n2, 16); + float scale_n1 = 1.0; + float scale_n2 = (m == 1) ? scale : 1.0; + float scale_m = scale; - auto& in = inputs[0]; + // n2 is a row contiguous power of 2 hadamard transform + MTL::Size group_dims_n2(n2 / max_radix_2, 1, 1); + MTL::Size grid_dims_n2(n2 / max_radix_2, x.size() / n2, 1); - std::vector copies; - // Only support the last axis for now - int axis = in.ndim() - 1; - auto check_input = [&copies, &s](const array& x) { - // TODO(alexbarron) pass strides to kernel to relax this constraint - bool no_copy = x.flags().row_contiguous; - if (no_copy) { - return x; - } else { - copies.push_back(array(x.shape(), x.dtype(), nullptr, {})); - copy_gpu(x, copies.back(), CopyType::General, s); - return copies.back(); + // n1 is a strided power of 2 hadamard transform with stride n2 + MTL::Size group_dims_n1(n1 / max_radix_1, 1, 1); + MTL::Size grid_dims_n1(n1 / max_radix_1, x.size() / n, n2); + + // m is a strided hadamard transform with stride n = n1 * n2 + MTL::Size group_dims_m( + std::min(n / read_width_m, MAX_HADAMARD_THREADS_PER_GROUP), 1, 1); + MTL::Size grid_dims_m( + group_dims_m.width, x.size() / m / read_width_m / group_dims_m.width, 1); + + // Make the kernel + std::string kname; + kname.reserve(32); + concatenate(kname, "hadamard_", n * m, "_", type_to_name(x)); + auto lib = d.get_library(kname, [&]() { + std::string kernel; + concatenate( + kernel, + metal::utils(), + gen_hadamard_codelet(m), + metal::hadamard(), + get_template_definition( + "n2" + kname, + "hadamard_n", + get_type_string(x.dtype()), + n2, + max_radix_2, + read_width_n2)); + if (n1 > 1) { + kernel += get_template_definition( + "n1" + kname, + "hadamard_n", + get_type_string(x.dtype()), + n1, + max_radix_1, + read_width_n1, + n2); } - }; - const array& in_contiguous = check_input(in); - - if (in_contiguous.is_donatable()) { - out.copy_shared_buffer(in_contiguous); - } else { - out.set_data(allocator::malloc(out.nbytes())); - } - - int n, m; - std::tie(n, m) = decompose_hadamard(in.shape(axis)); - - if (n * (int)size_of(in.dtype()) > MAX_HADAMARD_BYTES) { - throw std::invalid_argument( - "[hadamard] For n = m*2^k, 2^k > 8192 for FP32 or 2^k > 16384 for FP16/BF16 NYI"); - } - - int max_radix = std::min(n, 16); - // Use read_width 2 for m = 28 to avoid register spilling - int read_width = (n == 2 || m == 28) ? 
2 : 4; - - std::ostringstream kname; - kname << "hadamard_" << n * m << "_" << type_to_name(out); - auto kernel_name = kname.str(); - auto& d = metal::device(s.device); - const auto& lib_name = kernel_name; - auto lib = d.get_library(lib_name, [&]() { - std::ostringstream kernel_source; - auto codelet = gen_hadamard_codelet(m); - kernel_source << metal::utils() << codelet << metal::hadamard(); - kernel_source << get_template_definition( - "n" + kernel_name, - "hadamard_n", - get_type_string(in.dtype()), - n, - max_radix, - read_width); - kernel_source << get_template_definition( - "m" + kernel_name, - "hadamard_m", - get_type_string(in.dtype()), - n, - m, - read_width); - return kernel_source.str(); + if (m > 1) { + kernel += get_template_definition( + "m" + kname, + "hadamard_m", + get_type_string(x.dtype()), + n, + m, + read_width_m); + } + return kernel; }); - int batch_size = in.size() / n; - int threads_per = n / max_radix; - - auto& compute_encoder = d.get_command_encoder(s.index); - - auto launch_hadamard = [&](const array& in, - array& out, - const std::string& kernel_name, - float scale) { - auto kernel = d.get_kernel(kernel_name, lib); - assert(threads_per <= kernel->maxTotalThreadsPerThreadgroup()); - + // Launch the strided transform for n1 + if (n1 > 1) { + auto& compute_encoder = d.get_command_encoder(s.index); + auto kernel = d.get_kernel("n1" + kname, lib); compute_encoder.set_compute_pipeline_state(kernel); - compute_encoder.set_input_array(in, 0); - compute_encoder.set_output_array(out, 1); - compute_encoder.set_bytes(scale, 2); - - MTL::Size group_dims = MTL::Size(1, threads_per, 1); - MTL::Size grid_dims = MTL::Size(batch_size, threads_per, 1); - compute_encoder.dispatch_threads(grid_dims, group_dims); - }; - - if (m > 1) { - // When m is greater than 1, we decompose the - // computation into two uploads to the GPU: - // - // e.g. len(x) = 12*4 = 48, m = 12, n = 4 - // - // y = h48 @ x - // - // Upload 1: - // tmp = a.reshape(12, 4) @ h4 - // - // Upload 2: - // y = h12 @ tmp - array temp(in.shape(), in.dtype(), nullptr, {}); - temp.set_data(allocator::malloc(temp.nbytes())); - copies.push_back(temp); - - launch_hadamard(in_contiguous, temp, "n" + kernel_name, 1.0); - - // Metal sometimes reports 256 max threads per group for hadamard_m kernel - threads_per = std::min(n / read_width, MAX_HADAMARD_THREADS_PER_GROUP); - batch_size = in.size() / m / read_width / threads_per; - launch_hadamard(temp, out, "m" + kernel_name, scale_); - } else { - launch_hadamard(in_contiguous, out, "n" + kernel_name, scale_); + compute_encoder.set_input_array(x, 0); + compute_encoder.set_output_array(y, 1); + compute_encoder.set_bytes(scale_n1, 2); + compute_encoder.dispatch_threads(grid_dims_n1, group_dims_n1); } - d.add_temporaries(std::move(copies), s.index); + // Launch the transform for n2 + auto& compute_encoder = d.get_command_encoder(s.index); + auto kernel = d.get_kernel("n2" + kname, lib); + compute_encoder.set_compute_pipeline_state(kernel); + compute_encoder.set_input_array(n1 > 1 ? 
y : x, 0); + compute_encoder.set_output_array(y, 1); + compute_encoder.set_bytes(scale_n2, 2); + compute_encoder.dispatch_threads(grid_dims_n2, group_dims_n2); + + // Launch the strided transform for m + if (m > 1) { + auto kernel = d.get_kernel("m" + kname, lib); + compute_encoder.set_compute_pipeline_state(kernel); + compute_encoder.set_input_array(y, 0); + compute_encoder.set_output_array(y, 1); + compute_encoder.set_bytes(scale_m, 2); + compute_encoder.dispatch_threads(grid_dims_m, group_dims_m); + } +} + +void Hadamard::eval_gpu(const std::vector& inputs, array& out) { + auto& s = stream(); + auto& d = metal::device(s.device); + auto& in = inputs[0]; + + // Split the hadamard transform so that all of them work on vectors smaller + // than 8192 elements. + // + // We decompose it in the following way: + // + // n = m * n1 * n2 = m * 2^k1 * 2^k2 + // + // where m is in (1, 12, 20, 28) and n1 and n2 <= 8192 + auto [n, m] = decompose_hadamard(in.shape().back()); + int n1 = 1, n2 = n; + if (n > 8192) { + for (n2 = 2; n2 * n2 < n; n2 *= 2) { + } + n1 = n / n2; + } + + if (in.flags().row_contiguous) { + if (in.is_donatable()) { + out.copy_shared_buffer(in); + } else { + out.set_data(allocator::malloc(out.nbytes())); + } + hadamard_mn_contiguous(in, out, m, n1, n2, scale_, d, s); + } else { + copy_gpu(in, out, CopyType::General, s); + hadamard_mn_contiguous(out, out, m, n1, n2, scale_, d, s); + } } } // namespace mlx::core diff --git a/mlx/backend/metal/indexing.cpp b/mlx/backend/metal/indexing.cpp index d2a263051..cccfd908a 100644 --- a/mlx/backend/metal/indexing.cpp +++ b/mlx/backend/metal/indexing.cpp @@ -2,7 +2,7 @@ #include #include "mlx/backend/common/compiled.h" -#include "mlx/backend/metal/copy.h" +#include "mlx/backend/gpu/copy.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/jit/includes.h" #include "mlx/backend/metal/jit/indexing.h" diff --git a/mlx/backend/metal/kernels/binary.h b/mlx/backend/metal/kernels/binary.h index 91a02c818..ffc33ad82 100644 --- a/mlx/backend/metal/kernels/binary.h +++ b/mlx/backend/metal/kernels/binary.h @@ -9,64 +9,85 @@ template c[index] = Op()(a[0], b[0]); } -template +template ::n> [[kernel]] void binary_sv( device const T* a, device const T* b, device U* c, + constant uint& size, uint index [[thread_position_in_grid]]) { - c[index] = Op()(a[0], b[index]); + index *= N; + for (int i = 0; i < N && (index + i) < size; ++i) { + c[index + i] = Op()(a[0], b[index + i]); + } } -template +template ::n> [[kernel]] void binary_vs( device const T* a, device const T* b, device U* c, + constant uint& size, uint index [[thread_position_in_grid]]) { - c[index] = Op()(a[index], b[0]); + index *= N; + for (int i = 0; i < N && (index + i) < size; ++i) { + c[index + i] = Op()(a[index + i], b[0]); + } } -template +template ::n> [[kernel]] void binary_vv( device const T* a, device const T* b, device U* c, + constant uint& size, uint index [[thread_position_in_grid]]) { - c[index] = Op()(a[index], b[index]); + index *= N; + for (int i = 0; i < N && (index + i) < size; ++i) { + c[index + i] = Op()(a[index + i], b[index + i]); + } } -template +template ::n> [[kernel]] void binary_sv2( device const T* a, device const T* b, device U* c, + constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - int64_t offset = index.x + grid_dim.x * int64_t(index.y); - c[offset] = Op()(a[0], b[offset]); + int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); + for (int i = 0; i < N && (offset + i) < size; ++i) { + 
c[offset + i] = Op()(a[0], b[offset + i]); + } } -template +template ::n> [[kernel]] void binary_vs2( device const T* a, device const T* b, device U* c, + constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - int64_t offset = index.x + grid_dim.x * int64_t(index.y); - c[offset] = Op()(a[offset], b[0]); + int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); + for (int i = 0; i < N && (offset + i) < size; ++i) { + c[offset + i] = Op()(a[offset + i], b[0]); + } } -template +template ::n> [[kernel]] void binary_vv2( device const T* a, device const T* b, device U* c, + constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - int64_t offset = index.x + grid_dim.x * int64_t(index.y); - c[offset] = Op()(a[offset], b[offset]); + int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); + for (int i = 0; i < N && (offset + i) < size; ++i) { + c[offset + i] = Op()(a[offset + i], b[offset + i]); + } } template diff --git a/mlx/backend/metal/kernels/binary.metal b/mlx/backend/metal/kernels/binary.metal index 3ef8e6269..1d555fefa 100644 --- a/mlx/backend/metal/kernels/binary.metal +++ b/mlx/backend/metal/kernels/binary.metal @@ -71,6 +71,7 @@ instantiate_binary_types_bool(Less) instantiate_binary_types_bool(LessEqual) instantiate_binary_types_bool(NotEqual) instantiate_binary_float(LogAddExp) +instantiate_binary_all(LogAddExp, complex64, complex64_t, complex64_t) instantiate_binary_types(Maximum) instantiate_binary_types(Minimum) instantiate_binary_types(Multiply) diff --git a/mlx/backend/metal/kernels/binary_ops.h b/mlx/backend/metal/kernels/binary_ops.h index 8f961c2cf..4aaf2b4da 100644 --- a/mlx/backend/metal/kernels/binary_ops.h +++ b/mlx/backend/metal/kernels/binary_ops.h @@ -130,6 +130,24 @@ struct LogAddExp { ? maxval : (maxval + log1p(metal::exp(minval - maxval))); }; + + complex64_t operator()(complex64_t x, complex64_t y) { + if (metal::isnan(x.real) || metal::isnan(x.imag) || metal::isnan(y.real) || + metal::isnan(y.imag)) { + return metal::numeric_limits::quiet_NaN(); + } + constexpr float inf = metal::numeric_limits::infinity(); + complex64_t maxval = x > y ? x : y; + complex64_t minval = x < y ? 
x : y; + if (minval.real == -inf || maxval.real == inf) + return maxval; + float m = metal::exp(minval.real - maxval.real); + complex64_t dexp{ + m * metal::cos(minval.imag - maxval.imag), + m * metal::sin(minval.imag - maxval.imag), + }; + return maxval + log1p(dexp); + } }; struct Maximum { diff --git a/mlx/backend/metal/kernels/binary_two.h b/mlx/backend/metal/kernels/binary_two.h index 8f6b3392d..e261d33c4 100644 --- a/mlx/backend/metal/kernels/binary_two.h +++ b/mlx/backend/metal/kernels/binary_two.h @@ -12,82 +12,103 @@ template d[index] = out[1]; } -template +template ::n> [[kernel]] void binary_sv( device const T* a, device const T* b, device U* c, device U* d, + constant uint& size, uint index [[thread_position_in_grid]]) { - auto out = Op()(a[0], b[index]); - c[index] = out[0]; - d[index] = out[1]; + index *= N; + for (int i = 0; i < N && (index + i) < size; ++i) { + auto out = Op()(a[0], b[index + i]); + c[index + i] = out[0]; + d[index + i] = out[1]; + } } -template +template ::n> [[kernel]] void binary_vs( device const T* a, device const T* b, device U* c, device U* d, + constant uint& size, uint index [[thread_position_in_grid]]) { - auto out = Op()(a[index], b[0]); - c[index] = out[0]; - d[index] = out[1]; + index *= N; + for (int i = 0; i < N && (index + i) < size; ++i) { + auto out = Op()(a[index + i], b[0]); + c[index + i] = out[0]; + d[index + i] = out[1]; + } } -template +template ::n> [[kernel]] void binary_vv( device const T* a, device const T* b, device U* c, device U* d, + constant uint& size, uint index [[thread_position_in_grid]]) { - auto out = Op()(a[index], b[index]); - c[index] = out[0]; - d[index] = out[1]; + index *= N; + for (int i = 0; i < N && (index + i) < size; ++i) { + auto out = Op()(a[index + i], b[index + i]); + c[index + i] = out[0]; + d[index + i] = out[1]; + } } -template +template ::n> [[kernel]] void binary_sv2( device const T* a, device const T* b, device U* c, device U* d, + constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - auto offset = index.x + grid_dim.x * int64_t(index.y); - auto out = Op()(a[0], b[offset]); - c[offset] = out[0]; - d[offset] = out[1]; + auto offset = N * (index.x + grid_dim.x * int64_t(index.y)); + for (int i = 0; i < N && (offset + i) < size; ++i) { + auto out = Op()(a[0], b[offset + i]); + c[offset + i] = out[0]; + d[offset + i] = out[1]; + } } -template +template ::n> [[kernel]] void binary_vs2( device const T* a, device const T* b, device U* c, device U* d, + constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - auto offset = index.x + grid_dim.x * int64_t(index.y); - auto out = Op()(a[offset], b[0]); - c[offset] = out[0]; - d[offset] = out[1]; + auto offset = N * (index.x + grid_dim.x * int64_t(index.y)); + for (int i = 0; i < N && (offset + i) < size; ++i) { + auto out = Op()(a[offset + i], b[0]); + c[offset + i] = out[0]; + d[offset + i] = out[1]; + } } -template +template ::n> [[kernel]] void binary_vv2( device const T* a, device const T* b, device U* c, device U* d, + constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - auto offset = index.x + grid_dim.x * int64_t(index.y); - auto out = Op()(a[offset], b[offset]); - c[offset] = out[0]; - d[offset] = out[1]; + auto offset = N * (index.x + grid_dim.x * int64_t(index.y)); + for (int i = 0; i < N && (offset + i) < size; ++i) { + auto out = Op()(a[offset + i], b[offset + i]); + c[offset + i] = out[0]; + 
d[offset + i] = out[1]; + } } template diff --git a/mlx/backend/metal/kernels/copy.h b/mlx/backend/metal/kernels/copy.h index b1367cf4f..2469d1f3d 100644 --- a/mlx/backend/metal/kernels/copy.h +++ b/mlx/backend/metal/kernels/copy.h @@ -1,39 +1,53 @@ // Copyright © 2024 Apple Inc. -template +template ::n> [[kernel]] void copy_s( device const T* src [[buffer(0)]], device U* dst [[buffer(1)]], + constant uint& size, uint index [[thread_position_in_grid]]) { - dst[index] = static_cast(src[0]); + index *= N; + for (int i = 0; i < N && (index + i) < size; ++i) { + dst[index + i] = static_cast(src[0]); + } } -template +template ::n> [[kernel]] void copy_v( device const T* src [[buffer(0)]], device U* dst [[buffer(1)]], + constant uint& size, uint index [[thread_position_in_grid]]) { - dst[index] = static_cast(src[index]); + index *= N; + for (int i = 0; i < N && (index + i) < size; ++i) { + dst[index + i] = static_cast(src[index + i]); + } } -template +template ::n> [[kernel]] void copy_s2( device const T* src [[buffer(0)]], device U* dst [[buffer(1)]], + constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - auto offset = index.x + grid_dim.x * int64_t(index.y); - dst[offset] = static_cast(src[0]); + auto offset = N * (index.x + grid_dim.x * int64_t(index.y)); + for (int i = 0; i < N && (offset + i) < size; ++i) { + dst[offset + i] = static_cast(src[0]); + } } -template +template ::n> [[kernel]] void copy_v2( device const T* src [[buffer(0)]], device U* dst [[buffer(1)]], + constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - auto offset = index.x + grid_dim.x * int64_t(index.y); - dst[offset] = static_cast(src[offset]); + auto offset = N * (index.x + grid_dim.x * int64_t(index.y)); + for (int i = 0; i < N && (offset + i) < size; ++i) { + dst[offset + i] = static_cast(src[offset + i]); + } } template diff --git a/mlx/backend/metal/kernels/fft/readwrite.h b/mlx/backend/metal/kernels/fft/readwrite.h index ab699e136..f6724820d 100644 --- a/mlx/backend/metal/kernels/fft/readwrite.h +++ b/mlx/backend/metal/kernels/fft/readwrite.h @@ -10,7 +10,7 @@ For many sizes, GPU FFTs are memory bandwidth bound so read/write performance is important. Where possible, we read 128 bits sequentially in each thread, -coalesced with accesses from adajcent threads for optimal performance. +coalesced with accesses from adjacent threads for optimal performance. 
We implement specialized reading/writing for: - FFT diff --git a/mlx/backend/metal/kernels/hadamard.h b/mlx/backend/metal/kernels/hadamard.h index 93e2fb8a8..9f2311c10 100644 --- a/mlx/backend/metal/kernels/hadamard.h +++ b/mlx/backend/metal/kernels/hadamard.h @@ -26,7 +26,7 @@ METAL_FUNC void radix_func(thread float* x) { } } -template +template [[kernel]] void hadamard_n( const device T* in [[buffer(0)]], device T* out [[buffer(1)]], @@ -46,18 +46,25 @@ template constexpr short logFinal = logN % logR; constexpr short final_radix = 1 << (logFinal); - int batch_idx = elem.x * N; - short i = elem.y; + int batch_idx = elem.y * N * stride + elem.z; + short i = elem.x; threadgroup T buf[N]; // Read values from device - STEEL_PRAGMA_UNROLL - for (short j = 0; j < max_radix / read_width; j++) { - short index = j * read_width * num_threads + i * read_width; + if (stride == 1) { STEEL_PRAGMA_UNROLL - for (short r = 0; r < read_width; r++) { - buf[index + r] = in[batch_idx + index + r]; + for (short j = 0; j < max_radix / read_width; j++) { + short index = j * read_width * num_threads + i * read_width; + STEEL_PRAGMA_UNROLL + for (short r = 0; r < read_width; r++) { + buf[index + r] = in[batch_idx + index + r]; + } + } + } else { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < max_radix; j++) { + buf[j * num_threads + i] = in[batch_idx + (j * num_threads + i) * stride]; } } @@ -113,12 +120,20 @@ template } // Write values to device - STEEL_PRAGMA_UNROLL - for (short j = 0; j < max_radix / read_width; j++) { - short index = j * read_width * num_threads + i * read_width; + if (stride == 1) { STEEL_PRAGMA_UNROLL - for (short r = 0; r < read_width; r++) { - out[batch_idx + index + r] = T(buf[index + r] * scale); + for (short j = 0; j < max_radix / read_width; j++) { + short index = j * read_width * num_threads + i * read_width; + STEEL_PRAGMA_UNROLL + for (short r = 0; r < read_width; r++) { + out[batch_idx + index + r] = T(buf[index + r] * scale); + } + } + } else { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < max_radix; j++) { + out[batch_idx + (j * num_threads + i) * stride] = + buf[j * num_threads + i]; } } } diff --git a/mlx/backend/metal/kernels/quantized.h b/mlx/backend/metal/kernels/quantized.h index b2b0d8d8f..ba4fb2426 100644 --- a/mlx/backend/metal/kernels/quantized.h +++ b/mlx/backend/metal/kernels/quantized.h @@ -1008,11 +1008,11 @@ METAL_FUNC void qmm_t_impl( auto wl = (const device uint8_t*)w; - x += y_row * K; + x += y_row * static_cast(K); wl += y_col * K_w; scales += y_col * K_g; biases += y_col * K_g; - y += y_row * N + y_col; + y += y_row * static_cast(N) + y_col; // Make the x loader and mma operation const short num_els = min(BM, M - y_row); @@ -1132,11 +1132,11 @@ METAL_FUNC void qmm_n_impl( // Set the block const int y_row = tid.y * BM; const int y_col = tid.x * BN; - x += y_row * K; + x += y_row * static_cast(K); wl += y_col * bytes_per_pack / pack_factor; scales += y_col / group_size; biases += y_col / group_size; - y += y_row * N + y_col; + y += y_row * static_cast(N) + y_col; // Make the x loader and mma operation const short num_els = min(BM, M - y_row); diff --git a/mlx/backend/metal/kernels/scan.metal b/mlx/backend/metal/kernels/scan.metal index 8fcd7f61b..f38f8757e 100644 --- a/mlx/backend/metal/kernels/scan.metal +++ b/mlx/backend/metal/kernels/scan.metal @@ -104,4 +104,5 @@ instantiate_scan_helper(min_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumMi instantiate_scan_helper(min_complex64_complex64, complex64_t, complex64_t, CumMin, 2) 
instantiate_scan_helper(logaddexp_float16_float16, half, half, CumLogaddexp, 4) instantiate_scan_helper(logaddexp_float32_float32, float, float, CumLogaddexp, 4) -instantiate_scan_helper(logaddexp_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumLogaddexp, 4) // clang-format on +instantiate_scan_helper(logaddexp_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumLogaddexp, 4) +instantiate_scan_helper(logaddexp_complex64_complex64, complex64_t, complex64_t, CumLogaddexp, 2) // clang-format on diff --git a/mlx/backend/metal/kernels/sdpa_vector.h b/mlx/backend/metal/kernels/sdpa_vector.h index c4c0f6456..8258e9c14 100644 --- a/mlx/backend/metal/kernels/sdpa_vector.h +++ b/mlx/backend/metal/kernels/sdpa_vector.h @@ -56,9 +56,9 @@ template const int head_idx = tid.x; const int q_seq_idx = tid.y; const int kv_head_idx = head_idx / gqa_factor; - const int o_offset = tpg.x * q_seq_idx + head_idx; + const int o_offset = head_idx * tpg.y + q_seq_idx; const int q_offset = - query_transposed ? o_offset : head_idx * tpg.y + q_seq_idx; + query_transposed ? tpg.x * q_seq_idx + head_idx : o_offset; queries += q_offset * D + simd_lid * qk_per_thread; keys += kv_head_idx * k_head_stride + simd_gid * k_seq_stride + simd_lid * qk_per_thread; @@ -213,9 +213,9 @@ template const int block_idx = tid.z; const int head_idx = tid.x; const int q_seq_idx = tid.y; - const int o_offset = tpg.x * q_seq_idx + head_idx; + const int o_offset = head_idx * tpg.y + q_seq_idx; const int q_offset = - query_transposed ? o_offset : head_idx * tpg.y + q_seq_idx; + query_transposed ? tpg.x * q_seq_idx + head_idx : o_offset; const int kv_head_idx = head_idx / gqa_factor; queries += q_offset * D + simd_lid * qk_per_thread; @@ -358,8 +358,8 @@ template // Adjust positions const int head_idx = tid.x; const int q_seq_idx = tid.y; - const int n_heads = tpg.x; - const int q_offset = n_heads * q_seq_idx + head_idx; + const int q_offset = head_idx * tpg.y + q_seq_idx; + ; partials += q_offset * blocks * D + simd_gid * D + simd_lid * elem_per_thread; sums += q_offset * blocks; maxs += q_offset * blocks; diff --git a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h index 2e27ea06f..34d5bf58a 100644 --- a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +++ b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h @@ -95,7 +95,7 @@ template < Q += tidl.z * params->Q_strides[0] + // Batch tidl.y * params->Q_strides[1] + // Head - tidl.x * BQ * params->Q_strides[2]; // Seqeunce + tidl.x * BQ * params->Q_strides[2]; // Sequence ulong kv_head_idx = int(tid.y) / params->gqa_factor; K += tidl.z * params->K_strides[0] + // Batch @@ -106,7 +106,7 @@ template < O += tidl.z * params->O_strides[0] + // Batch tidl.y * params->O_strides[1] + // Head - tidl.x * BQ * params->O_strides[2]; // Seqeunce + tidl.x * BQ * params->O_strides[2]; // Sequence if (has_mask) { mask += tidl.z * mask_params->M_strides[0] + // Batch diff --git a/mlx/backend/metal/kernels/steel/attn/loader.h b/mlx/backend/metal/kernels/steel/attn/loader.h index 2849c00f1..7ec798146 100644 --- a/mlx/backend/metal/kernels/steel/attn/loader.h +++ b/mlx/backend/metal/kernels/steel/attn/loader.h @@ -113,7 +113,7 @@ struct BlockLoader { tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)]; } - // Zero out uneeded values + // Zero out unneeded values STEEL_PRAGMA_UNROLL for (short j = 0; j < vec_size; j++) { tmp_val[j] = tmp_idx[j] ? 
tmp_val[j] : T(0); @@ -240,7 +240,7 @@ struct BlockLoaderT { tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)]; } - // Zero out uneeded values + // Zero out unneeded values STEEL_PRAGMA_UNROLL for (short j = 0; j < vec_size; j++) { tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0); diff --git a/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h b/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h index e4b662cd3..8253638f1 100644 --- a/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +++ b/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h @@ -141,7 +141,7 @@ implicit_gemm_conv_2d_general( // Store results to device memory { - // Adjust for simdgroup and thread locatio + // Adjust for simdgroup and thread location int offset_m = c_row + mma_op.sm; int offset_n = c_col + mma_op.sn; C += offset_n; diff --git a/mlx/backend/metal/kernels/steel/gemm/loader.h b/mlx/backend/metal/kernels/steel/gemm/loader.h index 3f084d8ec..d421b2d1f 100644 --- a/mlx/backend/metal/kernels/steel/gemm/loader.h +++ b/mlx/backend/metal/kernels/steel/gemm/loader.h @@ -113,7 +113,7 @@ struct BlockLoader { tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)]; } - // Zero out uneeded values + // Zero out unneeded values STEEL_PRAGMA_UNROLL for (short j = 0; j < vec_size; j++) { tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0); diff --git a/mlx/backend/metal/kernels/ternary.h b/mlx/backend/metal/kernels/ternary.h index 4b3adcc80..5251dc7e9 100644 --- a/mlx/backend/metal/kernels/ternary.h +++ b/mlx/backend/metal/kernels/ternary.h @@ -1,25 +1,32 @@ // Copyright © 2024 Apple Inc. -template +template ::n> [[kernel]] void ternary_v( device const bool* a, device const T* b, device const T* c, device T* d, + constant uint& size, uint index [[thread_position_in_grid]]) { - d[index] = Op()(a[index], b[index], c[index]); + index *= N; + for (int i = 0; i < N && (index + i) < size; ++i) { + d[index + i] = Op()(a[index + i], b[index + i], c[index + i]); + } } -template +template ::n> [[kernel]] void ternary_v2( device const bool* a, device const T* b, device const T* c, device T* d, + constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - auto offset = index.x + grid_dim.x * int64_t(index.y); - d[offset] = Op()(a[offset], b[offset], c[offset]); + auto offset = N * (index.x + grid_dim.x * int64_t(index.y)); + for (int i = 0; i < N && (offset + i) < size; ++i) { + d[offset + i] = Op()(a[offset + i], b[offset + i], c[offset + i]); + } } template diff --git a/mlx/backend/metal/kernels/unary.h b/mlx/backend/metal/kernels/unary.h index 69828599f..b5eaab2e9 100644 --- a/mlx/backend/metal/kernels/unary.h +++ b/mlx/backend/metal/kernels/unary.h @@ -1,21 +1,28 @@ // Copyright © 2024 Apple Inc. 
-template +template ::n> [[kernel]] void unary_v( device const T* in, device U* out, + constant uint& size, uint index [[thread_position_in_grid]]) { - out[index] = Op()(in[index]); + index *= N; + for (int i = 0; i < N && (index + i) < size; ++i) { + out[index + i] = Op()(in[index + i]); + } } -template +template ::n> [[kernel]] void unary_v2( device const T* in, device U* out, + constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - auto offset = index.x + grid_dim.x * int64_t(index.y); - out[offset] = Op()(in[offset]); + auto offset = N * (index.x + grid_dim.x * int64_t(index.y)); + for (int i = 0; i < N && (offset + i) < size; ++i) { + out[offset + i] = Op()(in[offset + i]); + } } template < diff --git a/mlx/backend/metal/kernels/unary.metal b/mlx/backend/metal/kernels/unary.metal index d34c5a7ec..afced7eb7 100644 --- a/mlx/backend/metal/kernels/unary.metal +++ b/mlx/backend/metal/kernels/unary.metal @@ -77,6 +77,7 @@ instantiate_unary_all_same(Cos, complex64, complex64_t) instantiate_unary_all_same(Cosh, complex64, complex64_t) instantiate_unary_all_same(Exp, complex64, complex64_t) instantiate_unary_all_same(Log, complex64, complex64_t) +instantiate_unary_all_same(Log1p, complex64, complex64_t) instantiate_unary_all_same(Log2, complex64, complex64_t) instantiate_unary_all_same(Log10, complex64, complex64_t) instantiate_unary_all_same(Negative, complex64, complex64_t) diff --git a/mlx/backend/metal/kernels/utils.h b/mlx/backend/metal/kernels/utils.h index b31cd20d6..c30d186b8 100644 --- a/mlx/backend/metal/kernels/utils.h +++ b/mlx/backend/metal/kernels/utils.h @@ -15,6 +15,14 @@ typedef half float16_t; +// Work per thread values for different types. The values here are expected to +// match get_work_per_thread in mlx/backend/metal/utils.h +template +struct WorkPerThread { + static_assert(sizeof(U) <= 8, "Type too large"); + static constexpr int constant n = 8 / sizeof(U); +}; + /////////////////////////////////////////////////////////////////////////////// // Type limits utils /////////////////////////////////////////////////////////////////////////////// @@ -328,6 +336,23 @@ inline bfloat16_t log1p(bfloat16_t x) { return bfloat16_t(x * (metal::log(xp1) / (xp1 - 1.0f))); } +inline complex64_t log1p(complex64_t in) { + float x = in.real; + float y = in.imag; + float zabs = metal::precise::sqrt(x * x + y * y); + float theta = metal::atan2(y, x + 1); + if (zabs < 0.5f) { + float r = x * (2 + x) + y * y; + if (r == 0) { // handle underflow + return {x, theta}; + } + return {0.5f * log1p(r), theta}; + } else { + auto z0 = metal::sqrt((x + 1) * (x + 1) + y * y); + return {metal::log(z0), theta}; + } +} + /////////////////////////////////////////////////////////////////////////////// // SIMD shuffle ops /////////////////////////////////////////////////////////////////////////////// diff --git a/mlx/backend/metal/logsumexp.cpp b/mlx/backend/metal/logsumexp.cpp index 4901190e1..e53bc58d9 100644 --- a/mlx/backend/metal/logsumexp.cpp +++ b/mlx/backend/metal/logsumexp.cpp @@ -1,7 +1,7 @@ // Copyright © 2023-2024 Apple Inc. 
#include -#include "mlx/backend/metal/copy.h" +#include "mlx/backend/gpu/copy.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/kernels.h" #include "mlx/backend/metal/utils.h" diff --git a/mlx/backend/metal/matmul.cpp b/mlx/backend/metal/matmul.cpp index f55d20c9f..71221f8d9 100644 --- a/mlx/backend/metal/matmul.cpp +++ b/mlx/backend/metal/matmul.cpp @@ -7,7 +7,7 @@ #include "mlx/backend/common/broadcasting.h" #include "mlx/backend/common/utils.h" -#include "mlx/backend/metal/copy.h" +#include "mlx/backend/gpu/copy.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/kernels.h" #include "mlx/backend/metal/kernels/defines.h" diff --git a/mlx/backend/metal/metal.cpp b/mlx/backend/metal/metal.cpp index a9a1bc4f6..888207322 100644 --- a/mlx/backend/metal/metal.cpp +++ b/mlx/backend/metal/metal.cpp @@ -1,11 +1,11 @@ // Copyright © 2023-2024 Apple Inc. #include +#include + #include "mlx/backend/metal/device.h" +#include "mlx/backend/metal/metal.h" #include "mlx/backend/metal/utils.h" -#include "mlx/primitives.h" -#include "mlx/scheduler.h" -#include "mlx/utils.h" namespace mlx::core::metal { @@ -13,85 +13,6 @@ bool is_available() { return true; } -inline void check_error(MTL::CommandBuffer* cbuf) { - if (cbuf->status() == MTL::CommandBufferStatusError) { - std::ostringstream msg; - msg << "[METAL] Command buffer execution failed: " - << cbuf->error()->localizedDescription()->utf8String(); - throw std::runtime_error(msg.str()); - } -} - -void eval(array& arr) { - auto pool = new_scoped_memory_pool(); - auto s = arr.primitive().stream(); - auto& d = metal::device(s.device); - auto command_buffer = d.get_command_buffer(s.index); - - auto outputs = arr.outputs(); - { - // If the array is a tracer hold a reference - // to its inputs so they don't get donated - std::vector inputs; - if (arr.is_tracer()) { - inputs = arr.inputs(); - } - - debug_set_primitive_buffer_label(command_buffer, arr.primitive()); - arr.primitive().eval_gpu(arr.inputs(), outputs); - } - std::unordered_set> buffers; - for (auto& in : arr.inputs()) { - buffers.insert(in.data_shared_ptr()); - } - for (auto& s : arr.siblings()) { - buffers.insert(s.data_shared_ptr()); - } - // Remove the output if it was donated to by an input - if (auto it = buffers.find(arr.data_shared_ptr()); it != buffers.end()) { - buffers.erase(it); - } - - if (d.command_buffer_needs_commit(s.index)) { - d.end_encoding(s.index); - scheduler::notify_new_task(s); - command_buffer->addCompletedHandler( - [s, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) { - scheduler::notify_task_completion(s); - check_error(cbuf); - }); - d.commit_command_buffer(s.index); - d.get_command_buffer(s.index); - } else { - command_buffer->addCompletedHandler( - [s, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) { - check_error(cbuf); - }); - } -} - -void finalize(Stream s) { - auto pool = new_scoped_memory_pool(); - auto& d = metal::device(s.device); - auto cb = d.get_command_buffer(s.index); - d.end_encoding(s.index); - cb->addCompletedHandler([s](MTL::CommandBuffer* cbuf) { check_error(cbuf); }); - d.commit_command_buffer(s.index); - d.get_command_buffer(s.index); -} - -void synchronize(Stream s) { - auto pool = new_scoped_memory_pool(); - auto& d = metal::device(s.device); - auto cb = d.get_command_buffer(s.index); - cb->retain(); - d.end_encoding(s.index); - d.commit_command_buffer(s.index); - cb->waitUntilCompleted(); - check_error(cb); - cb->release(); -} - void start_capture(std::string path, id object) { auto pool = 
new_scoped_memory_pool(); @@ -128,4 +49,36 @@ void stop_capture() { manager->stopCapture(); } +const std::unordered_map>& +device_info() { + auto init_device_info = []() + -> std::unordered_map> { + auto pool = new_scoped_memory_pool(); + auto raw_device = device(default_device()).mtl_device(); + auto name = std::string(raw_device->name()->utf8String()); + auto arch = std::string(raw_device->architecture()->name()->utf8String()); + + size_t memsize = 0; + size_t length = sizeof(memsize); + sysctlbyname("hw.memsize", &memsize, &length, NULL, 0); + + size_t rsrc_limit = 0; + sysctlbyname("iogpu.rsrc_limit", &rsrc_limit, &length, NULL, 0); + if (rsrc_limit == 0) { + rsrc_limit = 499000; + } + + return { + {"device_name", name}, + {"architecture", arch}, + {"max_buffer_length", raw_device->maxBufferLength()}, + {"max_recommended_working_set_size", + raw_device->recommendedMaxWorkingSetSize()}, + {"memory_size", memsize}, + {"resource_limit", rsrc_limit}}; + }; + static auto device_info_ = init_device_info(); + return device_info_; +} + } // namespace mlx::core::metal diff --git a/mlx/backend/metal/metal.h b/mlx/backend/metal/metal.h index d162007d1..af2995b63 100644 --- a/mlx/backend/metal/metal.h +++ b/mlx/backend/metal/metal.h @@ -2,11 +2,10 @@ #pragma once +#include #include #include -#include "mlx/array.h" - namespace mlx::core::metal { /* Check if the Metal backend is available. */ diff --git a/mlx/backend/metal/no_metal.cpp b/mlx/backend/metal/no_metal.cpp new file mode 100644 index 000000000..b6142b280 --- /dev/null +++ b/mlx/backend/metal/no_metal.cpp @@ -0,0 +1,22 @@ +// Copyright © 2025 Apple Inc. + +#include + +#include "mlx/backend/metal/metal.h" + +namespace mlx::core::metal { + +bool is_available() { + return false; +} + +void start_capture(std::string) {} +void stop_capture() {} + +const std::unordered_map>& +device_info() { + throw std::runtime_error( + "[metal::device_info] Cannot get device info without metal backend"); +}; + +} // namespace mlx::core::metal diff --git a/mlx/backend/metal/normalization.cpp b/mlx/backend/metal/normalization.cpp index c1d993d2a..21142183e 100644 --- a/mlx/backend/metal/normalization.cpp +++ b/mlx/backend/metal/normalization.cpp @@ -1,7 +1,7 @@ // Copyright © 2024 Apple Inc. 
#include -#include "mlx/backend/metal/copy.h" +#include "mlx/backend/gpu/copy.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/kernels/defines.h" #include "mlx/backend/metal/reduce.h" diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp index 6946ffb9e..860e9ddd7 100644 --- a/mlx/backend/metal/primitives.cpp +++ b/mlx/backend/metal/primitives.cpp @@ -7,10 +7,10 @@ #include "mlx/backend/common/compiled.h" #include "mlx/backend/common/slicing.h" #include "mlx/backend/common/utils.h" -#include "mlx/backend/metal/copy.h" +#include "mlx/backend/gpu/copy.h" +#include "mlx/backend/gpu/slicing.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/kernels.h" -#include "mlx/backend/metal/slicing.h" #include "mlx/backend/metal/utils.h" #include "mlx/primitives.h" #include "mlx/scheduler.h" @@ -25,25 +25,6 @@ void arange_set_scalars(T start, T next, metal::CommandEncoder& enc) { enc.set_bytes(step, 1); } -void reshape(const array& in, array& out, Stream s) { - auto [copy_necessary, out_strides] = prepare_reshape(in, out); - if (copy_necessary) { - out.set_data(allocator::malloc(out.nbytes())); - copy_gpu_inplace( - in, - out, - in.shape(), - in.strides(), - make_contiguous_strides(in.shape()), - 0, - 0, - CopyType::General, - s); - } else { - shared_buffer_reshape(in, out_strides, out); - } -} - static array compute_dynamic_offset( const array& indices, const Strides& strides, @@ -226,105 +207,10 @@ void ArgReduce::eval_gpu(const std::vector& inputs, array& out) { } } -void AsType::eval_gpu(const std::vector& inputs, array& out) { - CopyType ctype = - inputs[0].flags().contiguous ? CopyType::Vector : CopyType::General; - copy_gpu(inputs[0], out, ctype); -} - -void AsStrided::eval_gpu(const std::vector& inputs, array& out) { - eval(inputs, out); -} - -void Broadcast::eval_gpu(const std::vector& inputs, array& out) { - eval(inputs, out); -} - -void BroadcastAxes::eval_gpu(const std::vector& inputs, array& out) { - eval(inputs, out); -} - -void Concatenate::eval_gpu(const std::vector& inputs, array& out) { - concatenate_gpu(inputs, out, axis_, stream()); -} - -void Contiguous::eval_gpu(const std::vector& inputs, array& out) { - assert(inputs.size() == 1); - auto& in = inputs[0]; - constexpr size_t extra_bytes = 16384; - if (in.buffer_size() <= out.nbytes() + extra_bytes && - (in.flags().row_contiguous || - (allow_col_major_ && in.flags().col_contiguous))) { - out.copy_shared_buffer(in); - } else { - copy_gpu(in, out, CopyType::General); - } -} - -void Copy::eval_gpu(const std::vector& inputs, array& out) { - eval(inputs, out); -} - -void CustomTransforms::eval_gpu( - const std::vector& inputs, - std::vector& outputs) { - eval(inputs, outputs); -} - -void Depends::eval_gpu( - const std::vector& inputs, - std::vector& outputs) { - eval(inputs, outputs); -} - -void Full::eval_gpu(const std::vector& inputs, array& out) { - auto in = inputs[0]; - CopyType ctype; - if (in.data_size() == 1) { - ctype = CopyType::Scalar; - } else if (in.flags().contiguous) { - ctype = CopyType::Vector; - } else { - ctype = CopyType::General; - } - copy_gpu(in, out, ctype); -} - -void ExpandDims::eval_gpu(const std::vector& inputs, array& out) { - eval(inputs, out); -} - -void Flatten::eval_gpu(const std::vector& inputs, array& out) { - reshape(inputs[0], out, stream()); -} - -void Unflatten::eval_gpu(const std::vector& inputs, array& out) { - reshape(inputs[0], out, stream()); -} - void Load::eval_gpu(const std::vector& inputs, array& out) { throw 
std::runtime_error("[Load::eval_gpu] Not implemented."); } -void NumberOfElements::eval_gpu(const std::vector& inputs, array& out) { - eval(inputs, out); -} - -void Pad::eval_gpu(const std::vector& inputs, array& out) { - // Inputs must be base input array and scalar val array - assert(inputs.size() == 2); - auto& in = inputs[0]; - auto& val = inputs[1]; - - // Padding value must be a scalar - assert(val.size() == 1); - - // Padding value, input and output must be of the same type - assert(val.dtype() == in.dtype() && in.dtype() == out.dtype()); - - pad_gpu(in, val, out, axes_, low_pad_size_, stream()); -} - void RandomBits::eval_gpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); @@ -370,27 +256,6 @@ void RandomBits::eval_gpu(const std::vector& inputs, array& out) { compute_encoder.dispatch_threads(grid_dims, group_dims); } -void Reshape::eval_gpu(const std::vector& inputs, array& out) { - reshape(inputs[0], out, stream()); -} - -void Split::eval_gpu( - const std::vector& inputs, - std::vector& outputs) { - eval(inputs, outputs); -} - -void Slice::eval_gpu(const std::vector& inputs, array& out) { - assert(inputs.size() == 1); - if (out.size() == 0) { - out.set_data(nullptr); - return; - } - - auto& in = inputs[0]; - slice_gpu(in, out, start_indices_, strides_, stream()); -} - void DynamicSlice::eval_gpu(const std::vector& inputs, array& out) { if (out.size() == 0) { out.set_data(nullptr); @@ -492,18 +357,6 @@ void SliceUpdate::eval_gpu(const std::vector& inputs, array& out) { /* const Stream& s = */ stream()); } -void Squeeze::eval_gpu(const std::vector& inputs, array& out) { - eval(inputs, out); -} - -void StopGradient::eval_gpu(const std::vector& inputs, array& out) { - eval(inputs, out); -} - -void Transpose::eval_gpu(const std::vector& inputs, array& out) { - eval(inputs, out); -} - void QRF::eval_gpu( const std::vector& inputs, std::vector& outputs) { @@ -537,35 +390,4 @@ void LUF::eval_gpu( throw std::runtime_error("[LUF::eval_gpu] Metal LU factorization NYI."); } -void View::eval_gpu(const std::vector& inputs, array& out) { - auto& in = inputs[0]; - auto ibytes = size_of(in.dtype()); - auto obytes = size_of(out.dtype()); - // Conditions for buffer copying (disjunction): - // - type size is the same - // - type size is smaller and the last axis is contiguous - // - the entire array is row contiguous - if (ibytes == obytes || (obytes < ibytes && in.strides().back() == 1) || - in.flags().row_contiguous) { - auto strides = in.strides(); - for (int i = 0; i < static_cast(strides.size()) - 1; ++i) { - strides[i] *= ibytes; - strides[i] /= obytes; - } - out.copy_shared_buffer( - in, strides, in.flags(), in.data_size() * ibytes / obytes); - } else { - auto tmp = array(in.shape(), in.dtype(), nullptr, {}); - tmp.set_data(allocator::malloc(tmp.nbytes())); - copy_gpu_inplace(in, tmp, CopyType::General, stream()); - - auto flags = out.flags(); - flags.contiguous = true; - flags.row_contiguous = true; - auto max_dim = std::max_element(out.shape().begin(), out.shape().end()); - flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim; - out.copy_shared_buffer(tmp, out.strides(), flags, out.size()); - } -} - } // namespace mlx::core diff --git a/mlx/backend/metal/quantized.cpp b/mlx/backend/metal/quantized.cpp index 6f5807543..11a2355cc 100644 --- a/mlx/backend/metal/quantized.cpp +++ b/mlx/backend/metal/quantized.cpp @@ -4,7 +4,7 @@ #include "mlx/backend/common/broadcasting.h" #include "mlx/backend/common/compiled.h" -#include "mlx/backend/metal/copy.h" +#include 
"mlx/backend/gpu/copy.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/kernels.h" #include "mlx/backend/metal/reduce.h" diff --git a/mlx/backend/metal/reduce.cpp b/mlx/backend/metal/reduce.cpp index c5650bdd7..8cb55ba58 100644 --- a/mlx/backend/metal/reduce.cpp +++ b/mlx/backend/metal/reduce.cpp @@ -3,7 +3,7 @@ #include #include -#include "mlx/backend/metal/copy.h" +#include "mlx/backend/gpu/copy.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/kernels.h" #include "mlx/backend/metal/kernels/defines.h" diff --git a/mlx/backend/metal/resident.cpp b/mlx/backend/metal/resident.cpp index 0a9e1b861..798824c2f 100644 --- a/mlx/backend/metal/resident.cpp +++ b/mlx/backend/metal/resident.cpp @@ -1,7 +1,6 @@ // Copyright © 2024 Apple Inc. #include "mlx/backend/metal/resident.h" -#include "mlx/backend/metal/metal_impl.h" namespace mlx::core::metal { diff --git a/mlx/backend/metal/rope.cpp b/mlx/backend/metal/rope.cpp index 060758333..d8201afe6 100644 --- a/mlx/backend/metal/rope.cpp +++ b/mlx/backend/metal/rope.cpp @@ -1,5 +1,5 @@ // Copyright © 2023-2024 Apple Inc. -#include "mlx/backend/metal/copy.h" +#include "mlx/backend/gpu/copy.h" #include "mlx/backend/metal/utils.h" #include "mlx/fast_primitives.h" diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp index 845962d01..3c7b7ff19 100644 --- a/mlx/backend/metal/scaled_dot_product_attention.cpp +++ b/mlx/backend/metal/scaled_dot_product_attention.cpp @@ -2,7 +2,7 @@ #include #include "mlx/backend/common/compiled.h" -#include "mlx/backend/metal/copy.h" +#include "mlx/backend/gpu/copy.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/kernels/steel/attn/params.h" @@ -154,9 +154,9 @@ void sdpa_vector( int gqa_factor = q.shape(1) / k.shape(1); int N = k.shape(2); int B = q.shape(0) * q.shape(1); - size_t k_head_stride = k.strides()[1]; + size_t k_head_stride = k.shape(1) == 1 ? k.strides(0) : k.strides(1); size_t k_seq_stride = k.strides()[2]; - size_t v_head_stride = v.strides()[1]; + size_t v_head_stride = v.shape(1) == 1 ? v.strides(0) : v.strides(1); size_t v_seq_stride = v.strides()[2]; MTL::Size group_dims(1024, 1, 1); @@ -199,11 +199,10 @@ void sdpa_vector( if (has_mask) { auto& m = *mask; compute_encoder.set_input_array(m, 11 + float_mask); - auto nd = m.ndim(); - int32_t kv_seq_stride = - nd >= 1 && m.shape(-1) > 1 ? m.strides()[nd - 1] : 0; - int32_t q_seq_stride = nd >= 2 && m.shape(-2) > 1 ? m.strides()[nd - 2] : 0; - int32_t head_stride = nd >= 3 && m.shape(-3) > 1 ? m.strides()[nd - 3] : 0; + int32_t kv_seq_stride = m.shape(3) > 1 ? m.strides(3) : 0; + int32_t q_seq_stride = m.shape(2) > 1 ? m.strides(2) : 0; + int32_t head_stride = + m.shape(1) > 1 ? m.strides(1) : (m.shape(0) > 1 ? m.strides(0) : 0); compute_encoder.set_bytes(kv_seq_stride, 13); compute_encoder.set_bytes(q_seq_stride, 14); compute_encoder.set_bytes(head_stride, 15); @@ -238,9 +237,10 @@ void sdpa_vector_2pass( int N = k.shape(2); int blocks = 32; int B = q.shape(0) * q.shape(1); - size_t k_head_stride = k.strides()[1]; + + size_t k_head_stride = k.shape(1) == 1 ? k.strides(0) : k.strides(1); size_t k_seq_stride = k.strides()[2]; - size_t v_head_stride = v.strides()[1]; + size_t v_head_stride = v.shape(1) == 1 ? 
v.strides(0) : v.strides(1); size_t v_seq_stride = v.strides()[2]; MTL::Size group_dims(8 * 32, 1, 1); MTL::Size grid_dims(B, q.shape(2), blocks); @@ -302,11 +302,10 @@ void sdpa_vector_2pass( if (has_mask) { auto& m = *mask; compute_encoder.set_input_array(m, 13 + float_mask); - auto nd = m.ndim(); - int32_t kv_seq_stride = - nd >= 1 && m.shape(-1) > 1 ? m.strides()[nd - 1] : 0; - int32_t q_seq_stride = nd >= 2 && m.shape(-2) > 1 ? m.strides()[nd - 2] : 0; - int32_t head_stride = nd >= 3 && m.shape(-3) > 1 ? m.strides()[nd - 3] : 0; + int32_t kv_seq_stride = m.shape(3) > 1 ? m.strides(3) : 0; + int32_t q_seq_stride = m.shape(2) > 1 ? m.strides(2) : 0; + int32_t head_stride = + m.shape(1) > 1 ? m.strides(1) : (m.shape(0) > 1 ? m.strides(0) : 0); compute_encoder.set_bytes(kv_seq_stride, 15); compute_encoder.set_bytes(q_seq_stride, 16); compute_encoder.set_bytes(head_stride, 17); @@ -368,18 +367,6 @@ void ScaledDotProductAttention::eval_gpu( } }; - // Checks if arr is row contiguous or the sequence and head dimension are - // transposed - auto is_contiguous_or_head_seq_transposed = [](const array& arr) { - if (arr.flags().row_contiguous) { - return true; - } - auto& strides = arr.strides(); - auto& shape = arr.shape(); - return (strides[3] == 1) && (strides[2] == shape[3] * shape[1]) && - (strides[1] == shape[3]) && (strides[0] == strides[2] * shape[2]); - }; - // Checks that the headdim dimension has stride 1. auto is_matrix_contiguous = [](const array& arr) { return arr.strides(-1) == 1; @@ -387,30 +374,58 @@ void ScaledDotProductAttention::eval_gpu( // We are in vector mode ie single query if (q_pre.shape(2) <= 8) { - const auto& q = copy_unless(is_contiguous_or_head_seq_transposed, q_pre); - const auto& k = copy_unless(is_matrix_contiguous, k_pre); - const auto& v = copy_unless(is_matrix_contiguous, v_pre); + auto q_copy_unless = [](const array& arr) { + if (arr.flags().row_contiguous) { + return true; + } + auto& strides = arr.strides(); + auto& shape = arr.shape(); + if (shape[0] == 1 || shape[1] == 1) { + // If either the batch or head dimension is a singleton, the other can + // be transposed with the sequence dimension + auto bidx = shape[0] == 1 ? 
1 : 0; + return (strides[3] == 1) && (strides[2] == shape[3] * shape[bidx]) && + (strides[bidx] == shape[3]); + } + return false; + }; + + auto kv_copy_unless = [](const array& arr) { + // keys and values should be copied if: + // - the last dimension is not contiguous + // - the batch and head dim are not contiguous + auto& strides = arr.strides(); + auto& shape = arr.shape(); + if (strides.back() != 1) { + return false; + } + if (shape[0] == 1 || shape[1] == 1) { + return true; + } + return (strides[0] == strides[1] * shape[1]); + }; + + const auto& q = copy_unless(q_copy_unless, q_pre); + const auto& k = copy_unless(kv_copy_unless, k_pre); + const auto& v = copy_unless(kv_copy_unless, v_pre); // Donate the query if possible - if (q.is_donatable() && (q.shape(2) == 1 || !q.flags().row_contiguous) && - q.size() == o.size()) { + if (q.is_donatable() && q.flags().row_contiguous && q.size() == o.size()) { o.copy_shared_buffer(q); } else { - if (o.shape(2) == 1) { - o.set_data(allocator::malloc(o.nbytes())); - } else { - auto strides = o.strides(); - strides[2] = o.shape(1) * o.shape(3); - strides[1] = o.shape(3); - auto flags = q.flags(); - flags.row_contiguous = q.shape(1) == 1; - o.set_data( - allocator::malloc(o.nbytes()), o.size(), std::move(strides), flags); - } + o.set_data(allocator::malloc(o.nbytes())); } - auto mask = - inputs.size() > 3 ? std::optional{inputs[3]} : std::nullopt; + auto mask_copy_unless = [&q](const array& arr) { + auto& strides = arr.strides(); + auto& shape = arr.shape(); + return arr.flags().row_contiguous || q.shape(0) == 1 || q.shape(1) == 1 || + (strides[0] == strides[1] * shape[1]); + }; + + auto mask = inputs.size() > 3 + ? std::optional{copy_unless(mask_copy_unless, inputs[3])} + : std::nullopt; // We route to the 2 pass fused attention if // - The device is large and the sequence length long diff --git a/mlx/backend/metal/scan.cpp b/mlx/backend/metal/scan.cpp index b1800fea9..3c4051105 100644 --- a/mlx/backend/metal/scan.cpp +++ b/mlx/backend/metal/scan.cpp @@ -3,7 +3,7 @@ #include #include -#include "mlx/backend/metal/copy.h" +#include "mlx/backend/gpu/copy.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/kernels.h" #include "mlx/backend/metal/utils.h" diff --git a/mlx/backend/metal/slicing.cpp b/mlx/backend/metal/slicing.cpp index 6ab08a108..3e1a8b541 100644 --- a/mlx/backend/metal/slicing.cpp +++ b/mlx/backend/metal/slicing.cpp @@ -2,21 +2,12 @@ #include -#include "mlx/backend/common/slicing.h" -#include "mlx/backend/metal/copy.h" +#include "mlx/backend/gpu/copy.h" +#include "mlx/backend/gpu/slicing.h" #include "mlx/backend/metal/device.h" namespace mlx::core { -void slice_gpu( - const array& in, - array& out, - const Shape& start_indices, - const Shape& strides, - const Stream& s) { - slice(in, out, start_indices, strides); -} - void concatenate_gpu( const std::vector& inputs, array& out, @@ -48,30 +39,4 @@ void concatenate_gpu( } } -void pad_gpu( - const array& in, - const array& val, - array& out, - const std::vector& axes, - const Shape& low_pad_size, - const Stream& s) { - // Fill output with val - fill_gpu(val, out, s); - - // Find offset for start of input values - size_t data_offset = 0; - for (int i = 0; i < axes.size(); i++) { - auto ax = axes[i] < 0 ? 
out.ndim() + axes[i] : axes[i]; - data_offset += out.strides()[ax] * low_pad_size[i]; - } - - // Extract slice from output where input will be pasted - array out_slice(in.shape(), out.dtype(), nullptr, {}); - out_slice.copy_shared_buffer( - out, out.strides(), out.flags(), out_slice.size(), data_offset); - - // Copy input values into the slice - copy_gpu_inplace(in, out_slice, CopyType::GeneralGeneral, s); -} - } // namespace mlx::core diff --git a/mlx/backend/metal/softmax.cpp b/mlx/backend/metal/softmax.cpp index 224721a50..59662b05d 100644 --- a/mlx/backend/metal/softmax.cpp +++ b/mlx/backend/metal/softmax.cpp @@ -1,7 +1,7 @@ // Copyright © 2023-2024 Apple Inc. #include -#include "mlx/backend/metal/copy.h" +#include "mlx/backend/gpu/copy.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/kernels.h" #include "mlx/backend/metal/kernels/defines.h" diff --git a/mlx/backend/metal/sort.cpp b/mlx/backend/metal/sort.cpp index 543dfd180..3c84022f2 100644 --- a/mlx/backend/metal/sort.cpp +++ b/mlx/backend/metal/sort.cpp @@ -2,7 +2,7 @@ #include -#include "mlx/backend/metal/copy.h" +#include "mlx/backend/gpu/copy.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/kernels.h" #include "mlx/backend/metal/utils.h" diff --git a/mlx/backend/metal/ternary.cpp b/mlx/backend/metal/ternary.cpp index 36bfd3e2b..0b821151e 100644 --- a/mlx/backend/metal/ternary.cpp +++ b/mlx/backend/metal/ternary.cpp @@ -45,7 +45,7 @@ void ternary_op_gpu_inplace( work_per_thread = large ? 4 : 2; } else { large = out.data_size() > INT32_MAX; - work_per_thread = 1; + work_per_thread = get_work_per_thread(b.dtype()); } std::string kernel_name; if (topt == TernaryOpType::General) { @@ -106,13 +106,19 @@ void ternary_op_gpu_inplace( compute_encoder.dispatch_threads(grid_dims, group_dims); } else { // Launch a 1D or 2D grid of threads - size_t nthreads = out.data_size(); + size_t nthreads = ceildiv(out.data_size(), work_per_thread); if (thread_group_size > nthreads) { thread_group_size = nthreads; } MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); - MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides()) - : MTL::Size(nthreads, 1, 1); + MTL::Size grid_dims; + if (large) { + compute_encoder.set_bytes(out.data_size(), 4); + grid_dims = get_2d_grid_dims(out.shape(), out.strides(), work_per_thread); + } else { + compute_encoder.set_bytes(out.data_size(), 4); + grid_dims = MTL::Size(nthreads, 1, 1); + } compute_encoder.dispatch_threads(grid_dims, group_dims); } } diff --git a/mlx/backend/metal/unary.cpp b/mlx/backend/metal/unary.cpp index be43c41c2..368e693a9 100644 --- a/mlx/backend/metal/unary.cpp +++ b/mlx/backend/metal/unary.cpp @@ -34,18 +34,19 @@ void unary_op_gpu_inplace( }; auto [shape, strides] = maybe_collapse(); int ndim = shape.size(); - size_t nthreads = contig ? in.data_size() : in.size(); bool large; if (!contig) { large = in.data_size() > INT32_MAX || out.size() > INT32_MAX; } else { large = in.data_size() > UINT32_MAX; } - int work_per_thread = !contig && large ? 4 : 1; + int work_per_thread; std::string kernel_name; if (contig) { + work_per_thread = get_work_per_thread(in.dtype()); kernel_name = (large ? "v2" : "v"); } else { + work_per_thread = large ? 
4 : 1; kernel_name = "gn" + std::to_string(work_per_thread); if (large) { kernel_name += "large"; @@ -75,12 +76,20 @@ void unary_op_gpu_inplace( MTL::Size grid_dims = MTL::Size(dim0, dim1, rest); compute_encoder.dispatch_threads(grid_dims, group_dims); } else { + size_t nthreads = ceildiv(in.data_size(), work_per_thread); if (thread_group_size > nthreads) { thread_group_size = nthreads; } + MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); - MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides()) - : MTL::Size(nthreads, 1, 1); + MTL::Size grid_dims; + if (large) { + compute_encoder.set_bytes(in.data_size(), 2); + grid_dims = get_2d_grid_dims(out.shape(), out.strides(), work_per_thread); + } else { + compute_encoder.set_bytes(in.data_size(), 2); + grid_dims = MTL::Size(nthreads, 1, 1); + } compute_encoder.dispatch_threads(grid_dims, group_dims); } } diff --git a/mlx/backend/metal/utils.h b/mlx/backend/metal/utils.h index 079d15f17..f9245a6d6 100644 --- a/mlx/backend/metal/utils.h +++ b/mlx/backend/metal/utils.h @@ -84,4 +84,12 @@ void concatenate(std::string& acc, T first, Args... args) { concatenate(acc, args...); } +inline int get_work_per_thread(Dtype dtype) { + return std::max(1, 8 / dtype.size()); +} + +inline size_t ceildiv(size_t n, size_t m) { + return (n + m - 1) / m; +} + } // namespace mlx::core diff --git a/mlx/backend/no_cpu/CMakeLists.txt b/mlx/backend/no_cpu/CMakeLists.txt index e1524ec63..2e6960829 100644 --- a/mlx/backend/no_cpu/CMakeLists.txt +++ b/mlx/backend/no_cpu/CMakeLists.txt @@ -1,6 +1,7 @@ target_sources( mlx - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/available.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../cpu/eval.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../cpu/encoder.cpp ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp) diff --git a/mlx/backend/no_cpu/available.cpp b/mlx/backend/no_cpu/available.cpp new file mode 100644 index 000000000..04c1bac8e --- /dev/null +++ b/mlx/backend/no_cpu/available.cpp @@ -0,0 +1,11 @@ +// Copyright © 2025 Apple Inc. 
+ +#include "mlx/backend/cpu/available.h" + +namespace mlx::core::cpu { + +bool is_available() { + return false; +} + +} // namespace mlx::core::cpu diff --git a/mlx/backend/no_cpu/compiled.cpp b/mlx/backend/no_cpu/compiled.cpp index c1c42c735..2eeddab47 100644 --- a/mlx/backend/no_cpu/compiled.cpp +++ b/mlx/backend/no_cpu/compiled.cpp @@ -18,7 +18,7 @@ void Compiled::eval_cpu( const std::vector& inputs, std::vector& outputs) { throw std::runtime_error( - "[Compiled::eval_cpu] CPU compialtion not supported on the platform."); + "[Compiled::eval_cpu] CPU compilation not supported on the platform."); } } // namespace mlx::core diff --git a/mlx/backend/no_metal/CMakeLists.txt b/mlx/backend/no_gpu/CMakeLists.txt similarity index 82% rename from mlx/backend/no_metal/CMakeLists.txt rename to mlx/backend/no_gpu/CMakeLists.txt index 962ceecb7..78e15ac69 100644 --- a/mlx/backend/no_metal/CMakeLists.txt +++ b/mlx/backend/no_gpu/CMakeLists.txt @@ -3,5 +3,5 @@ target_sources( PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp ${CMAKE_CURRENT_SOURCE_DIR}/event.cpp ${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp) diff --git a/mlx/backend/no_metal/allocator.cpp b/mlx/backend/no_gpu/allocator.cpp similarity index 96% rename from mlx/backend/no_metal/allocator.cpp rename to mlx/backend/no_gpu/allocator.cpp index a8b260b6b..320d1a267 100644 --- a/mlx/backend/no_metal/allocator.cpp +++ b/mlx/backend/no_gpu/allocator.cpp @@ -6,9 +6,9 @@ #include "mlx/allocator.h" #ifdef __APPLE__ -#include "mlx/backend/no_metal/apple_memory.h" +#include "mlx/backend/no_gpu/apple_memory.h" #elif defined(__linux__) -#include "mlx/backend/no_metal/linux_memory.h" +#include "mlx/backend/no_gpu/linux_memory.h" #else size_t get_memory_size() { return 0; diff --git a/mlx/backend/no_metal/apple_memory.h b/mlx/backend/no_gpu/apple_memory.h similarity index 100% rename from mlx/backend/no_metal/apple_memory.h rename to mlx/backend/no_gpu/apple_memory.h diff --git a/mlx/backend/no_gpu/eval.cpp b/mlx/backend/no_gpu/eval.cpp new file mode 100644 index 000000000..8bff86a98 --- /dev/null +++ b/mlx/backend/no_gpu/eval.cpp @@ -0,0 +1,28 @@ +// Copyright © 2025 Apple Inc. 
+ +#include + +#include "mlx/backend/gpu/available.h" +#include "mlx/backend/gpu/eval.h" + +namespace mlx::core::gpu { + +bool is_available() { + return false; +} + +void new_stream(Stream) {} + +void eval(array&) { + throw std::runtime_error("[gpu::eval] GPU backend is not available"); +} + +void finalize(Stream) { + throw std::runtime_error("[gpu::finalize] GPU backend is not available"); +} + +void synchronize(Stream) { + throw std::runtime_error("[gpu::synchronize] GPU backend is not available"); +} + +} // namespace mlx::core::gpu diff --git a/mlx/backend/no_metal/event.cpp b/mlx/backend/no_gpu/event.cpp similarity index 100% rename from mlx/backend/no_metal/event.cpp rename to mlx/backend/no_gpu/event.cpp diff --git a/mlx/backend/no_metal/fence.cpp b/mlx/backend/no_gpu/fence.cpp similarity index 100% rename from mlx/backend/no_metal/fence.cpp rename to mlx/backend/no_gpu/fence.cpp diff --git a/mlx/backend/no_metal/linux_memory.h b/mlx/backend/no_gpu/linux_memory.h similarity index 100% rename from mlx/backend/no_metal/linux_memory.h rename to mlx/backend/no_gpu/linux_memory.h diff --git a/mlx/backend/no_metal/primitives.cpp b/mlx/backend/no_gpu/primitives.cpp similarity index 100% rename from mlx/backend/no_metal/primitives.cpp rename to mlx/backend/no_gpu/primitives.cpp diff --git a/mlx/backend/no_metal/metal.cpp b/mlx/backend/no_metal/metal.cpp deleted file mode 100644 index ef9af8800..000000000 --- a/mlx/backend/no_metal/metal.cpp +++ /dev/null @@ -1,43 +0,0 @@ -// Copyright © 2023-2024 Apple Inc. - -#include - -#include "mlx/backend/metal/metal.h" -#include "mlx/backend/metal/metal_impl.h" -namespace mlx::core::metal { - -bool is_available() { - return false; -} - -void new_stream(Stream) {} - -std::unique_ptr> new_scoped_memory_pool() { - return nullptr; -} - -void eval(array&) { - throw std::runtime_error( - "[metal::eval] Cannot eval on GPU without metal backend"); -} - -void finalize(Stream) { - throw std::runtime_error( - "[metal::finalize] Cannot finalize GPU without metal backend"); -} - -void synchronize(Stream) { - throw std::runtime_error( - "[metal::synchronize] Cannot synchronize GPU without metal backend"); -} - -void start_capture(std::string) {} -void stop_capture() {} - -const std::unordered_map>& -device_info() { - throw std::runtime_error( - "[metal::device_info] Cannot get device info without metal backend"); -}; - -} // namespace mlx::core::metal diff --git a/mlx/compile.cpp b/mlx/compile.cpp index 7ff5c8f9e..2baeb6fcf 100644 --- a/mlx/compile.cpp +++ b/mlx/compile.cpp @@ -168,6 +168,15 @@ void merge_one(array& dst, array& src, ParentsMap& parents_map) { parent.first.inputs()[parent.second] = dst; pairs.push_back(parent); } + + // If src is a parent of dst, remove it from dst's parents + for (auto it = pairs.begin(); it != pairs.end();) { + if (it->first.id() == src.id()) { + it = pairs.erase(it); + } else { + it++; + } + } // Remove the source from the map to avoid fusing with it again parents_map.erase(src_parents); } diff --git a/mlx/device.cpp b/mlx/device.cpp index e635782e2..ec17a509a 100644 --- a/mlx/device.cpp +++ b/mlx/device.cpp @@ -1,23 +1,28 @@ // Copyright © 2023 Apple Inc. +#include + +#include "mlx/backend/cpu/available.h" +#include "mlx/backend/gpu/available.h" #include "mlx/device.h" -#include "mlx/backend/metal/metal.h" namespace mlx::core { -static Device default_device_{ - metal::is_available() ? Device::gpu : Device::cpu}; +Device& mutable_default_device() { + static Device default_device{gpu::is_available() ? 
Device::gpu : Device::cpu}; + return default_device; +} const Device& default_device() { - return default_device_; + return mutable_default_device(); } void set_default_device(const Device& d) { - if (!metal::is_available() && d == Device::gpu) { + if (!gpu::is_available() && d == Device::gpu) { throw std::invalid_argument( "[set_default_device] Cannot set gpu device without gpu backend."); } - default_device_ = d; + mutable_default_device() = d; } bool operator==(const Device& lhs, const Device& rhs) { @@ -28,4 +33,15 @@ bool operator!=(const Device& lhs, const Device& rhs) { return !(lhs == rhs); } +bool is_available(const Device& d) { + switch (d.type) { + case Device::cpu: + return cpu::is_available(); + case Device::gpu: + return gpu::is_available(); + } + // appease compiler + return false; +} + } // namespace mlx::core diff --git a/mlx/device.h b/mlx/device.h index a11e40e9d..80c624c1c 100644 --- a/mlx/device.h +++ b/mlx/device.h @@ -26,4 +26,6 @@ void set_default_device(const Device& d); bool operator==(const Device& lhs, const Device& rhs); bool operator!=(const Device& lhs, const Device& rhs); +bool is_available(const Device& d); + } // namespace mlx::core diff --git a/mlx/export.cpp b/mlx/export.cpp index effc7a0c1..c9139e156 100644 --- a/mlx/export.cpp +++ b/mlx/export.cpp @@ -470,6 +470,9 @@ bool FunctionTable::match( if (x.dtype() != y.dtype()) { return false; } + if (x.ndim() != y.ndim()) { + return false; + } if (!shapeless && x.shape() != y.shape()) { return false; } diff --git a/mlx/fft.cpp b/mlx/fft.cpp index 961c1226c..8d06c7c54 100644 --- a/mlx/fft.cpp +++ b/mlx/fft.cpp @@ -186,6 +186,7 @@ array irfftn( StreamOrDevice s /* = {} */) { return fft_impl(a, axes, true, true, s); } + array irfftn(const array& a, StreamOrDevice s /* = {} */) { return fft_impl(a, true, true, s); } @@ -308,4 +309,73 @@ array istft( return signal; } +array fftshift( + const array& a, + const std::vector& axes, + StreamOrDevice s /* = {} */) { + if (axes.empty()) { + return a; + } + + Shape shifts; + for (int ax : axes) { + // Convert negative axes to positive + int axis = ax < 0 ? ax + a.ndim() : ax; + if (axis < 0 || axis >= a.ndim()) { + std::ostringstream msg; + msg << "[fftshift] Invalid axis " << ax << " for array with " << a.ndim() + << " dimensions."; + throw std::invalid_argument(msg.str()); + } + // Match NumPy's implementation + shifts.push_back(a.shape(axis) / 2); + } + + return roll(a, shifts, axes, s); +} + +array ifftshift( + const array& a, + const std::vector& axes, + StreamOrDevice s /* = {} */) { + if (axes.empty()) { + return a; + } + + Shape shifts; + for (int ax : axes) { + // Convert negative axes to positive + int axis = ax < 0 ? 
ax + a.ndim() : ax; + if (axis < 0 || axis >= a.ndim()) { + std::ostringstream msg; + msg << "[ifftshift] Invalid axis " << ax << " for array with " << a.ndim() + << " dimensions."; + throw std::invalid_argument(msg.str()); + } + // Match NumPy's implementation + int size = a.shape(axis); + shifts.push_back(-(size / 2)); + } + + return roll(a, shifts, axes, s); +} + +// Default versions that operate on all axes +array fftshift(const array& a, StreamOrDevice s /* = {} */) { + if (a.ndim() < 1) { + return a; + } + std::vector axes(a.ndim()); + std::iota(axes.begin(), axes.end(), 0); + return fftshift(a, axes, s); +} + +array ifftshift(const array& a, StreamOrDevice s /* = {} */) { + if (a.ndim() < 1) { + return a; + } + std::vector axes(a.ndim()); + std::iota(axes.begin(), axes.end(), 0); + return ifftshift(a, axes, s); +} } // namespace mlx::core::fft \ No newline at end of file diff --git a/mlx/fft.h b/mlx/fft.h index 1ccdf300d..b329c454c 100644 --- a/mlx/fft.h +++ b/mlx/fft.h @@ -148,6 +148,24 @@ inline array irfft2( StreamOrDevice s = {}) { return irfftn(a, axes, s); } +/** Shift the zero-frequency component to the center of the spectrum. */ +array fftshift(const array& a, StreamOrDevice s = {}); + +/** Shift the zero-frequency component to the center of the spectrum along + * specified axes. */ +array fftshift( + const array& a, + const std::vector& axes, + StreamOrDevice s = {}); + +/** The inverse of fftshift. */ +array ifftshift(const array& a, StreamOrDevice s = {}); + +/** The inverse of fftshift along specified axes. */ +array ifftshift( + const array& a, + const std::vector& axes, + StreamOrDevice s = {}); inline array stft( const array& x, diff --git a/mlx/io/load.cpp b/mlx/io/load.cpp index 59d91c007..2f9053f4d 100644 --- a/mlx/io/load.cpp +++ b/mlx/io/load.cpp @@ -335,7 +335,10 @@ ThreadPool& thread_pool() { return pool_; } -ThreadPool ParallelFileReader::thread_pool_{4}; +ThreadPool& ParallelFileReader::thread_pool() { + static ThreadPool thread_pool{4}; + return thread_pool; +} void ParallelFileReader::read(char* data, size_t n) { while (n != 0) { @@ -371,7 +374,8 @@ void ParallelFileReader::read(char* data, size_t n, size_t offset) { break; } else { size_t m = batch_size_; - futs.emplace_back(thread_pool_.enqueue(readfn, offset, m, data)); + futs.emplace_back( + ParallelFileReader::thread_pool().enqueue(readfn, offset, m, data)); data += m; n -= m; offset += m; diff --git a/mlx/io/load.h b/mlx/io/load.h index 138098e82..8b5dd95b6 100644 --- a/mlx/io/load.h +++ b/mlx/io/load.h @@ -101,7 +101,7 @@ class ParallelFileReader : public Reader { private: static constexpr size_t batch_size_ = 1 << 25; - static ThreadPool thread_pool_; + static ThreadPool& thread_pool(); int fd_; std::string label_; }; diff --git a/mlx/linalg.cpp b/mlx/linalg.cpp index 5b9b51ad3..53f13486a 100644 --- a/mlx/linalg.cpp +++ b/mlx/linalg.cpp @@ -379,7 +379,12 @@ array pinv(const array& a, StreamOrDevice s /* = {} */) { // Prepare S S = expand_dims(S, -2, s); - return matmul(divide(V, S, s), U); + auto rcond = 10. 
* std::max(m, n) * finfo(a.dtype()).eps; + auto cutoff = multiply(array(rcond, a.dtype()), max(S, -1, true, s), s); + auto rS = + where(greater(S, cutoff, s), reciprocal(S, s), array(0.0f, a.dtype()), s); + + return matmul(multiply(V, rS, s), U, s); } array cholesky_inv( diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 54ac62fef..4aa5e88b7 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -473,8 +473,19 @@ array hadamard_transform( std::optional scale_ /* = std::nullopt */, StreamOrDevice s /* = {} */) { // Default to an orthonormal Hadamard matrix scaled by 1/sqrt(N) - float scale = scale_.has_value() ? *scale_ : 1.0f / std::sqrt(a.shape(-1)); + int n = a.ndim() > 0 ? a.shape(-1) : 1; + float scale = scale_.has_value() ? *scale_ : 1.0f / std::sqrt(n); auto dtype = issubdtype(a.dtype(), floating) ? a.dtype() : float32; + + // Nothing to do for a scalar + if (n == 1) { + if (scale == 1) { + return a; + } + + return multiply(a, array(scale, dtype), s); + } + return array( a.shape(), dtype, @@ -3769,6 +3780,7 @@ array conv_transpose_general( std::vector stride, std::vector padding, std::vector dilation, + std::vector output_padding, int groups, StreamOrDevice s) { std::vector padding_lo(padding.size()); @@ -3782,7 +3794,8 @@ array conv_transpose_general( int in_size = 1 + (conv_output_shape - 1); int out_size = 1 + stride[i] * (input.shape(1 + i) - 1); - padding_hi[i] = in_size - out_size + padding[i]; + padding_hi[i] = in_size - out_size + padding[i] + + output_padding[i]; // Adjust with output_padding } return conv_general( @@ -3805,10 +3818,11 @@ array conv_transpose1d( int stride /* = 1 */, int padding /* = 0 */, int dilation /* = 1 */, + int output_padding /* = 0 */, int groups /* = 1 */, StreamOrDevice s /* = {} */) { return conv_transpose_general( - in_, wt_, {stride}, {padding}, {dilation}, groups, s); + in_, wt_, {stride}, {padding}, {dilation}, {output_padding}, groups, s); } /** 2D transposed convolution with a filter */ @@ -3818,6 +3832,7 @@ array conv_transpose2d( const std::pair& stride /* = {1, 1} */, const std::pair& padding /* = {0, 0} */, const std::pair& dilation /* = {1, 1} */, + const std::pair& output_padding /* = {0, 0} */, int groups /* = 1 */, StreamOrDevice s /* = {} */) { return conv_transpose_general( @@ -3826,6 +3841,7 @@ array conv_transpose2d( {stride.first, stride.second}, {padding.first, padding.second}, {dilation.first, dilation.second}, + {output_padding.first, output_padding.second}, groups, s); } @@ -3837,6 +3853,7 @@ array conv_transpose3d( const std::tuple& stride /* = {1, 1, 1} */, const std::tuple& padding /* = {0, 0, 0} */, const std::tuple& dilation /* = {1, 1, 1} */, + const std::tuple& output_padding /* = {0, 0, 0} */, int groups /* = 1 */, StreamOrDevice s /* = {} */) { return conv_transpose_general( @@ -3845,6 +3862,9 @@ array conv_transpose3d( {std::get<0>(stride), std::get<1>(stride), std::get<2>(stride)}, {std::get<0>(padding), std::get<1>(padding), std::get<2>(padding)}, {std::get<0>(dilation), std::get<1>(dilation), std::get<2>(dilation)}, + {std::get<0>(output_padding), + std::get<1>(output_padding), + std::get<2>(output_padding)}, groups, s); } @@ -4873,8 +4893,9 @@ array bitwise_impl( const array& b, BitwiseBinary::Op op, const std::string& op_name, - const StreamOrDevice& s) { - auto out_type = promote_types(a.dtype(), b.dtype()); + const StreamOrDevice& s, + std::optional out_type_ = std::nullopt) { + auto out_type = out_type_ ? 
*out_type_ : promote_types(a.dtype(), b.dtype()); if (!(issubdtype(out_type, integer) || out_type == bool_)) { std::ostringstream msg; msg << "[" << op_name @@ -4919,12 +4940,7 @@ array left_shift(const array& a, const array& b, StreamOrDevice s /* = {} */) { if (t == bool_) { t = uint8; } - return bitwise_impl( - astype(a, t, s), - astype(b, t, s), - BitwiseBinary::Op::LeftShift, - "left_shift", - s); + return bitwise_impl(a, b, BitwiseBinary::Op::LeftShift, "left_shift", s, t); } array operator<<(const array& a, const array& b) { return left_shift(a, b); @@ -4940,7 +4956,8 @@ array right_shift(const array& a, const array& b, StreamOrDevice s /* = {} */) { astype(b, t, s), BitwiseBinary::Op::RightShift, "right_shift", - s); + s, + t); } array operator>>(const array& a, const array& b) { return right_shift(a, b); @@ -5019,8 +5036,11 @@ array roll( } auto sh = shift[i]; - auto split_index = - (sh < 0) ? (-sh) % a.shape(ax) : a.shape(ax) - sh % a.shape(ax); + auto size = a.shape(ax); + if (size == 0) { + continue; // skip rolling this axis if it has size 0 + } + auto split_index = (sh < 0) ? (-sh) % size : size - sh % size; auto parts = split(result, Shape{split_index}, ax, s); std::swap(parts[0], parts[1]); diff --git a/mlx/ops.h b/mlx/ops.h index e79ea235d..af3cdb5bd 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -569,7 +569,7 @@ inline array std(const array& a, StreamOrDevice s = {}) { return std(a, false, 0, to_stream(s)); } -/** Computes the standard deviatoin of the elements of an array along the given +/** Computes the standard deviation of the elements of an array along the given * axes */ array std( const array& a, @@ -1291,6 +1291,7 @@ array conv_transpose1d( int stride = 1, int padding = 0, int dilation = 1, + int output_padding = 0, int groups = 1, StreamOrDevice s = {}); @@ -1301,6 +1302,7 @@ array conv_transpose2d( const std::pair& stride = {1, 1}, const std::pair& padding = {0, 0}, const std::pair& dilation = {1, 1}, + const std::pair& output_padding = {0, 0}, int groups = 1, StreamOrDevice s = {}); @@ -1311,6 +1313,7 @@ array conv_transpose3d( const std::tuple& stride = {1, 1, 1}, const std::tuple& padding = {0, 0, 0}, const std::tuple& dilation = {1, 1, 1}, + const std::tuple& output_padding = {0, 0, 0}, int groups = 1, StreamOrDevice s = {}); diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 3d36f0881..7288a4885 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -3056,6 +3056,7 @@ std::vector QuantizedMatmul::vjp( std::vector vjps; // We rely on the fact that w is always 2D so transpose is simple + std::optional dsb = std::nullopt; for (auto arg : argnums) { // gradient wrt to x if (arg == 0) { @@ -3071,9 +3072,34 @@ std::vector QuantizedMatmul::vjp( } // gradient wrt to w_q, scales or biases - else { + else if (arg == 1) { throw std::runtime_error( - "[QuantizedMatmul::vjp] no gradient wrt the quantized matrix yet."); + "[QuantizedMatmul::vjp] no gradient wrt the quantized weights."); + } else { + if (!dsb) { + auto fc = flatten(cotangents[0], 0, -2, stream()); + auto fx = flatten(primals[0], 0, -2, stream()); + auto dw = transpose_ + ? 
matmul(swapaxes(fc, -1, -2, stream()), fx, stream()) + : matmul(swapaxes(fx, -1, -2, stream()), fc, stream()); + dsb = unflatten(dw, -1, {-1, group_size_}, stream()); + } + if (arg == 3) { + // biases + vjps.push_back(sum(*dsb, -1, false, stream())); + } else { + // scales + auto s = stream(); + auto wq = dequantize( + primals[1], + ones_like(primals[2], stream()), + zeros_like(primals[3], stream()), + group_size_, + bits_, + stream()); + wq = unflatten(wq, -1, {-1, group_size_}, stream()); + vjps.push_back(sum(multiply(*dsb, wq, stream()), -1, false, stream())); + } } } return vjps; diff --git a/mlx/random.cpp b/mlx/random.cpp index d6ce5bb0e..89a027b17 100644 --- a/mlx/random.cpp +++ b/mlx/random.cpp @@ -223,7 +223,7 @@ array multivariate_normal( auto n = mean.shape(-1); - // Check shapes comatibility of mean and cov + // Check shapes compatibility of mean and cov if (cov.shape(-1) != cov.shape(-2)) { throw std::invalid_argument( "[multivariate_normal] last two dimensions of cov must be equal."); @@ -402,7 +402,7 @@ array categorical( if (broadcast_shapes(shape, reduced_shape) != shape) { std::ostringstream msg; msg << "[categorical] Requested shape " << shape - << " is not broadcast compatable with reduced logits shape" + << " is not broadcast compatible with reduced logits shape" << reduced_shape << "."; throw std::invalid_argument(msg.str()); } diff --git a/mlx/scheduler.cpp b/mlx/scheduler.cpp index 7bd128c10..b19f6434a 100644 --- a/mlx/scheduler.cpp +++ b/mlx/scheduler.cpp @@ -1,12 +1,13 @@ // Copyright © 2023 Apple Inc. #include "mlx/scheduler.h" -#include "mlx/backend/metal/metal.h" +#include "mlx/backend/gpu/available.h" +#include "mlx/backend/gpu/eval.h" namespace mlx::core { Stream default_stream(Device d) { - if (!metal::is_available() && d == Device::gpu) { + if (!gpu::is_available() && d == Device::gpu) { throw std::invalid_argument( "[default_stream] Cannot get gpu stream without gpu backend."); } @@ -14,7 +15,7 @@ Stream default_stream(Device d) { } void set_default_stream(Stream s) { - if (!metal::is_available() && s.device == Device::gpu) { + if (!gpu::is_available() && s.device == Device::gpu) { throw std::invalid_argument( "[set_default_stream] Cannot set gpu stream without gpu backend."); } @@ -26,7 +27,7 @@ Stream get_stream(int index) { } Stream new_stream(Device d) { - if (!metal::is_available() && d == Device::gpu) { + if (!gpu::is_available() && d == Device::gpu) { throw std::invalid_argument( "[new_stream] Cannot make gpu stream without gpu backend."); } @@ -44,7 +45,7 @@ void synchronize(Stream s) { scheduler::enqueue(s, [p = std::move(p)]() { p->set_value(); }); f.wait(); } else { - metal::synchronize(s); + gpu::synchronize(s); } } diff --git a/mlx/scheduler.h b/mlx/scheduler.h index b2c6b842b..877fdd5f6 100644 --- a/mlx/scheduler.h +++ b/mlx/scheduler.h @@ -8,8 +8,7 @@ #include #include -#include "mlx/backend/metal/metal.h" -#include "mlx/backend/metal/metal_impl.h" +#include "mlx/backend/gpu/eval.h" #include "mlx/device.h" #include "mlx/stream.h" @@ -67,7 +66,7 @@ struct StreamThread { class Scheduler { public: Scheduler() : n_active_tasks_(0) { - if (metal::is_available()) { + if (is_available(Device::gpu)) { default_streams_.insert({Device::gpu, new_stream(Device::gpu)}); } default_streams_.insert({Device::cpu, new_stream(Device::cpu)}); @@ -83,7 +82,7 @@ class Scheduler { streams_.emplace_back(streams_.size(), d); if (d == Device::gpu) { threads_.push_back(nullptr); - metal::new_stream(streams_.back()); + gpu::new_stream(streams_.back()); } else { 
threads_.push_back(new StreamThread{}); } diff --git a/mlx/transforms.cpp b/mlx/transforms.cpp index b305257f0..2d9942eda 100644 --- a/mlx/transforms.cpp +++ b/mlx/transforms.cpp @@ -10,7 +10,7 @@ #include #include "mlx/backend/cpu/eval.h" -#include "mlx/backend/metal/metal_impl.h" +#include "mlx/backend/gpu/eval.h" #include "mlx/fence.h" #include "mlx/memory.h" #include "mlx/ops.h" @@ -42,7 +42,10 @@ class Synchronizer : public Primitive { // are currently under a function transformation and the retain_graph() // function which returns true if we are forced to retain the graph during // evaluation. -std::vector> detail::InTracing::trace_stack{}; +std::vector>& detail::InTracing::trace_stack() { + static std::vector> trace_stack_; + return trace_stack_; +} int detail::InTracing::grad_counter{0}; int detail::RetainGraph::tracing_counter{0}; @@ -215,7 +218,7 @@ array eval_impl(std::vector outputs, bool async) { } if (arr.primitive().device() == Device::gpu) { - metal::eval(arr); + gpu::eval(arr); } else { cpu::eval(arr); } @@ -226,7 +229,7 @@ array eval_impl(std::vector outputs, bool async) { // Commit any open streams for (auto& [_, e] : events) { if (e.stream().device == Device::gpu) { - metal::finalize(e.stream()); + gpu::finalize(e.stream()); } } scheduler::wait_for_one(); @@ -264,7 +267,7 @@ array eval_impl(std::vector outputs, bool async) { auto s = e.stream(); e.signal(s); if (s.device == Device::gpu) { - metal::finalize(s); + gpu::finalize(s); } } diff --git a/mlx/transforms_impl.h b/mlx/transforms_impl.h index 7f62c406b..46851fa3d 100644 --- a/mlx/transforms_impl.h +++ b/mlx/transforms_impl.h @@ -22,19 +22,19 @@ std::vector vmap_replace( struct InTracing { explicit InTracing(bool dynamic = false, bool grad = false) { grad_counter += grad; - trace_stack.push_back({dynamic, grad}); + trace_stack().push_back({dynamic, grad}); } ~InTracing() { - grad_counter -= trace_stack.back().second; - trace_stack.pop_back(); + grad_counter -= trace_stack().back().second; + trace_stack().pop_back(); } static bool in_tracing() { - return !trace_stack.empty(); + return !trace_stack().empty(); } static bool in_dynamic_tracing() { // compile is always and only the outer-most transform - return in_tracing() && trace_stack.front().first; + return in_tracing() && trace_stack().front().first; } static bool in_grad_tracing() { @@ -43,7 +43,7 @@ struct InTracing { private: static int grad_counter; - static std::vector> trace_stack; + static std::vector>& trace_stack(); }; struct RetainGraph { diff --git a/mlx/types/limits.h b/mlx/types/limits.h index 7e0de15bc..5f2b1e9e0 100644 --- a/mlx/types/limits.h +++ b/mlx/types/limits.h @@ -33,6 +33,9 @@ struct numeric_limits { static constexpr float16_t max() { return bits_to_half(0x7BFF); } + static constexpr float16_t epsilon() { + return bits_to_half(0x1400); + } static constexpr float16_t infinity() { return bits_to_half(0x7C00); } @@ -56,6 +59,9 @@ struct numeric_limits { static constexpr bfloat16_t max() { return bits_to_bfloat(0x7F7F); } + static constexpr bfloat16_t epsilon() { + return bits_to_bfloat(0x3C00); + } static constexpr bfloat16_t infinity() { return bits_to_bfloat(0x7F80); } diff --git a/mlx/utils.cpp b/mlx/utils.cpp index 188584174..0b2e66352 100644 --- a/mlx/utils.cpp +++ b/mlx/utils.cpp @@ -283,9 +283,10 @@ int get_var(const char* name, int default_value) { } // namespace env template -void set_finfo_limits(double& min, double& max) { +void set_finfo_limits(double& min, double& max, double& eps) { min = numeric_limits::lowest(); max = 
numeric_limits::max(); + eps = numeric_limits::epsilon(); } finfo::finfo(Dtype dtype) : dtype(dtype) { @@ -295,16 +296,16 @@ finfo::finfo(Dtype dtype) : dtype(dtype) { throw std::invalid_argument(msg.str()); } if (dtype == float32) { - set_finfo_limits(min, max); + set_finfo_limits(min, max, eps); } else if (dtype == float16) { - set_finfo_limits(min, max); + set_finfo_limits(min, max, eps); } else if (dtype == bfloat16) { - set_finfo_limits(min, max); + set_finfo_limits(min, max, eps); } else if (dtype == float64) { - set_finfo_limits(min, max); + set_finfo_limits(min, max, eps); } else if (dtype == complex64) { this->dtype = float32; - set_finfo_limits(min, max); + set_finfo_limits(min, max, eps); } } diff --git a/mlx/utils.h b/mlx/utils.h index 19241e4c6..f0aa7c2de 100644 --- a/mlx/utils.h +++ b/mlx/utils.h @@ -65,6 +65,7 @@ struct finfo { Dtype dtype; double min; double max; + double eps; }; /** Holds information about integral types. */ diff --git a/mlx/version.h b/mlx/version.h index fe47d96cc..8340e1e8c 100644 --- a/mlx/version.h +++ b/mlx/version.h @@ -4,7 +4,7 @@ #define MLX_VERSION_MAJOR 0 #define MLX_VERSION_MINOR 25 -#define MLX_VERSION_PATCH 0 +#define MLX_VERSION_PATCH 1 #define MLX_VERSION_NUMERIC \ (100000 * MLX_VERSION_MAJOR + 1000 * MLX_VERSION_MINOR + MLX_VERSION_PATCH) diff --git a/python/mlx/distributed_run.py b/python/mlx/distributed_run.py index 9c946005b..404ecc349 100644 --- a/python/mlx/distributed_run.py +++ b/python/mlx/distributed_run.py @@ -270,9 +270,11 @@ def launch_ring(parser, hosts, args, command): # Repeat the stdout and stderr to the local machine to_read = [p.stdout.fileno(), p.stderr.fileno()] - to_write = [p.stdin.fileno()] + to_write = [p.stdin.fileno(), sys.stdout.fileno(), sys.stderr.fileno()] pidfile = "" stdin_buffer = b"" + stdout_buffer = b"" + stderr_buffer = b"" while p.poll() is None: try: stdin_buffer += input_queue.get_nowait() @@ -280,8 +282,6 @@ def launch_ring(parser, hosts, args, command): pass rlist, wlist, _ = select(to_read, to_write, [], 1.0) for fd in rlist: - is_stdout = fd == p.stdout.fileno() - outfile = sys.stdout if is_stdout else sys.stderr msg = os.read(fd, 8192).decode(errors="ignore") # Fetch the PID file first if we haven't already @@ -289,12 +289,21 @@ def launch_ring(parser, hosts, args, command): pidfile, *msg = msg.split("\n", maxsplit=1) msg = msg[0] if msg else "" - outfile.write(msg) - outfile.flush() + is_stdout = fd == p.stdout.fileno() + if is_stdout: + stdout_buffer += msg.encode() + else: + stderr_buffer += msg.encode() for fd in wlist: - if len(stdin_buffer) > 0: + if fd == p.stdin.fileno() and len(stdin_buffer) > 0: n = os.write(fd, stdin_buffer) stdin_buffer = stdin_buffer[n:] + elif fd == sys.stdout.fileno() and len(stdout_buffer) > 0: + n = os.write(fd, stdout_buffer) + stdout_buffer = stdout_buffer[n:] + elif fd == sys.stderr.fileno() and len(stderr_buffer) > 0: + n = os.write(fd, stderr_buffer) + stderr_buffer = stderr_buffer[n:] if stop: p.terminate() break diff --git a/python/mlx/nn/layers/convolution_transpose.py b/python/mlx/nn/layers/convolution_transpose.py index edacab061..a11c4cb40 100644 --- a/python/mlx/nn/layers/convolution_transpose.py +++ b/python/mlx/nn/layers/convolution_transpose.py @@ -25,6 +25,8 @@ class ConvTranspose1d(Module): padding (int, optional): How many positions to 0-pad the input with. Default: ``0``. dilation (int, optional): The dilation of the convolution. + output_padding(int, optional): Additional size added to one side of the + output shape. Default: ``0``. 
bias (bool, optional): If ``True`` add a learnable bias to the output. Default: ``True`` """ @@ -37,6 +39,7 @@ class ConvTranspose1d(Module): stride: int = 1, padding: int = 0, dilation: int = 1, + output_padding: int = 0, bias: bool = True, ): super().__init__() @@ -53,18 +56,25 @@ class ConvTranspose1d(Module): self.padding = padding self.dilation = dilation self.stride = stride + self.output_padding = output_padding def _extra_repr(self): return ( f"{self.weight.shape[-1]}, {self.weight.shape[0]}, " f"kernel_size={self.weight.shape[1]}, stride={self.stride}, " f"padding={self.padding}, dilation={self.dilation}, " + f"output_padding={self.output_padding}, " f"bias={'bias' in self}" ) def __call__(self, x): y = mx.conv_transpose1d( - x, self.weight, self.stride, self.padding, self.dilation + x, + self.weight, + self.stride, + self.padding, + self.dilation, + self.output_padding, ) if "bias" in self: y = y + self.bias @@ -90,6 +100,8 @@ class ConvTranspose2d(Module): padding (int or tuple, optional): How many positions to 0-pad the input with. Default: ``0``. dilation (int or tuple, optional): The dilation of the convolution. + output_padding(int or tuple, optional): Additional size added to one + side of the output shape. Default: ``0``. bias (bool, optional): If ``True`` add a learnable bias to the output. Default: ``True`` """ @@ -102,13 +114,14 @@ class ConvTranspose2d(Module): stride: Union[int, tuple] = 1, padding: Union[int, tuple] = 0, dilation: Union[int, tuple] = 1, + output_padding: Union[int, tuple] = 0, bias: bool = True, ): super().__init__() - kernel_size, stride, padding = map( + kernel_size, stride, padding, output_padding = map( lambda x: (x, x) if isinstance(x, int) else x, - (kernel_size, stride, padding), + (kernel_size, stride, padding, output_padding), ) scale = math.sqrt(1 / (in_channels * kernel_size[0] * kernel_size[1])) self.weight = mx.random.uniform( @@ -122,18 +135,25 @@ class ConvTranspose2d(Module): self.padding = padding self.stride = stride self.dilation = dilation + self.output_padding = output_padding def _extra_repr(self): return ( f"{self.weight.shape[-1]}, {self.weight.shape[0]}, " f"kernel_size={self.weight.shape[1:2]}, stride={self.stride}, " f"padding={self.padding}, dilation={self.dilation}, " + f"output_padding={self.output_padding}, " f"bias={'bias' in self}" ) def __call__(self, x): y = mx.conv_transpose2d( - x, self.weight, self.stride, self.padding, self.dilation + x, + self.weight, + self.stride, + self.padding, + self.dilation, + self.output_padding, ) if "bias" in self: y = y + self.bias @@ -160,6 +180,8 @@ class ConvTranspose3d(Module): padding (int or tuple, optional): How many positions to 0-pad the input with. Default: ``0``. dilation (int or tuple, optional): The dilation of the convolution. + output_padding(int or tuple, optional): Additional size added to one + side of the output shape. Default: ``0``. bias (bool, optional): If ``True`` add a learnable bias to the output. 
Default: ``True`` """ @@ -172,13 +194,14 @@ class ConvTranspose3d(Module): stride: Union[int, tuple] = 1, padding: Union[int, tuple] = 0, dilation: Union[int, tuple] = 1, + output_padding: Union[int, tuple] = 0, bias: bool = True, ): super().__init__() - kernel_size, stride, padding = map( + kernel_size, stride, padding, output_padding = map( lambda x: (x, x, x) if isinstance(x, int) else x, - (kernel_size, stride, padding), + (kernel_size, stride, padding, output_padding), ) scale = math.sqrt( 1 / (in_channels * kernel_size[0] * kernel_size[1] * kernel_size[2]) @@ -194,18 +217,25 @@ class ConvTranspose3d(Module): self.padding = padding self.stride = stride self.dilation = dilation + self.output_padding = output_padding def _extra_repr(self): return ( f"{self.weight.shape[-1]}, {self.weight.shape[0]}, " f"kernel_size={self.weight.shape[1:3]}, stride={self.stride}, " f"padding={self.padding}, dilation={self.dilation}, " + f"output_padding={self.output_padding}, " f"bias={'bias' in self}" ) def __call__(self, x): y = mx.conv_transpose3d( - x, self.weight, self.stride, self.padding, self.dilation + x, + self.weight, + self.stride, + self.padding, + self.dilation, + self.output_padding, ) if "bias" in self: y = y + self.bias diff --git a/python/mlx/nn/layers/quantized.py b/python/mlx/nn/layers/quantized.py index 823a0084f..2d6dc0882 100644 --- a/python/mlx/nn/layers/quantized.py +++ b/python/mlx/nn/layers/quantized.py @@ -193,12 +193,6 @@ class QuantizedLinear(Module): # Freeze this model's parameters self.freeze() - def unfreeze(self, *args, **kwargs): - """Wrap unfreeze so that we unfreeze any layers we might contain but - our parameters will remain frozen.""" - super().unfreeze(*args, **kwargs) - self.freeze(recurse=False) - def _extra_repr(self): out_dims, in_dims = self.weight.shape in_dims *= 32 // self.bits diff --git a/python/src/array.cpp b/python/src/array.cpp index 467bd0fa5..5f8dbe021 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -197,6 +197,13 @@ void init_array(nb::module_& m) { "max", &mx::finfo::max, R"pbdoc(The largest representable number.)pbdoc") + .def_ro( + "eps", + &mx::finfo::eps, + R"pbdoc( + The difference between 1.0 and the next smallest + representable number larger than 1.0. + )pbdoc") .def_ro("dtype", &mx::finfo::dtype, R"pbdoc(The :obj:`Dtype`.)pbdoc") .def("__repr__", [](const mx::finfo& f) { std::ostringstream os; diff --git a/python/src/fft.cpp b/python/src/fft.cpp index aadb3893f..5d5f0dbc2 100644 --- a/python/src/fft.cpp +++ b/python/src/fft.cpp @@ -554,5 +554,55 @@ void init_fft(nb::module_& parent_module) { Returns: array: The reconstructed signal. + m.def( + "fftshift", + [](const mx::array& a, + const std::optional>& axes, + mx::StreamOrDevice s) { + if (axes.has_value()) { + return mx::fft::fftshift(a, axes.value(), s); + } else { + return mx::fft::fftshift(a, s); + } + }, + "a"_a, + "axes"_a = nb::none(), + "stream"_a = nb::none(), + R"pbdoc( + Shift the zero-frequency component to the center of the spectrum. + + Args: + a (array): The input array. + axes (list(int), optional): Axes over which to perform the shift. + If ``None``, shift all axes. + + Returns: + array: The shifted array with the same shape as the input. 
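To make the intent of the new `fftshift` binding concrete, here is a short usage sketch that is not part of the diff; it only assumes the `mx.fft.fft` and `mx.fft.fftshift` bindings shown above, with an arbitrary test signal:

```python
import math
import mlx.core as mx

# A 4-cycle cosine sampled at 64 points; its spectrum has peaks in bins 4 and 60.
t = mx.arange(64) / 64.0
sig = mx.cos(2 * math.pi * 4 * t)

spec = mx.fft.fftshift(mx.fft.fft(sig))
# After the shift the zero-frequency bin sits at index 32 and the two
# symmetric peaks appear at indices 32 - 4 and 32 + 4.
print(mx.abs(spec))
```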
+ )pbdoc"); + m.def( + "ifftshift", + [](const mx::array& a, + const std::optional>& axes, + mx::StreamOrDevice s) { + if (axes.has_value()) { + return mx::fft::ifftshift(a, axes.value(), s); + } else { + return mx::fft::ifftshift(a, s); + } + }, + "a"_a, + "axes"_a = nb::none(), + "stream"_a = nb::none(), + R"pbdoc( + The inverse of :func:`fftshift`. While identical to :func:`fftshift` for even-length axes, + the behavior differs for odd-length axes. + + Args: + a (array): The input array. + axes (list(int), optional): Axes over which to perform the inverse shift. + If ``None``, shift all axes. + + Returns: + array: The inverse-shifted array with the same shape as the input. )pbdoc"); } diff --git a/python/src/ops.cpp b/python/src/ops.cpp index f98aa80aa..a1e77d681 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -3455,8 +3455,8 @@ void init_ops(nb::module_& m) { 1D convolution over an input with several channels Args: - input (array): Input array of shape ``(N, H, C_in)``. - weight (array): Weight array of shape ``(C_out, H, C_in)``. + input (array): Input array of shape ``(N, L, C_in)``. + weight (array): Weight array of shape ``(C_out, K, C_in)``. stride (int, optional): Kernel stride. Default: ``1``. padding (int, optional): Input padding. Default: ``0``. dilation (int, optional): Kernel dilation. Default: ``1``. @@ -3514,7 +3514,7 @@ void init_ops(nb::module_& m) { Args: input (array): Input array of shape ``(N, H, W, C_in)``. - weight (array): Weight array of shape ``(C_out, H, W, C_in)``. + weight (array): Weight array of shape ``(C_out, KH, KW, C_in)``. stride (int or tuple(int), optional): :obj:`tuple` of size 2 with kernel strides. All spatial dimensions get the same stride if only one number is specified. Default: ``1``. @@ -3586,7 +3586,7 @@ void init_ops(nb::module_& m) { Args: input (array): Input array of shape ``(N, D, H, W, C_in)``. - weight (array): Weight array of shape ``(C_out, D, H, W, C_in)``. + weight (array): Weight array of shape ``(C_out, KD, KH, KW, C_in)``. stride (int or tuple(int), optional): :obj:`tuple` of size 3 with kernel strides. All spatial dimensions get the same stride if only one number is specified. Default: ``1``. @@ -3609,20 +3609,22 @@ void init_ops(nb::module_& m) { "stride"_a = 1, "padding"_a = 0, "dilation"_a = 1, + "output_padding"_a = 0, "groups"_a = 1, nb::kw_only(), "stream"_a = nb::none(), nb::sig( - "def conv_transpose1d(input: array, weight: array, /, stride: int = 1, padding: int = 0, dilation: int = 1, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"), + "def conv_transpose1d(input: array, weight: array, /, stride: int = 1, padding: int = 0, dilation: int = 1, output_padding: int = 0, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( 1D transposed convolution over an input with several channels Args: - input (array): Input array of shape ``(N, H, C_in)``. - weight (array): Weight array of shape ``(C_out, H, C_in)``. + input (array): Input array of shape ``(N, L, C_in)``. + weight (array): Weight array of shape ``(C_out, K, C_in)``. stride (int, optional): Kernel stride. Default: ``1``. padding (int, optional): Input padding. Default: ``0``. dilation (int, optional): Kernel dilation. Default: ``1``. + output_padding (int, optional): Output padding. Default: ``0``. groups (int, optional): Input feature groups. Default: ``1``. 
Returns: @@ -3635,11 +3637,13 @@ void init_ops(nb::module_& m) { const std::variant>& stride, const std::variant>& padding, const std::variant>& dilation, + const std::variant>& output_padding, int groups, mx::StreamOrDevice s) { std::pair stride_pair{1, 1}; std::pair padding_pair{0, 0}; std::pair dilation_pair{1, 1}; + std::pair output_padding_pair{0, 0}; if (auto pv = std::get_if(&stride); pv) { stride_pair = std::pair{*pv, *pv}; @@ -3659,19 +3663,33 @@ void init_ops(nb::module_& m) { dilation_pair = std::get>(dilation); } + if (auto pv = std::get_if(&output_padding); pv) { + output_padding_pair = std::pair{*pv, *pv}; + } else { + output_padding_pair = std::get>(output_padding); + } + return mx::conv_transpose2d( - input, weight, stride_pair, padding_pair, dilation_pair, groups, s); + input, + weight, + stride_pair, + padding_pair, + dilation_pair, + output_padding_pair, + groups, + s); }, nb::arg(), nb::arg(), "stride"_a = 1, "padding"_a = 0, "dilation"_a = 1, + "output_padding"_a = 0, "groups"_a = 1, nb::kw_only(), "stream"_a = nb::none(), nb::sig( - "def conv_transpose2d(input: array, weight: array, /, stride: Union[int, Tuple[int, int]] = 1, padding: Union[int, Tuple[int, int]] = 0, dilation: Union[int, Tuple[int, int]] = 1, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"), + "def conv_transpose2d(input: array, weight: array, /, stride: Union[int, Tuple[int, int]] = 1, padding: Union[int, Tuple[int, int]] = 0, dilation: Union[int, Tuple[int, int]] = 1, output_padding: Union[int, Tuple[int, int]] = 0, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( 2D transposed convolution over an input with several channels @@ -3679,7 +3697,7 @@ void init_ops(nb::module_& m) { Args: input (array): Input array of shape ``(N, H, W, C_in)``. - weight (array): Weight array of shape ``(C_out, H, W, C_in)``. + weight (array): Weight array of shape ``(C_out, KH, KW, C_in)``. stride (int or tuple(int), optional): :obj:`tuple` of size 2 with kernel strides. All spatial dimensions get the same stride if only one number is specified. Default: ``1``. @@ -3689,6 +3707,9 @@ void init_ops(nb::module_& m) { dilation (int or tuple(int), optional): :obj:`tuple` of size 2 with kernel dilation. All spatial dimensions get the same dilation if only one number is specified. Default: ``1`` + output_padding (int or tuple(int), optional): :obj:`tuple` of size 2 with + output padding. All spatial dimensions get the same output + padding if only one number is specified. Default: ``0``. groups (int, optional): input feature groups. Default: ``1``. 
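For the layer-level API touched earlier in this diff, the new argument threads through `nn.ConvTranspose2d` in the same way. A brief sketch with illustrative sizes; as in other frameworks, `output_padding` is normally kept smaller than the stride:

```python
import mlx.core as mx
import mlx.nn as nn

layer = nn.ConvTranspose2d(
    in_channels=8, out_channels=4, kernel_size=3,
    stride=2, padding=1, output_padding=1,
)
x = mx.random.normal((2, 16, 16, 8))  # NHWC layout
print(layer(x).shape)                  # (2, 32, 32, 4): exactly doubles the spatial size
```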
Returns: @@ -3701,11 +3722,13 @@ void init_ops(nb::module_& m) { const std::variant>& stride, const std::variant>& padding, const std::variant>& dilation, + const std::variant>& output_padding, int groups, mx::StreamOrDevice s) { std::tuple stride_tuple{1, 1, 1}; std::tuple padding_tuple{0, 0, 0}; std::tuple dilation_tuple{1, 1, 1}; + std::tuple output_padding_tuple{0, 0, 0}; if (auto pv = std::get_if(&stride); pv) { stride_tuple = std::tuple{*pv, *pv, *pv}; @@ -3725,12 +3748,20 @@ void init_ops(nb::module_& m) { dilation_tuple = std::get>(dilation); } + if (auto pv = std::get_if(&output_padding); pv) { + output_padding_tuple = std::tuple{*pv, *pv, *pv}; + } else { + output_padding_tuple = + std::get>(output_padding); + } + return mx::conv_transpose3d( input, weight, stride_tuple, padding_tuple, dilation_tuple, + output_padding_tuple, groups, s); }, @@ -3739,11 +3770,12 @@ void init_ops(nb::module_& m) { "stride"_a = 1, "padding"_a = 0, "dilation"_a = 1, + "output_padding"_a = 0, "groups"_a = 1, nb::kw_only(), "stream"_a = nb::none(), nb::sig( - "def conv_transpose3d(input: array, weight: array, /, stride: Union[int, Tuple[int, int, int]] = 1, padding: Union[int, Tuple[int, int, int]] = 0, dilation: Union[int, Tuple[int, int, int]] = 1, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"), + "def conv_transpose3d(input: array, weight: array, /, stride: Union[int, Tuple[int, int, int]] = 1, padding: Union[int, Tuple[int, int, int]] = 0, dilation: Union[int, Tuple[int, int, int]] = 1, output_padding: Union[int, Tuple[int, int, int]] = 0, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( 3D transposed convolution over an input with several channels @@ -3751,7 +3783,7 @@ void init_ops(nb::module_& m) { Args: input (array): Input array of shape ``(N, D, H, W, C_in)``. - weight (array): Weight array of shape ``(C_out, D, H, W, C_in)``. + weight (array): Weight array of shape ``(C_out, KD, KH, KW, C_in)``. stride (int or tuple(int), optional): :obj:`tuple` of size 3 with kernel strides. All spatial dimensions get the same stride if only one number is specified. Default: ``1``. @@ -3761,6 +3793,9 @@ void init_ops(nb::module_& m) { dilation (int or tuple(int), optional): :obj:`tuple` of size 3 with kernel dilation. All spatial dimensions get the same dilation if only one number is specified. Default: ``1`` + output_padding (int or tuple(int), optional): :obj:`tuple` of size 3 with + output padding. All spatial dimensions get the same output + padding if only one number is specified. Default: ``0``. groups (int, optional): input feature groups. Default: ``1``. Returns: @@ -5189,4 +5224,46 @@ void init_ops(nb::module_& m) { Returns: array: The row or col contiguous output. )pbdoc"); + m.def( + "broadcast_shapes", + [](const nb::args& shapes) { + if (shapes.size() == 0) + throw std::invalid_argument( + "[broadcast_shapes] Must provide at least one shape."); + + mx::Shape result = nb::cast(shapes[0]); + for (size_t i = 1; i < shapes.size(); ++i) { + if (!nb::isinstance(shapes[i]) && + !nb::isinstance(shapes[i])) + throw std::invalid_argument( + "[broadcast_shapes] Expects a sequence of shapes (tuple or list of ints)."); + result = mx::broadcast_shapes(result, nb::cast(shapes[i])); + } + + return nb::tuple(nb::cast(result)); + }, + nb::sig("def broadcast_shapes(*shapes: Sequence[int]) -> Tuple[int]"), + R"pbdoc( + Broadcast shapes. + + Returns the shape that results from broadcasting the supplied array shapes + against each other. 
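Beyond the docstring examples below, one practical use of the new binding is validating shapes before materializing a result; a minimal sketch with placeholder arrays:

```python
import mlx.core as mx

a = mx.zeros((5, 1, 4))
b = mx.zeros((1, 3, 1))

out_shape = mx.broadcast_shapes(a.shape, b.shape)  # (5, 3, 4)
assert (a + b).shape == out_shape
```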
+ + Args: + *shapes (Sequence[int]): The shapes to broadcast. + + Returns: + tuple: The broadcasted shape. + + Raises: + ValueError: If the shapes cannot be broadcast. + + Example: + >>> mx.broadcast_shapes((1,), (3, 1)) + (3, 1) + >>> mx.broadcast_shapes((6, 7), (5, 6, 1), (7,)) + (5, 6, 7) + >>> mx.broadcast_shapes((5, 1, 4), (1, 3, 1)) + (5, 3, 4) + )pbdoc"); } diff --git a/python/src/random.cpp b/python/src/random.cpp index e9c0a87fc..22b706174 100644 --- a/python/src/random.cpp +++ b/python/src/random.cpp @@ -422,7 +422,7 @@ void init_random(nb::module_& parent_module) { axis (int, optional): The axis which specifies the distribution. Default: ``-1``. shape (list(int), optional): The shape of the output. This must - be broadcast compatable with ``logits.shape`` with the ``axis`` + be broadcast compatible with ``logits.shape`` with the ``axis`` dimension removed. Default: ``None`` num_samples (int, optional): The number of samples to draw from each of the categorical distributions in ``logits``. The output will have diff --git a/python/tests/test_array.py b/python/tests/test_array.py index fa5784ea9..792e666d6 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -103,10 +103,12 @@ class TestDtypes(mlx_tests.MLXTestCase): self.assertEqual(mx.finfo(mx.float32).min, np.finfo(np.float32).min) self.assertEqual(mx.finfo(mx.float32).max, np.finfo(np.float32).max) + self.assertEqual(mx.finfo(mx.float32).eps, np.finfo(np.float32).eps) self.assertEqual(mx.finfo(mx.float32).dtype, mx.float32) self.assertEqual(mx.finfo(mx.float16).min, np.finfo(np.float16).min) self.assertEqual(mx.finfo(mx.float16).max, np.finfo(np.float16).max) + self.assertEqual(mx.finfo(mx.float16).eps, np.finfo(np.float16).eps) self.assertEqual(mx.finfo(mx.float16).dtype, mx.float16) def test_iinfo(self): diff --git a/python/tests/test_conv_transpose.py b/python/tests/test_conv_transpose.py index 1ac20cbb1..2085e09d7 100644 --- a/python/tests/test_conv_transpose.py +++ b/python/tests/test_conv_transpose.py @@ -596,6 +596,215 @@ class TestConvTranspose(mlx_tests.MLXTestCase): N, C, O, idim, kdim, stride, padding, dilation, dtype=dtype ) + @unittest.skipIf(not has_torch, "requires Torch") + def test_torch_conv_tranpose_1d_output_padding(self): + def run_conv_transpose_1d_output_padding( + N, C, O, iH, kH, stride, padding, output_padding, dtype="float32", atol=1e-5 + ): + with self.subTest( + dtype=dtype, + N=N, + C=C, + O=O, + iH=iH, + kH=kH, + stride=stride, + padding=padding, + output_padding=output_padding, + ): + np_dtype = getattr(np, dtype) + np.random.seed(0) + in_np = np.random.normal(0, 1.0 / C, (N, iH, C)).astype(np_dtype) + wt_np = np.random.normal(0, 1.0 / C, (O, kH, C)).astype(np_dtype) + + in_mx, wt_mx = map(mx.array, (in_np, wt_np)) + in_pt = torch.from_numpy(in_np.transpose(0, 2, 1)) + wt_pt = torch.from_numpy(wt_np.transpose(2, 0, 1)) + + out_mx = mx.conv_transpose1d( + in_mx, + wt_mx, + stride=stride, + padding=padding, + output_padding=output_padding, + ) + + out_pt = torch.conv_transpose1d( + in_pt, + wt_pt, + stride=stride, + padding=padding, + output_padding=output_padding, + ) + out_pt = torch.transpose(out_pt, 2, 1) + + self.assertEqual(out_pt.shape, out_mx.shape) + self.assertTrue(np.allclose(out_pt.numpy(), out_mx, atol=atol)) + + for dtype in ("float32",): + for N, C, O in ((1, 1, 1), (1, 6, 1), (4, 32, 64)): + for iH, kH, stride, padding, output_padding in ( + (3, 2, 2, 0, 1), + (5, 3, 2, 1, 0), + (7, 4, 3, 1, 2), + ): + run_conv_transpose_1d_output_padding( + N, C, O, iH, kH, 
stride, padding, output_padding, dtype=dtype + ) + + @unittest.skipIf(not has_torch, "requires Torch") + def test_torch_conv_transpose_2d_output_padding(self): + def run_conv_transpose_2d_output_padding( + N, + C, + O, + idim, + kdim, + stride, + padding, + output_padding, + dtype="float32", + atol=1e-5, + ): + with self.subTest( + dtype=dtype, + N=N, + C=C, + O=O, + idim=idim, + kdim=kdim, + stride=stride, + padding=padding, + output_padding=output_padding, + ): + np_dtype = getattr(np, dtype) + np.random.seed(0) + iH, iW = idim + kH, kW = kdim + in_np = np.random.normal(0, 1.0 / C, (N, iH, iW, C)).astype(np_dtype) + wt_np = np.random.normal(0, 1.0 / C, (O, kH, kW, C)).astype(np_dtype) + + in_mx, wt_mx = map(mx.array, (in_np, wt_np)) + in_pt = torch.from_numpy(in_np.transpose(0, 3, 1, 2)) + wt_pt = torch.from_numpy(wt_np.transpose(3, 0, 1, 2)) + + out_mx = mx.conv_transpose2d( + in_mx, + wt_mx, + stride=stride, + padding=padding, + output_padding=output_padding, + ) + + out_pt = torch.conv_transpose2d( + in_pt, + wt_pt, + stride=stride, + padding=padding, + output_padding=output_padding, + ) + out_pt = torch.permute(out_pt, (0, 2, 3, 1)).numpy(force=True) + + self.assertEqual(out_pt.shape, out_mx.shape) + self.assertTrue(np.allclose(out_pt, out_mx, atol=atol)) + + for dtype in ("float32",): + for N, C, O in ((1, 1, 1), (1, 6, 1), (4, 32, 64)): + for idim, kdim, stride, padding, output_padding in ( + ((3, 3), (2, 2), (2, 2), (0, 0), (1, 1)), + ((5, 5), (3, 3), (2, 2), (1, 1), (0, 0)), + ((7, 7), (4, 4), (3, 3), (1, 1), (2, 2)), + ): + run_conv_transpose_2d_output_padding( + N, + C, + O, + idim, + kdim, + stride, + padding, + output_padding, + dtype=dtype, + ) + + @unittest.skipIf(not has_torch, "requires Torch") + def test_torch_conv_transpose_3d_output_padding(self): + def run_conv_transpose_3d_output_padding( + N, + C, + O, + idim, + kdim, + stride, + padding, + output_padding, + dtype="float32", + atol=1e-5, + ): + with self.subTest( + dtype=dtype, + N=N, + C=C, + O=O, + idim=idim, + kdim=kdim, + stride=stride, + padding=padding, + output_padding=output_padding, + ): + np_dtype = getattr(np, dtype) + np.random.seed(0) + iD, iH, iW = idim + kD, kH, kW = kdim + in_np = np.random.normal(0, 1.0 / C, (N, iD, iH, iW, C)).astype( + np_dtype + ) + wt_np = np.random.normal(0, 1.0 / C, (O, kD, kH, kW, C)).astype( + np_dtype + ) + + in_mx, wt_mx = map(mx.array, (in_np, wt_np)) + in_pt = torch.from_numpy(in_np.transpose(0, 4, 1, 2, 3)) + wt_pt = torch.from_numpy(wt_np.transpose(4, 0, 1, 2, 3)) + + out_mx = mx.conv_transpose3d( + in_mx, + wt_mx, + stride=stride, + padding=padding, + output_padding=output_padding, + ) + out_pt = torch.conv_transpose3d( + in_pt, + wt_pt, + stride=stride, + padding=padding, + output_padding=output_padding, + ) + out_pt = torch.permute(out_pt, (0, 2, 3, 4, 1)).numpy(force=True) + + self.assertEqual(out_pt.shape, out_mx.shape) + self.assertTrue(np.allclose(out_pt, out_mx, atol=atol)) + + for dtype in ("float32",): + for N, C, O in ((1, 1, 1), (1, 6, 1), (4, 32, 64)): + for idim, kdim, stride, padding, output_padding in ( + ((3, 3, 3), (2, 2, 2), (2, 2, 2), (0, 0, 0), (1, 1, 1)), + ((5, 5, 5), (3, 3, 3), (2, 2, 2), (1, 1, 1), (0, 0, 0)), + ((7, 7, 7), (4, 4, 4), (3, 3, 3), (1, 1, 1), (2, 2, 2)), + ): + run_conv_transpose_3d_output_padding( + N, + C, + O, + idim, + kdim, + stride, + padding, + output_padding, + dtype=dtype, + ) + if __name__ == "__main__": unittest.main() diff --git a/python/tests/test_export_import.py b/python/tests/test_export_import.py index 
2b4b425ca..0190827bd 100644 --- a/python/tests/test_export_import.py +++ b/python/tests/test_export_import.py @@ -242,6 +242,7 @@ class TestExportImport(mlx_tests.MLXTestCase): def test_leaks(self): path = os.path.join(self.test_dir, "fn.mlxfn") + mx.synchronize() if mx.metal.is_available(): mem_pre = mx.get_active_memory() else: @@ -267,6 +268,24 @@ class TestExportImport(mlx_tests.MLXTestCase): self.assertEqual(mem_pre, mem_post) + def test_export_import_shapeless(self): + path = os.path.join(self.test_dir, "fn.mlxfn") + + def fun(*args): + return sum(args) + + with mx.exporter(path, fun, shapeless=True) as exporter: + exporter(mx.array(1)) + exporter(mx.array(1), mx.array(2)) + exporter(mx.array(1), mx.array(2), mx.array(3)) + + f2 = mx.import_function(path) + self.assertEqual(f2(mx.array(1))[0].item(), 1) + self.assertEqual(f2(mx.array(1), mx.array(1))[0].item(), 2) + self.assertEqual(f2(mx.array(1), mx.array(1), mx.array(1))[0].item(), 3) + with self.assertRaises(ValueError): + f2(mx.array(10), mx.array([5, 10, 20])) + if __name__ == "__main__": unittest.main() diff --git a/python/tests/test_fast_sdpa.py b/python/tests/test_fast_sdpa.py index d35a2b1da..8f55d41e3 100644 --- a/python/tests/test_fast_sdpa.py +++ b/python/tests/test_fast_sdpa.py @@ -473,6 +473,46 @@ class TestFastSDPA(mlx_tests.MLXTestCase): out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale) self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4)) + def test_sdpa_vector_batched(self): + D = 64 + q = mx.random.normal(shape=(2, 1, 3, D)) + k = mx.random.normal(shape=(2, 1, 3, D)) + v = mx.random.normal(shape=(2, 1, 3, D)) + + out = mx.fast.scaled_dot_product_attention(q, k, v, mask=None, scale=1.0) + ref = mlx_ref_attn(q, k, v) + self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4)) + + q = mx.random.normal(shape=(2, 4, 3, D)) + out = mx.fast.scaled_dot_product_attention(q, k, v, mask=None, scale=1.0) + ref = mlx_ref_attn(q, k, v) + self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4)) + + q = mx.random.normal(shape=(2, 3, 4, D)).swapaxes(1, 2) + out = mx.fast.scaled_dot_product_attention(q, k, v, mask=None, scale=1.0) + ref = mlx_ref_attn(q, k, v) + self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4)) + + k = mx.random.normal(shape=(2, 3, 1, D)).swapaxes(1, 2) + out = mx.fast.scaled_dot_product_attention(q, k, v, mask=None, scale=1.0) + ref = mlx_ref_attn(q, k, v) + self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4)) + + q = mx.random.normal(shape=(2, 4, 3, D)) + k = mx.random.normal(shape=(2, 3, 2, D)).swapaxes(1, 2) + v = mx.random.normal(shape=(2, 2, 3, D)) + out = mx.fast.scaled_dot_product_attention(q, k, v, mask=None, scale=1.0) + ref = mlx_ref_attn(q, k, v) + self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4)) + + q = mx.random.normal(shape=(2, 4, 3, D)) + k = mx.random.normal(shape=(2, 1, 3, D)) + v = mx.random.normal(shape=(2, 1, 3, D)) + mask = 10 * mx.random.normal(shape=(1, 2, 3, 3)).swapaxes(0, 1) + out = mx.fast.scaled_dot_product_attention(q, k, v, mask=mask, scale=1.0) + ref = mlx_ref_attn(q, k, v, mask=mask) + self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4)) + class TestSDPA(mlx_tests.MLXTestCase): @property diff --git a/python/tests/test_fft.py b/python/tests/test_fft.py index c887cd968..f644944c7 100644 --- a/python/tests/test_fft.py +++ b/python/tests/test_fft.py @@ -199,6 +199,68 @@ class TestFFT(mlx_tests.MLXTestCase): with self.assertRaises(ValueError): mx.fft.irfftn(x) + def test_fftshift(self): + # Test 1D arrays + r = 
np.random.rand(100).astype(np.float32) + self.check_mx_np(mx.fft.fftshift, np.fft.fftshift, r) + + # Test with specific axis + r = np.random.rand(4, 6).astype(np.float32) + self.check_mx_np(mx.fft.fftshift, np.fft.fftshift, r, axes=[0]) + self.check_mx_np(mx.fft.fftshift, np.fft.fftshift, r, axes=[1]) + self.check_mx_np(mx.fft.fftshift, np.fft.fftshift, r, axes=[0, 1]) + + # Test with negative axes + self.check_mx_np(mx.fft.fftshift, np.fft.fftshift, r, axes=[-1]) + + # Test with odd lengths + r = np.random.rand(5, 7).astype(np.float32) + self.check_mx_np(mx.fft.fftshift, np.fft.fftshift, r) + self.check_mx_np(mx.fft.fftshift, np.fft.fftshift, r, axes=[0]) + + # Test with complex input + r = np.random.rand(8, 8).astype(np.float32) + i = np.random.rand(8, 8).astype(np.float32) + c = r + 1j * i + self.check_mx_np(mx.fft.fftshift, np.fft.fftshift, c) + + def test_ifftshift(self): + # Test 1D arrays + r = np.random.rand(100).astype(np.float32) + self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r) + + # Test with specific axis + r = np.random.rand(4, 6).astype(np.float32) + self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r, axes=[0]) + self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r, axes=[1]) + self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r, axes=[0, 1]) + + # Test with negative axes + self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r, axes=[-1]) + + # Test with odd lengths + r = np.random.rand(5, 7).astype(np.float32) + self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r) + self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r, axes=[0]) + + # Test with complex input + r = np.random.rand(8, 8).astype(np.float32) + i = np.random.rand(8, 8).astype(np.float32) + c = r + 1j * i + self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, c) + + def test_fftshift_errors(self): + # Test invalid axes + x = mx.array(np.random.rand(4, 4).astype(np.float32)) + with self.assertRaises(ValueError): + mx.fft.fftshift(x, axes=[2]) + with self.assertRaises(ValueError): + mx.fft.fftshift(x, axes=[-3]) + + # Test empty array + x = mx.array([]) + self.assertTrue(mx.array_equal(mx.fft.fftshift(x), x)) + if __name__ == "__main__": unittest.main() diff --git a/python/tests/test_linalg.py b/python/tests/test_linalg.py index ffa355c10..a9fe572af 100644 --- a/python/tests/test_linalg.py +++ b/python/tests/test_linalg.py @@ -232,6 +232,11 @@ class TestLinalg(mlx_tests.MLXTestCase): for M, M_plus in zip(AB, pinvs): self.assertTrue(mx.allclose(M @ M_plus @ M, M, rtol=0, atol=1e-3)) + # Test singular matrix + A = mx.array([[4.0, 1.0], [4.0, 1.0]]) + A_plus = mx.linalg.pinv(A, stream=mx.cpu) + self.assertTrue(mx.allclose(A @ A_plus @ A, A)) + def test_cholesky_inv(self): mx.random.seed(7) diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py index 9cfa25dae..826d53d96 100644 --- a/python/tests/test_nn.py +++ b/python/tests/test_nn.py @@ -8,7 +8,7 @@ import mlx.core as mx import mlx.nn as nn import mlx_tests import numpy as np -from mlx.utils import tree_flatten, tree_map +from mlx.utils import tree_flatten, tree_map, tree_reduce class TestBase(mlx_tests.MLXTestCase): @@ -198,6 +198,13 @@ class TestBase(mlx_tests.MLXTestCase): self.assertTrue(isinstance(m.layers[1], nn.ReLU)) self.assertTrue(isinstance(m.layers[2], nn.QuantizedLinear)) + def test_quantize_freeze(self): + lin = nn.Linear(512, 512) + qlin = lin.to_quantized() + qlin.unfreeze(keys=["scales"]) + size = tree_reduce(lambda acc, p: acc + p.size, qlin.trainable_parameters(), 0) + self.assertTrue(size > 0) + def 
test_grad_of_module(self): class Model(nn.Module): def __init__(self): diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index 31ea79345..d9e143d82 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -10,6 +10,47 @@ import mlx_tests import numpy as np +def np_wrap_between(x, a): + """Wraps `x` between `[-a, a]`.""" + two_a = 2 * a + zero = 0 + rem = np.remainder(np.add(x, a), two_a) + if isinstance(rem, np.ndarray): + rem = np.select(rem < zero, np.add(rem, two_a), rem) + else: + rem = np.add(rem, two_a) if rem < zero else rem + return np.subtract(rem, a) + + +def np_logaddexp(x1: np.ndarray, x2: np.ndarray): + amax = np.maximum(x1, x2) + if np.issubdtype(x1.dtype, np.floating): + delta = np.subtract(x1, x2) + if isinstance(delta, np.ndarray): + return np.select( + np.isnan(delta), + np.add(x1, x2), + np.add(amax, np.log1p(np.exp(np.negative(np.abs(delta))))), + ) + else: + return ( + np.add(x1, x2) + if np.isnan(delta) + else np.add(amax, np.log1p(np.exp(np.negative(np.abs(delta))))) + ) + else: + delta = np.subtract(np.add(x1, x2), np.multiply(amax, 2)) + out = np.add(amax, np.log1p(np.exp(delta))) + return np.real(out) + 1j * np_wrap_between(np.imag(out), np.pi) + + +def np_cumlogaddexp(x1: np.ndarray, axis: int = -1): + out = x1.copy() + for i in range(1, out.shape[axis]): + out[i] = np_logaddexp(out[i], out[i - 1]) + return out + + class TestOps(mlx_tests.MLXTestCase): def test_full_ones_zeros(self): x = mx.full(2, 3.0) @@ -853,6 +894,16 @@ class TestOps(mlx_tests.MLXTestCase): self.assertTrue(np.allclose(result, expected)) + # Complex test + + a = mx.array([0, 1, 2, 9.0]) + 1j + b = mx.array([1, 0, 4, 2.5]) + 1j + + result = mx.logaddexp(a, b) + expected = np_logaddexp(np.array(a), np.array(b)) + + self.assertTrue(np.allclose(result, expected)) + a = mx.array([float("nan")]) b = mx.array([0.0]) self.assertTrue(math.isnan(mx.logaddexp(a, b).item())) @@ -977,6 +1028,13 @@ class TestOps(mlx_tests.MLXTestCase): self.assertTrue(np.allclose(result, expected)) + # Complex test + a = mx.array([1, 0.5, 10, 100]) + 1j + result = mx.log1p(a) + expected = np.log1p(a, dtype=np.complex64) + + self.assertTrue(np.allclose(result, expected)) + def test_sigmoid(self): a = mx.array([0.0, 1.0, -1.0, 5.0, -5.0]) result = mx.sigmoid(a) @@ -1881,10 +1939,31 @@ class TestOps(mlx_tests.MLXTestCase): c_mlx = mxop(a_mlx, axis=0) self.assertTrue(np.allclose(c_npy, c_mlx, rtol=1e-3, atol=1e-3)) + # Complex tests + + a_npy = np.array([1, 2, 3]).astype(np.float32) + 1j + a_mlx = mx.array(a_npy) + c_npy = np_cumlogaddexp(a_npy, axis=-1) + c_mlx = mxop(a_mlx, axis=-1) + self.assertTrue(np.allclose(c_npy, c_mlx, rtol=1e-3, atol=1e-3)) + def test_scans(self): a_npy = np.random.randn(32, 32, 32).astype(np.float32) a_mlx = mx.array(a_npy) + for op in ["cumsum", "cumprod"]: + npop = getattr(np, op) + mxop = getattr(mx, op) + for axis in (None, 0, 1, 2): + c_npy = npop(a_npy, axis=axis) + c_mlx = mxop(a_mlx, axis=axis) + self.assertTrue(np.allclose(c_npy, c_mlx, rtol=1e-3, atol=1e-3)) + + # Complex test + + a_npy = np.random.randn(32, 32, 32).astype(np.float32) + 0.5j + a_mlx = mx.array(a_npy) + for op in ["cumsum", "cumprod"]: npop = getattr(np, op) mxop = getattr(mx, op) @@ -2789,11 +2868,33 @@ class TestOps(mlx_tests.MLXTestCase): h28 = parse_h_string(h28_str) + x = mx.array(5) + y = mx.hadamard_transform(x) + self.assertEqual(y.item(), 5) + + x = mx.array(5) + y = mx.hadamard_transform(x, scale=0.2) + self.assertEqual(y.item(), 1) + + x = mx.random.normal((8, 8, 1)) + y = 
mx.hadamard_transform(x) + self.assertTrue(mx.all(y == x).item()) + + # Too slow to compare to numpy so let's compare CPU to GPU + if mx.default_device() == mx.gpu: + rk = mx.random.key(42) + for k in range(14, 17): + for m in [1, 3, 5, 7]: + x = mx.random.normal((4, m * 2**k), key=rk) + y1 = mx.hadamard_transform(x, stream=mx.cpu) + y2 = mx.hadamard_transform(x, stream=mx.gpu) + self.assertLess(mx.abs(y1 - y2).max().item(), 5e-6) + np.random.seed(7) - tests = product([np.float32, np.float16, np.int32], [1, 28], range(1, 15)) + tests = product([np.float32, np.float16, np.int32], [1, 28], range(1, 14)) for dtype, m, k in tests: # skip large m=28 cases because they're very slow in NumPy - if (m > 1 and k > 8) or (dtype != np.float16 and k == 14): + if m > 1 and k > 8: continue with self.subTest(dtype=dtype, m=m, k=k): n = m * 2**k @@ -2882,6 +2983,11 @@ class TestOps(mlx_tests.MLXTestCase): y2 = mx.roll(x, s, a) self.assertTrue(mx.array_equal(y1, y2).item()) + def test_roll_errors(self): + x = mx.array([]) + result = mx.roll(x, [0], [0]) + self.assertTrue(mx.array_equal(result, x)) + def test_real_imag(self): x = mx.random.uniform(shape=(4, 4)) out = mx.real(x) @@ -2964,5 +3070,45 @@ class TestOps(mlx_tests.MLXTestCase): self.assertTrue(np.allclose(mx.rsqrt(x), 1.0 / np.sqrt(x))) +class TestBroadcast(mlx_tests.MLXTestCase): + def test_broadcast_shapes(self): + # Basic broadcasting + self.assertEqual(mx.broadcast_shapes((1, 2, 3), (3,)), (1, 2, 3)) + self.assertEqual(mx.broadcast_shapes((4, 1, 6), (5, 6)), (4, 5, 6)) + self.assertEqual(mx.broadcast_shapes((5, 1, 4), (1, 3, 4)), (5, 3, 4)) + + # Multiple arguments + self.assertEqual(mx.broadcast_shapes((1, 1), (1, 8), (7, 1)), (7, 8)) + self.assertEqual( + mx.broadcast_shapes((6, 1, 5), (1, 7, 1), (6, 7, 5)), (6, 7, 5) + ) + + # Same shapes + self.assertEqual(mx.broadcast_shapes((3, 4, 5), (3, 4, 5)), (3, 4, 5)) + + # Single argument + self.assertEqual(mx.broadcast_shapes((2, 3)), (2, 3)) + + # Empty shapes + self.assertEqual(mx.broadcast_shapes((), ()), ()) + self.assertEqual(mx.broadcast_shapes((), (1,)), (1,)) + self.assertEqual(mx.broadcast_shapes((1,), ()), (1,)) + + # Broadcasting with zeroes + self.assertEqual(mx.broadcast_shapes((0,), (0,)), (0,)) + self.assertEqual(mx.broadcast_shapes((1, 0, 5), (3, 1, 5)), (3, 0, 5)) + self.assertEqual(mx.broadcast_shapes((5, 0), (0, 5, 0)), (0, 5, 0)) + + # Error cases + with self.assertRaises(ValueError): + mx.broadcast_shapes((3, 4), (4, 3)) + + with self.assertRaises(ValueError): + mx.broadcast_shapes((2, 3, 4), (2, 5, 4)) + + with self.assertRaises(ValueError): + mx.broadcast_shapes() + + if __name__ == "__main__": unittest.main() diff --git a/python/tests/test_quantized.py b/python/tests/test_quantized.py index eeefcd94f..60ab421c6 100644 --- a/python/tests/test_quantized.py +++ b/python/tests/test_quantized.py @@ -549,6 +549,31 @@ class TestQuantized(mlx_tests.MLXTestCase): self.assertTrue(mx.allclose(y1, y3, atol=1e-5)) self.assertTrue(mx.allclose(y1, y4, atol=1e-5)) + def test_vjp_scales_biases(self): + mx.random.seed(0) + x = mx.random.normal(shape=(2, 2, 512)) + w = mx.random.normal(shape=(512, 512)) + wq, s, b = mx.quantize(w, bits=4, group_size=64) + + def mm(sb, x, wq): + return mx.quantized_matmul(x, wq, *sb, bits=4, group_size=64).sum() + + params = (s, b) + dparams = mx.grad(mm)((s, b), x, wq) + + eps = 8e-3 + # numerical grad check with a few indices + indices = [(0, 0), (11, 4), (22, 7)] + for idx in indices: + for p in [0, 1]: + params[p][idx] += eps + out_up = mm(params, x, wq) + 
params[p][idx] -= 2 * eps + out_down = mm(params, x, wq) + params[p][idx] += eps + num_ds = (out_up - out_down) / (2 * eps) + self.assertAlmostEqual(dparams[p][idx], num_ds, delta=2e-2) + if __name__ == "__main__": unittest.main()
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index be4479e70..cf0ba3d5d 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -10,7 +10,7 @@ FetchContent_MakeAvailable(doctest) add_executable(tests ${PROJECT_SOURCE_DIR}/tests/tests.cpp) if(MLX_BUILD_METAL) - set(METAL_TEST_SOURCES metal_tests.cpp) + set(METAL_TEST_SOURCES gpu_tests.cpp) endif() include(${doctest_SOURCE_DIR}/scripts/cmake/doctest.cmake)
diff --git a/tests/compile_tests.cpp b/tests/compile_tests.cpp index 66511682d..96552ef9d 100644 --- a/tests/compile_tests.cpp +++ b/tests/compile_tests.cpp @@ -795,3 +795,12 @@ TEST_CASE("test compile lambda") { out = cfun2({array(0)}); CHECK_EQ(out[0].item(), 3); } + +TEST_CASE("test compile with no-ops") { + auto fun = [](const std::vector& inputs) { + return std::vector{abs(stop_gradient(abs(inputs[0])))}; + }; + auto in = array(1.0); + auto out = compile(fun)({in})[0]; + CHECK_EQ(out.inputs()[0].id(), in.id()); +}
diff --git a/tests/fft_tests.cpp b/tests/fft_tests.cpp index b0d8c8e52..4373e2920 100644 --- a/tests/fft_tests.cpp +++ b/tests/fft_tests.cpp @@ -309,6 +309,7 @@ TEST_CASE("test fft grads") { CHECK_EQ(vjp_out.shape(), Shape{5, 5}); } TEST_CASE("test stft and istft") { int n_fft = 4; int hop_length = 2; @@ -381,4 +382,62 @@ CHECK_EQ(stft_result.shape(1), n_fft); } -} \ No newline at end of file +} +TEST_CASE("test fftshift and ifftshift") { + // Test 1D array with even length + auto x = arange(8); + auto y = fft::fftshift(x); + CHECK_EQ(y.shape(), x.shape()); + // print y + CHECK(array_equal(y, array({4, 5, 6, 7, 0, 1, 2, 3})).item()); + + // Test 1D array with odd length + x = arange(7); + y = fft::fftshift(x); + CHECK_EQ(y.shape(), x.shape()); + CHECK(array_equal(y, array({4, 5, 6, 0, 1, 2, 3})).item()); + + // Test 2D array + x = reshape(arange(16), {4, 4}); + y = fft::fftshift(x); + auto expected = + array({10, 11, 8, 9, 14, 15, 12, 13, 2, 3, 0, 1, 6, 7, 4, 5}, {4, 4}); + CHECK(array_equal(y, expected).item()); + + // Test with specific axes + y = fft::fftshift(x, {0}); + expected = + array({8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7}, {4, 4}); + CHECK(array_equal(y, expected).item()); + + y = fft::fftshift(x, {1}); + expected = + array({2, 3, 0, 1, 6, 7, 4, 5, 10, 11, 8, 9, 14, 15, 12, 13}, {4, 4}); + CHECK(array_equal(y, expected).item()); + + // Test ifftshift (inverse operation) + x = arange(8); + y = fft::ifftshift(x); + CHECK_EQ(y.shape(), x.shape()); + CHECK(array_equal(y, array({4, 5, 6, 7, 0, 1, 2, 3})).item()); + + // Test ifftshift with odd length (different from fftshift) + x = arange(7); + y = fft::ifftshift(x); + CHECK_EQ(y.shape(), x.shape()); + CHECK(array_equal(y, array({3, 4, 5, 6, 0, 1, 2})).item()); + + // Test 2D ifftshift + x = reshape(arange(16), {4, 4}); + y = fft::ifftshift(x); + expected = + array({10, 11, 8, 9, 14, 15, 12, 13, 2, 3, 0, 1, 6, 7, 4, 5}, {4, 4}); + CHECK(array_equal(y, expected).item()); + + // Test error cases + CHECK_THROWS_AS(fft::fftshift(x, {3}), std::invalid_argument); + CHECK_THROWS_AS(fft::fftshift(x, {-5}), std::invalid_argument); + CHECK_THROWS_AS(fft::ifftshift(x, {3}), std::invalid_argument); + CHECK_THROWS_AS(fft::ifftshift(x, {-5}), std::invalid_argument); +}
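For readers more comfortable in Python, the behavior exercised by the C++ test above can be sketched as follows; even-length shifts reduce to a roll by n // 2, and odd lengths are where fftshift and ifftshift differ (a minimal sketch, not part of the diff):

```python
import mlx.core as mx

x = mx.arange(8)
print(mx.fft.fftshift(x))                    # -> 4, 5, 6, 7, 0, 1, 2, 3
print(mx.roll(x, 8 // 2))                    # identical result for even lengths

y = mx.arange(7)
print(mx.fft.fftshift(y))                    # -> 4, 5, 6, 0, 1, 2, 3
print(mx.fft.ifftshift(mx.fft.fftshift(y)))  # round-trips back to 0, 1, ..., 6
```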
diff --git a/tests/metal_tests.cpp b/tests/gpu_tests.cpp similarity index 95% rename from tests/metal_tests.cpp rename to tests/gpu_tests.cpp index 7aabdf36d..f0ef969cf 100644 --- a/tests/metal_tests.cpp +++ b/tests/gpu_tests.cpp @@ -1,11 +1,8 @@ // Copyright © 2023-2024 Apple Inc. #include -#include "doctest/doctest.h" -#include "mlx/backend/metal/allocator.h" -#include "mlx/backend/metal/device.h" -#include "mlx/backend/metal/metal.h" +#include "doctest/doctest.h" +#include "mlx/mlx.h" using namespace mlx::core; @@ -13,13 +10,7 @@ using namespace mlx::core; static const std::array types = {bool_, uint32, int32, int64, float32}; -TEST_CASE("test metal device") { - // Make sure the device and library can load - CHECK(metal::is_available()); - auto& device = metal::device(Device::gpu); -} - -TEST_CASE("test metal arange") { +TEST_CASE("test gpu arange") { for (auto t : types) { if (t == bool_) { continue; } @@ -34,7 +25,7 @@ } } -TEST_CASE("test metal full") { +TEST_CASE("test gpu full") { for (auto t : types) { auto out_cpu = full({4, 4}, 2, t, Device::cpu); auto out_gpu = full({4, 4}, 2, t, Device::gpu); @@ -63,7 +54,7 @@ } } -TEST_CASE("test metal astype") { +TEST_CASE("test gpu astype") { array x = array({-4, -3, -2, -1, 0, 1, 2, 3}); // Check all types work for (auto t : types) { @@ -80,7 +71,7 @@ } } -TEST_CASE("test metal reshape") { +TEST_CASE("test gpu reshape") { array x = array({0, 1, 2, 3, 4, 5, 6, 7}); auto out_cpu = reshape(x, {2, 2, 2}); auto out_gpu = reshape(x, {2, 2, 2}, Device::gpu); @@ -96,7 +87,7 @@ CHECK(array_equal(out_gpu, out_cpu, Device::cpu).item()); } -TEST_CASE("test metal reduce") { +TEST_CASE("test gpu reduce") { { array a(true); CHECK_EQ(all(a, Device::gpu).item(), true); @@ -190,7 +181,7 @@ } } -TEST_CASE("test metal binary ops") { +TEST_CASE("test gpu binary ops") { // scalar-scalar { array a(2.0f); @@ -338,7 +329,7 @@ } } -TEST_CASE("test metal unary ops") { +TEST_CASE("test gpu unary ops") { // contiguous { array x({-1.0f, 0.0f, 1.0f}); @@ -392,7 +383,7 @@ } } -TEST_CASE("test metal random") { +TEST_CASE("test gpu random") { { auto key = random::key(0); auto x = random::bits({}, 4, key, Device::gpu); @@ -415,7 +406,7 @@ } } -TEST_CASE("test metal matmul") { +TEST_CASE("test gpu matmul") { { auto a = ones({2, 2}); auto b = ones({2, 2}); @@ -440,7 +431,7 @@ } } -TEST_CASE("test metal validation") { +TEST_CASE("test gpu validation") { // Run this test with Metal validation enabled // METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 ./tests/tests \ // -tc="test metal validation" \
diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp index de0f3352c..5e2bae5a0 100644 --- a/tests/ops_tests.cpp +++ b/tests/ops_tests.cpp @@ -3859,6 +3859,9 @@ TEST_CASE("test roll") { y = roll(x, {1, 2}, {0, 1}); CHECK(array_equal(y, array({8, 9, 5, 6, 7, 3, 4, 0, 1, 2}, {2, 5})) .item()); + + y = roll(array({}), 0, 0); + CHECK(array_equal(y, array({})).item()); } TEST_CASE("test contiguous") { @@ -3911,4 +3914,70 @@ TEST_CASE("test bitwise shift operations") { CHECK_EQ(right_shift_bool_result.dtype(), uint8); CHECK(array_equal(right_shift_bool_result, full({4}, 0, uint8)).item()); -} \ No newline at end of file +} + +TEST_CASE("test
conv_transpose1d with output_padding") { + auto in = array({1.0, 2.0, 3.0}, {1, 1, 3}); + auto wt = array({1.0, 1.0, 1.0}, {1, 1, 3}); + int stride = 2; + int padding = 0; + int dilation = 1; + int output_padding = 1; + int groups = 1; + + auto out = conv_transpose1d( + in, wt, stride, padding, dilation, output_padding, groups); + auto expected = array({6.0, 0.0}, {1, 2, 1}); + CHECK(array_equal(out, expected).item()); +} + +TEST_CASE("test conv_transpose2d with output_padding") { + auto in = array({1.0, 2.0, 3.0, 4.0}, {1, 1, 2, 2}); + auto wt = array({1.0, 1.0, 1.0, 1.0}, {2, 1, 1, 2}); + std::pair stride{2, 2}; + std::pair padding{0, 0}; + std::pair output_padding{1, 1}; + std::pair dilation{1, 1}; + int groups = 1; + + auto out = conv_transpose2d( + in, wt, stride, padding, dilation, output_padding, groups); + auto expected = array( + {3.0, + 3.0, + 0.0, + 0.0, + 7.0, + 7.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0}, + {1, 2, 4, 2}); + CHECK(array_equal(out, expected).item()); +} + +TEST_CASE("test conv_transpose3d with output_padding") { + auto in = array({1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0}, {1, 1, 2, 2, 2}); + auto wt = array({1.0, 1.0}, {1, 1, 1, 1, 2}); + std::tuple stride{2, 2, 2}; + std::tuple padding{0, 0, 0}; + std::tuple output_padding{1, 1, 1}; + std::tuple dilation{1, 1, 1}; + int groups = 1; + + auto out = conv_transpose3d( + in, wt, stride, padding, dilation, output_padding, groups); + auto expected = array( + {3.0, 0.0, 7.0, 0.0, 0.0, 0.0, 0.0, 0.0, 11.0, 0.0, 15.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}, + {1, 2, 4, 4, 1}); + CHECK(array_equal(out, expected).item()); +}
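Stepping back from the C++ tests above, the effect of `output_padding` on the result shape can be summarized with a short Python sketch; the shapes are illustrative, and the length formula follows the PyTorch convention these tests are compared against:

```python
import mlx.core as mx

N, L, C_in, C_out, K = 1, 5, 3, 2, 3
stride, padding, dilation, output_padding = 2, 1, 1, 1

x = mx.random.normal((N, L, C_in))      # (N, L, C_in)
w = mx.random.normal((C_out, K, C_in))  # (C_out, K, C_in)

y = mx.conv_transpose1d(
    x, w,
    stride=stride, padding=padding, dilation=dilation,
    output_padding=output_padding,
)

# Output length of a transposed convolution:
# (L - 1) * stride - 2 * padding + dilation * (K - 1) + 1 + output_padding
expected_L = (L - 1) * stride - 2 * padding + dilation * (K - 1) + 1 + output_padding
assert y.shape == (N, expected_L, C_out)  # here: (1, 10, 2)
```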