Mirror of https://github.com/ml-explore/mlx.git (synced 2025-12-16 01:49:05 +08:00)

Compare commits: v0.27.1 ... ef631d63af (6 commits)

| SHA1 |
|---|
| ef631d63af |
| 970dbe8e25 |
| 641be9463b |
| ab0e608862 |
| 1588659062 |
| b9e88fb976 |
@@ -212,22 +212,42 @@ jobs:
     resource_class: gpu.nvidia.small.gen2
     steps:
       - checkout
+      - restore_cache:
+          keys:
+            - cuda-<< parameters.image_date >>-{{ arch }}-
       - run:
-          name: Install Python package
+          name: Install dependencies
           command: |
            sudo apt-get update
            sudo apt-get install libcudnn9-dev-cuda-12
            sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
+           curl -sL https://github.com/ccache/ccache/releases/download/v4.11.3/ccache-4.11.3-linux-x86_64.tar.xz | tar xJf -
+           sudo mv ccache-4.11.3-linux-x86_64/ccache /usr/bin/ccache
+           rm -rf ccache-4.11.3-linux-x86_64
+      - run:
+          name: Install Python package
+          command: |
            python3 -m venv env
            source env/bin/activate
            CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
-           pip install -e ".[dev]"
+           pip install -e ".[dev]" -v
       - run:
          name: Run Python tests
          command: |
            source env/bin/activate
            LOW_MEMORY=1 DEVICE=cpu python -m unittest discover python/tests -v
            LOW_MEMORY=1 DEVICE=gpu python -m tests discover python/tests -v
+      - run:
+          name: CCache report
+          command: |
+            ccache --show-stats
+            ccache --zero-stats
+            ccache --max-size 400MB
+            ccache --cleanup
+      - save_cache:
+          key: cuda-<< parameters.image_date >>-{{ arch }}-{{ epoch }}
+          paths:
+            - /home/circleci/.cache/ccache

   build_release:
     parameters:
@@ -555,6 +575,9 @@ workflows:
           requires: [ hold ]
       - cuda_build_and_test:
           requires: [ hold ]
+          matrix:
+            parameters:
+              image_date: ["2023.11.1", "2025.05.1"]
   nightly_build:
     when:
       and:

@@ -41,6 +41,7 @@ option(MLX_BUILD_GGUF "Include support for GGUF format" ON)
 option(MLX_BUILD_SAFETENSORS "Include support for safetensors format" ON)
 option(MLX_BUILD_BLAS_FROM_SOURCE "Build OpenBLAS from source code" OFF)
 option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF)
+option(MLX_USE_CCACHE "Use CCache for compilation cache when available" ON)
 option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)

 # --------------------- Processor tests -------------------------
@@ -68,6 +69,15 @@ else()
   set(MLX_BUILD_METAL OFF)
 endif()

+if(MLX_USE_CCACHE)
+  find_program(CCACHE_PROGRAM ccache)
+  if(CCACHE_PROGRAM)
+    set(CMAKE_C_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
+    set(CMAKE_CXX_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
+    set(CMAKE_CUDA_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
+  endif()
+endif()
+
 # ----------------------------- Lib -----------------------------

 include(FetchContent)
@@ -105,11 +105,11 @@ if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.8.0)
     mlx PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--compress-mode=size>")
 endif()

-# Compute capability 7 is required for synchronization between CPU/GPU with
-# managed memory. TODO: Add more architectures for potential performance gain.
-set(MLX_CUDA_ARCHITECTURES
-    "70;80"
-    CACHE STRING "CUDA architectures")
+# Compute capability >= 7.0 is required for synchronization between CPU/GPU with
+# managed memory.
+if(NOT DEFINED MLX_CUDA_ARCHITECTURES)
+  set(MLX_CUDA_ARCHITECTURES "native")
+endif()
 message(STATUS "CUDA architectures: ${MLX_CUDA_ARCHITECTURES}")
 set_target_properties(mlx PROPERTIES CUDA_ARCHITECTURES
                                      "${MLX_CUDA_ARCHITECTURES}")

@@ -28,7 +28,7 @@ __global__ void binary_ss(const In* a, const In* b, Out* out, IdxT size) {
   AlignedVector<Out, N_READS> out_vec;
 #pragma unroll
   for (int i = 0; i < N_READS; ++i) {
-    out_vec.val[i] = Op{}(a[0], b[0]);
+    out_vec[i] = Op{}(a[0], b[0]);
   }

   store_vector<N_READS>(out, index, out_vec);
@@ -49,7 +49,7 @@ __global__ void binary_sv(const In* a, const In* b, Out* out, IdxT size) {
   AlignedVector<Out, N_READS> out_vec;
 #pragma unroll
   for (int i = 0; i < N_READS; ++i) {
-    out_vec.val[i] = Op{}(a[0], b_vec.val[i]);
+    out_vec[i] = Op{}(a[0], b_vec[i]);
   }

   store_vector<N_READS>(out, index, out_vec);
@@ -70,7 +70,7 @@ __global__ void binary_vs(const In* a, const In* b, Out* out, IdxT size) {
   AlignedVector<Out, N_READS> out_vec;
 #pragma unroll
   for (int i = 0; i < N_READS; ++i) {
-    out_vec.val[i] = Op{}(a_vec.val[i], b[0]);
+    out_vec[i] = Op{}(a_vec[i], b[0]);
   }

   store_vector<N_READS>(out, index, out_vec);
@@ -92,7 +92,7 @@ __global__ void binary_vv(const In* a, const In* b, Out* out, IdxT size) {
   AlignedVector<Out, N_READS> out_vec;
 #pragma unroll
   for (int i = 0; i < N_READS; ++i) {
-    out_vec.val[i] = Op{}(a_vec.val[i], b_vec.val[i]);
+    out_vec[i] = Op{}(a_vec[i], b_vec[i]);
   }

   store_vector<N_READS>(out, index, out_vec);
@@ -248,8 +248,7 @@ void binary_op_gpu_inplace(
   } else {
     dispatch_bool(out.data_size() > UINT32_MAX, [&](auto large) {
       using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
-      // TODO: Choose optimized value based on type size.
-      constexpr int N_READS = 4;
+      constexpr int N_READS = 16 / sizeof(InType);
       auto kernel = cu::binary_ss<Op, InType, OutType, IdxT, N_READS>;
       if (bopt == BinaryOpType::ScalarVector) {
         kernel = cu::binary_sv<Op, InType, OutType, IdxT, N_READS>;

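Note on the `N_READS` change above: sizing the per-thread tile as `16 / sizeof(InType)` keeps every vectorized access 16 bytes wide (one 128-bit load or store) instead of a fixed four elements, so half-precision types move 8 elements per access and 8-byte types move 2. A minimal illustration of the arithmetic, not taken from this diff:

```cpp
#include <cuda_bf16.h>
#include <cuda_fp16.h>

// Illustrative only: 16-byte tiles regardless of element width.
static_assert(16 / sizeof(float) == 4, "float32: 4 elements per 128-bit access");
static_assert(16 / sizeof(__half) == 8, "float16: 8 elements per 128-bit access");
static_assert(16 / sizeof(__nv_bfloat16) == 8, "bfloat16: 8 elements per 128-bit access");
static_assert(16 / sizeof(double) == 2, "8-byte types: 2 elements per 128-bit access");
```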
@@ -33,8 +33,8 @@ binary_two_ss(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
 #pragma unroll
   for (int i = 0; i < N_READS; ++i) {
     auto out = Op{}(a[0], b[0]);
-    out_a_vec.val[i] = out[0];
-    out_b_vec.val[i] = out[1];
+    out_a_vec[i] = out[0];
+    out_b_vec[i] = out[1];
   }

   store_vector<N_READS>(out_a, index, out_a_vec);
@@ -60,9 +60,9 @@ binary_two_sv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
   AlignedVector<Out, N_READS> out_b_vec;
 #pragma unroll
   for (int i = 0; i < N_READS; ++i) {
-    auto out = Op{}(a[0], b_vec.val[i]);
-    out_a_vec.val[i] = out[0];
-    out_b_vec.val[i] = out[1];
+    auto out = Op{}(a[0], b_vec[i]);
+    out_a_vec[i] = out[0];
+    out_b_vec[i] = out[1];
   }

   store_vector<N_READS>(out_a, index, out_a_vec);
@@ -88,9 +88,9 @@ binary_two_vs(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
   AlignedVector<Out, N_READS> out_b_vec;
 #pragma unroll
   for (int i = 0; i < N_READS; ++i) {
-    auto out = Op{}(a_vec.val[i], b[0]);
-    out_a_vec.val[i] = out[0];
-    out_b_vec.val[i] = out[1];
+    auto out = Op{}(a_vec[i], b[0]);
+    out_a_vec[i] = out[0];
+    out_b_vec[i] = out[1];
   }

   store_vector<N_READS>(out_a, index, out_a_vec);
@@ -117,9 +117,9 @@ binary_two_vv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
   AlignedVector<Out, N_READS> out_b_vec;
 #pragma unroll
   for (int i = 0; i < N_READS; ++i) {
-    auto out = Op{}(a_vec.val[i], b_vec.val[i]);
-    out_a_vec.val[i] = out[0];
-    out_b_vec.val[i] = out[1];
+    auto out = Op{}(a_vec[i], b_vec[i]);
+    out_a_vec[i] = out[0];
+    out_b_vec[i] = out[1];
   }

   store_vector<N_READS>(out_a, index, out_a_vec);
@@ -270,8 +270,7 @@ void binary_two_op_gpu_inplace(
   } else {
     dispatch_bool(out_a.data_size() > UINT32_MAX, [&](auto large) {
       using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
-      // TODO: Choose optimized value based on type size.
-      constexpr int N_READS = 4;
+      constexpr int N_READS = 16 / sizeof(InType);
       auto kernel = cu::binary_two_ss<Op, InType, OutType, IdxT, N_READS>;
       if (bopt == BinaryOpType::ScalarVector) {
         kernel = cu::binary_two_sv<Op, InType, OutType, IdxT, N_READS>;

@@ -22,7 +22,7 @@ __global__ void copy_s(const In* in, Out* out, IdxT size) {
   AlignedVector<Out, N_READS> out_vec;
 #pragma unroll
   for (int i = 0; i < N_READS; ++i) {
-    out_vec.val[i] = cast_to<Out>(in[0]);
+    out_vec[i] = cast_to<Out>(in[0]);
   }

   store_vector<N_READS>(out, index, out_vec);
@@ -43,7 +43,7 @@ __global__ void copy_v(const In* in, Out* out, IdxT size) {
   AlignedVector<Out, N_READS> out_vec;
 #pragma unroll
   for (int i = 0; i < N_READS; ++i) {
-    out_vec.val[i] = cast_to<Out>(in_vec.val[i]);
+    out_vec[i] = cast_to<Out>(in_vec[i]);
   }

   store_vector<N_READS>(out, index, out_vec);
@@ -65,8 +65,7 @@ void copy_contiguous(
       using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>;
       using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>;
       using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
-      // TODO: Choose optimized value based on type size.
-      constexpr int N_READS = 4;
+      constexpr int N_READS = 16 / sizeof(InType);
       auto kernel = cu::copy_s<InType, OutType, IdxT, N_READS>;
       if (ctype == CopyType::Vector) {
         kernel = cu::copy_v<InType, OutType, IdxT, N_READS>;

@@ -1,6 +1,7 @@
 // Copyright © 2025 Apple Inc.

 #include "mlx/backend/cuda/device.h"
+#include "mlx/backend/cuda/jit_module.h"
 #include "mlx/backend/cuda/worker.h"
 #include "mlx/utils.h"

@@ -54,6 +55,10 @@ Device::Device(int device) : device_(device) {
   CHECK_CUBLAS_ERROR(cublasLtCreate(&lt_));
   // The cudnn handle is used by Convolution.
   CHECK_CUDNN_ERROR(cudnnCreate(&cudnn_));
+
+  // Initialize the jit module cache here ensures it is not
+  // unloaded before any evaluation is done
+  get_jit_module_cache();
 }

 Device::~Device() {
@@ -92,23 +97,6 @@ CommandEncoder::CaptureContext::~CaptureContext() {
   if (discard) {
     return;
   }

-  // Extract and add as single kernel node when possible.
-  size_t num_nodes;
-  CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, NULL, &num_nodes));
-  if (num_nodes == 1) {
-    cudaGraphNode_t captured_node;
-    CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, &captured_node, &num_nodes));
-    cudaGraphNodeType type;
-    CHECK_CUDA_ERROR(cudaGraphNodeGetType(captured_node, &type));
-    if (type == cudaGraphNodeTypeKernel) {
-      CUDA_KERNEL_NODE_PARAMS params;
-      CHECK_CUDA_ERROR(cuGraphKernelNodeGetParams(captured_node, &params));
-      enc.add_kernel_node(params);
-      return;
-    }
-  }
-  // Otherwise add the captured graph as subgraph.
   enc.add_graph_node(graph);
 }

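The `get_jit_module_cache()` call added to the `Device` constructor relies on a function-local static: touching the cache while the device is being constructed guarantees the cache is created first and therefore destroyed last, so JIT modules are not unloaded while any device that may still evaluate work is alive. A stripped-down sketch of that pattern (the types here are placeholders, not MLX's):

```cpp
#include <string>
#include <unordered_map>

struct Module {
  // Stands in for a loaded JIT module.
};

// Function-local static: constructed on first use and destroyed in reverse
// order of construction, i.e. after any object constructed later.
std::unordered_map<std::string, Module>& module_cache() {
  static std::unordered_map<std::string, Module> cache;
  return cache;
}

struct Device {
  Device() {
    module_cache();  // force the cache to exist before (and outlive) this Device
  }
};
```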
@@ -49,11 +49,7 @@ inline __device__ void atomic_add(__half* out, __half val) {
 }

 inline __device__ void atomic_add(complex64_t* out, complex64_t val) {
-#if __CUDA_ARCH__ < 900
   atomic_add_general(out, val);
-#else
-  atomicAdd(out, val);
-#endif
 }

 inline __device__ void atomic_add(__nv_bfloat16* out, __nv_bfloat16 val) {

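For context, `atomic_add_general` (whose body is not shown in this diff) is the portable path that now handles complex values on every architecture. A compare-and-swap fallback for an 8-byte two-float value typically looks like the hypothetical sketch below; this is not MLX's implementation, and it assumes the destination is 8-byte aligned:

```cpp
struct cplx {
  float re, im;
};

__device__ inline unsigned long long pack(cplx c) {
  return (static_cast<unsigned long long>(__float_as_uint(c.im)) << 32) |
      __float_as_uint(c.re);
}

__device__ inline cplx unpack(unsigned long long bits) {
  return {__uint_as_float(static_cast<unsigned int>(bits)),
          __uint_as_float(static_cast<unsigned int>(bits >> 32))};
}

// CAS loop over the single 64-bit word holding {re, im}.
__device__ void atomic_add_cas(cplx* addr, cplx val) {
  auto* word = reinterpret_cast<unsigned long long*>(addr);
  unsigned long long observed = *word;
  unsigned long long assumed;
  do {
    assumed = observed;
    cplx cur = unpack(assumed);
    cur.re += val.re;
    cur.im += val.im;
    observed = atomicCAS(word, assumed, pack(cur));  // retry if another thread won
  } while (observed != assumed);
}
```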
@@ -32,21 +32,103 @@ using Strides = cuda::std::array<int64_t, MAX_NDIM>;
 template <typename T, int N>
 struct alignas(sizeof(T) * N) AlignedVector {
   T val[N];
+
+  __device__ T& operator[](int i) {
+    return val[i];
+  }
+
+  __device__ T operator[](int i) const {
+    return val[i];
+  }
 };

+template <int N, typename T>
+inline __device__ bool is_aligned(T* x) {
+  return (reinterpret_cast<uintptr_t>(x) % (N * sizeof(T))) == 0;
+}
+
 template <int N, typename T>
 inline __device__ AlignedVector<T, N> load_vector(
     const T* ptr,
     uint32_t offset) {
-  auto* from = reinterpret_cast<const AlignedVector<T, N>*>(ptr);
-  return from[offset];
+  if (is_aligned<N>(ptr)) {
+    auto* from = reinterpret_cast<const AlignedVector<T, N>*>(ptr);
+    return from[offset];
+  } else {
+    AlignedVector<T, N> v;
+#pragma unroll
+    for (int i = 0; i < N; ++i) {
+      v[i] = ptr[offset * N + i];
+    }
+    return v;
+  }
 }

+template <int N, typename T, typename SizeT>
+inline __device__ AlignedVector<T, N>
+load_vector(const T* ptr, uint32_t offset, SizeT size, T fallback) {
+  if (is_aligned<N>(ptr) && (offset + 1) * N <= size) {
+    auto* from = reinterpret_cast<const AlignedVector<T, N>*>(ptr);
+    return from[offset];
+  } else {
+    AlignedVector<T, N> v;
+#pragma unroll
+    for (int i = 0; i < N; ++i) {
+      v[i] = (N * offset + i) < size ? ptr[offset * N + i] : fallback;
+    }
+    return v;
+  }
+}
+
+template <int N, typename T, typename SizeT>
+inline __device__ AlignedVector<T, N> load_vector(
+    const T* ptr,
+    uint32_t offset,
+    SizeT size,
+    int64_t stride,
+    T fallback) {
+  if (is_aligned<N>(ptr) && stride == 1 && (offset + 1) * N <= size) {
+    auto* from = reinterpret_cast<const AlignedVector<T, N>*>(ptr);
+    return from[offset];
+  } else {
+    AlignedVector<T, N> v;
+#pragma unroll
+    for (int i = 0; i < N; ++i) {
+      v[i] =
+          (N * offset + i) < size ? ptr[stride * (offset * N + i)] : fallback;
+    }
+    return v;
+  }
+}
+
 template <int N, typename T>
 inline __device__ void
 store_vector(T* ptr, uint32_t offset, const AlignedVector<T, N>& vec) {
-  auto* to = reinterpret_cast<AlignedVector<T, N>*>(ptr);
-  to[offset] = vec;
+  if (is_aligned<N>(ptr)) {
+    auto* to = reinterpret_cast<AlignedVector<T, N>*>(ptr);
+    to[offset] = vec;
+  } else {
+#pragma unroll
+    for (int i = 0; i < N; ++i) {
+      ptr[offset * N + i] = vec[i];
+    }
+  }
 }

+template <int N, typename T, typename SizeT>
+inline __device__ void store_vector(
+    T* ptr,
+    uint32_t offset,
+    const AlignedVector<T, N>& vec,
+    SizeT size) {
+  if (is_aligned<N>(ptr) && (offset + 1) * N <= size) {
+    auto* to = reinterpret_cast<AlignedVector<T, N>*>(ptr);
+    to[offset] = vec;
+  } else {
+    for (int i = 0; (offset * N + i) < size && i < N; ++i) {
+      ptr[offset * N + i] = vec[i];
+    }
+  }
+}
+
 // Helper for accessing strided data.

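A hypothetical kernel using the helpers above, to show the intended calling pattern: the same call does one 16-byte access per thread on the aligned fast path and degrades to scalar accesses when the base pointer is unaligned (for example a sliced array) or the final tile runs past `size`:

```cpp
// Sketch only; assumes the load_vector/store_vector overloads declared above.
template <typename T, typename IdxT, int N_READS>
__global__ void scale(const T* in, T* out, T alpha, IdxT size) {
  uint32_t tile = blockIdx.x * blockDim.x + threadIdx.x;  // one N_READS tile per thread
  if (static_cast<IdxT>(tile) * N_READS >= size) {
    return;
  }
  // Vectorized when `in` is suitably aligned and the tile is full, scalar otherwise.
  auto v = load_vector<N_READS>(in, tile, size, T(0));
#pragma unroll
  for (int i = 0; i < N_READS; ++i) {
    v[i] = alpha * v[i];
  }
  // Bounded store so a partial last tile never writes past `size`.
  store_vector<N_READS>(out, tile, v, size);
}
```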
@@ -11,7 +11,6 @@ namespace mlx::core::cu {

 namespace cg = cooperative_groups;

-static constexpr int n_per_thread = 4;
 static constexpr int rows_per_block = 8;

 template <typename T, int rows_per_block, int n_per_thread>
@@ -32,8 +31,8 @@ gemv_impl(const T* mat, const T* vec, T* out, int rows, int cols) {
     auto local_vec = load_vector<n_per_thread>(vec + col, 0);
 #pragma unroll
     for (int j = 0; j < n_per_thread; ++j) {
-      sum += static_cast<float>(local_mat.val[j]) *
-          static_cast<float>(local_vec.val[j]);
+      sum +=
+          static_cast<float>(local_mat[j]) * static_cast<float>(local_vec[j]);
     }
   }

@@ -74,8 +73,22 @@ __global__ void gemv_batched(
 }

 bool can_use_gemv(int M, int N, int K, bool a_transposed, bool b_transposed) {
-  return K % (WARP_SIZE * n_per_thread) == 0 &&
-      ((M == 1 && b_transposed) || (N == 1 && !a_transposed));
+  return K % 32 == 0 && ((M == 1 && b_transposed) || (N == 1 && !a_transposed));
 }

+template <typename F>
+void dispatch_n_per_thread(int n_per_thread, F&& f) {
+  switch (n_per_thread) {
+    case 1:
+      f(std::integral_constant<int, 1>{});
+      break;
+    case 2:
+      f(std::integral_constant<int, 2>{});
+      break;
+    case 4:
+      f(std::integral_constant<int, 4>{});
+      break;
+  }
+}
+
 void gemv(
@@ -114,33 +127,39 @@ void gemv(
     rows = M;
   }
   uint32_t num_blocks_x = (rows + rows_per_block - 1) / rows_per_block;
-  if (batch_count == 1) {
-    auto kernel = gemv_single<DataType, rows_per_block, n_per_thread>;
-    encoder.add_kernel_node(
-        kernel,
-        num_blocks_x,
-        block_dims,
-        mat,
-        vec,
-        out.data<DataType>(),
-        rows,
-        cols);
-  } else {
-    auto kernel = gemv_batched<DataType, rows_per_block, n_per_thread>;
-    encoder.add_kernel_node(
-        kernel,
-        dim3{num_blocks_x, batch_count},
-        block_dims,
-        mat,
-        vec,
-        out.data<DataType>(),
-        rows,
-        cols,
-        const_param(batch_shape),
-        mat_strides,
-        vec_strides,
-        batch_shape.size());
+  int n_per_t = 4;
+  while (K % (n_per_t * WARP_SIZE) != 0) {
+    n_per_t >>= 1;
+  }
+  dispatch_n_per_thread(n_per_t, [&](auto n_per_thread) {
+    if (batch_count == 1) {
+      auto kernel = gemv_single<DataType, rows_per_block, n_per_thread()>;
+      encoder.add_kernel_node(
+          kernel,
+          num_blocks_x,
+          block_dims,
+          mat,
+          vec,
+          out.data<DataType>(),
+          rows,
+          cols);
+    } else {
+      auto kernel = gemv_batched<DataType, rows_per_block, n_per_thread()>;
+      encoder.add_kernel_node(
+          kernel,
+          dim3{num_blocks_x, batch_count},
+          block_dims,
+          mat,
+          vec,
+          out.data<DataType>(),
+          rows,
+          cols,
+          const_param(batch_shape),
+          mat_strides,
+          vec_strides,
+          batch_shape.size());
+    }
+  });
   });
 }

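On the host side, `n_per_thread` is now chosen at run time and mapped to a compile-time constant through `dispatch_n_per_thread`. For example `K = 256` keeps `n_per_t = 4` (256 % 128 == 0), `K = 64` drops to 2, and `K = 96` falls back to 1 (96 % 128 != 0, 96 % 64 != 0, but 96 % 32 == 0, which `can_use_gemv` already guarantees). A compressed, self-contained sketch of the pattern with illustrative names:

```cpp
#include <type_traits>

constexpr int WARP_SIZE = 32;  // assumption: the warp size used by the kernels above

template <typename F>
void dispatch_n_per_thread(int n, F&& f) {
  // Map a runtime value onto a compile-time template argument.
  switch (n) {
    case 1: f(std::integral_constant<int, 1>{}); break;
    case 2: f(std::integral_constant<int, 2>{}); break;
    default: f(std::integral_constant<int, 4>{}); break;
  }
}

void launch_example(int K) {
  int n_per_t = 4;
  while (K % (n_per_t * WARP_SIZE) != 0) {
    n_per_t >>= 1;  // 4 -> 2 -> 1; K % 32 == 0 is assumed to hold
  }
  dispatch_n_per_thread(n_per_t, [&](auto n) {
    constexpr int N = decltype(n)::value;  // compile-time constant
    static_assert(N == 1 || N == 2 || N == 4);
    // N can now be used as a template argument,
    // e.g. some_gemv_kernel<float, 8, N><<<grid, block>>>(...);
  });
}
```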
@@ -9,7 +9,6 @@
 #include <cstdlib>
 #include <filesystem>
 #include <fstream>
-#include <unordered_map>

 #include <fmt/format.h>
 #include <nvrtc.h>
@@ -330,11 +329,16 @@ CUfunction JitModule::get_kernel(const std::string& kernel_name) {
   return it->second;
 }

+std::unordered_map<std::string, JitModule>& get_jit_module_cache() {
+  static std::unordered_map<std::string, JitModule> map;
+  return map;
+}
+
 JitModule& get_jit_module(
     const mlx::core::Device& device,
     const std::string& name,
     const KernelBuilder& builder) {
-  static std::unordered_map<std::string, JitModule> map;
+  auto& map = get_jit_module_cache();
   auto it = map.find(name);
   if (it == map.end()) {
     it = map.try_emplace(name, cu::device(device), name, builder).first;

@@ -99,6 +99,8 @@ class JitModule {
   std::unordered_map<std::string, CUfunction> kernels_;
 };

+std::unordered_map<std::string, JitModule>& get_jit_module_cache();
+
 JitModule& get_jit_module(
     const mlx::core::Device& device,
     const std::string& name,

@@ -120,20 +120,6 @@ dim3 get_2d_grid_dims(
     size_t divisor);
 std::pair<dim3, dim3> get_grid_and_block(int dim0, int dim1, int dim2);

-// Return a block size that achieves maximum potential occupancy for kernel.
-template <typename T>
-inline uint max_occupancy_block_dim(T kernel) {
-  int _, block_dim;
-  if constexpr (std::is_same_v<T, CUfunction>) {
-    CHECK_CUDA_ERROR(
-        cuOccupancyMaxPotentialBlockSize(&_, &block_dim, kernel, 0, 0, 0));
-  } else {
-    CHECK_CUDA_ERROR(
-        cudaOccupancyMaxPotentialBlockSize(&_, &block_dim, kernel));
-  }
-  return block_dim;
-}
-
 // Get the num_blocks and block_dims that maximize occupancy for |kernel|,
 // assuming each thread handles |work_per_thread| elements of |arr|.
 template <typename T>
@@ -145,7 +131,7 @@ inline std::tuple<dim3, uint> get_launch_args(
     bool large,
     int work_per_thread = 1) {
   size_t nthreads = cuda::ceil_div(size, work_per_thread);
-  uint block_dim = max_occupancy_block_dim(kernel);
+  uint block_dim = 1024;
   if (block_dim > nthreads) {
     block_dim = nthreads;
   }

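With the occupancy query gone, the launch configuration is derived purely from the element count and a fixed 1024-thread cap. A simplified sketch of that arithmetic (the real helper also produces 2-D grids and handles the large-index case):

```cpp
#include <cstddef>
#include <cstdint>
#include <utility>

// Returns {num_blocks, block_dim}; illustrative only.
inline std::pair<uint32_t, uint32_t> launch_shape(size_t size, int work_per_thread = 1) {
  size_t nthreads = (size + work_per_thread - 1) / work_per_thread;  // ceil_div
  if (nthreads == 0) {
    nthreads = 1;
  }
  uint32_t block_dim = 1024;  // fixed cap instead of an occupancy query
  if (block_dim > nthreads) {
    block_dim = static_cast<uint32_t>(nthreads);
  }
  uint32_t num_blocks = static_cast<uint32_t>((nthreads + block_dim - 1) / block_dim);
  return {num_blocks, block_dim};
}
```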
@@ -10,8 +10,6 @@
 #include <cooperative_groups.h>
-#include <cooperative_groups/reduce.h>
 #include <nvtx3/nvtx3.hpp>
-#include <cub/block/block_load.cuh>
 #include <cub/block/block_reduce.cuh>

 namespace mlx::core {

@@ -74,9 +72,11 @@ __global__ void layer_norm(
   float sum = 0;
   for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
     auto index = r * BLOCK_DIM + block.thread_rank();
-    T xn[N_READS] = {};
-    cub::LoadDirectBlocked(index, x, xn, axis_size);
-    sum += static_cast<float>(cub::ThreadReduce(xn, cuda::std::plus<>{}));
+    auto xn = load_vector<N_READS>(x, index, axis_size, T(0));
+#pragma unroll
+    for (int i = 0; i < N_READS; ++i) {
+      sum += static_cast<float>(xn[i]);
+    }
   }
   sum = BlockReduceT{block, temp}.Sum(sum);

@@ -87,11 +87,18 @@ __global__ void layer_norm(
   float normalizer = 0;
   for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
     auto index = r * BLOCK_DIM + block.thread_rank();
-    T xn[N_READS];
-    cub::LoadDirectBlocked(index, x, xn, axis_size, mean);
-    for (int i = 0; i < N_READS; ++i) {
-      float t = static_cast<float>(xn[i]) - mean;
-      normalizer += t * t;
+    if ((index + 1) * N_READS <= axis_size) {
+      auto xn = load_vector<N_READS>(x, index);
+#pragma unroll
+      for (int i = 0; i < N_READS; ++i) {
+        float t = static_cast<float>(xn[i]) - mean;
+        normalizer += t * t;
+      }
+    } else {
+      for (int i = index * N_READS; i < axis_size; ++i) {
+        float t = static_cast<float>(x[i]) - mean;
+        normalizer += t * t;
+      }
     }
   }
   normalizer = BlockReduceT{block, temp}.Sum(normalizer);
@@ -100,17 +107,15 @@
   // Outputs.
   for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
     auto index = r * BLOCK_DIM + block.thread_rank();
-    T xn[N_READS];
-    T wn[N_READS];
-    T bn[N_READS];
-    cub::LoadDirectBlocked(index, x, xn, axis_size);
-    cub::LoadDirectBlocked(index, StridedIterator(w, w_stride), wn, axis_size);
-    cub::LoadDirectBlocked(index, StridedIterator(b, b_stride), bn, axis_size);
+    auto xn = load_vector<N_READS>(x, index, axis_size, T(0));
+    auto wn = load_vector<N_READS>(w, index, axis_size, w_stride, T(0));
+    auto bn = load_vector<N_READS>(b, index, axis_size, b_stride, T(0));
+#pragma unroll
     for (int i = 0; i < N_READS; ++i) {
       float norm = (static_cast<float>(xn[i]) - mean) * normalizer;
       xn[i] = wn[i] * static_cast<T>(norm) + bn[i];
     }
-    cub::StoreDirectBlocked(index, out, xn, axis_size);
+    store_vector<N_READS>(out, index, xn, axis_size);
   }
 }

@@ -143,9 +148,11 @@ __global__ void layer_norm_vjp(
   float sum = 0;
   for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
     auto index = r * BLOCK_DIM + block.thread_rank();
-    T xn[N_READS] = {};
-    cub::LoadDirectBlocked(index, x, xn, axis_size);
-    sum += static_cast<float>(cub::ThreadReduce(xn, cuda::std::plus<>{}));
+    auto xn = load_vector<N_READS>(x, index, axis_size, T(0));
+#pragma unroll
+    for (int i = 0; i < N_READS; ++i) {
+      sum += static_cast<float>(xn[i]);
+    }
   }
   sum = BlockReduceF{block, temp.f}.Sum(sum);

@@ -155,19 +162,28 @@
   // Normalizer.
   float3 factors = {};
   for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
-    T xn[N_READS];
-    T wn[N_READS] = {};
-    T gn[N_READS] = {};
     auto index = r * BLOCK_DIM + block.thread_rank();
-    cub::LoadDirectBlocked(index, x, xn, axis_size, mean);
-    cub::LoadDirectBlocked(index, g, gn, axis_size);
-    cub::LoadDirectBlocked(index, StridedIterator(w, w_stride), wn, axis_size);
-    for (int i = 0; i < N_READS; i++) {
-      float t = static_cast<float>(xn[i]) - mean;
-      float wi = wn[i];
-      float gi = gn[i];
-      float wg = wi * gi;
-      factors = plus_f3(factors, {wg, wg * t, t * t});
+    auto gn = load_vector<N_READS>(g, index, axis_size, T(0));
+    auto wn = load_vector<N_READS>(w, index, axis_size, w_stride, T(0));
+
+    if ((index + 1) * N_READS <= axis_size) {
+      auto xn = load_vector<N_READS>(x, index);
+#pragma unroll
+      for (int i = 0; i < N_READS; ++i) {
+        float t = static_cast<float>(xn[i]) - mean;
+        float wi = wn[i];
+        float gi = gn[i];
+        float wg = wi * gi;
+        factors = plus_f3(factors, {wg, wg * t, t * t});
+      }
+    } else {
+      for (int i = index * N_READS; i < axis_size; ++i) {
+        float t = static_cast<float>(x[i]) - mean;
+        float wi = wn[i];
+        float gi = gn[i];
+        float wg = wi * gi;
+        factors = plus_f3(factors, {wg, wg * t, t * t});
+      }
     }
   }
   factors = BlockReduceF3{block, temp.f3}.Reduce(factors, plus_f3, {});
@@ -179,12 +195,10 @@ __global__ void layer_norm_vjp(
   // Outputs.
   for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
     auto index = r * BLOCK_DIM + block.thread_rank();
-    T xn[N_READS];
-    T wn[N_READS];
-    T gn[N_READS];
-    cub::LoadDirectBlocked(index, x, xn, axis_size);
-    cub::LoadDirectBlocked(index, g, gn, axis_size);
-    cub::LoadDirectBlocked(index, StridedIterator(w, w_stride), wn, axis_size);
+    auto xn = load_vector<N_READS>(x, index, axis_size, T(0));
+    auto gn = load_vector<N_READS>(g, index, axis_size, T(0));
+    auto wn = load_vector<N_READS>(w, index, axis_size, w_stride, T(0));
+
     for (int i = 0; i < N_READS; i++) {
       float xi = (static_cast<float>(xn[i]) - mean) * normalizer;
       float wi = wn[i];
@@ -194,9 +208,9 @@ __global__ void layer_norm_vjp(
        wn[i] = gi * xi;
      }
    }
-    cub::StoreDirectBlocked(index, gx, xn, axis_size);
+    store_vector<N_READS>(gx, index, xn, axis_size);
     if constexpr (HAS_W) {
-      cub::StoreDirectBlocked(index, gw, wn, axis_size);
+      store_vector<N_READS>(gw, index, wn, axis_size);
     }
   }
 }
@@ -257,9 +271,9 @@ void LayerNorm::eval_gpu(
   encoder.set_input_array(b);
   encoder.set_output_array(out);
   dispatch_float_types(out.dtype(), "layernorm", [&](auto type_tag) {
-    constexpr uint32_t N_READS = 4;
+    using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
+    constexpr int N_READS = 16 / sizeof(DataType);
     dispatch_block_dim(cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
-      using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
       auto kernel = cu::layer_norm<DataType, block_dim(), N_READS>;
       encoder.add_kernel_node(
           kernel,
@@ -364,10 +378,10 @@ void LayerNormVJP::eval_gpu(
   encoder.set_output_array(gw_temp);
   dispatch_float_types(gx.dtype(), "layernorm_vjp", [&](auto type_tag) {
     dispatch_bool(has_w, [&](auto has_w_constant) {
-      constexpr int N_READS = 4;
+      using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
+      constexpr int N_READS = 16 / sizeof(DataType);
       dispatch_block_dim(
           cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
-            using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
             auto kernel = cu::layer_norm_vjp<
                 DataType,
                 has_w_constant.value,

@@ -5,8 +5,6 @@
 #include "mlx/backend/gpu/copy.h"

 #include <nvtx3/nvtx3.hpp>
-#include <thrust/device_ptr.h>
-#include <thrust/fill.h>

 #include <cassert>

@@ -10,8 +10,6 @@
 #include <cooperative_groups.h>
-#include <cooperative_groups/reduce.h>
 #include <nvtx3/nvtx3.hpp>
-#include <cub/block/block_load.cuh>
 #include <cub/block/block_reduce.cuh>

 namespace mlx::core {

@@ -57,7 +55,7 @@ __global__ void rms_norm(
     const T* w,
     T* out,
     float eps,
-    int32_t axis_size,
+    uint32_t axis_size,
     int64_t w_stride) {
   auto grid = cg::this_grid();
   auto block = cg::this_thread_block();
@@ -72,8 +70,8 @@ __global__ void rms_norm(
   float normalizer = 0;
   for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
     auto index = r * BLOCK_DIM + block.thread_rank();
-    T xn[N_READS];
-    cub::LoadDirectBlocked(index, x, xn, axis_size, cast_to<T>(0));
+    auto xn = load_vector<N_READS>(x, index, axis_size, T(0));
+#pragma unroll
     for (int i = 0; i < N_READS; ++i) {
       float t = static_cast<float>(xn[i]);
       normalizer += t * t;
@@ -85,15 +83,14 @@ __global__ void rms_norm(
   // Outputs.
   for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
     auto index = r * BLOCK_DIM + block.thread_rank();
-    T xn[N_READS];
-    T wn[N_READS];
-    cub::LoadDirectBlocked(index, x, xn, axis_size);
-    cub::LoadDirectBlocked(index, StridedIterator(w, w_stride), wn, axis_size);
+    auto xn = load_vector<N_READS>(x, index, axis_size, T(0));
+    auto wn = load_vector<N_READS>(w, index, axis_size, w_stride, T(0));
+#pragma unroll
     for (int i = 0; i < N_READS; ++i) {
-      float norm = static_cast<float>(xn[i]) * normalizer;
-      xn[i] = wn[i] * static_cast<T>(norm);
+      float y = static_cast<float>(xn[i]) * normalizer;
+      xn[i] = wn[i] * static_cast<T>(y);
     }
-    cub::StoreDirectBlocked(index, out, xn, axis_size);
+    store_vector<N_READS>(out, index, xn, axis_size);
   }
 }

@@ -125,13 +122,10 @@ __global__ void rms_norm_vjp(
   // Normalizer.
   float2 factors = {};
   for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
-    T xn[N_READS];
-    T wn[N_READS] = {};
-    T gn[N_READS] = {};
     auto index = r * BLOCK_DIM + block.thread_rank();
-    cub::LoadDirectBlocked(index, x, xn, axis_size, cast_to<T>(0));
-    cub::LoadDirectBlocked(index, g, gn, axis_size);
-    cub::LoadDirectBlocked(index, StridedIterator(w, w_stride), wn, axis_size);
+    auto xn = load_vector<N_READS>(x, index, axis_size, T(0));
+    auto gn = load_vector<N_READS>(g, index, axis_size, T(0));
+    auto wn = load_vector<N_READS>(w, index, axis_size, w_stride, T(0));
     for (int i = 0; i < N_READS; i++) {
       float t = static_cast<float>(xn[i]);
       float wi = wn[i];
@@ -148,12 +142,9 @@ __global__ void rms_norm_vjp(
   // Outputs.
   for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
     auto index = r * BLOCK_DIM + block.thread_rank();
-    T xn[N_READS];
-    T wn[N_READS];
-    T gn[N_READS];
-    cub::LoadDirectBlocked(index, x, xn, axis_size);
-    cub::LoadDirectBlocked(index, g, gn, axis_size);
-    cub::LoadDirectBlocked(index, StridedIterator(w, w_stride), wn, axis_size);
+    auto xn = load_vector<N_READS>(x, index, axis_size, T(0));
+    auto gn = load_vector<N_READS>(g, index, axis_size, T(0));
+    auto wn = load_vector<N_READS>(w, index, axis_size, w_stride, T(0));
     for (int i = 0; i < N_READS; i++) {
       float xi = xn[i];
       float wi = wn[i];
@@ -163,9 +154,9 @@ __global__ void rms_norm_vjp(
        wn[i] = static_cast<T>(gi * xi * normalizer);
      }
    }
-    cub::StoreDirectBlocked(index, gx, xn, axis_size);
+    store_vector<N_READS>(gx, index, xn, axis_size);
     if constexpr (HAS_W) {
-      cub::StoreDirectBlocked(index, gw, wn, axis_size);
+      store_vector<N_READS>(gw, index, wn, axis_size);
     }
   }
 }
@@ -223,9 +214,9 @@ void RMSNorm::eval_gpu(
   encoder.set_input_array(w);
   encoder.set_output_array(out);
   dispatch_float_types(out.dtype(), "rms_norm", [&](auto type_tag) {
-    constexpr uint32_t N_READS = 4;
+    using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
+    constexpr int N_READS = 16 / sizeof(DataType);
     dispatch_block_dim(cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
-      using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
       auto kernel = cu::rms_norm<DataType, block_dim(), N_READS>;
       encoder.add_kernel_node(
           kernel,
@@ -312,11 +303,10 @@ void RMSNormVJP::eval_gpu(
   encoder.set_output_array(gw_temp);
   dispatch_float_types(gx.dtype(), "rms_norm_vjp", [&](auto type_tag) {
     dispatch_bool(has_w, [&](auto has_w_constant) {
-      constexpr int N_READS = 4;
+      using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
+      constexpr int N_READS = 16 / sizeof(DataType);
       dispatch_block_dim(
           cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
-            using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
-            constexpr int N_READS = 4;
             auto kernel = cu::rms_norm_vjp<
                 DataType,
                 has_w_constant.value,

@@ -32,7 +32,7 @@ ternary_v(const bool* a, const T* b, const T* c, T* out, IdxT size) {
   AlignedVector<T, N_READS> out_vec;
 #pragma unroll
   for (int i = 0; i < N_READS; ++i) {
-    out_vec.val[i] = Op{}(a_vec.val[i], b_vec.val[i], c_vec.val[i]);
+    out_vec[i] = Op{}(a_vec[i], b_vec[i], c_vec[i]);
   }

   store_vector<N_READS>(out, index, out_vec);
@@ -166,8 +166,7 @@ void ternary_op_gpu_inplace(
   } else {
     dispatch_bool(out.data_size() > UINT32_MAX, [&](auto large) {
       using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
-      // TODO: Choose optimized value based on type size.
-      constexpr int N_READS = 4;
+      constexpr int N_READS = 16 / sizeof(DType);
       auto kernel = cu::ternary_v<Op, DType, IdxT, N_READS>;
       auto [num_blocks, block_dims] = get_launch_args(
           kernel,

@@ -30,7 +30,7 @@ __global__ void unary_v(const In* in, Out* out, IdxT size) {
   AlignedVector<Out, N_READS> out_vec;
 #pragma unroll
   for (int i = 0; i < N_READS; ++i) {
-    out_vec.val[i] = Op{}(in_vec.val[i]);
+    out_vec[i] = Op{}(in_vec[i]);
   }

   store_vector<N_READS>(out, index, out_vec);

@@ -3049,6 +3049,25 @@ class TestOps(mlx_tests.MLXTestCase):
         out = mx.power(mx.array(0j), float("nan"))
         self.assertTrue(mx.isnan(out))

+    def test_irregular_alignments(self):
+        # Unaligned unary op
+        a = mx.ones((64, 1))
+        b = -a[1:]
+        self.assertTrue(mx.all(b == -1.0))
+
+        # Unaligned binary op
+        a = mx.ones((64, 1))
+        b = a[1:]
+        c = b + b
+        self.assertTrue(mx.all(c == 2.0))
+
+        # Unaligned ternary op
+        a = mx.ones((64, 1))
+        b = mx.zeros((63, 1))
+        c = mx.ones((63, 1)).astype(mx.bool_)
+        d = mx.where(c, a[1:], b)
+        self.assertTrue(mx.all(d == 1.0))
+

 class TestBroadcast(mlx_tests.MLXTestCase):
     def test_broadcast_shapes(self):

setup.py

@@ -44,6 +44,8 @@ def get_version():


 build_stage = int(os.environ.get("MLX_BUILD_STAGE", 0))
+build_macos = platform.system() == "Darwin"
+build_cuda = "MLX_BUILD_CUDA=ON" in os.environ.get("CMAKE_ARGS", "")


 # A CMakeExtension needs a sourcedir instead of a file list.
@@ -85,6 +87,11 @@ class CMakeBuild(build_ext):
             "-DMLX_BUILD_EXAMPLES=OFF",
             f"-DMLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY={extdir}{os.sep}",
         ]
+        if build_stage == 2 and build_cuda:
+            # Last arch is always real and virtual for forward-compatibility
+            cuda_archs = ";".join(("70-real", "80-real", "90-real", "100-real", "120"))
+            cmake_args += [f"-DMLX_CUDA_ARCHITECTURES={cuda_archs}"]
+
         # Some generators require explcitly passing config when building.
         build_args = ["--config", cfg]
         # Adding CMake arguments set as environment variable
@@ -95,7 +102,7 @@ class CMakeBuild(build_ext):
         # Pass version to C++
         cmake_args += [f"-DMLX_VERSION={self.distribution.get_version()}"]  # type: ignore[attr-defined]

-        if platform.system() == "Darwin":
+        if build_macos:
             # Cross-compile support for macOS - respect ARCHFLAGS if set
             archs = re.findall(r"-arch (\S+)", os.environ.get("ARCHFLAGS", ""))
             if archs:
@@ -113,6 +120,9 @@ class CMakeBuild(build_ext):
         if "CMAKE_BUILD_PARALLEL_LEVEL" not in os.environ:
             build_args += [f"-j{os.cpu_count()}"]

+        # Avoid cache miss when building from temporary dirs.
+        os.environ["CCACHE_BASEDIR"] = os.path.abspath(self.build_temp)
+
         subprocess.run(
             ["cmake", ext.sourcedir, *cmake_args], cwd=build_temp, check=True
         )
@@ -202,9 +212,6 @@ if __name__ == "__main__":
         ],
     )

-    build_macos = platform.system() == "Darwin"
-    build_cuda = "MLX_BUILD_CUDA=ON" in os.environ.get("CMAKE_ARGS", "")
-
     version = get_version()

     _setup = partial(