Mirror of https://github.com/ml-explore/mlx.git
Synced 2025-12-16 01:49:05 +08:00

Compare commits: 937ce79660...012fb220a1 (3 commits)

| Author | SHA1 | Date |
|---|---|---|
|  | 012fb220a1 |  |
|  | e1fee0074b |  |
|  | 3c8ce9b00e |  |
.github/actions/build-macos/action.yml (vendored, 2 changes)

@@ -11,7 +11,7 @@ runs:
     shell: bash -l {0}
     run: |
       pip install --upgrade pip
-      pip install cmake setuptools nanobind==2.4.0
+      pip install cmake setuptools nanobind==2.10.2
       pip install -e . -v

   - name: Generate package stubs
.github/actions/setup-linux/action.yml (vendored, 2 changes)

@@ -36,7 +36,7 @@ runs:
     run: |
       python -m venv .venv
       source .venv/bin/activate
-      pip install setuptools cmake nanobind==2.4.0
+      pip install setuptools cmake nanobind==2.10.2
       echo PATH=$PATH >> $GITHUB_ENV
       # Make cmake search .venv for nanobind
       echo PYTHONPATH=`python -c 'import sys; print(sys.path[-1])'` >> $GITHUB_ENV
.github/workflows/release.yml (vendored, 2 changes)

@@ -95,7 +95,7 @@ jobs:
     shell: bash -l {0}
     run: |
       pip install --upgrade pip
-      pip install cmake setuptools nanobind==2.4.0
+      pip install cmake setuptools nanobind==2.10.2
       pip install -e . -v
   - name: Generate package stubs
     shell: bash -l {0}
@@ -3,6 +3,6 @@ requires = [
     "setuptools>=42",
     "cmake>=3.25",
     "mlx>=0.18.0",
-    "nanobind==2.4.0",
+    "nanobind==2.10.2",
 ]
 build-backend = "setuptools.build_meta"
@@ -1,4 +1,4 @@
 setuptools>=42
 cmake>=3.25
 mlx>=0.21.0
-nanobind==2.4.0
+nanobind==2.10.2
@@ -130,7 +130,7 @@ void compiled_allocate_outputs(
   // - Donatable
   // - Not a constant
   if (in.itemsize() == outputs[o].itemsize() && !is_scalar(in) &&
-      in.is_donatable() && is_constant(i)) {
+      in.is_donatable() && !is_constant(i)) {
     outputs[o++].copy_shared_buffer(in);
   }
   // Get representative input flags to properly set non-donated outputs
@@ -158,7 +158,7 @@ void compiled_allocate_outputs(
   // - Not a constant
   if (in.flags().row_contiguous && in.size() == outputs[o].size() &&
       in.itemsize() == outputs[o].itemsize() && in.is_donatable() &&
-      is_constant(i)) {
+      !is_constant(i)) {
     outputs[o].copy_shared_buffer(
         in, outputs[o].strides(), in.flags(), in.data_size());
     o++;
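For context, a standalone sketch of the corrected donation predicate: an input buffer may back an output only when it is donatable and not a constant captured by the compiled graph. The old condition tested `is_constant(i)` instead of `!is_constant(i)`, inverting the check that the comments describe. The `Buf` struct below is a hypothetical stand-in, not the `mlx::core::array` API.

```cpp
#include <cstddef>

// Hypothetical stand-in for the array metadata consulted above.
struct Buf {
  size_t itemsize;
  bool scalar;
  bool donatable;
  bool is_constant; // captured as a constant by the compiled graph
};

// Donate only buffers that are type-compatible, non-scalar, donatable,
// and NOT constants; the last test is what this commit fixes.
bool can_donate(const Buf& in, const Buf& out) {
  return in.itemsize == out.itemsize && !in.scalar && in.donatable &&
      !in.is_constant;
}
```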
@@ -2,7 +2,11 @@
 
 #include "mlx/backend/cuda/device.h"
 #include "mlx/backend/cuda/kernel_utils.cuh"
+#include "mlx/backend/cuda/quantized/mxfp8_quantize.cuh"
+#include "mlx/backend/cuda/quantized/nvfp4_quantize.cuh"
 #include "mlx/backend/cuda/quantized/quantized.h"
 #include "mlx/backend/cuda/quantized/quantized_utils.cuh"
+#include "mlx/backend/cuda/vector_types.cuh"
 #include "mlx/dtype_utils.h"
 
 #include <cooperative_groups.h>
@@ -13,17 +17,6 @@
 namespace mlx::core {
 namespace cu {
 
-template <int bits>
-struct Quantize {
-  __device__ uint8_t operator()(float x) {
-    if constexpr (bits == 8) {
-      return __nv_fp8_e4m3(x).__x;
-    } else {
-      return __nv_fp4_e2m1(x).__x;
-    }
-  }
-};
-
 template <int bits>
 struct Dequantize {
   __device__ float operator()(uint8_t x) {
@@ -37,29 +30,40 @@ struct Dequantize {
 
 namespace cg = cooperative_groups;
 
-template <typename T, int group_size, int bits, bool use_mx_scale>
-__global__ void
-fp_quantize(const T* w, uint8_t* out, uint8_t* scales, size_t size) {
+template <typename T, int group_size, int bits, bool use_mx_scale, bool USE_SR>
+__global__ void fp_quantize(T* w, uint8_t* out, uint8_t* scales, size_t size) {
+  using Tx2 = Vector2_t<T>;
+  using Tx4 = Vector4_t<T>;
+  uint32_t rbits = 0; // reserved bits for future use
   auto block_size = cg::this_thread_block().dim_threads();
   auto block_idx = cg::this_thread_block().group_index();
   auto idx_in_block = cg::this_thread_block().thread_index();
 
   auto tidx = block_idx.x * block_size.x + idx_in_block.x;
   auto tidy = block_idx.y * block_size.y + idx_in_block.y;
-  auto grid_dim_x = cg::this_grid().dim_blocks().x * block_size.x;
-  size_t index = tidx + grid_dim_x * size_t(tidy);
-  if (index >= size) {
+
+  auto grid_dim_x =
+      cg::this_grid().dim_blocks().x * cg::this_grid().block_index().x;
+  size_t thread_idx = tidx + grid_dim_x * size_t(tidy);
+  size_t base_idx = thread_idx * group_size;
+
+  if (base_idx >= size) {
     return;
   }
 
-  float w_thread = w[index];
-
-  cg::greater<float> max_op;
-  auto warp = cg::tiled_partition<group_size>(cg::this_thread_block());
-
-  float scale = cg::reduce(warp, abs(w_thread), max_op);
+  auto w_tile = load_vector<group_size, T>(w, thread_idx);
+  float scale = 0.0f;
+  Tx2 amax_2x = Tx2{0.0f, 0.0f};
+
+#pragma unroll
+  for (int i = 0; i < group_size; i += 2) {
+    auto pair = Tx2{w_tile[i], w_tile[i + 1]};
+    abs_max_x2<Tx2>(amax_2x, amax_2x, pair);
+  }
+
+  scale = static_cast<float>(
+      max(fabsf(static_cast<float>(amax_2x.x)),
+          fabsf(static_cast<float>(amax_2x.y))));
   scale /= bits == 4 ? 6.0f : 448.0f;
   // Convert to mx scale or nv scale
   using ScaleType =
@@ -68,21 +72,24 @@ fp_quantize(const T* w, uint8_t* out, uint8_t* scales, size_t size) {
   uint8_t q_scale = s.__x;
   scale = float(s);
 
   // Write out the scales
-  size_t gindex = index / group_size;
-  if (index % group_size == 0) {
-    scales[gindex] = q_scale;
-  }
+  scales[thread_idx] = q_scale;
+  constexpr int elem_per_byte = bits == 8 ? 1 : 2;
+  AlignedVector<uint8_t, group_size / elem_per_byte> quantized;
 
-  uint8_t output = Quantize<bits>{}(scale == 0 ? 0.0f : w_thread / scale);
-  if (bits == 4) {
-    uint8_t sval = warp.shfl_down(output, 1);
-    output |= sval << bits;
-  }
-  constexpr int pack_factor = bits == 8 ? 1 : 2;
-  if (index % pack_factor == 0) {
-    out[index / pack_factor] = output;
-  }
+#pragma unroll
+  for (int i = 0; i < group_size / 4; i++) {
+    Tx4 w_Tx4 = *reinterpret_cast<Tx4*>(&w_tile[i * 4]);
+    if constexpr (bits == 8) {
+      uint32_t quantized_val =
+          scale_cvt_Tx4_to_fp8x4<T, USE_SR>(w_Tx4, 1.0f / scale, rbits);
+      *reinterpret_cast<uint32_t*>(&quantized[i * 4]) = quantized_val;
+    } else {
+      uint16_t quantized_val =
+          scale_cvt_Tx4_to_fp4x4<T, USE_SR>(w_Tx4, 1.0f / scale, rbits);
+      *reinterpret_cast<uint16_t*>(&quantized[i * 2]) = quantized_val;
+    }
+  }
+  store_vector<group_size / elem_per_byte>(out, thread_idx, quantized);
 }
 
 template <typename T, int group_size, int bits, bool use_mx_scale>
@@ -142,15 +149,16 @@ void fp_quantize(
   dispatch_float_types(w.dtype(), "fp_quantize", [&](auto type_tag) {
     using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
     if constexpr (!std::is_same_v<T, double>) {
-      auto kernel = cu::fp_quantize<T, 32, 4, true>;
+      auto kernel = cu::fp_quantize<T, 32, 4, true, false>;
       if (bits == 8) {
-        kernel = cu::fp_quantize<T, 32, 8, true>;
+        kernel = cu::fp_quantize<T, 32, 8, true, false>;
      } else if (group_size == 16) {
-        kernel = cu::fp_quantize<T, 16, 4, false>;
+        kernel = cu::fp_quantize<T, 16, 4, false, false>;
       }
       bool large = w.size() > UINT_MAX;
       auto [num_blocks, block_dims] =
-          get_launch_args(w.size(), w.shape(), w.strides(), large);
+          get_launch_args(w.size(), w.shape(), w.strides(), large, group_size);
 
       enc.add_kernel_node(
           kernel,
           num_blocks,
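A worked host-side sketch (illustrative only, not part of the diff) of the per-group scale these kernels derive: each group's absolute maximum is divided by the largest finite value of the target format, 6.0 for fp4 (e2m1) and 448 for fp8 (e4m3), so that dividing the group by the scale maps its largest element onto the top of the representable range.

```cpp
#include <cassert>

// Largest finite magnitudes: e2m1 (fp4) -> 6.0, e4m3 (fp8) -> 448.0.
float group_scale(float amax, int bits) {
  return amax / (bits == 4 ? 6.0f : 448.0f);
}

int main() {
  // A group whose amax is 12.0 gets scale 2.0 for fp4; dividing by the
  // scale maps the largest element onto 6.0, the top of the e2m1 range.
  assert(12.0f / group_scale(12.0f, 4) == 6.0f);
  return 0;
}
```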
mlx/backend/cuda/quantized/mxfp8_quantize.cuh (new file, 32 lines)

#pragma once

#include <cuda.h>
#include <cuda_fp8.h>
#include <cuda_runtime.h>
#include "mlx/backend/cuda/vector_types.cuh"

namespace mlx::core::cu {

// TODO implement fast path
template <typename T>
__device__ __forceinline__ uint32_t
scale_cvt_Tx4_to_fp8x4_fallback(const Vector4_t<T> input, const float scale) {
  uint32_t out_fp8x4 = 0;
  float4 scaled;
  scaled.x = static_cast<float>(input.x) * scale;
  scaled.y = static_cast<float>(input.y) * scale;
  scaled.z = static_cast<float>(input.z) * scale;
  scaled.w = static_cast<float>(input.w) * scale;
  out_fp8x4 = __nv_fp8x4_e4m3(scaled).__x;
  return out_fp8x4;
}

// Placeholder for a future fast path implementation
template <typename T, bool USE_SR>
__device__ __forceinline__ uint32_t scale_cvt_Tx4_to_fp8x4(
    const Vector4_t<T> input,
    const float scale,
    uint32_t rbits) {
  return scale_cvt_Tx4_to_fp8x4_fallback(input, scale);
}

} // namespace mlx::core::cu
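A hypothetical device-side usage sketch (the kernel below is illustrative, not part of the diff): four floats are multiplied by the reciprocal scale and packed into one `uint32_t` of fp8 e4m3 bytes.

```cpp
#include "mlx/backend/cuda/quantized/mxfp8_quantize.cuh"

using namespace mlx::core::cu;

// Quantize a single group of four values into one packed fp8x4 word.
__global__ void quantize_one_quad(const float* in, uint32_t* out, float scale) {
  Vector4_t<float> v{in[0], in[1], in[2], in[3]};
  uint32_t rbits = 0; // ignored by the fallback path
  out[0] = scale_cvt_Tx4_to_fp8x4<float, false>(v, 1.0f / scale, rbits);
}
```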
mlx/backend/cuda/quantized/nvfp4_quantize.cuh (new file, 334 lines)

#pragma once

#include <cuda.h>
#include <cuda_fp4.h>
#include <cuda_runtime.h>
#include "mlx/backend/cuda/vector_types.cuh"

namespace mlx::core::cu {

using bf16x4 = Vector4_t<__nv_bfloat16>;
using fp16x4 = Vector4_t<__half>;
using f32x4 = Vector4_t<float>;

template <typename T>
__device__ __forceinline__ uint16_t
scale_cvt_Tx4_to_fp4x4_fallback(const Vector4_t<T> input, const float scale) {
  // Fallback implementation for architectures that do not support cvt
  // instructions or for CUDA versions with no fp4 support (< 12.8) -> scalar
  uint16_t out_fp4x4 = 0;
  fp32x4 scaled;
  scaled.x = static_cast<float>(input.x) * scale;
  scaled.y = static_cast<float>(input.y) * scale;
  scaled.z = static_cast<float>(input.z) * scale;
  scaled.w = static_cast<float>(input.w) * scale;
  uint8_t q0 = __nv_fp4_e2m1(scaled.x).__x;
  uint8_t q1 = __nv_fp4_e2m1(scaled.y).__x;
  uint8_t q2 = __nv_fp4_e2m1(scaled.z).__x;
  uint8_t q3 = __nv_fp4_e2m1(scaled.w).__x;
  out_fp4x4 = (static_cast<uint16_t>(q3) << 12) |
      (static_cast<uint16_t>(q2) << 8) | (static_cast<uint16_t>(q1) << 4) |
      static_cast<uint16_t>(q0);
  return out_fp4x4;
}

#if (CUDART_VERSION >= 12080) && (__CUDA_ARCH__ >= 1000) && \
    defined(__CUDA_ARCH_SPECIFIC__)

__device__ __forceinline__ uint16_t
scale_cvt_bf16x4_to_fp4x4_rn(const bf16x4 input_bf16x4, const float2 scale) {
  uint16_t out_fp4x4 = 0;
  asm volatile(
      "{\n"
      ".reg.b16 x0_bf16; \n\t" // first bf16
      ".reg.b16 x1_bf16; \n\t" // second bf16
      ".reg.b16 x2_bf16; \n\t" // third bf16
      ".reg.b16 x3_bf16; \n\t" // fourth bf16
      ".reg.b32 x0; \n\t" // to hold scaled first
      ".reg.b32 x1; \n\t" // to hold scaled second
      ".reg.b32 x2; \n\t" // to hold scaled third
      ".reg.b32 x3; \n\t" // to hold scaled fourth
      ".reg.b64 x01; \n\t" // to hold vector mul
      ".reg.b64 x23; \n\t"
      ".reg.b8 q0; \n\t" // output byte fp4x2 (first pair)
      ".reg.b8 q1; \n\t" // output byte fp4x2 (second pair)
      "mov.b64 {x0_bf16, x1_bf16, x2_bf16, x3_bf16} , %1; \n\t" // unpack bf16
      "cvt.f32.bf16 x0, x0_bf16; \n\t" // convert to f32
      "cvt.f32.bf16 x1, x1_bf16; \n\t"
      "cvt.f32.bf16 x2, x2_bf16; \n\t"
      "cvt.f32.bf16 x3, x3_bf16; \n\t"
      "mov.b64 x01, {x0, x1}; \n\t"
      "mul.f32x2 x01, x01, %2; \n\t" // scale first pair
      "mov.b64 x23, {x2, x3}; \n\t"
      "mul.f32x2 x23, x23, %2; \n\t" // scale second pair
      "mov.b64 {x0, x1}, x01; \n\t"
      "mov.b64 {x2, x3}, x23; \n\t"
      "cvt.rn.satfinite.e2m1x2.f32 q0, x1, x0; \n\t" // convert first pair to fp4x2
      "cvt.rn.satfinite.e2m1x2.f32 q1, x3, x2; \n\t" // convert second pair to fp4x2
      "mov.b16 %0, {q0, q1}; \n\t" // pack to output
      "}"
      : "=h"(out_fp4x4)
      : "l"(reinterpret_cast<const uint64_t&>(input_bf16x4)),
        "l"(reinterpret_cast<const uint64_t&>(
            scale))); // the cast is needed because an asm operand must have a
                      // scalar type
  return out_fp4x4;
}

__device__ __forceinline__ uint16_t scale_cvt_bf16x4_to_fp4x4_rs(
    const bf16x4 input_bf16x4,
    const float2 scale,
    uint32_t rbits) {
  uint16_t out_fp4x4 = 0;
  asm volatile(
      "{\n"
      ".reg.b16 x0_bf16; \n\t"
      ".reg.b16 x1_bf16; \n\t"
      ".reg.b16 x2_bf16; \n\t"
      ".reg.b16 x3_bf16; \n\t"
      ".reg.b32 x0; \n\t"
      ".reg.b32 x1; \n\t"
      ".reg.b32 x2; \n\t"
      ".reg.b32 x3; \n\t"
      ".reg.b64 x01; \n\t"
      ".reg.b64 x23; \n\t"
      ".reg.b16 q0; \n\t"
      "mov.b64 {x0_bf16, x1_bf16, x2_bf16, x3_bf16} , %1; \n\t"
      "cvt.f32.bf16 x0, x0_bf16; \n\t"
      "cvt.f32.bf16 x1, x1_bf16; \n\t"
      "cvt.f32.bf16 x2, x2_bf16; \n\t"
      "cvt.f32.bf16 x3, x3_bf16; \n\t"
      "mov.b64 x01, {x0, x1}; \n\t"
      "mul.f32x2 x01, x01, %2; \n\t"
      "mov.b64 x23, {x2, x3}; \n\t"
      "mul.f32x2 x23, x23, %2; \n\t"
      "mov.b64 {x0, x1}, x01; \n\t"
      "mov.b64 {x2, x3}, x23; \n\t"
      "cvt.rs.satfinite.e2m1x4.f32 q0, {x3, x2, x1, x0}, %3; \n\t"
      "}"
      : "=h"(out_fp4x4)
      : "l"(reinterpret_cast<const uint64_t&>(input_bf16x4)),
        "l"(reinterpret_cast<const uint64_t&>(scale)),
        "r"(rbits));
  return out_fp4x4;
}

__device__ __forceinline__ uint16_t scale_cvt_fp32x4_to_fp4x4_rn(
    const float2 input_fp32x2_0,
    const float2 input_fp32x2_1,
    const float2 scale) {
  uint16_t out_fp4x4 = 0;
  asm volatile(
      "{\n"
      ".reg.b32 x0; \n\t"
      ".reg.b32 x1; \n\t"
      ".reg.b32 x2; \n\t"
      ".reg.b32 x3; \n\t"
      ".reg.b64 x01; \n\t"
      ".reg.b64 x23; \n\t"
      ".reg.b8 q0; \n\t"
      ".reg.b8 q1; \n\t"
      "mov.b64 x01, {%1, %2}; \n\t"
      "mul.f32x2 x01, x01, %5; \n\t"
      "mov.b64 x23, {%3, %4}; \n\t"
      "mul.f32x2 x23, x23, %5; \n\t"
      "mov.b64 {x0, x1}, x01; \n\t"
      "mov.b64 {x2, x3}, x23; \n\t"
      "cvt.rn.satfinite.e2m1x2.f32 q0, x1, x0; \n\t"
      "cvt.rn.satfinite.e2m1x2.f32 q1, x3, x2; \n\t"
      "mov.b16 %0, {q0, q1}; \n\t"
      "}"
      : "=h"(out_fp4x4)
      : "f"(input_fp32x2_0.x),
        "f"(input_fp32x2_0.y),
        "f"(input_fp32x2_1.x),
        "f"(input_fp32x2_1.y),
        "l"(reinterpret_cast<const uint64_t&>(scale)));
  return out_fp4x4;
}

__device__ __forceinline__ uint16_t scale_cvt_fp32x4_to_fp4x4_rs(
    const float2 input_fp32x2_0,
    const float2 input_fp32x2_1,
    const float2 scale,
    uint32_t rbits) {
  uint16_t out_fp4x4 = 0;
  asm volatile(
      "{\n"
      ".reg.b32 x0; \n\t"
      ".reg.b32 x1; \n\t"
      ".reg.b32 x2; \n\t"
      ".reg.b32 x3; \n\t"
      ".reg.b64 x01; \n\t"
      ".reg.b64 x23; \n\t"
      ".reg.b16 q0; \n\t"
      "mov.b64 x01, {%1, %2}; \n\t"
      "mul.f32x2 x01, x01, %5; \n\t"
      "mov.b64 x23, {%3, %4}; \n\t"
      "mul.f32x2 x23, x23, %5; \n\t"
      "mov.b64 {x0, x1}, x01; \n\t"
      "mov.b64 {x2, x3}, x23; \n\t"
      "cvt.rs.satfinite.e2m1x4.f32 q0, {x3, x2, x1, x0}, %6; \n\t"
      "}"
      : "=h"(out_fp4x4)
      : "f"(input_fp32x2_0.x),
        "f"(input_fp32x2_0.y),
        "f"(input_fp32x2_1.x),
        "f"(input_fp32x2_1.y),
        "l"(reinterpret_cast<const uint64_t&>(scale)),
        "r"(rbits));
  return out_fp4x4;
}

__device__ __forceinline__ uint16_t
scale_cvt_fp16x4_to_fp4x4_rn(const fp16x4 input_fp16x4, const float2 scale) {
  uint16_t out_fp4x4 = 0;
  asm volatile(
      "{\n"
      ".reg.b16 x0_fp16; \n\t"
      ".reg.b16 x1_fp16; \n\t"
      ".reg.b16 x2_fp16; \n\t"
      ".reg.b16 x3_fp16; \n\t"
      ".reg.b32 x0; \n\t"
      ".reg.b32 x1; \n\t"
      ".reg.b32 x2; \n\t"
      ".reg.b32 x3; \n\t"
      ".reg.b64 x01; \n\t"
      ".reg.b64 x23; \n\t"
      ".reg.b8 q0; \n\t"
      ".reg.b8 q1; \n\t"
      "mov.b64 {x0_fp16, x1_fp16, x2_fp16, x3_fp16} , %1; \n\t"
      "cvt.f32.f16 x0, x0_fp16; \n\t"
      "cvt.f32.f16 x1, x1_fp16; \n\t"
      "cvt.f32.f16 x2, x2_fp16; \n\t"
      "cvt.f32.f16 x3, x3_fp16; \n\t"
      "mov.b64 x01, {x0, x1}; \n\t"
      "mul.f32x2 x01, x01, %2; \n\t"
      "mov.b64 x23, {x2, x3}; \n\t"
      "mul.f32x2 x23, x23, %2; \n\t"
      "mov.b64 {x0, x1}, x01; \n\t"
      "mov.b64 {x2, x3}, x23; \n\t"
      "cvt.rn.satfinite.e2m1x2.f32 q0, x1, x0; \n\t"
      "cvt.rn.satfinite.e2m1x2.f32 q1, x3, x2; \n\t"
      "mov.b16 %0, {q0, q1}; \n\t"
      "}"
      : "=h"(out_fp4x4)
      : "l"(reinterpret_cast<const uint64_t&>(input_fp16x4)),
        "l"(reinterpret_cast<const uint64_t&>(scale)));
  return out_fp4x4;
}

__device__ __forceinline__ uint16_t scale_cvt_fp16x4_to_fp4x4_rs(
    const fp16x4 input_fp16x4,
    const float2 scale,
    uint32_t rbits) {
  uint16_t out_fp4x4 = 0;
  asm volatile(
      "{\n"
      ".reg.b16 x0_fp16; \n\t"
      ".reg.b16 x1_fp16; \n\t"
      ".reg.b16 x2_fp16; \n\t"
      ".reg.b16 x3_fp16; \n\t"
      ".reg.b32 x0; \n\t"
      ".reg.b32 x1; \n\t"
      ".reg.b32 x2; \n\t"
      ".reg.b32 x3; \n\t"
      ".reg.b64 x01; \n\t"
      ".reg.b64 x23; \n\t"
      ".reg.b16 q0; \n\t"
      "mov.b64 {x0_fp16, x1_fp16, x2_fp16, x3_fp16} , %1; \n\t"
      "cvt.f32.f16 x0, x0_fp16; \n\t"
      "cvt.f32.f16 x1, x1_fp16; \n\t"
      "cvt.f32.f16 x2, x2_fp16; \n\t"
      "cvt.f32.f16 x3, x3_fp16; \n\t"
      "mov.b64 x01, {x0, x1}; \n\t"
      "mul.f32x2 x01, x01, %2; \n\t"
      "mov.b64 x23, {x2, x3}; \n\t"
      "mul.f32x2 x23, x23, %2; \n\t"
      "mov.b64 {x0, x1}, x01; \n\t"
      "mov.b64 {x2, x3}, x23; \n\t"
      "cvt.rs.satfinite.e2m1x4.f32 q0, {x3, x2, x1, x0}, %3; \n\t"
      "}"
      : "=h"(out_fp4x4)
      : "l"(reinterpret_cast<const uint64_t&>(input_fp16x4)),
        "l"(reinterpret_cast<const uint64_t&>(scale)),
        "r"(rbits));
  return out_fp4x4;
}

template <bool USE_SR>
__device__ __forceinline__ uint16_t scale_cvt_bf16x4_to_fp4x4(
    const bf16x4 input,
    const float scale,
    uint32_t rbits) {
  float2 scale_fp32x2 = make_float2(scale, scale);
  if constexpr (USE_SR) {
    return scale_cvt_bf16x4_to_fp4x4_rs(input, scale_fp32x2, rbits);
  } else {
    return scale_cvt_bf16x4_to_fp4x4_rn(input, scale_fp32x2);
  }
}

template <bool USE_SR>
__device__ __forceinline__ uint16_t scale_cvt_fp16x4_to_fp4x4(
    const fp16x4 input,
    const float scale,
    uint32_t rbits) {
  float2 scale_fp32x2 = make_float2(scale, scale);
  if constexpr (USE_SR) {
    return scale_cvt_fp16x4_to_fp4x4_rs(input, scale_fp32x2, rbits);
  } else {
    return scale_cvt_fp16x4_to_fp4x4_rn(input, scale_fp32x2);
  }
}

template <bool USE_SR>
__device__ __forceinline__ uint16_t
scale_cvt_f32x4_to_fp4x4(const f32x4 input, const float scale, uint32_t rbits) {
  float2 scale_fp32x2 = make_float2(scale, scale);
  float2 input_fp32x2_0 = make_float2(input.x, input.y);
  float2 input_fp32x2_1 = make_float2(input.z, input.w);

  if constexpr (USE_SR) {
    return scale_cvt_fp32x4_to_fp4x4_rs(
        input_fp32x2_0, input_fp32x2_1, scale_fp32x2, rbits);
  } else {
    return scale_cvt_fp32x4_to_fp4x4_rn(
        input_fp32x2_0, input_fp32x2_1, scale_fp32x2);
  }
}

template <typename T, bool USE_SR>
__device__ __forceinline__ uint16_t scale_cvt_Tx4_to_fp4x4_fast(
    const Vector4_t<T> input,
    const float scale,
    uint32_t rbits) {
  if constexpr (std::is_same<T, __nv_bfloat16>::value) {
    return scale_cvt_bf16x4_to_fp4x4<USE_SR>(input, scale, rbits);
  } else if constexpr (std::is_same<T, __half>::value) {
    return scale_cvt_fp16x4_to_fp4x4<USE_SR>(input, scale, rbits);
  } else {
    return scale_cvt_f32x4_to_fp4x4<USE_SR>(input, scale, rbits);
  }
}
#endif // (CUDART_VERSION >= 12080) && (__CUDA_ARCH__ >= 1000) &&
       // (__CUDA_ARCH_FAMILY_SPECIFIC__ >= 1000)

template <typename T, bool USE_SR>
__device__ __forceinline__ uint16_t scale_cvt_Tx4_to_fp4x4(
    const Vector4_t<T> input,
    const float scale,
    uint32_t rbits) {
#if (CUDART_VERSION >= 12080) && (__CUDA_ARCH__ >= 1000) && \
    (__CUDA_ARCH_FAMILY_SPECIFIC__ >= 1000)
  return scale_cvt_Tx4_to_fp4x4_fast<T, USE_SR>(input, scale, rbits);
#else
  static_assert(
      !USE_SR,
      "Stochastic rounding (USE_SR=true) requires CUDA >= 12.8 and compute capability >= 1000.");
  return scale_cvt_Tx4_to_fp4x4_fallback(input, scale);
#endif
}
} // namespace mlx::core::cu
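A host-side sketch (illustrative, mirroring the fallback's packing above): four e2m1 nibbles are packed little-nibble-first, with q0 in bits 0-3 and q3 in bits 12-15.

```cpp
#include <cassert>
#include <cstdint>

// Pack four 4-bit codes into a uint16_t exactly as the fallback does.
uint16_t pack_fp4x4(uint8_t q0, uint8_t q1, uint8_t q2, uint8_t q3) {
  return (uint16_t(q3) << 12) | (uint16_t(q2) << 8) | (uint16_t(q1) << 4) |
      uint16_t(q0);
}

int main() {
  assert(pack_fp4x4(0x1, 0x2, 0x3, 0x4) == 0x4321);
  return 0;
}
```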
@@ -15,6 +15,22 @@ inline constexpr __device__ short get_bytes_per_pack() {
   return power_of_2_bits ? (wsize / 8) : (bits == 5 ? 5 : 3);
 }
 
+template <typename T>
+__device__ __forceinline__ void abs_max_x2(T& out, const T& x1, const T& x2) {
+  if constexpr (
+      (std::is_same<T, __nv_bfloat162>::value) ||
+      (std::is_same<T, __half2>::value)) {
+    T a = x1;
+    T b = x2;
+    out = __hmax2(__habs2(a), __habs2(b));
+  } else if constexpr (std::is_same<T, float2>::value) {
+    float2 a = x1;
+    float2 b = x2;
+    out.x = fmaxf(fabsf(a.x), fabsf(b.x));
+    out.y = fmaxf(fabsf(a.y), fabsf(b.y));
+  }
+}
+
 } // namespace cu
 
 template <typename F>
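A device-side fragment (illustrative only, assuming the header above is included within `mlx::core::cu`) showing how the new helper accumulates an elementwise absolute maximum, as the rewritten fp_quantize kernel does when computing each group's amax:

```cpp
// Accumulate |max| over n float2 pairs; each call widens amax elementwise.
__device__ void group_amax_example(const float2* pairs, int n, float2& amax) {
  amax = make_float2(0.0f, 0.0f);
  for (int i = 0; i < n; ++i) {
    abs_max_x2<float2>(amax, amax, pairs[i]);
  }
}
```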
@@ -3,31 +3,10 @@
 #pragma once
 
 #include "mlx/backend/cuda/steel/utils.cuh"
+#include "mlx/backend/cuda/vector_types.cuh"
 
 namespace mlx::core::cu {
 
-// Map types to their vector of 2 type float -> float2, double -> double2 etc
-template <typename T>
-struct Vector2;
-template <>
-struct Vector2<double> {
-  using type = double2;
-};
-template <>
-struct Vector2<float> {
-  using type = float2;
-};
-template <>
-struct Vector2<__half> {
-  using type = __half2;
-};
-template <>
-struct Vector2<__nv_bfloat16> {
-  using type = __nv_bfloat162;
-};
-template <typename T>
-using Vector2_t = typename Vector2<T>::type;
-
 /**
  * The basic building block for Ampere mmas. A 16x16 tile distributed across
  * the warp.
mlx/backend/cuda/vector_types.cuh (new file, 48 lines)

// Copyright © 2025 Apple Inc.

#pragma once

#include <cuda_bf16.h>
#include <cuda_fp16.h>

namespace mlx::core::cu {

template <typename T>
struct Vector2;

template <>
struct Vector2<double> {
  using type = double2;
};

template <>
struct Vector2<float> {
  using type = float2;
};

template <>
struct Vector2<__half> {
  using type = __half2;
};

template <>
struct Vector2<__nv_bfloat16> {
  using type = __nv_bfloat162;
};

template <typename T>
using Vector2_t = typename Vector2<T>::type;

template <typename T>
struct Vector4 {
  T x, y, z, w;
};

template <typename T>
using Vector4_t = Vector4<T>;

using bf16x4 = Vector4_t<__nv_bfloat16>;
using fp16x4 = Vector4_t<__half>;
using fp32x4 = Vector4_t<float>;

} // namespace mlx::core::cu
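A small compile-time sketch (illustrative) of how these traits are consumed:

```cpp
#include <type_traits>
#include "mlx/backend/cuda/vector_types.cuh"

using namespace mlx::core::cu;

// Vector2_t maps scalars to CUDA's native paired types...
static_assert(std::is_same_v<Vector2_t<float>, float2>);
static_assert(std::is_same_v<Vector2_t<__nv_bfloat16>, __nv_bfloat162>);
// ...while Vector4 supplies a uniform 4-wide struct even for types with
// no native 4-wide CUDA vector, such as __nv_bfloat16.
static_assert(sizeof(bf16x4) == 4 * sizeof(__nv_bfloat16));
```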
@@ -1,7 +1,7 @@
 [build-system]
 requires = [
     "setuptools>=80",
-    "nanobind==2.4.0",
+    "nanobind==2.10.2",
     "cmake>=3.25",
 ]
 build-backend = "setuptools.build_meta"
@@ -89,7 +89,8 @@ static PyType_Spec gc_func_spec = {
     /* .name = */ "mlx.gc_func",
     /* .basicsize = */ (int)sizeof(gc_func),
     /* .itemsize = */ 0,
-    /* .flags = */ Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC | NB_HAVE_VECTORCALL,
+    /* .flags = */ Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC |
+        Py_TPFLAGS_HAVE_VECTORCALL,
     /* .slots = */ gc_func_slots};
 
 static PyTypeObject* gc_func_tp = nullptr;
@@ -16,8 +16,7 @@ struct type_caster<mlx::core::SmallVector<Type, Size, Alloc>> {
 
   NB_TYPE_CASTER(
       List,
-      const_name(NB_TYPING_TUPLE "[") + make_caster<Type>::Name +
-          const_name(", ...]"))
+      const_name("tuple[") + make_caster<Type>::Name + const_name(", ...]"))
 
   bool from_python(handle src, uint8_t flags, cleanup_list* cleanup) noexcept {
     size_t size;
@@ -4,12 +4,12 @@ import gc
 import inspect
+import io
 import math
 import unittest
 from functools import partial, wraps
-from io import StringIO
 
 import mlx.core as mx
 import mlx_tests
 import numpy as np
 
 
 class TestCompile(mlx_tests.MLXTestCase):
@@ -1252,6 +1252,26 @@ class TestCompile(mlx_tests.MLXTestCase):
         loss, grads = step(emb, w, x)
         mx.eval(loss, grads)
 
+    def test_compile_donates_input_buffer(self):
+        mx.set_default_device(mx.cpu)
+
+        def fun(x):
+            return mx.sin(x) + 1
+
+        compiled_fn = mx.compile(fun)
+
+        input = mx.arange(16, dtype=mx.float32)
+        mx.eval(input)
+        in_ptr = np.asarray(input, copy=False).__array_interface__["data"][0]
+
+        out = compiled_fn(input)
+        del input  # Ensure the reference is dropped
+        mx.eval(out)
+
+        self.assertEqual(
+            np.asarray(out, copy=False).__array_interface__["data"][0], in_ptr
+        )
+
 
 if __name__ == "__main__":
     mlx_tests.MLXTestRunner()