Compare commits

...

7 Commits

Author SHA1 Message Date
Anastasiia Filippova
012fb220a1 fp quantize (#2892)
Some checks failed
Build and Test / Check Lint (push) Has been cancelled
Build and Test / Linux (cpu, aarch64) (push) Has been cancelled
Build and Test / Linux (cpu, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, x86_64) (push) Has been cancelled
Build and Test / macOS (14.0) (push) Has been cancelled
Build and Test / macOS (15.0) (push) Has been cancelled
Build and Test / Build Documentation (push) Has been cancelled
Build and Test / Linux Fedora (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora (x86_64) (push) Has been cancelled
2025-12-11 06:11:25 -08:00
Nathan Goldbaum
e1fee0074b Update nanobind pin to most recent version (#2896) 2025-12-11 06:07:36 -08:00
CCYeh
3c8ce9b00e Fix input buffer donation in compile (#2897) 2025-12-11 06:07:03 -08:00
David Koski
937ce79660 do not use simd neon intrinsics on x86 (#2893)
Some checks failed
Build and Test / Check Lint (push) Has been cancelled
Build and Test / Linux (cpu, aarch64) (push) Has been cancelled
Build and Test / Linux (cpu, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, x86_64) (push) Has been cancelled
Build and Test / macOS (14.0) (push) Has been cancelled
Build and Test / macOS (15.0) (push) Has been cancelled
Build and Test / Build Documentation (push) Has been cancelled
Build and Test / Linux Fedora (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora (x86_64) (push) Has been cancelled
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
2025-12-10 12:23:28 -08:00
Nathan Goldbaum
208f5441a7 bump minimum required Python version (#2891)
Some checks failed
Build and Test / Check Lint (push) Has been cancelled
Build and Test / Linux (cpu, aarch64) (push) Has been cancelled
Build and Test / Linux (cpu, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, x86_64) (push) Has been cancelled
Build and Test / macOS (14.0) (push) Has been cancelled
Build and Test / macOS (15.0) (push) Has been cancelled
Build and Test / Build Documentation (push) Has been cancelled
Build and Test / Linux Fedora (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora (x86_64) (push) Has been cancelled
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
2025-12-09 16:54:38 -08:00
Awni Hannun
b862d842e1 Allow events in sub graph to be updatable (#2886)
Some checks failed
Build and Test / Check Lint (push) Has been cancelled
Build and Test / Linux (cpu, aarch64) (push) Has been cancelled
Build and Test / Linux (cpu, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, x86_64) (push) Has been cancelled
Build and Test / macOS (14.0) (push) Has been cancelled
Build and Test / macOS (15.0) (push) Has been cancelled
Build and Test / Build Documentation (push) Has been cancelled
Build and Test / Linux Fedora (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora (x86_64) (push) Has been cancelled
2025-12-09 12:34:37 -08:00
Satyam singh
f7a400951a Fix docs: replace mx.random.randn with mx.random.normal (#2890) 2025-12-09 11:46:30 -08:00
21 changed files with 551 additions and 98 deletions

View File

@@ -11,7 +11,7 @@ runs:
shell: bash -l {0} shell: bash -l {0}
run: | run: |
pip install --upgrade pip pip install --upgrade pip
pip install cmake setuptools nanobind==2.4.0 pip install cmake setuptools nanobind==2.10.2
pip install -e . -v pip install -e . -v
- name: Generate package stubs - name: Generate package stubs

View File

@@ -36,7 +36,7 @@ runs:
run: | run: |
python -m venv .venv python -m venv .venv
source .venv/bin/activate source .venv/bin/activate
pip install setuptools cmake nanobind==2.4.0 pip install setuptools cmake nanobind==2.10.2
echo PATH=$PATH >> $GITHUB_ENV echo PATH=$PATH >> $GITHUB_ENV
# Make cmake search .venv for nanobind # Make cmake search .venv for nanobind
echo PYTHONPATH=`python -c 'import sys; print(sys.path[-1])'` >> $GITHUB_ENV echo PYTHONPATH=`python -c 'import sys; print(sys.path[-1])'` >> $GITHUB_ENV

View File

@@ -95,7 +95,7 @@ jobs:
shell: bash -l {0} shell: bash -l {0}
run: | run: |
pip install --upgrade pip pip install --upgrade pip
pip install cmake setuptools nanobind==2.4.0 pip install cmake setuptools nanobind==2.10.2
pip install -e . -v pip install -e . -v
- name: Generate package stubs - name: Generate package stubs
shell: bash -l {0} shell: bash -l {0}

View File

@@ -273,7 +273,7 @@ target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:fmt::fmt-header-only>)
if(MLX_BUILD_PYTHON_BINDINGS) if(MLX_BUILD_PYTHON_BINDINGS)
message(STATUS "Building Python bindings.") message(STATUS "Building Python bindings.")
find_package( find_package(
Python 3.8 Python 3.10
COMPONENTS Interpreter Development.Module COMPONENTS Interpreter Development.Module
REQUIRED) REQUIRED)
execute_process( execute_process(

View File

@@ -186,7 +186,7 @@ Boolean masks follow NumPy semantics:
.. code-block:: shell .. code-block:: shell
>>> a = mx.arange(1000).reshape(10, 10, 10) >>> a = mx.arange(1000).reshape(10, 10, 10)
>>> a[mx.random.randn(10, 10) > 0.0] = 0 # valid: mask covers axes 0 and 1 >>> a[mx.random.normal((10, 10)) > 0.0] = 0 # valid: mask covers axes 0 and 1
The mask of shape ``(10, 10)`` applies to the first two axes, so ``a[mask]`` The mask of shape ``(10, 10)`` applies to the first two axes, so ``a[mask]``
selects the 1-D slices ``a[i, j, :]`` where ``mask[i, j]`` is ``True``. selects the 1-D slices ``a[i, j, :]`` where ``mask[i, j]`` is ``True``.

View File

@@ -3,6 +3,6 @@ requires = [
"setuptools>=42", "setuptools>=42",
"cmake>=3.25", "cmake>=3.25",
"mlx>=0.18.0", "mlx>=0.18.0",
"nanobind==2.4.0", "nanobind==2.10.2",
] ]
build-backend = "setuptools.build_meta" build-backend = "setuptools.build_meta"

View File

@@ -1,4 +1,4 @@
setuptools>=42 setuptools>=42
cmake>=3.25 cmake>=3.25
mlx>=0.21.0 mlx>=0.21.0
nanobind==2.4.0 nanobind==2.10.2

View File

@@ -130,7 +130,7 @@ void compiled_allocate_outputs(
// - Donatable // - Donatable
// - Not a constant // - Not a constant
if (in.itemsize() == outputs[o].itemsize() && !is_scalar(in) && if (in.itemsize() == outputs[o].itemsize() && !is_scalar(in) &&
in.is_donatable() && is_constant(i)) { in.is_donatable() && !is_constant(i)) {
outputs[o++].copy_shared_buffer(in); outputs[o++].copy_shared_buffer(in);
} }
// Get representative input flags to properly set non-donated outputs // Get representative input flags to properly set non-donated outputs
@@ -158,7 +158,7 @@ void compiled_allocate_outputs(
// - Not a constant // - Not a constant
if (in.flags().row_contiguous && in.size() == outputs[o].size() && if (in.flags().row_contiguous && in.size() == outputs[o].size() &&
in.itemsize() == outputs[o].itemsize() && in.is_donatable() && in.itemsize() == outputs[o].itemsize() && in.is_donatable() &&
is_constant(i)) { !is_constant(i)) {
outputs[o].copy_shared_buffer( outputs[o].copy_shared_buffer(
in, outputs[o].strides(), in.flags(), in.data_size()); in, outputs[o].strides(), in.flags(), in.data_size());
o++; o++;

View File

@@ -3,5 +3,9 @@
#include "mlx/backend/cpu/simd/base_simd.h" #include "mlx/backend/cpu/simd/base_simd.h"
#ifdef MLX_USE_ACCELERATE #ifdef MLX_USE_ACCELERATE
#if defined(__x86_64__)
// the accelerate_simd implementation require neon -- use base implementation
#else
#include "mlx/backend/cpu/simd/accelerate_simd.h" #include "mlx/backend/cpu/simd/accelerate_simd.h"
#endif #endif
#endif

View File

@@ -338,28 +338,40 @@ std::pair<std::string, bool> subgraph_to_key(cudaGraph_t graph) {
} }
cudaGraphNodeType type; cudaGraphNodeType type;
CHECK_CUDA_ERROR(cudaGraphNodeGetType(node, &type)); CHECK_CUDA_ERROR(cudaGraphNodeGetType(node, &type));
if (type == cudaGraphNodeTypeGraph) { switch (type) {
// Try to be updatable for a structure like graph -> graph -> kernel case cudaGraphNodeTypeGraph: {
cudaGraph_t child; // Try to be updatable for a structure like graph -> graph -> kernel
CHECK_CUDA_ERROR(cudaGraphChildGraphNodeGetGraph(node, &child)); cudaGraph_t child;
auto [subkey, sub_is_updatable] = subgraph_to_key(child); CHECK_CUDA_ERROR(cudaGraphChildGraphNodeGetGraph(node, &child));
is_updatable &= sub_is_updatable; auto [subkey, sub_is_updatable] = subgraph_to_key(child);
key += subkey; is_updatable &= sub_is_updatable;
} else if (type == cudaGraphNodeTypeMemset) { key += subkey;
key += "M"; break;
} else if (type != cudaGraphNodeTypeKernel) {
is_updatable = false;
} else {
cudaLaunchAttributeValue cluster_dim;
CHECK_CUDA_ERROR(cudaGraphKernelNodeGetAttribute(
node, cudaLaunchAttributeClusterDimension, &cluster_dim));
// Only allow dim.x to be greater than 1
if (cluster_dim.clusterDim.y > 1 || cluster_dim.clusterDim.z > 1) {
is_updatable = false;
} else {
key += "K";
key += std::to_string(cluster_dim.clusterDim.x);
} }
case cudaGraphNodeTypeMemset:
key += "M";
break;
case cudaGraphNodeTypeKernel: {
cudaLaunchAttributeValue cluster_dim;
CHECK_CUDA_ERROR(cudaGraphKernelNodeGetAttribute(
node, cudaLaunchAttributeClusterDimension, &cluster_dim));
// Only allow dim.x to be greater than 1
if (cluster_dim.clusterDim.y > 1 || cluster_dim.clusterDim.z > 1) {
is_updatable = false;
} else {
key += "K";
key += std::to_string(cluster_dim.clusterDim.x);
}
break;
}
case cudaGraphNodeTypeWaitEvent:
key += "W";
break;
case cudaGraphNodeTypeEventRecord:
key += "R";
break;
default:
is_updatable = false;
} }
} }
key += ")"; key += ")";

View File

@@ -2,7 +2,11 @@
#include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernel_utils.cuh" #include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/quantized/mxfp8_quantize.cuh"
#include "mlx/backend/cuda/quantized/nvfp4_quantize.cuh"
#include "mlx/backend/cuda/quantized/quantized.h" #include "mlx/backend/cuda/quantized/quantized.h"
#include "mlx/backend/cuda/quantized/quantized_utils.cuh"
#include "mlx/backend/cuda/vector_types.cuh"
#include "mlx/dtype_utils.h" #include "mlx/dtype_utils.h"
#include <cooperative_groups.h> #include <cooperative_groups.h>
@@ -13,17 +17,6 @@
namespace mlx::core { namespace mlx::core {
namespace cu { namespace cu {
template <int bits>
struct Quantize {
__device__ uint8_t operator()(float x) {
if constexpr (bits == 8) {
return __nv_fp8_e4m3(x).__x;
} else {
return __nv_fp4_e2m1(x).__x;
}
}
};
template <int bits> template <int bits>
struct Dequantize { struct Dequantize {
__device__ float operator()(uint8_t x) { __device__ float operator()(uint8_t x) {
@@ -37,29 +30,40 @@ struct Dequantize {
namespace cg = cooperative_groups; namespace cg = cooperative_groups;
template <typename T, int group_size, int bits, bool use_mx_scale> template <typename T, int group_size, int bits, bool use_mx_scale, bool USE_SR>
__global__ void __global__ void fp_quantize(T* w, uint8_t* out, uint8_t* scales, size_t size) {
fp_quantize(const T* w, uint8_t* out, uint8_t* scales, size_t size) { using Tx2 = Vector2_t<T>;
using Tx4 = Vector4_t<T>;
uint32_t rbits = 0; // reserved bits for future use
auto block_size = cg::this_thread_block().dim_threads(); auto block_size = cg::this_thread_block().dim_threads();
auto block_idx = cg::this_thread_block().group_index(); auto block_idx = cg::this_thread_block().group_index();
auto idx_in_block = cg::this_thread_block().thread_index(); auto idx_in_block = cg::this_thread_block().thread_index();
auto tidx = block_idx.x * block_size.x + idx_in_block.x; auto tidx = block_idx.x * block_size.x + idx_in_block.x;
auto tidy = block_idx.y * block_size.y + idx_in_block.y; auto tidy = block_idx.y * block_size.y + idx_in_block.y;
auto grid_dim_x = cg::this_grid().dim_blocks().x * block_size.x;
auto grid_dim_x = size_t thread_idx = tidx + grid_dim_x * size_t(tidy);
cg::this_grid().dim_blocks().x * cg::this_grid().block_index().x; size_t base_idx = thread_idx * group_size;
size_t index = tidx + grid_dim_x * size_t(tidy);
if (index >= size) { if (base_idx >= size) {
return; return;
} }
float w_thread = w[index]; auto w_tile = load_vector<group_size, T>(w, thread_idx);
float scale = 0.0f;
cg::greater<float> max_op; Tx2 amax_2x = Tx2{0.0f, 0.0f};
auto warp = cg::tiled_partition<group_size>(cg::this_thread_block());
#pragma unroll
for (int i = 0; i < group_size; i += 2) {
auto pair = Tx2{w_tile[i], w_tile[i + 1]};
abs_max_x2<Tx2>(amax_2x, amax_2x, pair);
}
scale = static_cast<float>(
max(fabsf(static_cast<float>(amax_2x.x)),
fabsf(static_cast<float>(amax_2x.y))));
float scale = cg::reduce(warp, abs(w_thread), max_op);
scale /= bits == 4 ? 6.0f : 448.0f; scale /= bits == 4 ? 6.0f : 448.0f;
// Convert to mx scale or nv scale // Convert to mx scale or nv scale
using ScaleType = using ScaleType =
@@ -68,21 +72,24 @@ fp_quantize(const T* w, uint8_t* out, uint8_t* scales, size_t size) {
uint8_t q_scale = s.__x; uint8_t q_scale = s.__x;
scale = float(s); scale = float(s);
// Write out the scales scales[thread_idx] = q_scale;
size_t gindex = index / group_size; constexpr int elem_per_byte = bits == 8 ? 1 : 2;
if (index % group_size == 0) { AlignedVector<uint8_t, group_size / elem_per_byte> quantized;
scales[gindex] = q_scale;
}
uint8_t output = Quantize<bits>{}(scale == 0 ? 0.0f : w_thread / scale); #pragma unroll
if (bits == 4) { for (int i = 0; i < group_size / 4; i++) {
uint8_t sval = warp.shfl_down(output, 1); Tx4 w_Tx4 = *reinterpret_cast<Tx4*>(&w_tile[i * 4]);
output |= sval << bits; if constexpr (bits == 8) {
} uint32_t quantized_val =
constexpr int pack_factor = bits == 8 ? 1 : 2; scale_cvt_Tx4_to_fp8x4<T, USE_SR>(w_Tx4, 1.0f / scale, rbits);
if (index % pack_factor == 0) { *reinterpret_cast<uint32_t*>(&quantized[i * 4]) = quantized_val;
out[index / pack_factor] = output; } else {
uint16_t quantized_val =
scale_cvt_Tx4_to_fp4x4<T, USE_SR>(w_Tx4, 1.0f / scale, rbits);
*reinterpret_cast<uint16_t*>(&quantized[i * 2]) = quantized_val;
}
} }
store_vector<group_size / elem_per_byte>(out, thread_idx, quantized);
} }
template <typename T, int group_size, int bits, bool use_mx_scale> template <typename T, int group_size, int bits, bool use_mx_scale>
@@ -142,15 +149,16 @@ void fp_quantize(
dispatch_float_types(w.dtype(), "fp_quantize", [&](auto type_tag) { dispatch_float_types(w.dtype(), "fp_quantize", [&](auto type_tag) {
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>; using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
if constexpr (!std::is_same_v<T, double>) { if constexpr (!std::is_same_v<T, double>) {
auto kernel = cu::fp_quantize<T, 32, 4, true>; auto kernel = cu::fp_quantize<T, 32, 4, true, false>;
if (bits == 8) { if (bits == 8) {
kernel = cu::fp_quantize<T, 32, 8, true>; kernel = cu::fp_quantize<T, 32, 8, true, false>;
} else if (group_size == 16) { } else if (group_size == 16) {
kernel = cu::fp_quantize<T, 16, 4, false>; kernel = cu::fp_quantize<T, 16, 4, false, false>;
} }
bool large = w.size() > UINT_MAX; bool large = w.size() > UINT_MAX;
auto [num_blocks, block_dims] = auto [num_blocks, block_dims] =
get_launch_args(w.size(), w.shape(), w.strides(), large); get_launch_args(w.size(), w.shape(), w.strides(), large, group_size);
enc.add_kernel_node( enc.add_kernel_node(
kernel, kernel,
num_blocks, num_blocks,

View File

@@ -0,0 +1,32 @@
#pragma once
#include <cuda.h>
#include <cuda_fp8.h>
#include <cuda_runtime.h>
#include "mlx/backend/cuda/vector_types.cuh"
namespace mlx::core::cu {
// TODO implement fast path
template <typename T>
__device__ __forceinline__ uint32_t
scale_cvt_Tx4_to_fp8x4_fallback(const Vector4_t<T> input, const float scale) {
uint32_t out_fp8x4 = 0;
float4 scaled;
scaled.x = static_cast<float>(input.x) * scale;
scaled.y = static_cast<float>(input.y) * scale;
scaled.z = static_cast<float>(input.z) * scale;
scaled.w = static_cast<float>(input.w) * scale;
out_fp8x4 = __nv_fp8x4_e4m3(scaled).__x;
return out_fp8x4;
}
// Place holder for future fast path implementation
template <typename T, bool USE_SR>
__device__ __forceinline__ uint32_t scale_cvt_Tx4_to_fp8x4(
const Vector4_t<T> input,
const float scale,
uint32_t rbits) {
return scale_cvt_Tx4_to_fp8x4_fallback(input, scale);
}
} // namespace mlx::core::cu

View File

@@ -0,0 +1,334 @@
#pragma once
#include <cuda.h>
#include <cuda_fp4.h>
#include <cuda_runtime.h>
#include "mlx/backend/cuda/vector_types.cuh"
namespace mlx::core::cu {
using bf16x4 = Vector4_t<__nv_bfloat16>;
using fp16x4 = Vector4_t<__half>;
using f32x4 = Vector4_t<float>;
template <typename T>
__device__ __forceinline__ uint16_t
scale_cvt_Tx4_to_fp4x4_fallback(const Vector4_t<T> input, const float scale) {
// Fallback implementation for architectures that do not support cvt
// instructions or for cuda versions with no fp4 support (< 12.8) -> scalar
uint16_t out_fp4x4 = 0;
fp32x4 scaled;
scaled.x = static_cast<float>(input.x) * scale;
scaled.y = static_cast<float>(input.y) * scale;
scaled.z = static_cast<float>(input.z) * scale;
scaled.w = static_cast<float>(input.w) * scale;
uint8_t q0 = __nv_fp4_e2m1(scaled.x).__x;
uint8_t q1 = __nv_fp4_e2m1(scaled.y).__x;
uint8_t q2 = __nv_fp4_e2m1(scaled.z).__x;
uint8_t q3 = __nv_fp4_e2m1(scaled.w).__x;
out_fp4x4 = (static_cast<uint16_t>(q3) << 12) |
(static_cast<uint16_t>(q2) << 8) | (static_cast<uint16_t>(q1) << 4) |
static_cast<uint16_t>(q0);
return out_fp4x4;
}
#if (CUDART_VERSION >= 12080) && (__CUDA_ARCH__ >= 1000) && \
defined(__CUDA_ARCH_SPECIFIC__)
__device__ __forceinline__ uint16_t
scale_cvt_bf16x4_to_fp4x4_rn(const bf16x4 input_bf16x4, const float2 scale) {
uint16_t out_fp4x4 = 0;
asm volatile(
"{\n"
".reg.b16 x0_bf16; \n\t" // first bf16
".reg.b16 x1_bf16; \n\t" // second bf16
".reg.b16 x2_bf16; \n\t" // third bf16
".reg.b16 x3_bf16; \n\t" // fourth bf16
".reg.b32 x0; \n\t" // to hold scaled first
".reg.b32 x1; \n\t" // to hold scaled second
".reg.b32 x2; \n\t" // to hold scaled third
".reg.b32 x3; \n\t" // to hold scaled fourth
".reg.b64 x01; \n\t" // to hold vector mul
".reg.b64 x23; \n\t"
".reg.b8 q0; \n\t" // output byte fp4x2 (first pair)
".reg.b8 q1; \n\t" // output byte fp4x2 (second pair)
"mov.b64 {x0_bf16, x1_bf16, x2_bf16, x3_bf16} , %1; \n\t" // unpack bf16
"cvt.f32.bf16 x0, x0_bf16; \n\t" // convert to f32
"cvt.f32.bf16 x1, x1_bf16; \n\t"
"cvt.f32.bf16 x2, x2_bf16; \n\t"
"cvt.f32.bf16 x3, x3_bf16; \n\t"
"mov.b64 x01, {x0, x1}; \n\t"
"mul.f32x2 x01, x01, %2; \n\t" // scale first pair
"mov.b64 x23, {x2, x3}; \n\t"
"mul.f32x2 x23, x23, %2; \n\t" // scale second pair
"mov.b64 {x0, x1}, x01; \n\t"
"mov.b64 {x2, x3}, x23; \n\t"
"cvt.rn.satfinite.e2m1x2.f32 q0, x1, x0; \n\t" // convert to fp4x2 first
// pair
"cvt.rn.satfinite.e2m1x2.f32 q1, x3, x2; \n\t" // convert to fp4x2 second
// pair
"mov.b16 %0, {q0, q1}; \n\t" // pack to output
"}"
: "=h"(out_fp4x4)
: "l"(reinterpret_cast<const uint64_t&>(input_bf16x4)),
"l"(reinterpret_cast<const uint64_t&>(
scale))); // here cast is needed becuase an asm operand must have
// scalar type
return out_fp4x4;
}
__device__ __forceinline__ uint16_t scale_cvt_bf16x4_to_fp4x4_rs(
const bf16x4 input_bf16x4,
const float2 scale,
uint32_t rbits) {
uint16_t out_fp4x4 = 0;
asm volatile(
"{\n"
".reg.b16 x0_bf16; \n\t"
".reg.b16 x1_bf16; \n\t"
".reg.b16 x2_bf16; \n\t"
".reg.b16 x3_bf16; \n\t"
".reg.b32 x0; \n\t"
".reg.b32 x1; \n\t"
".reg.b32 x2; \n\t"
".reg.b32 x3; \n\t"
".reg.b64 x01; \n\t"
".reg.b64 x23; \n\t"
".reg.b16 q0; \n\t"
"mov.b64 {x0_bf16, x1_bf16, x2_bf16, x3_bf16} , %1; \n\t"
"cvt.f32.bf16 x0, x0_bf16; \n\t"
"cvt.f32.bf16 x1, x1_bf16; \n\t"
"cvt.f32.bf16 x2, x2_bf16; \n\t"
"cvt.f32.bf16 x3, x3_bf16; \n\t"
"mov.b64 x01, {x0, x1}; \n\t"
"mul.f32x2 x01, x01, %2; \n\t"
"mov.b64 x23, {x2, x3}; \n\t"
"mul.f32x2 x23, x23, %2; \n\t"
"mov.b64 {x0, x1}, x01; \n\t"
"mov.b64 {x2, x3}, x23; \n\t"
"cvt.rs.satfinite.e2m1x4.f32 q0, {x3, x2, x1, x0}, %3; \n\t"
"}"
: "=h"(out_fp4x4)
: "l"(reinterpret_cast<const uint64_t&>(input_bf16x4)),
"l"(reinterpret_cast<const uint64_t&>(scale)),
"r"(rbits));
return out_fp4x4;
}
__device__ __forceinline__ uint16_t scale_cvt_fp32x4_to_fp4x4_rn(
const float2 input_fp32x2_0,
const float2 input_fp32x2_1,
const float2 scale) {
uint16_t out_fp4x4 = 0;
asm volatile(
"{\n"
".reg.b32 x0; \n\t"
".reg.b32 x1; \n\t"
".reg.b32 x2; \n\t"
".reg.b32 x3; \n\t"
".reg.b64 x01; \n\t"
".reg.b64 x23; \n\t"
".reg.b8 q0; \n\t"
".reg.b8 q1; \n\t"
"mov.b64 x01, {%1, %2}; \n\t"
"mul.f32x2 x01, x01, %5; \n\t"
"mov.b64 x23, {%3, %4}; \n\t"
"mul.f32x2 x23, x23, %5; \n\t"
"mov.b64 {x0, x1}, x01; \n\t"
"mov.b64 {x2, x3}, x23; \n\t"
"cvt.rn.satfinite.e2m1x2.f32 q0, x1, x0; \n\t"
"cvt.rn.satfinite.e2m1x2.f32 q1, x3, x2; \n\t"
"mov.b16 %0, {q0, q1}; \n\t"
"}"
: "=h"(out_fp4x4)
: "f"(input_fp32x2_0.x),
"f"(input_fp32x2_0.y),
"f"(input_fp32x2_1.x),
"f"(input_fp32x2_1.y),
"l"(reinterpret_cast<const uint64_t&>(scale)));
return out_fp4x4;
}
__device__ __forceinline__ uint16_t scale_cvt_fp32x4_to_fp4x4_rs(
const float2 input_fp32x2_0,
const float2 input_fp32x2_1,
const float2 scale,
uint32_t rbits) {
uint16_t out_fp4x4 = 0;
asm volatile(
"{\n"
".reg.b32 x0; \n\t"
".reg.b32 x1; \n\t"
".reg.b32 x2; \n\t"
".reg.b32 x3; \n\t"
".reg.b64 x01; \n\t"
".reg.b64 x23; \n\t"
".reg.b16 q0; \n\t"
"mov.b64 x01, {%1, %2}; \n\t"
"mul.f32x2 x01, x01, %5; \n\t"
"mov.b64 x23, {%3, %4}; \n\t"
"mul.f32x2 x23, x23, %5; \n\t"
"mov.b64 {x0, x1}, x01; \n\t"
"mov.b64 {x2, x3}, x23; \n\t"
"cvt.rs.satfinite.e2m1x4.f32 q0, {x3, x2, x1, x0}, %6; \n\t"
"}"
: "=h"(out_fp4x4)
: "f"(input_fp32x2_0.x),
"f"(input_fp32x2_0.y),
"f"(input_fp32x2_1.x),
"f"(input_fp32x2_1.y),
"l"(reinterpret_cast<const uint64_t&>(scale)),
"r"(rbits));
return out_fp4x4;
}
__device__ __forceinline__ uint16_t
scale_cvt_fp16x4_to_fp4x4_rn(const fp16x4 input_fp16x4, const float2 scale) {
uint16_t out_fp4x4 = 0;
asm volatile(
"{\n"
".reg.b16 x0_fp16; \n\t"
".reg.b16 x1_fp16; \n\t"
".reg.b16 x2_fp16; \n\t"
".reg.b16 x3_fp16; \n\t"
".reg.b32 x0; \n\t"
".reg.b32 x1; \n\t"
".reg.b32 x2; \n\t"
".reg.b32 x3; \n\t"
".reg.b64 x01; \n\t"
".reg.b64 x23; \n\t"
".reg.b8 q0; \n\t"
".reg.b8 q1; \n\t"
"mov.b64 {x0_fp16, x1_fp16, x2_fp16, x3_fp16} , %1; \n\t"
"cvt.f32.f16 x0, x0_fp16; \n\t"
"cvt.f32.f16 x1, x1_fp16; \n\t"
"cvt.f32.f16 x2, x2_fp16; \n\t"
"cvt.f32.f16 x3, x3_fp16; \n\t"
"mov.b64 x01, {x0, x1}; \n\t"
"mul.f32x2 x01, x01, %2; \n\t"
"mov.b64 x23, {x2, x3}; \n\t"
"mul.f32x2 x23, x23, %2; \n\t"
"mov.b64 {x0, x1}, x01; \n\t"
"mov.b64 {x2, x3}, x23; \n\t"
"cvt.rn.satfinite.e2m1x2.f32 q0, x1, x0; \n\t"
"cvt.rn.satfinite.e2m1x2.f32 q1, x3, x2; \n\t"
"mov.b16 %0, {q0, q1}; \n\t"
"}"
: "=h"(out_fp4x4)
: "l"(reinterpret_cast<const uint64_t&>(input_fp16x4)),
"l"(reinterpret_cast<const uint64_t&>(scale)));
return out_fp4x4;
}
__device__ __forceinline__ uint16_t scale_cvt_fp16x4_to_fp4x4_rs(
const fp16x4 input_fp16x4,
const float2 scale,
uint32_t rbits) {
uint16_t out_fp4x4 = 0;
asm volatile(
"{\n"
".reg.b16 x0_fp16; \n\t"
".reg.b16 x1_fp16; \n\t"
".reg.b16 x2_fp16; \n\t"
".reg.b16 x3_fp16; \n\t"
".reg.b32 x0; \n\t"
".reg.b32 x1; \n\t"
".reg.b32 x2; \n\t"
".reg.b32 x3; \n\t"
".reg.b64 x01; \n\t"
".reg.b64 x23; \n\t"
".reg.b16 q0; \n\t"
"mov.b64 {x0_fp16, x1_fp16, x2_fp16, x3_fp16} , %1; \n\t"
"cvt.f32.f16 x0, x0_fp16; \n\t"
"cvt.f32.f16 x1, x1_fp16; \n\t"
"cvt.f32.f16 x2, x2_fp16; \n\t"
"cvt.f32.f16 x3, x3_fp16; \n\t"
"mov.b64 x01, {x0, x1}; \n\t"
"mul.f32x2 x01, x01, %2; \n\t"
"mov.b64 x23, {x2, x3}; \n\t"
"mul.f32x2 x23, x23, %2; \n\t"
"mov.b64 {x0, x1}, x01; \n\t"
"mov.b64 {x2, x3}, x23; \n\t"
"cvt.rs.satfinite.e2m1x4.f32 q0, {x3, x2, x1, x0}, %3; \n\t"
"}"
: "=h"(out_fp4x4)
: "l"(reinterpret_cast<const uint64_t&>(input_fp16x4)),
"l"(reinterpret_cast<const uint64_t&>(scale)),
"r"(rbits));
return out_fp4x4;
}
template <bool USE_SR>
__device__ __forceinline__ uint16_t scale_cvt_bf16x4_to_fp4x4(
const bf16x4 input,
const float scale,
uint32_t rbits) {
float2 scale_fp32x2 = make_float2(scale, scale);
if constexpr (USE_SR) {
return scale_cvt_bf16x4_to_fp4x4_rs(input, scale_fp32x2, rbits);
} else {
return scale_cvt_bf16x4_to_fp4x4_rn(input, scale_fp32x2);
}
}
template <bool USE_SR>
__device__ __forceinline__ uint16_t scale_cvt_fp16x4_to_fp4x4(
const fp16x4 input,
const float scale,
uint32_t rbits) {
float2 scale_fp32x2 = make_float2(scale, scale);
if constexpr (USE_SR) {
return scale_cvt_fp16x4_to_fp4x4_rs(input, scale_fp32x2, rbits);
} else {
return scale_cvt_fp16x4_to_fp4x4_rn(input, scale_fp32x2);
}
}
template <bool USE_SR>
__device__ __forceinline__ uint16_t
scale_cvt_f32x4_to_fp4x4(const f32x4 input, const float scale, uint32_t rbits) {
float2 scale_fp32x2 = make_float2(scale, scale);
float2 input_fp32x2_0 = make_float2(input.x, input.y);
float2 input_fp32x2_1 = make_float2(input.z, input.w);
if constexpr (USE_SR) {
return scale_cvt_fp32x4_to_fp4x4_rs(
input_fp32x2_0, input_fp32x2_1, scale_fp32x2, rbits);
} else {
return scale_cvt_fp32x4_to_fp4x4_rn(
input_fp32x2_0, input_fp32x2_1, scale_fp32x2);
}
}
template <typename T, bool USE_SR>
__device__ __forceinline__ uint16_t scale_cvt_Tx4_to_fp4x4_fast(
const Vector4_t<T> input,
const float scale,
uint32_t rbits) {
if constexpr (std::is_same<T, __nv_bfloat16>::value) {
return scale_cvt_bf16x4_to_fp4x4<USE_SR>(input, scale, rbits);
} else if constexpr (std::is_same<T, __half>::value) {
return scale_cvt_fp16x4_to_fp4x4<USE_SR>(input, scale, rbits);
} else {
return scale_cvt_f32x4_to_fp4x4<USE_SR>(input, scale, rbits);
}
}
#endif // (CUDART_VERSION >= 12080) && (__CUDA_ARCH__ >= 1000) &&
// (__CUDA_ARCH_FAMILY_SPECIFIC__ >= 1000)
template <typename T, bool USE_SR>
__device__ __forceinline__ uint16_t scale_cvt_Tx4_to_fp4x4(
const Vector4_t<T> input,
const float scale,
uint32_t rbits) {
#if (CUDART_VERSION >= 12080) && (__CUDA_ARCH__ >= 1000) && \
(__CUDA_ARCH_FAMILY_SPECIFIC__ >= 1000)
return scale_cvt_Tx4_to_fp4x4_fast<T, USE_SR>(input, scale, rbits);
#else
static_assert(
!USE_SR,
"Stochastic rounding (USE_SR=true) requires CUDA >= 12.8 and compute capability >= 1000.");
return scale_cvt_Tx4_to_fp4x4_fallback(input, scale);
#endif
}
} // namespace mlx::core::cu

View File

@@ -15,6 +15,22 @@ inline constexpr __device__ short get_bytes_per_pack() {
return power_of_2_bits ? (wsize / 8) : (bits == 5 ? 5 : 3); return power_of_2_bits ? (wsize / 8) : (bits == 5 ? 5 : 3);
} }
template <typename T>
__device__ __forceinline__ void abs_max_x2(T& out, const T& x1, const T& x2) {
if constexpr (
(std::is_same<T, __nv_bfloat162>::value) ||
(std::is_same<T, __half2>::value)) {
T a = x1;
T b = x2;
out = __hmax2(__habs2(a), __habs2(b));
} else if constexpr (std::is_same<T, float2>::value) {
float2 a = x1;
float2 b = x2;
out.x = fmaxf(fabsf(a.x), fabsf(b.x));
out.y = fmaxf(fabsf(a.y), fabsf(b.y));
}
}
} // namespace cu } // namespace cu
template <typename F> template <typename F>

View File

@@ -3,31 +3,10 @@
#pragma once #pragma once
#include "mlx/backend/cuda/steel/utils.cuh" #include "mlx/backend/cuda/steel/utils.cuh"
#include "mlx/backend/cuda/vector_types.cuh"
namespace mlx::core::cu { namespace mlx::core::cu {
// Map types to their vector of 2 type float -> float2, double -> double2 etc
template <typename T>
struct Vector2;
template <>
struct Vector2<double> {
using type = double2;
};
template <>
struct Vector2<float> {
using type = float2;
};
template <>
struct Vector2<__half> {
using type = __half2;
};
template <>
struct Vector2<__nv_bfloat16> {
using type = __nv_bfloat162;
};
template <typename T>
using Vector2_t = typename Vector2<T>::type;
/** /**
* The basic building block for Ampere mmas. A 16x16 tile distributed across * The basic building block for Ampere mmas. A 16x16 tile distributed across
* the warp. * the warp.

View File

@@ -0,0 +1,48 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include <cuda_bf16.h>
#include <cuda_fp16.h>
namespace mlx::core::cu {
template <typename T>
struct Vector2;
template <>
struct Vector2<double> {
using type = double2;
};
template <>
struct Vector2<float> {
using type = float2;
};
template <>
struct Vector2<__half> {
using type = __half2;
};
template <>
struct Vector2<__nv_bfloat16> {
using type = __nv_bfloat162;
};
template <typename T>
using Vector2_t = typename Vector2<T>::type;
template <typename T>
struct Vector4 {
T x, y, z, w;
};
template <typename T>
using Vector4_t = Vector4<T>;
using bf16x4 = Vector4_t<__nv_bfloat16>;
using fp16x4 = Vector4_t<__half>;
using fp32x4 = Vector4_t<float>;
} // namespace mlx::core::cu

View File

@@ -1,7 +1,7 @@
[build-system] [build-system]
requires = [ requires = [
"setuptools>=80", "setuptools>=80",
"nanobind==2.4.0", "nanobind==2.10.2",
"cmake>=3.25", "cmake>=3.25",
] ]
build-backend = "setuptools.build_meta" build-backend = "setuptools.build_meta"

View File

@@ -89,7 +89,8 @@ static PyType_Spec gc_func_spec = {
/* .name = */ "mlx.gc_func", /* .name = */ "mlx.gc_func",
/* .basicsize = */ (int)sizeof(gc_func), /* .basicsize = */ (int)sizeof(gc_func),
/* .itemsize = */ 0, /* .itemsize = */ 0,
/* .flags = */ Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC | NB_HAVE_VECTORCALL, /* .flags = */ Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC |
Py_TPFLAGS_HAVE_VECTORCALL,
/* .slots = */ gc_func_slots}; /* .slots = */ gc_func_slots};
static PyTypeObject* gc_func_tp = nullptr; static PyTypeObject* gc_func_tp = nullptr;

View File

@@ -16,8 +16,7 @@ struct type_caster<mlx::core::SmallVector<Type, Size, Alloc>> {
NB_TYPE_CASTER( NB_TYPE_CASTER(
List, List,
const_name(NB_TYPING_TUPLE "[") + make_caster<Type>::Name + const_name("tuple[") + make_caster<Type>::Name + const_name(", ...]"))
const_name(", ...]"))
bool from_python(handle src, uint8_t flags, cleanup_list* cleanup) noexcept { bool from_python(handle src, uint8_t flags, cleanup_list* cleanup) noexcept {
size_t size; size_t size;

View File

@@ -4,12 +4,12 @@ import gc
import inspect import inspect
import io import io
import math import math
import unittest
from functools import partial, wraps from functools import partial, wraps
from io import StringIO from io import StringIO
import mlx.core as mx import mlx.core as mx
import mlx_tests import mlx_tests
import numpy as np
class TestCompile(mlx_tests.MLXTestCase): class TestCompile(mlx_tests.MLXTestCase):
@@ -1252,6 +1252,26 @@ class TestCompile(mlx_tests.MLXTestCase):
loss, grads = step(emb, w, x) loss, grads = step(emb, w, x)
mx.eval(loss, grads) mx.eval(loss, grads)
def test_compile_donates_input_buffer(self):
mx.set_default_device(mx.cpu)
def fun(x):
return mx.sin(x) + 1
compiled_fn = mx.compile(fun)
input = mx.arange(16, dtype=mx.float32)
mx.eval(input)
in_ptr = np.asarray(input, copy=False).__array_interface__["data"][0]
out = compiled_fn(input)
del input # Ensure the reference is dropped
mx.eval(out)
self.assertEqual(
np.asarray(out, copy=False).__array_interface__["data"][0], in_ptr
)
if __name__ == "__main__": if __name__ == "__main__":
mlx_tests.MLXTestRunner() mlx_tests.MLXTestRunner()

View File

@@ -255,7 +255,7 @@ if __name__ == "__main__":
extras = { extras = {
"dev": [ "dev": [
"nanobind==2.4.0", "nanobind==2.10.2",
"numpy", "numpy",
"pre-commit", "pre-commit",
"setuptools>=80", "setuptools>=80",