mirror of https://github.com/ml-explore/mlx.git
synced 2025-06-29 04:51:13 +08:00

Compare commits: 7 commits, dfd0cc4a5a ... ff30a58b11

Commits (SHA1):
- ff30a58b11
- 580776559b
- a14aaa7c9d
- a6d780154f
- 6871e2eeb7
- 8402a2acf4
- fddb6933e1
@@ -209,4 +209,14 @@ Dims get_2d_grid_dims_common(
       static_cast<uint32_t>(grid_x), static_cast<uint32_t>(grid_y), 1);
 }
 
+std::pair<Dims, Dims> get_grid_and_block_common(int dim0, int dim1, int dim2) {
+  auto [bx, by, bz] = get_block_dims_common(dim0, dim1, dim2);
+  auto gx = (dim0 + bx - 1) / bx;
+  auto gy = (dim1 + by - 1) / by;
+  auto gz = (dim2 + bz - 1) / bz;
+
+  return std::make_pair(
+      std::make_tuple(gx, gy, gz), std::make_tuple(bx, by, bz));
+}
+
 } // namespace mlx::core
@@ -95,6 +95,9 @@ Dims get_2d_grid_dims_common(
     const Strides& strides,
     size_t divisor);
 
+// Get both the block and a grid of blocks that covers dim0, dim1 and dim2.
+std::pair<Dims, Dims> get_grid_and_block_common(int dim0, int dim1, int dim2);
+
 struct ContiguousIterator {
   inline void step() {
     int dims = shape_.size();
@@ -32,6 +32,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/reduce/segmented_reduce.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cu
+  ${CMAKE_CURRENT_SOURCE_DIR}/rope.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/sort.cu
@@ -1,5 +1,4 @@
 // Copyright © 2025 Apple Inc.
-
 #include "mlx/backend/common/utils.h"
 #include "mlx/backend/cuda/device.h"
 #include "mlx/backend/cuda/iterators/strided_iterator.cuh"
@@ -113,7 +112,7 @@ __global__ void arg_reduce_general(
 
   for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
     T vals[N_READS];
-    auto tid = r * BLOCK_DIM + block.thread_index().z;
+    auto tid = r * BLOCK_DIM + block.thread_index().x;
     cub::LoadDirectBlocked(
         tid, strided_iterator(in + in_idx, axis_stride), vals, axis_size, init);
     best = op.reduce_many(best, vals, tid * N_READS);
@@ -158,7 +157,7 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
   constexpr uint32_t N_READS = 4;
   MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, {
     dim3 num_blocks = get_2d_grid_dims(out.shape(), out.strides());
-    dim3 block_dims{1, 1, BLOCK_DIM};
+    dim3 block_dims{BLOCK_DIM, 1, 1};
     auto kernel = &cu::arg_reduce_general<
         InType,
         cu::ArgMax<InType>,
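The two changes in this file go together: the launch now places its BLOCK_DIM threads along x (`dim3 block_dims{BLOCK_DIM, 1, 1}`), so the kernel has to read `thread_index().x` instead of `.z`. A minimal sketch of that invariant, as a hypothetical standalone kernel rather than MLX code:

```cpp
// Threads live on the x axis of the block, so the kernel reads threadIdx.x;
// reading threadIdx.z here would always be 0 and every thread would collide.
__global__ void iota(int* out, int n) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid < n) {
    out[tid] = tid;
  }
}

// The launch side must match the indexing the kernel uses:
//   dim3 block(256, 1, 1);
//   dim3 grid((n + 255) / 256, 1, 1);
//   iota<<<grid, block>>>(out, n);
```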
@@ -194,6 +194,13 @@ struct Power {
       }
       return res;
     } else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
+      if (base.y == 0 && base.x == 0) {
+        if (isnan(exp.x) || isnan(exp.y)) {
+          auto nan = cuda::std::numeric_limits<float>::quiet_NaN();
+          return make_cuFloatComplex(nan, nan);
+        }
+        return make_cuFloatComplex(0.0, 0.0);
+      }
       auto x_theta = atan2f(base.y, base.x);
       auto x_ln_r = 0.5 * logf(base.x * base.x + base.y * base.y);
       auto mag = expf(exp.x * x_ln_r - exp.y * x_theta);
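The surrounding operator evaluates complex powers in polar form, and the added early return keeps a zero base away from `atan2f`/`logf`. A host-side sketch of the same math under the same convention (illustrative only, not the MLX kernel):

```cpp
#include <cmath>
#include <complex>
#include <limits>

// z^w = exp(w * log z) in polar form, with the zero-base guard added above.
std::complex<float> complex_pow(std::complex<float> base, std::complex<float> exp) {
  if (base.real() == 0.0f && base.imag() == 0.0f) {
    if (std::isnan(exp.real()) || std::isnan(exp.imag())) {
      float nan = std::numeric_limits<float>::quiet_NaN();
      return {nan, nan};
    }
    return {0.0f, 0.0f}; // zero base with a non-NaN exponent yields zero here
  }
  float theta = std::atan2(base.imag(), base.real()); // arg(base)
  float ln_r =
      0.5f * std::log(base.real() * base.real() + base.imag() * base.imag());
  float mag = std::exp(exp.real() * ln_r - exp.imag() * theta);
  float phase = exp.imag() * ln_r + exp.real() * theta;
  return {mag * std::cos(phase), mag * std::sin(phase)};
}
```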
@@ -145,7 +145,7 @@ bool compiler_supports_device_sass(Device& device) {
   }
 }
 
-#define INCLUDE_PREFIX "mlx/backend/cuda/kernels/"
+#define INCLUDE_PREFIX "mlx/backend/cuda/device/"
 
 constexpr const char* g_include_names[] = {
     INCLUDE_PREFIX "atomic_ops.cuh",
@@ -23,4 +23,11 @@ dim3 get_2d_grid_dims(
   return dim3(std::get<0>(dims), std::get<1>(dims), std::get<2>(dims));
 }
 
+std::pair<dim3, dim3> get_grid_and_block(int dim0, int dim1, int dim2) {
+  auto [grid, block] = get_grid_and_block_common(dim0, dim1, dim2);
+  auto [gx, gy, gz] = grid;
+  auto [bx, by, bz] = block;
+  return std::make_pair(dim3(gx, gy, gz), dim3(bx, by, bz));
+}
+
 } // namespace mlx::core
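`get_grid_and_block_common` ceil-divides each dimension by the chosen block size so that grid times block covers the whole range, and the CUDA wrapper above only repackages the tuples as `dim3`. A self-contained toy version of the ceil-division (the real helper also picks the block shape via `get_block_dims_common`):

```cpp
#include <cstdio>
#include <tuple>
#include <utility>

// grid[i] = ceil(dim[i] / block[i]) so every element is covered by some thread.
std::pair<std::tuple<int, int, int>, std::tuple<int, int, int>>
grid_and_block(int dim0, int dim1, int dim2, int bx, int by, int bz) {
  return {{(dim0 + bx - 1) / bx, (dim1 + by - 1) / by, (dim2 + bz - 1) / bz},
          {bx, by, bz}};
}

int main() {
  // A 1000 x 37 problem with 32 x 8 blocks needs a 32 x 5 grid; the last
  // blocks are only partially full, so kernels bounds-check their indices.
  auto [grid, block] = grid_and_block(1000, 37, 1, 32, 8, 1);
  std::printf("%d %d %d\n",
              std::get<0>(grid), std::get<1>(grid), std::get<2>(grid));
}
```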
@@ -121,6 +121,7 @@ dim3 get_2d_grid_dims(
     const Shape& shape,
     const Strides& strides,
     size_t divisor);
+std::pair<dim3, dim3> get_grid_and_block(int dim0, int dim1, int dim2);
 
 // Return a block size that achieves maximum potential occupancy for kernel.
 template <typename T>
@@ -5,6 +5,7 @@
 #include "mlx/backend/gpu/copy.h"
 #include "mlx/dtype_utils.h"
 #include "mlx/primitives.h"
+#include "mlx/utils.h"
 
 #include <cublasLt.h>
 #include <fmt/format.h>
@@ -44,9 +45,12 @@ class MatMul {
       int64_t b_batch_stride) {
     heuristic_.state = CUBLAS_STATUS_NOT_INITIALIZED;
 
-    auto type = dtype_to_cuda_type(dtype);
+    auto scale_type = dtype_to_cuda_type(dtype);
+    if (dtype == bfloat16 || dtype == float16) {
+      scale_type = CUDA_R_32F;
+    }
     CHECK_CUBLAS_ERROR(cublasLtMatmulDescCreate(
-        &matmul_desc_, dtype_to_compute_type(dtype), type));
+        &matmul_desc_, dtype_to_compute_type(dtype), scale_type));
     int32_t pointer_mode = CUBLASLT_POINTER_MODE_HOST;
     CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
         matmul_desc_,
|
|||||||
&op,
|
&op,
|
||||||
sizeof(cublasOperation_t)));
|
sizeof(cublasOperation_t)));
|
||||||
|
|
||||||
|
auto type = dtype_to_cuda_type(dtype);
|
||||||
a_desc_ = create_matrix_layout(
|
a_desc_ = create_matrix_layout(
|
||||||
type, a_rows, a_cols, a_transposed, lda, batch_count, a_batch_stride);
|
type, a_rows, a_cols, a_transposed, lda, batch_count, a_batch_stride);
|
||||||
b_desc_ = create_matrix_layout(
|
b_desc_ = create_matrix_layout(
|
||||||
@ -187,17 +192,13 @@ class MatMul {
|
|||||||
private:
|
private:
|
||||||
cublasComputeType_t dtype_to_compute_type(Dtype dtype) {
|
cublasComputeType_t dtype_to_compute_type(Dtype dtype) {
|
||||||
switch (dtype) {
|
switch (dtype) {
|
||||||
case uint8:
|
|
||||||
case uint16:
|
|
||||||
case int8:
|
|
||||||
case int16:
|
|
||||||
case int32:
|
|
||||||
return CUBLAS_COMPUTE_32I;
|
|
||||||
case float16:
|
case float16:
|
||||||
case bfloat16:
|
|
||||||
return CUBLAS_COMPUTE_16F;
|
|
||||||
case float32:
|
|
||||||
return CUBLAS_COMPUTE_32F;
|
return CUBLAS_COMPUTE_32F;
|
||||||
|
case bfloat16:
|
||||||
|
return CUBLAS_COMPUTE_32F;
|
||||||
|
case float32:
|
||||||
|
return mlx::core::env::enable_tf32() ? CUBLAS_COMPUTE_32F_FAST_TF32
|
||||||
|
: CUBLAS_COMPUTE_32F;
|
||||||
case float64:
|
case float64:
|
||||||
case complex64:
|
case complex64:
|
||||||
return CUBLAS_COMPUTE_64F;
|
return CUBLAS_COMPUTE_64F;
|
||||||
@@ -209,16 +210,6 @@ class MatMul {
 
  cudaDataType_t dtype_to_cuda_type(Dtype dtype) {
    switch (dtype) {
-      case uint8:
-        return CUDA_R_8U;
-      case uint16:
-        return CUDA_R_16U;
-      case int8:
-        return CUDA_R_8I;
-      case int16:
-        return CUDA_R_16I;
-      case int32:
-        return CUDA_R_32I;
       case float16:
         return CUDA_R_16F;
       case bfloat16:
@@ -94,7 +94,6 @@ NO_GPU_MULTI(Eig)
 NO_GPU_MULTI(Eigh)
 
 namespace fast {
-NO_GPU_USE_FALLBACK(RoPE)
 NO_GPU(ScaledDotProductAttention)
 NO_GPU_MULTI(AffineQuantize)
 NO_GPU_MULTI(CustomKernel)
@@ -4,6 +4,7 @@
 #include "mlx/backend/cuda/kernel_utils.cuh"
 #include "mlx/primitives.h"
 
+#include <cooperative_groups.h>
 #include <nvtx3/nvtx3.hpp>
 
 #include <cassert>
@@ -12,6 +13,8 @@ namespace mlx::core {
 
 namespace cu {
 
+namespace cg = cooperative_groups;
+
 __constant__ constexpr uint32_t rotations[2][4] = {
     {13, 15, 26, 6},
     {17, 29, 16, 24}};
@@ -47,27 +50,28 @@ __global__ void rbitsc(
     dim3 grid_dims,
     bool odd,
     uint32_t bytes_per_key) {
-  uint2 index{
-      blockIdx.x * blockDim.x + threadIdx.x,
-      blockIdx.y * blockDim.y + threadIdx.y};
-  if (index.x >= grid_dims.x || index.y >= grid_dims.y) {
+  auto grid = cg::this_grid();
+  uint thread_index = grid.thread_rank();
+  uint index_x = thread_index % grid_dims.x;
+  uint index_y = thread_index / grid_dims.x;
+  if (index_x >= grid_dims.x || index_y >= grid_dims.y) {
     return;
   }
 
-  auto kidx = 2 * index.x;
+  auto kidx = 2 * index_x;
   auto key = uint2{keys[kidx], keys[kidx + 1]};
   auto half_size = grid_dims.y - odd;
-  out += index.x * bytes_per_key;
-  bool drop_last = odd && (index.y == half_size);
+  out += index_x * bytes_per_key;
+  bool drop_last = odd && (index_y == half_size);
   auto bits = threefry2x32_hash(
-      key, uint2{index.y, drop_last ? 0 : index.y + grid_dims.y});
-  size_t idx = size_t(index.y) << 2;
+      key, uint2{index_y, drop_last ? 0 : index_y + grid_dims.y});
+  size_t idx = size_t(index_y) << 2;
   for (int i = 0; i < 4; ++i) {
     out[idx + i] = bits.bytes[0][i];
   }
   if (!drop_last) {
-    idx = (drop_last ? 0 : size_t(index.y) + grid_dims.y) << 2;
-    if ((index.y + 1) == half_size && (bytes_per_key % 4) > 0) {
+    idx = (drop_last ? 0 : size_t(index_y) + grid_dims.y) << 2;
+    if ((index_y + 1) == half_size && (bytes_per_key % 4) > 0) {
       int edge_bytes = (bytes_per_key % 4);
       for (int i = 0; i < edge_bytes; ++i) {
         out[idx + i] = bits.bytes[1][i];
@@ -89,30 +93,31 @@ __global__ void rbits(
     int32_t ndim,
     const __grid_constant__ Shape key_shape,
     const __grid_constant__ Strides key_strides) {
-  uint2 index{
-      blockIdx.x * blockDim.x + threadIdx.x,
-      blockIdx.y * blockDim.y + threadIdx.y};
-  if (index.x >= grid_dims.x || index.y >= grid_dims.y) {
+  auto grid = cg::this_grid();
+  uint thread_index = grid.thread_rank();
+  uint index_x = thread_index % grid_dims.x;
+  uint index_y = thread_index / grid_dims.x;
+  if (index_x >= grid_dims.x || index_y >= grid_dims.y) {
     return;
   }
 
-  auto kidx = 2 * index.x;
+  auto kidx = 2 * index_x;
   auto k1_elem = elem_to_loc(kidx, key_shape.data(), key_strides.data(), ndim);
   auto k2_elem =
       elem_to_loc(kidx + 1, key_shape.data(), key_strides.data(), ndim);
   auto key = uint2{keys[k1_elem], keys[k2_elem]};
   auto half_size = grid_dims.y - odd;
-  out += size_t(index.x) * bytes_per_key;
-  bool drop_last = odd && (index.y == half_size);
+  out += size_t(index_x) * bytes_per_key;
+  bool drop_last = odd && (index_y == half_size);
   auto bits = threefry2x32_hash(
-      key, uint2{index.y, drop_last ? 0 : index.y + grid_dims.y});
-  size_t idx = size_t(index.y) << 2;
+      key, uint2{index_y, drop_last ? 0 : index_y + grid_dims.y});
+  size_t idx = size_t(index_y) << 2;
   for (int i = 0; i < 4; ++i) {
     out[idx + i] = bits.bytes[0][i];
   }
   if (!drop_last) {
-    idx = (drop_last ? 0 : size_t(index.y) + grid_dims.y) << 2;
-    if ((index.y + 1) == half_size && (bytes_per_key % 4) > 0) {
+    idx = (drop_last ? 0 : size_t(index_y) + grid_dims.y) << 2;
+    if ((index_y + 1) == half_size && (bytes_per_key % 4) > 0) {
       int edge_bytes = (bytes_per_key % 4);
       for (int i = 0; i < edge_bytes; ++i) {
         out[idx + i] = bits.bytes[1][i];
@@ -153,19 +158,22 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
   encoder.set_output_array(out);
   encoder.launch_kernel([&](cudaStream_t stream) {
     dim3 grid_dims{num_keys, half_size + odd};
-    dim3 block_dims = get_block_dims(grid_dims.x, grid_dims.y, 1);
-    dim3 num_blocks{
-        cuda::ceil_div(grid_dims.x, block_dims.x),
-        cuda::ceil_div(grid_dims.y, block_dims.y)};
+    int64_t total = grid_dims.x * grid_dims.y;
+    int32_t threads_y = 1;
+    while ((total / threads_y) >= (1U << 31)) {
+      threads_y *= 2;
+    }
+    int32_t threads_x = cuda::ceil_div(total, threads_y);
+    auto [grid, block] = get_grid_and_block(threads_x, threads_y, 1);
     if (keys.flags().row_contiguous) {
-      cu::rbitsc<<<num_blocks, block_dims, 0, stream>>>(
+      cu::rbitsc<<<grid, block, 0, stream>>>(
           keys.data<uint32_t>(),
           out.data<uint8_t>(),
           grid_dims,
           odd,
           bytes_per_key);
     } else {
-      cu::rbits<<<num_blocks, block_dims, 0, stream>>>(
+      cu::rbits<<<grid, block, 0, stream>>>(
          keys.data<uint32_t>(),
          out.data<uint8_t>(),
          grid_dims,
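The rewritten launch flattens the `(num_keys, half_size + odd)` domain into one linear extent, grows `threads_y` only when the flat count would overflow a 32-bit grid dimension, and the kernels recover 2-D coordinates from `grid.thread_rank()` with `%` and `/`. A self-contained sketch of that index mapping, with illustrative values:

```cpp
#include <cassert>
#include <cstdint>

// Recover (x, y) from a linear thread rank, as rbits/rbitsc now do.
struct Index2D {
  uint32_t x, y;
};

Index2D unflatten(uint32_t thread_rank, uint32_t dim_x) {
  return {thread_rank % dim_x, thread_rank / dim_x};
}

int main() {
  const uint32_t dim_x = 5, dim_y = 3;
  for (uint32_t r = 0; r < dim_x * dim_y; ++r) {
    Index2D idx = unflatten(r, dim_x);
    assert(idx.x < dim_x && idx.y < dim_y); // each rank maps inside the grid
  }
}
```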
mlx/backend/cuda/rope.cu (new file, 385 lines)

@@ -0,0 +1,385 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/gpu/copy.h"
#include "mlx/dtype_utils.h"
#include "mlx/fast_primitives.h"

#include <nvtx3/nvtx3.hpp>

namespace mlx::core {

namespace cu {

template <typename T, bool traditional, bool forward>
__device__ void rope_single_impl(
    const T* in,
    T* out,
    int32_t offset,
    float inv_freq,
    float scale,
    int64_t stride,
    uint2 pos,
    uint2 dims) {
  float L = scale * static_cast<float>(offset);

  // Compute costheta, sintheta
  float theta = L * inv_freq;
  float costheta = cos(theta);
  float sintheta = sin(theta);

  // Compute the input and output indices
  uint index_1, index_2;
  if (traditional) {
    index_1 = 2 * pos.x + pos.y * stride;
    index_2 = index_1 + 1;
  } else {
    index_1 = pos.x + pos.y * stride;
    index_2 = index_1 + dims.x;
  }

  // Read and write the output
  float x1 = static_cast<float>(in[index_1]);
  float x2 = static_cast<float>(in[index_2]);
  float rx1;
  float rx2;
  if (forward) {
    rx1 = x1 * costheta - x2 * sintheta;
    rx2 = x1 * sintheta + x2 * costheta;
  } else {
    rx1 = x2 * sintheta + x1 * costheta;
    rx2 = x2 * costheta - x1 * sintheta;
  }
  out[index_1] = static_cast<T>(rx1);
  out[index_2] = static_cast<T>(rx2);
}

template <typename T, bool traditional, bool forward>
__global__ void rope_single(
    const T* in,
    T* out,
    const int32_t* offset,
    float scale,
    float base,
    int64_t stride,
    uint2 dims) {
  uint2 pos = make_uint2(
      blockIdx.x * blockDim.x + threadIdx.x,
      blockIdx.y * blockDim.y + threadIdx.y);
  if (pos.x >= dims.x || pos.y >= dims.y) {
    return;
  }

  float d = static_cast<float>(pos.x) / static_cast<float>(dims.x);
  float inv_freq = exp2(-d * base);
  rope_single_impl<T, traditional, forward>(
      in, out, *offset, inv_freq, scale, stride, pos, dims);
}

template <typename T, bool traditional, bool forward>
__global__ void rope_single_freqs(
    const T* in,
    T* out,
    const int32_t* offset,
    const float* freqs,
    float scale,
    int64_t stride,
    uint2 dims,
    int64_t freq_stride) {
  uint2 pos = make_uint2(
      blockIdx.x * blockDim.x + threadIdx.x,
      blockIdx.y * blockDim.y + threadIdx.y);
  if (pos.x >= dims.x || pos.y >= dims.y) {
    return;
  }

  float inv_freq = 1.0 / freqs[freq_stride * pos.x];
  rope_single_impl<T, traditional, forward>(
      in, out, *offset, inv_freq, scale, stride, pos, dims);
}

template <typename T, bool traditional, bool forward, int N = 4>
__device__ void rope_impl(
    const T* in,
    T* out,
    int offset,
    float inv_freq,
    float scale,
    const cuda::std::array<int64_t, 3> strides,
    const cuda::std::array<int64_t, 3> out_strides,
    int64_t n_batch,
    uint3 pos,
    uint3 dims) {
  float L = scale * static_cast<float>(pos.y + offset);

  // Compute costheta, sintheta
  float theta = L * inv_freq;
  float costheta = cos(theta);
  float sintheta = sin(theta);

  // Compute the input and output indices
  size_t in_index_1, in_index_2;
  size_t out_index_1, out_index_2;
  if (traditional) {
    out_index_1 = 2 * pos.x * out_strides[2] + pos.y * out_strides[1] +
        N * pos.z * out_strides[0];
    out_index_2 = out_index_1 + 1;
    in_index_1 =
        2 * pos.x * strides[2] + pos.y * strides[1] + N * pos.z * strides[0];
    in_index_2 = in_index_1 + strides[2];
  } else {
    out_index_1 = pos.x * out_strides[2] + pos.y * out_strides[1] +
        N * pos.z * out_strides[0];
    out_index_2 = out_index_1 + dims.x * out_strides[2];
    in_index_1 =
        pos.x * strides[2] + pos.y * strides[1] + N * pos.z * strides[0];
    in_index_2 = in_index_1 + dims.x * strides[2];
  }
  for (int i = 0; i < N && pos.z * N + i < n_batch; ++i) {
    // Read and write the output
    float x1 = static_cast<float>(in[in_index_1]);
    float x2 = static_cast<float>(in[in_index_2]);
    float rx1;
    float rx2;
    if (forward) {
      rx1 = x1 * costheta - x2 * sintheta;
      rx2 = x1 * sintheta + x2 * costheta;
    } else {
      rx1 = x2 * sintheta + x1 * costheta;
      rx2 = x2 * costheta - x1 * sintheta;
    }
    out[out_index_1] = static_cast<T>(rx1);
    out[out_index_2] = static_cast<T>(rx2);
    in_index_1 += strides[0];
    in_index_2 += strides[0];
    out_index_1 += out_strides[0];
    out_index_2 += out_strides[0];
  }
}

template <typename T, bool traditional, bool forward>
__global__ void rope(
    const T* in,
    T* out,
    const int32_t* offset,
    float scale,
    float base,
    const __grid_constant__ cuda::std::array<int64_t, 3> strides,
    const __grid_constant__ cuda::std::array<int64_t, 3> out_strides,
    int64_t n_batch,
    uint3 dims) {
  uint3 pos = make_uint3(
      blockIdx.x * blockDim.x + threadIdx.x,
      blockIdx.y * blockDim.y + threadIdx.y,
      blockIdx.z * blockDim.z + threadIdx.z);
  if (pos.x >= dims.x || pos.y >= dims.y || pos.z >= dims.z) {
    return;
  }

  float d = static_cast<float>(pos.x) / static_cast<float>(dims.x);
  float inv_freq = exp2(-d * base);
  rope_impl<T, traditional, forward>(
      in,
      out,
      *offset,
      inv_freq,
      scale,
      strides,
      out_strides,
      n_batch,
      pos,
      dims);
}

template <typename T, bool traditional, bool forward>
__global__ void rope_freqs(
    const T* in,
    T* out,
    const int32_t* offset,
    const float* freqs,
    float scale,
    float base,
    const __grid_constant__ cuda::std::array<int64_t, 3> strides,
    const __grid_constant__ cuda::std::array<int64_t, 3> out_strides,
    int64_t n_batch,
    uint3 dims,
    int64_t freq_stride) {
  uint3 pos = make_uint3(
      blockIdx.x * blockDim.x + threadIdx.x,
      blockIdx.y * blockDim.y + threadIdx.y,
      blockIdx.z * blockDim.z + threadIdx.z);
  if (pos.x >= dims.x || pos.y >= dims.y || pos.z >= dims.z) {
    return;
  }

  float inv_freq = 1.0 / freqs[freq_stride * pos.x];
  rope_impl<T, traditional, forward>(
      in,
      out,
      *offset,
      inv_freq,
      scale,
      strides,
      out_strides,
      n_batch,
      pos,
      dims);
}

} // namespace cu

namespace fast {

bool RoPE::use_fallback(Stream s) {
  return s.device == Device::cpu;
}

void RoPE::eval_gpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  nvtx3::scoped_range r("RoPE::eval_gpu");

  auto& s = stream();
  auto& in = inputs[0];
  auto& offset = inputs[1];
  auto& out = outputs[0];

  if (in.ndim() < 3) {
    throw std::runtime_error("[RoPE] Input must have at least 3 dimensions");
  }

  cuda::std::array<int64_t, 3> strides;
  cuda::std::array<int64_t, 3> out_strides;
  bool donated = false;
  int ndim = in.ndim();
  int dispatch_ndim = in.ndim();
  while (in.shape(-dispatch_ndim) == 1 && dispatch_ndim > 3) {
    dispatch_ndim--;
  }
  size_t mat_size = in.shape(-2) * in.shape(-1);

  // We apply rope to less that the whole vector so copy to output and then
  // apply in-place.
  if (dims_ < in.shape(-1)) {
    donated = true;
    auto ctype =
        (in.flags().row_contiguous) ? CopyType::Vector : CopyType::General;
    copy_gpu(in, out, ctype, s);
    strides[0] = mat_size;
    strides[1] = out.strides()[ndim - 2];
    strides[2] = out.strides()[ndim - 1];
  }

  // Either copy or apply in-place
  else if (in.flags().row_contiguous) {
    if (in.is_donatable()) {
      donated = true;
      out.copy_shared_buffer(in);
    } else {
      out.set_data(allocator::malloc(out.nbytes()));
    }
    strides[0] = mat_size;
    strides[1] = in.strides()[ndim - 2];
    strides[2] = in.strides()[ndim - 1];
  } else if (dispatch_ndim == 3) {
    // Handle non-contiguous 3D inputs
    out.set_data(allocator::malloc(out.nbytes()));
    strides[0] = in.strides()[ndim - 3];
    strides[1] = in.strides()[ndim - 2];
    strides[2] = in.strides()[ndim - 1];
  } else {
    // Copy non-contiguous > 3D inputs into the output and treat
    // input as donated
    donated = true;
    copy_gpu(in, out, CopyType::General, s);
    strides[0] = mat_size;
    strides[1] = out.strides()[ndim - 2];
    strides[2] = out.strides()[ndim - 1];
  }
  out_strides[0] = mat_size;
  out_strides[1] = out.strides()[ndim - 2];
  out_strides[2] = out.strides()[ndim - 1];

  // Some flags to help us dispatch below
  bool single = in.flags().row_contiguous && (mat_size == in.shape(-1));
  bool with_freqs = inputs.size() == 3;

  auto& encoder = cu::get_command_encoder(s);
  encoder.set_input_array(donated ? out : in);
  encoder.set_input_array(offset);
  encoder.set_output_array(out);
  encoder.launch_kernel([&](cudaStream_t stream) {
    MLX_SWITCH_FLOAT_TYPES_CHECKED(in.dtype(), "rope", CTYPE, {
      using DataType = cuda_type_t<CTYPE>;
      MLX_SWITCH_BOOL(traditional_, TRADITIONAL, {
        MLX_SWITCH_BOOL(forward_, FORWARD, {
          if (single && !with_freqs) {
            auto kernel = cu::rope_single<DataType, TRADITIONAL, FORWARD>;
            uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size);
            auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1);
            kernel<<<grid, block, 0, stream>>>(
                (donated ? out : in).data<DataType>(),
                out.data<DataType>(),
                offset.data<int32_t>(),
                scale_,
                std::log2(base_),
                mat_size,
                dims);
          } else if (single) {
            auto kernel = cu::rope_single_freqs<DataType, TRADITIONAL, FORWARD>;
            uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size);
            auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1);
            kernel<<<grid, block, 0, stream>>>(
                (donated ? out : in).data<DataType>(),
                out.data<DataType>(),
                offset.data<int32_t>(),
                inputs[2].data<float>(),
                scale_,
                mat_size,
                dims,
                inputs[2].strides(0));
          } else if (with_freqs) {
            auto kernel = cu::rope_freqs<DataType, TRADITIONAL, FORWARD>;
            uint3 dims =
                make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size);
            dims.z = (dims.z + 3) / 4;
            auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z);
            kernel<<<grid, block, 0, stream>>>(
                (donated ? out : in).data<DataType>(),
                out.data<DataType>(),
                offset.data<int32_t>(),
                inputs[2].data<float>(),
                scale_,
                std::log2(base_),
                strides,
                out_strides,
                in.size() / mat_size,
                dims,
                inputs[2].strides(0));
          } else {
            auto kernel = cu::rope<DataType, TRADITIONAL, FORWARD>;
            uint3 dims =
                make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size);
            dims.z = (dims.z + 3) / 4;
            auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z);
            kernel<<<grid, block, 0, stream>>>(
                (donated ? out : in).data<DataType>(),
                out.data<DataType>(),
                offset.data<int32_t>(),
                scale_,
                std::log2(base_),
                strides,
                out_strides,
                in.size() / mat_size,
                dims);
          }
        });
      });
    });
  });
}

} // namespace fast

} // namespace mlx::core
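Every kernel in the new file rotates feature pairs `(x1, x2)` by an angle `theta = scale * position * inv_freq`, where `inv_freq` is either derived from the base via `exp2(-d * log2(base))` or read from an explicit frequency array. A host-side sketch of the forward rotation for a single pair, mirroring `rope_single_impl` (illustrative only):

```cpp
#include <cmath>
#include <cstdio>

// Rotate one (x1, x2) pair by theta = (scale * offset) * inv_freq.
void rope_pair(float x1, float x2, float offset, float scale, float inv_freq,
               float& out1, float& out2) {
  float theta = scale * offset * inv_freq;
  float c = std::cos(theta), s = std::sin(theta);
  out1 = x1 * c - x2 * s;
  out2 = x1 * s + x2 * c;
}

int main() {
  float r1, r2;
  rope_pair(1.0f, 0.0f, /*offset=*/7.0f, /*scale=*/1.0f, /*inv_freq=*/1.0f, r1, r2);
  std::printf("%f %f\n", r1, r2); // the unit vector (1, 0) rotated by 7 radians
}
```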
@@ -155,26 +155,26 @@ void explicit_gemm_conv_group_ND_gpu(
   // Perform gemm
   std::vector<array> copies = {in_unfolded, wt_transpose};
   return steel_matmul_regular(
-      s,
-      d,
-      /* a = */ in_unfolded,
-      /* b = */ wt_transpose,
-      /* c = */ out,
-      /* M = */ implicit_M,
-      /* N = */ implicit_N,
-      /* K = */ implicit_K,
-      /* batch_size_out = */ groups,
-      /* a_cols = */ implicit_K * groups,
-      /* b_cols = */ implicit_K,
-      /* out_cols = */ implicit_N * groups,
-      /* a_transposed = */ false,
-      /* b_transposed = */ true,
-      /* batch_shape = */ {1},
-      /* batch_strides = */ {0},
-      /* A_batch_strides = */ size_t(implicit_K),
-      /* B_batch_strides = */ size_t(implicit_N) * implicit_K,
-      /* matrix_stride_out = */ size_t(implicit_N),
-      /*copies = */ copies);
+      /* const Stream& s = */ s,
+      /* Device& d = */ d,
+      /* const array& a = */ in_unfolded,
+      /* const array& b = */ wt_transpose,
+      /* array& c = */ out,
+      /* int M = */ implicit_M,
+      /* int N = */ implicit_N,
+      /* int K = */ implicit_K,
+      /* int batch_size_out = */ groups,
+      /* int lda = */ implicit_K * groups,
+      /* int ldb = */ implicit_K,
+      /* int ldd = */ implicit_N * groups,
+      /* bool transpose_a = */ false,
+      /* bool transpose_b = */ true,
+      /* std::vector<array>& copies = */ copies,
+      /* Shape batch_shape = */ {1},
+      /* Strides batch_strides = */ {0},
+      /* int64_t A_batch_strides = */ int64_t(implicit_K),
+      /* int64_t B_batch_strides = */ int64_t(implicit_N) * implicit_K,
+      /* int64_t matrix_stride_out = */ int64_t(implicit_N));
 }
 
 void implicit_gemm_conv_2D_gpu(
@@ -297,6 +297,9 @@ Device::Device() {
   device_ = load_device();
   default_library_ = load_default_library(device_);
   arch_ = std::string(device_->architecture()->name()->utf8String());
+  int ag_tens = arch_[arch_.size() - 3] - '0';
+  int ag_ones = arch_[arch_.size() - 2] - '0';
+  arch_gen_ = ag_tens * 10 + ag_ones;
   auto arch = arch_.back();
   switch (arch) {
     case 'p': // phone
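The added lines read the two digits just before the final letter of the Metal architecture string as a numeric GPU generation. A standalone sketch with a hypothetical architecture name; the exact string format is an assumption here:

```cpp
#include <cassert>
#include <string>

// Generation = the two digits preceding the trailing variant letter,
// e.g. a name ending in "16g" parses as generation 16.
int architecture_gen(const std::string& arch) {
  int tens = arch[arch.size() - 3] - '0';
  int ones = arch[arch.size() - 2] - '0';
  return tens * 10 + ones;
}

int main() {
  assert(architecture_gen("applegpu_g16g") == 16); // hypothetical name
}
```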
@@ -177,6 +177,10 @@ class Device {
     return arch_;
   }
 
+  int get_architecture_gen() const {
+    return arch_gen_;
+  }
+
   void new_queue(int index);
 
   MTL::CommandQueue* get_queue(Stream stream);
@@ -268,6 +272,7 @@ class Device {
       library_kernels_;
   const MTL::ResidencySet* residency_set_{nullptr};
   std::string arch_;
+  int arch_gen_;
   int max_ops_per_buffer_;
   int max_mb_per_buffer_;
 };
@@ -235,6 +235,13 @@ struct Power {
 
   template <>
   complex64_t operator()(complex64_t x, complex64_t y) {
+    if (x.real == 0 && x.imag == 0) {
+      if (metal::isnan(y.real) || metal::isnan(y.imag)) {
+        auto nan = metal::numeric_limits<float>::quiet_NaN();
+        return {nan, nan};
+      }
+      return {0.0, 0.0};
+    }
     auto x_theta = metal::atan2(x.imag, x.real);
     auto x_ln_r = 0.5 * metal::log(x.real * x.real + x.imag * x.imag);
     auto mag = metal::exp(y.real * x_ln_r - y.imag * x_theta);
@@ -33,8 +33,8 @@ template <
     device T* D [[buffer(3)]],
     const constant GEMMParams* params [[buffer(4)]],
     const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]],
-    const constant int* batch_shape [[buffer(6)]],
-    const constant int64_t* batch_strides [[buffer(7)]],
+    const constant int* batch_shape [[buffer(6), function_constant(has_batch)]],
+    const constant int64_t* batch_strides [[buffer(7), function_constant(has_batch)]],
     uint simd_lane_id [[thread_index_in_simdgroup]],
     uint simd_group_id [[simdgroup_index_in_threadgroup]],
     uint3 tid [[threadgroup_position_in_grid]],
(One file's diff is suppressed because it is too large.)
@@ -6,7 +6,34 @@
 
 namespace mlx::core {
 
-void steel_matmul_regular(
+template <bool CHECK_AB = true>
+void steel_matmul_regular_axpby(
+    const Stream& s,
+    metal::Device& d,
+    const array& a,
+    const array& b,
+    const array& c,
+    array& out,
+    int M,
+    int N,
+    int K,
+    int batch_size_out,
+    int lda,
+    int ldb,
+    int ldd,
+    bool transpose_a,
+    bool transpose_b,
+    std::vector<array>& copies,
+    Shape batch_shape,
+    Strides batch_strides,
+    int64_t A_batch_stride,
+    int64_t B_batch_stride,
+    int64_t matrix_stride_out,
+    int64_t C_batch_stride = 0,
+    float alpha = 1.0f,
+    float beta = 0.0f);
+
+inline void steel_matmul_regular(
     const Stream& s,
     metal::Device& d,
     const array& a,
@@ -21,14 +48,61 @@ void steel_matmul_regular(
     int ldd,
     bool transpose_a,
     bool transpose_b,
+    std::vector<array>& copies,
     Shape batch_shape,
     Strides batch_strides,
     int64_t A_batch_stride,
     int64_t B_batch_stride,
-    int64_t matrix_stride_out,
-    std::vector<array>& copies);
+    int64_t matrix_stride_out) {
+  return steel_matmul_regular_axpby<false>(
+      /* const Stream& s = */ s,
+      /* metal::Device& d = */ d,
+      /* const array& a = */ a,
+      /* const array& b = */ b,
+      /* const array& c = */ b,
+      /* array& out = */ out,
+      /* int M = */ M,
+      /* int N = */ N,
+      /* int K = */ K,
+      /* int batch_size_out = */ batch_size_out,
+      /* int lda = */ lda,
+      /* int ldb = */ ldb,
+      /* int ldd = */ ldd,
+      /* bool transpose_a = */ transpose_a,
+      /* bool transpose_b = */ transpose_b,
+      /* std::vector<array>& copies = */ copies,
+      /* Shape batch_shape = */ batch_shape,
+      /* Strides batch_strides = */ batch_strides,
+      /* int64_t A_batch_stride = */ A_batch_stride,
+      /* int64_t B_batch_stride = */ B_batch_stride,
+      /* int64_t matrix_stride_out = */ matrix_stride_out);
+}
 
-void steel_matmul(
+template <bool CHECK_AB = true>
+void steel_matmul_axpby(
+    const Stream& s,
+    metal::Device& d,
+    const array& a,
+    const array& b,
+    const array& c,
+    array& out,
+    int M,
+    int N,
+    int K,
+    int batch_size_out,
+    int lda,
+    int ldb,
+    bool transpose_a,
+    bool transpose_b,
+    std::vector<array>& copies,
+    Shape batch_shape = {},
+    Strides A_batch_stride = {},
+    Strides B_batch_stride = {},
+    Strides C_batch_stride = {},
+    float alpha = 1.0f,
+    float beta = 0.0f);
+
+inline void steel_matmul(
     const Stream& s,
     metal::Device& d,
     const array& a,
@@ -45,6 +119,26 @@ void steel_matmul(
     std::vector<array>& copies,
     Shape batch_shape = {},
     Strides A_batch_stride = {},
-    Strides B_batch_stride = {});
+    Strides B_batch_stride = {}) {
+  return steel_matmul_axpby<false>(
+      /* const Stream& s = */ s,
+      /* metal::Device& d = */ d,
+      /* const array& a = */ a,
+      /* const array& b = */ b,
+      /* const array& c = */ b,
+      /* array& out = */ out,
+      /* int M = */ M,
+      /* int N = */ N,
+      /* int K = */ K,
+      /* int batch_size_out = */ batch_size_out,
+      /* int lda = */ lda,
+      /* int ldb = */ ldb,
+      /* bool transpose_a = */ transpose_a,
+      /* bool transpose_b = */ transpose_b,
+      /* std::vector<array>& copies = */ copies,
+      /* Shape batch_shape = */ batch_shape,
+      /* Strides A_batch_stride = */ A_batch_stride,
+      /* Strides B_batch_stride = */ B_batch_stride);
+}
 
 } // namespace mlx::core
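The `_axpby` variants generalize the existing GEMM entry points to

$$ \mathrm{out} = \alpha\,(A B) + \beta\,C, $$

and the old `steel_matmul` / `steel_matmul_regular` signatures become inline wrappers that forward with `CHECK_AB = false` and the defaults `alpha = 1`, `beta = 0`, which reduces to plain `out = A B`.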
@@ -26,7 +26,7 @@ void RMSNorm::eval_gpu(
   bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1;
   if (no_copy && x.ndim() > 1) {
     auto s = x.strides()[x.ndim() - 2];
-    no_copy &= (s == 0 || s == x.shape().back());
+    no_copy &= (s == 0 || s == x.shape().back() || x.shape(-2) == 1);
   }
   if (no_copy) {
     if (x.is_donatable()) {
@@ -227,7 +227,7 @@ void LayerNorm::eval_gpu(
   bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1;
   if (no_copy && x.ndim() > 1) {
     auto s = x.strides()[x.ndim() - 2];
-    no_copy &= (s == 0 || s == x.shape().back());
+    no_copy &= (s == 0 || s == x.shape().back() || x.shape(-2) == 1);
   }
   if (no_copy) {
     if (x.is_donatable()) {
mlx/ops.cpp (51 changed lines)

@@ -2847,21 +2847,6 @@ array matmul(
         "[matmul] Got 0 dimension input. Inputs must "
         "have at least one dimension.");
   }
-  if (a.ndim() == 1) {
-    // Insert a singleton dim in the beginning
-    a = expand_dims(a, 0, s);
-  }
-  if (b.ndim() == 1) {
-    // Insert a singleton dim at the end
-    b = expand_dims(b, 1, s);
-  }
-  if (a.shape(-1) != b.shape(-2)) {
-    std::ostringstream msg;
-    msg << "[matmul] Last dimension of first input with shape " << a.shape()
-        << " must match second to last dimension of"
-        << " second input with shape " << b.shape() << ".";
-    throw std::invalid_argument(msg.str());
-  }
-
   // complex matmul using Karatsuba's Algorithm
   if (a.dtype() == complex64 || b.dtype() == complex64) {
@@ -2883,6 +2868,22 @@ array matmul(
         c_real, multiply(array(complex64_t{0, 1}, complex64), c_imag, s), s);
   }
 
+  if (a.ndim() == 1) {
+    // Insert a singleton dim in the beginning
+    a = expand_dims(a, 0, s);
+  }
+  if (b.ndim() == 1) {
+    // Insert a singleton dim at the end
+    b = expand_dims(b, 1, s);
+  }
+  if (a.shape(-1) != b.shape(-2)) {
+    std::ostringstream msg;
+    msg << "[matmul] Last dimension of first input with shape " << a.shape()
+        << " must match second to last dimension of"
+        << " second input with shape " << b.shape() << ".";
+    throw std::invalid_argument(msg.str());
+  }
+
   // Type promotion
   auto out_type = promote_types(a.dtype(), b.dtype());
 
@@ -4240,6 +4241,16 @@ array addmm(
         "have at least one dimension.");
   }
 
+  // Type promotion
+  auto out_type = result_type(a, b, c);
+
+  if (out_type == complex64) {
+    return add(
+        multiply(matmul(a, b, s), array(alpha), s),
+        multiply(array(beta), c, s),
+        s);
+  }
+
   if (a.ndim() == 1) {
     // Insert a singleton dim in the beginning
     a = expand_dims(a, 0, s);
@@ -4257,16 +4268,6 @@ array addmm(
     throw std::invalid_argument(msg.str());
   }
 
-  // Type promotion
-  auto out_type = result_type(a, b, c);
-
-  if (out_type == complex64) {
-    return add(
-        multiply(matmul(a, b, s), array(alpha), s),
-        multiply(array(beta), c, s),
-        s);
-  }
-
   if (!issubdtype(out_type, floating)) {
     std::ostringstream msg;
     msg << "[addmm] Only real floating point types are supported but "
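With the type promotion hoisted above the 1-D reshaping, a complex `addmm` is decomposed before `a` and `b` are expanded:

$$ \mathrm{addmm}(c, a, b, \alpha, \beta) = \alpha\,(a\,b) + \beta\,c, $$

so the complex case goes through the complex `matmul` path (the Karatsuba decomposition above) rather than the real-only GEMM kernels.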
@@ -69,7 +69,12 @@ inline void PrintFormatter::print(std::ostream& os, double val) {
   os << val;
 }
 inline void PrintFormatter::print(std::ostream& os, complex64_t val) {
-  os << val;
+  os << val.real();
+  if (val.imag() >= 0 || std::isnan(val.imag())) {
+    os << "+" << val.imag() << "j";
+  } else {
+    os << "-" << -val.imag() << "j";
+  }
 }
 
 PrintFormatter& get_global_formatter() {
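The formatter now prints complex scalars in NumPy-style `a+bj` form instead of streaming the raw value. A standalone sketch of the same formatting rule using `std::complex`; MLX's `complex64_t` is its own type, so this is illustrative only:

```cpp
#include <cmath>
#include <complex>
#include <iostream>

// Keep the sign attached to the imaginary part; a NaN imaginary part takes
// the "+" branch, so it prints as e.g. "1.5+nanj".
void print_complex(std::ostream& os, std::complex<float> val) {
  os << val.real();
  if (val.imag() >= 0 || std::isnan(val.imag())) {
    os << "+" << val.imag() << "j";
  } else {
    os << "-" << -val.imag() << "j";
  }
}

int main() {
  print_complex(std::cout, {1.5f, -2.0f}); // prints 1.5-2j
  std::cout << '\n';
}
```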
@@ -149,6 +149,11 @@ inline bool metal_fast_synch() {
   return metal_fast_synch;
 }
 
+inline bool enable_tf32() {
+  static bool enable_tf32_ = get_var("MLX_ENABLE_TF32", 1);
+  return enable_tf32_;
+}
+
 } // namespace env
 
 } // namespace mlx::core
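`enable_tf32()` follows the same pattern as the other toggles in this header: read the variable once, cache the result, default to enabled. A standalone approximation using `getenv` (`get_var` is the MLX helper; this simplified version only treats "0" as false):

```cpp
#include <cstdlib>
#include <string>

// Read MLX_ENABLE_TF32 once and cache it; unset defaults to true,
// mirroring get_var("MLX_ENABLE_TF32", 1).
inline bool enable_tf32_env() {
  static bool enabled = [] {
    const char* v = std::getenv("MLX_ENABLE_TF32");
    return v == nullptr || std::string(v) != "0";
  }();
  return enabled;
}
```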
@@ -1195,6 +1195,16 @@ class TestBlas(mlx_tests.MLXTestCase):
         c_np = np.matmul(np.array(a).T, b)
         self.assertTrue(np.allclose(c, c_np))
 
+        # Check shapes
+        a = mx.random.normal((2, 3)).astype(mx.complex64)
+        b = mx.random.normal((3,))
+        self.assertEqual((a @ b).shape, (2,))
+
+        a = mx.random.normal((2, 3)).astype(mx.complex64)
+        b = mx.random.normal((3,))
+        c = mx.random.normal((2,))
+        self.assertEqual(mx.addmm(c, a, b).shape, (2,))
+
     def test_complex_gemm(self):
         M = 16
         K = 50
@@ -3078,6 +3078,13 @@ class TestOps(mlx_tests.MLXTestCase):
         )
         self.assertTrue(np.allclose(mx.rsqrt(x), 1.0 / np.sqrt(x)))
 
+    def test_complex_power(self):
+        out = mx.power(mx.array(0j), 2)
+        self.assertEqual(out.item(), 0j)
+
+        out = mx.power(mx.array(0j), float("nan"))
+        self.assertTrue(mx.isnan(out))
+
 
 class TestBroadcast(mlx_tests.MLXTestCase):
     def test_broadcast_shapes(self):