Mirror of https://github.com/ml-explore/mlx.git
Synced 2025-12-16 01:49:05 +08:00

Compare commits: d1f4d291e8...v0.27.1 (9 commits)

Commits:
- 4ad53414dd
- d1165b215e
- dcb8319f3d
- 5597fa089c
- 9acec364c2
- 7d9d6ef456
- 6f5874a2f2
- 70dc336785
- 4e504039f5
.circleci/config.yml

@@ -203,8 +203,12 @@ jobs:
             python -m xmlrunner discover -v python/tests -o test-results/gpu_jit

   cuda_build_and_test:
+    parameters:
+      image_date:
+        type: string
+        default: "2023.11.1"
     machine:
-      image: linux-cuda-12:2023.11.1
+      image: "linux-cuda-12:<< parameters.image_date >>"
       resource_class: gpu.nvidia.small.gen2
     steps:
       - checkout
@@ -212,6 +216,7 @@ jobs:
           name: Install Python package
           command: |
             sudo apt-get update
+            sudo apt-get install libcudnn9-dev-cuda-12
             sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
             python3 -m venv env
             source env/bin/activate
@@ -381,7 +386,7 @@ jobs:
            wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2404/x86_64/cuda-keyring_1.1-1_all.deb
            sudo dpkg -i cuda-keyring_1.1-1_all.deb
            sudo apt-get update
-           sudo apt install cuda-toolkit-12-9
+           sudo apt-get install cuda-toolkit-12-9 libcudnn9-dev-cuda-12
            sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
            sudo apt-get install zip
            pip install auditwheel
@@ -419,7 +424,10 @@ workflows:
           parameters:
             macosx_deployment_target: ["13.5", "14.0"]
       - linux_build_and_test
-      - cuda_build_and_test
+      - cuda_build_and_test:
+          matrix:
+            parameters:
+              image_date: ["2023.11.1", "2025.05.1"]
       - build_documentation

   build_pypi_release:
README.md (21 lines changed)

@@ -11,10 +11,10 @@ brought to you by Apple machine learning research.

 Some key features of MLX include:

 - **Familiar APIs**: MLX has a Python API that closely follows NumPy. MLX
   also has fully featured C++, [C](https://github.com/ml-explore/mlx-c), and
   [Swift](https://github.com/ml-explore/mlx-swift/) APIs, which closely mirror
   the Python API. MLX has higher-level packages like `mlx.nn` and
   `mlx.optimizers` with APIs that closely follow PyTorch to simplify building
   more complex models.

@@ -68,18 +68,23 @@ in the documentation.

 ## Installation

-MLX is available on [PyPI](https://pypi.org/project/mlx/). To install the Python API, run:
+MLX is available on [PyPI](https://pypi.org/project/mlx/). To install MLX on
+macOS, run:

-**With `pip`**:
-
-```
+```bash
 pip install mlx
 ```

-**With `conda`**:
+To install the CUDA backend on Linux, run:

-```
-conda install -c conda-forge mlx
+```bash
+pip install "mlx[cuda]"
+```
+
+To install a CPU-only Linux package, run:
+
+```bash
+pip install "mlx[cpu]"
 ```

 Checkout the
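The features bullet above notes that the C++ API mirrors the Python/NumPy-style API. As a rough illustration only (not part of this commit range; the `ones`, `eval`, and array operator names are assumed from the MLX C++ examples), a minimal program might look like:

```cpp
// Minimal sketch of the NumPy-like MLX C++ API. Function names are assumed
// from the MLX C++ examples, not taken from this diff.
#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  array x = ones({2, 2}); // 2x2 array of ones
  array y = x + x;        // lazy elementwise add
  eval(y);                // force evaluation of the lazy graph
  return 0;
}
```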
docs/src/install.rst

@@ -13,7 +13,7 @@ silicon computer is

     pip install mlx

-To install from PyPI you must meet the following requirements:
+To install from PyPI your system must meet the following requirements:

 - Using an M series chip (Apple silicon)
 - Using a native Python >= 3.9
@@ -26,13 +26,22 @@ To install from PyPI you must meet the following requirements:
 CUDA
 ^^^^

-MLX has a CUDA backend which you can use on any Linux platform with CUDA 12
-and SM 7.0 (Volta) and up. To install MLX with CUDA support, run:
+MLX has a CUDA backend which you can install with:

 .. code-block:: shell

     pip install "mlx[cuda]"

+To install the CUDA package from PyPi your system must meet the following
+requirements:
+
+- Nvidia architecture >= SM 7.0 (Volta)
+- Nvidia driver >= 550.54.14
+- CUDA toolkit >= 12.0
+- Linux distribution with glibc >= 2.35
+- Python >= 3.9
+
 CPU-only (Linux)
 ^^^^^^^^^^^^^^^^

@@ -42,6 +51,13 @@ For a CPU-only version of MLX that runs on Linux use:

     pip install "mlx[cpu]"

+To install the CPU-only package from PyPi your system must meet the following
+requirements:
+
+- Linux distribution with glibc >= 2.35
+- Python >= 3.9
+
 Troubleshooting
 ^^^^^^^^^^^^^^^
mlx/backend/cuda/CMakeLists.txt

@@ -15,12 +15,14 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_dynamic.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_input.cu
+  ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/cuda.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/event.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
-  ${CMAKE_CURRENT_SOURCE_DIR}/gemv.cu
+  ${CMAKE_CURRENT_SOURCE_DIR}/gemms/gemv.cu
+  ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/jit_module.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu
@@ -46,6 +48,14 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/quantized.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp)

+if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.9.0)
+  target_sources(
+    mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_batched_gemm_12_9.cu)
+else()
+  target_sources(
+    mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_batched_gemm_12_0.cpp)
+endif()
+
 target_compile_definitions(mlx PRIVATE MLX_USE_CUDA)

 # Embed kernel sources in binary for JIT compilation.
@@ -131,6 +141,23 @@ target_link_libraries(mlx PRIVATE CUDA::cublasLt)
 # Use NVRTC and driver APIs.
 target_link_libraries(mlx PRIVATE CUDA::nvrtc CUDA::cuda_driver)

+# Use the frontend APIs of cuDNN.
+FetchContent_Declare(
+  cudnn
+  GIT_REPOSITORY https://github.com/NVIDIA/cudnn-frontend.git
+  GIT_TAG v1.12.1
+  GIT_SHALLOW TRUE
+  EXCLUDE_FROM_ALL)
+set(CUDNN_FRONTEND_SKIP_JSON_LIB ON)
+set(CUDNN_FRONTEND_BUILD_SAMPLES OFF)
+set(CUDNN_FRONTEND_BUILD_TESTS OFF)
+set(CUDNN_FRONTEND_BUILD_PYTHON_BINDINGS OFF)
+FetchContent_MakeAvailable(cudnn)
+target_link_libraries(mlx PRIVATE cudnn_frontend)
+# Link with the actual cuDNN libraries.
+include(${cudnn_frontend_SOURCE_DIR}/cmake/cuDNN.cmake)
+target_link_libraries(mlx PRIVATE CUDNN::cudnn_all)
+
 # Suppress nvcc warnings on MLX headers.
 target_compile_options(mlx PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe
                                    --diag_suppress=997>)
mlx/backend/cuda/conv.cpp (new file, 340 lines)

@@ -0,0 +1,340 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/config.h"
#include "mlx/backend/cuda/lru_cache.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"

// cudnn_frontend.h redefines this macro.
#undef CHECK_CUDA_ERROR

#include <cudnn_frontend.h>
#include <cudnn_frontend_find_plan.h>
#include <fmt/format.h>
#include <nvtx3/nvtx3.hpp>

#include <cassert>
#include <numeric>

namespace mlx::core {

namespace {

// Not all engines support it so can not use this API now.
#define MLX_USE_CUDNN_NATIVE_CUDA_GRAPH_API 0

struct ConvCacheKey {
  int device_id;
  cudnnBackendDescriptorType_t backend_type;
  cudnnDataType_t cudnn_type;
  std::array<int, MAX_NDIM> input_shape;
  std::array<int, MAX_NDIM> filter_shape;
  std::array<int, MAX_NDIM> padding_lo;
  std::array<int, MAX_NDIM> padding_hi;
  std::array<int, MAX_NDIM> stride;
  std::array<int, MAX_NDIM> dilation;
  int groups;
  uint8_t input_alignment;
  uint8_t filter_alignment;
  uint8_t output_alignment;
};

auto& conv_cache() {
  static LRUBytesKeyCache<ConvCacheKey, cudnn_frontend::ExecutionPlan> cache(
      /* capacity */ 128);
  return cache;
}

template <typename T, typename U>
inline std::vector<T> convert_vector(const std::vector<U>& vec) {
  return std::vector<T>(vec.begin(), vec.end());
}

template <typename T>
inline std::array<T, MAX_NDIM> fixed_vector(const std::vector<T>& vec) {
  if (vec.size() > MAX_NDIM) {
    throw std::runtime_error(
        fmt::format("ndim can not be larger than {}.", MAX_NDIM));
  }
  std::array<T, MAX_NDIM> result = {};
  std::copy_n(vec.begin(), vec.size(), result.begin());
  return result;
}

auto nhwc_to_nchw(const array& x) {
  auto shape = convert_vector<int64_t>(x.shape());
  shape.insert(shape.begin() + 1, shape.back());
  shape.erase(shape.end() - 1);
  auto strides = convert_vector<int64_t>(x.strides());
  strides.insert(strides.begin() + 1, strides.back());
  strides.erase(strides.end() - 1);
  return std::make_tuple(shape, strides);
}

inline cudnnDataType_t dtype_to_cudnn_type(Dtype dtype) {
  switch (dtype) {
    case int8:
      return CUDNN_DATA_INT8;
    case int32:
      return CUDNN_DATA_INT32;
    case uint8:
      return CUDNN_DATA_UINT8;
    case float16:
      return CUDNN_DATA_HALF;
    case bfloat16:
      return CUDNN_DATA_BFLOAT16;
    case float32:
      return CUDNN_DATA_FLOAT;
    case float64:
      return CUDNN_DATA_DOUBLE;
    default:
      throw std::runtime_error(fmt::format(
          "Unsupported dtype in Convolution: {}.", dtype_to_string(dtype)));
  }
}

inline uint8_t get_alignment(const array& x) {
  uint8_t alignment = 1;
  uintptr_t address = reinterpret_cast<uintptr_t>(x.data<void>());
  for (; alignment < 32; alignment *= 2) {
    if (address % (alignment * 2)) {
      return alignment;
    }
  }
  return alignment;
}

inline cudnn_frontend::Tensor build_tensor(int64_t id, const array& x) {
  auto [shape, strides] = nhwc_to_nchw(x);
  return cudnn_frontend::TensorBuilder()
      .setDim(shape.size(), shape.data())
      .setStrides(strides.size(), strides.data())
      .setId(id)
      .setAlignment(get_alignment(x))
      .setDataType(dtype_to_cudnn_type(x.dtype()))
      .build();
}

cudnn_frontend::EngineConfigList get_engine_configs(
    cudnnBackendDescriptorType_t backend_type,
    Dtype dtype,
    cudnn_frontend::OperationGraph& op_graph,
    bool use_fallback = false) {
  cudnn_frontend::GeneratorSource source;
  if (use_fallback) {
    source = [&backend_type](cudnn_frontend::OperationGraph& op_graph) {
      auto fallback = cudnn_frontend::EngineFallbackListBuilder()
                          .setOperationGraph(op_graph)
                          .setOperation(backend_type)
                          .build();
      return fallback.getFallbackList();
    };
  } else {
    source = [](cudnn_frontend::OperationGraph& op_graph) {
      auto heuristics = cudnn_frontend::EngineHeuristicsBuilder()
                            .setOperationGraph(op_graph)
                            .setHeurMode(CUDNN_HEUR_MODE_A)
                            .build();
      return heuristics.getEngineConfig(heuristics.getEngineConfigCount());
    };
  }

  cudnn_frontend::EngineConfigGenerator generator(1, &source);
  auto configs = generator.generate_engine_config(op_graph);

  cudnn_frontend::EngineConfigList filtered_configs;
  cudnn_frontend::filter(configs, filtered_configs, [dtype](auto c) {
    if (cudnn_frontend::hasNumericalNote<
            CUDNN_NUMERICAL_NOTE_DOWN_CONVERT_INPUTS>(c)) {
      return true;
    }
    if (cudnn_frontend::hasNumericalNote<CUDNN_NUMERICAL_NOTE_TENSOR_CORE>(c) &&
        dtype == float32 && !env::enable_tf32()) {
      return true;
    }
    return false;
  });
  return filtered_configs;
}

bool execute_plan(
    cu::CommandEncoder& encoder,
    cudnn_frontend::ExecutionPlan& plan,
    const array& in,
    const array& wt,
    array& out) {
  int workspace_size = plan.getWorkspaceSize();
  array workspace(allocator::malloc(workspace_size), {workspace_size}, uint8);

  int64_t uids[3] = {'x', 'w', 'y'};
  void* data_ptrs[3] = {
      const_cast<void*>(in.data<void>()),
      const_cast<void*>(wt.data<void>()),
      out.data<void>(),
  };

  auto variantPack = cudnn_frontend::VariantPackBuilder()
                         .setWorkspacePointer(workspace.data<void>())
                         .setDataPointers(3, data_ptrs)
                         .setUids(3, uids)
                         .build();

  auto handle = encoder.device().cudnn_handle();
  cudnnSetStream(handle, encoder.stream());

#if CUDNN_VERSION >= 90500 && MLX_USE_CUDNN_NATIVE_CUDA_GRAPH_API
  cudaGraph_t graph;
  cudaGraphCreate(&graph, 0);
  std::unique_ptr<cudaGraph_t, void (*)(cudaGraph_t*)> graph_freer(
      &graph, [](cudaGraph_t* p) { cudaGraphDestroy(*p); });
  if (cudnnBackendPopulateCudaGraph(
          handle, plan.get_raw_desc(), variantPack.get_raw_desc(), graph) !=
      CUDNN_STATUS_SUCCESS) {
    return false;
  }
  encoder.add_graph_node(graph);
#else
  auto capture = encoder.capture_context();
  if (cudnnBackendExecute(
          handle, plan.get_raw_desc(), variantPack.get_raw_desc()) !=
      CUDNN_STATUS_SUCCESS) {
    // Discard the captured graph when failed.
    capture.discard = true;
    return false;
  }
#endif

  encoder.add_temporary(workspace);
  return true;
}

bool try_engines(
    cu::CommandEncoder& encoder,
    cudnn_frontend::EngineConfigList& configs,
    const ConvCacheKey& cache_key,
    const std::string& op_graph_tag,
    const array& in,
    const array& wt,
    array& out) {
  for (auto& config : configs) {
    try {
      auto plan = cudnn_frontend::ExecutionPlanBuilder()
                      .setHandle(encoder.device().cudnn_handle())
                      .setEngineConfig(config, op_graph_tag)
                      .build();
      if (execute_plan(encoder, plan, in, wt, out)) {
        conv_cache().emplace(cache_key, std::move(plan));
        return true;
      }
    } catch (cudnn_frontend::cudnnException&) {
    }
  }
  return false;
}

} // namespace

void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
  nvtx3::scoped_range r("Convolution::eval_gpu");
  if (out.size() == 0) {
    return;
  }

  assert(inputs.size() == 2);
  array in = inputs[0];
  array wt = inputs[1];
  out.set_data(allocator::malloc(out.nbytes()));

  auto& s = stream();
  auto& encoder = cu::get_command_encoder(s);

  // cuDNN requires contiguous input.
  // TODO: Handle NCHW format specially.
  if (!in.flags().row_contiguous) {
    in = contiguous_copy_gpu(in, s);
    encoder.add_temporary(in);
  }
  if (!wt.flags().row_contiguous) {
    wt = contiguous_copy_gpu(wt, s);
    encoder.add_temporary(wt);
  }

  encoder.set_input_array(in);
  encoder.set_input_array(wt);
  encoder.set_output_array(out);

  auto backend_type = CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR;
  auto cudnn_type = dtype_to_cudnn_type(in.dtype());

  // Search cache.
  ConvCacheKey cache_key{
      encoder.device().cuda_device(),
      backend_type,
      cudnn_type,
      fixed_vector(in.shape()),
      fixed_vector(wt.shape()),
      fixed_vector(padding_lo_),
      fixed_vector(padding_hi_),
      fixed_vector(kernel_strides_),
      fixed_vector(kernel_dilation_),
      groups_,
      get_alignment(in),
      get_alignment(wt),
      get_alignment(out)};
  if (auto it = conv_cache().find(cache_key); it != conv_cache().end()) {
    if (!execute_plan(encoder, it->second, in, wt, out)) {
      throw std::runtime_error("Cached convolution plan failed to execute.");
    }
    return;
  }

  // Build operation graph.
  auto compute_data_type = (in.dtype() == float16 || in.dtype() == bfloat16)
      ? CUDNN_DATA_FLOAT
      : cudnn_type;

  auto stride = convert_vector<int64_t>(kernel_strides_);
  auto padding_lo = convert_vector<int64_t>(padding_lo_);
  auto padding_hi = convert_vector<int64_t>(padding_hi_);
  auto dilation = convert_vector<int64_t>(kernel_dilation_);

  auto conv_desc = cudnn_frontend::ConvDescBuilder()
                       .setDataType(compute_data_type)
                       .setMathMode(CUDNN_CROSS_CORRELATION)
                       .setNDims(stride.size())
                       .setStrides(stride.size(), stride.data())
                       .setPrePadding(padding_lo.size(), padding_lo.data())
                       .setPostPadding(padding_hi.size(), padding_hi.data())
                       .setDilation(dilation.size(), dilation.data())
                       .build();

  auto op = cudnn_frontend::OperationBuilder(backend_type)
                .setxDesc(build_tensor('x', in))
                .setwDesc(build_tensor('w', wt))
                .setyDesc(build_tensor('y', out))
                .setcDesc(conv_desc)
                .build();

  std::array<cudnn_frontend::Operation const*, 1> ops = {&op};
  auto op_graph = cudnn_frontend::OperationGraphBuilder()
                      .setHandle(encoder.device().cudnn_handle())
                      .setOperationGraph(ops.size(), ops.data())
                      .build();

  // Try to run plans based on heuristics.
  auto configs = get_engine_configs(backend_type, in.dtype(), op_graph);
  auto op_graph_tag = op_graph.getTag();
  if (try_engines(encoder, configs, cache_key, op_graph_tag, in, wt, out)) {
    return;
  }
  // Then try fallback plans.
  configs = get_engine_configs(backend_type, in.dtype(), op_graph);
  if (try_engines(encoder, configs, cache_key, op_graph_tag, in, wt, out)) {
    return;
  }
  throw std::runtime_error("Unable to find an engine for convolution.");
}

} // namespace mlx::core
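The `nhwc_to_nchw` helper above only permutes metadata: it moves the channel dimension (and its stride) from last to second position so cuDNN sees an NCHW descriptor over the same NHWC buffer. A standalone sketch of that permutation on plain vectors (hypothetical shapes, not from the diff):

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// Move the last (channel) entry of a shape or stride vector to position 1,
// mirroring the metadata-only NHWC -> NCHW trick used by conv.cpp above.
std::vector<int64_t> channels_last_to_second(std::vector<int64_t> v) {
  v.insert(v.begin() + 1, v.back());
  v.erase(v.end() - 1);
  return v;
}

int main() {
  // Example NHWC shape {N=1, H=4, W=4, C=8} with row-contiguous strides.
  std::vector<int64_t> shape = {1, 4, 4, 8};
  std::vector<int64_t> strides = {128, 32, 8, 1};
  auto nchw_shape = channels_last_to_second(shape);     // {1, 8, 4, 4}
  auto nchw_strides = channels_last_to_second(strides); // {128, 1, 32, 8}
  for (auto d : nchw_shape) {
    std::cout << d << ' ';
  }
  std::cout << '\n';
  return 0;
}
```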
mlx/backend/cuda/device.cpp

@@ -9,12 +9,23 @@
 #include <future>
 #include <unordered_set>

-namespace mlx::core {
+namespace mlx::core::cu {
+
+namespace {

 // Can be tuned with MLX_MAX_OPS_PER_BUFFER
 // This should be less than 255
 constexpr int default_max_nodes_per_graph = 20;

+#define CHECK_CUDNN_ERROR(cmd) check_cudnn_error(#cmd, (cmd))
+
+void check_cudnn_error(const char* name, cudnnStatus_t err) {
+  if (err != CUDNN_STATUS_SUCCESS) {
+    throw std::runtime_error(
+        fmt::format("{} failed: {}.", name, cudnnGetErrorString(err)));
+  }
+}
+
 int cuda_graph_cache_size() {
   static int cache_size = []() {
     return env::get_var("MLX_CUDA_GRAPH_CACHE_SIZE", 100);
@@ -22,7 +33,7 @@ int cuda_graph_cache_size() {
   return cache_size;
 }

-namespace cu {
+} // namespace

 Device::Device(int device) : device_(device) {
   CHECK_CUDA_ERROR(cudaDeviceGetAttribute(
@@ -40,11 +51,14 @@ Device::Device(int device) : device_(device) {
   }
   // The cublasLt handle is used by matmul.
   make_current();
-  cublasLtCreate(&lt_);
+  CHECK_CUBLAS_ERROR(cublasLtCreate(&lt_));
+  // The cudnn handle is used by Convolution.
+  CHECK_CUDNN_ERROR(cudnnCreate(&cudnn_));
 }

 Device::~Device() {
-  cublasLtDestroy(lt_);
+  CHECK_CUDNN_ERROR(cudnnDestroy(cudnn_));
+  CHECK_CUBLAS_ERROR(cublasLtDestroy(lt_));
 }

 void Device::make_current() {
@@ -66,29 +80,36 @@ CommandEncoder& Device::get_command_encoder(Stream s) {
 }

 CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) {
+  enc.device().make_current();
   CHECK_CUDA_ERROR(
       cudaStreamBeginCapture(enc.stream(), cudaStreamCaptureModeGlobal));
 }

 CommandEncoder::CaptureContext::~CaptureContext() {
   CHECK_CUDA_ERROR(cudaStreamEndCapture(enc.stream(), &graph));
+  std::unique_ptr<cudaGraph_t, void (*)(cudaGraph_t*)> graph_freer(
+      &graph, [](cudaGraph_t* p) { CHECK_CUDA_ERROR(cudaGraphDestroy(*p)); });
+  if (discard) {
+    return;
+  }
+
+  // Extract and add as single kernel node when possible.
   size_t num_nodes;
   CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, NULL, &num_nodes));
   if (num_nodes == 1) {
     cudaGraphNode_t captured_node;
     CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, &captured_node, &num_nodes));
-    CUDA_KERNEL_NODE_PARAMS params;
-    CHECK_CUDA_ERROR(cuGraphKernelNodeGetParams(captured_node, &params));
-    cudaGraphNode_t node;
-    CHECK_CUDA_ERROR(cuGraphAddKernelNode(&node, enc.graph_, NULL, 0, &params));
-    enc.insert_graph_dependencies(GraphNode{node, 'K'});
-  } else {
-    cudaGraphNode_t node;
-    CHECK_CUDA_ERROR(
-        cudaGraphAddChildGraphNode(&node, enc.graph_, NULL, 0, graph));
-    enc.insert_graph_dependencies(GraphNode{node, 'G'});
+    cudaGraphNodeType type;
+    CHECK_CUDA_ERROR(cudaGraphNodeGetType(captured_node, &type));
+    if (type == cudaGraphNodeTypeKernel) {
+      CUDA_KERNEL_NODE_PARAMS params;
+      CHECK_CUDA_ERROR(cuGraphKernelNodeGetParams(captured_node, &params));
+      enc.add_kernel_node(params);
+      return;
+    }
   }
-  CHECK_CUDA_ERROR(cudaGraphDestroy(graph));
+  // Otherwise add the captured graph as subgraph.
+  enc.add_graph_node(graph);
 }

 CommandEncoder::ConcurrentContext::ConcurrentContext(CommandEncoder& enc)
@@ -221,10 +242,7 @@ void CommandEncoder::add_kernel_node(
   kernel_params.gridDim = grid_dim;
   kernel_params.blockDim = block_dim;
   kernel_params.kernelParams = params;
-  cudaGraphNode_t node;
-  CHECK_CUDA_ERROR(
-      cudaGraphAddKernelNode(&node, graph_, NULL, 0, &kernel_params));
-  insert_graph_dependencies(GraphNode{node, 'K'});
+  add_kernel_node(kernel_params);
 }

 void CommandEncoder::add_kernel_node(
@@ -241,12 +259,27 @@ void CommandEncoder::add_kernel_node(
   kernel_params.blockDimY = block_dim.y;
   kernel_params.blockDimZ = block_dim.z;
   kernel_params.kernelParams = params;
-  CUgraphNode node;
-  CHECK_CUDA_ERROR(
-      cuGraphAddKernelNode(&node, graph_, NULL, 0, &kernel_params));
+  add_kernel_node(kernel_params);
+}
+
+void CommandEncoder::add_kernel_node(const cudaKernelNodeParams& params) {
+  cudaGraphNode_t node;
+  CHECK_CUDA_ERROR(cudaGraphAddKernelNode(&node, graph_, NULL, 0, &params));
   insert_graph_dependencies(GraphNode{node, 'K'});
 }

+void CommandEncoder::add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params) {
+  CUgraphNode node;
+  CHECK_CUDA_ERROR(cuGraphAddKernelNode(&node, graph_, NULL, 0, &params));
+  insert_graph_dependencies(GraphNode{node, 'K'});
+}
+
+void CommandEncoder::add_graph_node(cudaGraph_t child) {
+  cudaGraphNode_t node;
+  CHECK_CUDA_ERROR(cudaGraphAddChildGraphNode(&node, graph_, NULL, 0, child));
+  insert_graph_dependencies(GraphNode{node, 'G'});
+}
+
 void CommandEncoder::commit() {
   if (!temporaries_.empty()) {
     add_completed_handler([temporaries = std::move(temporaries_)]() {});
@@ -331,6 +364,4 @@ CommandEncoder& get_command_encoder(Stream s) {
   return device(s.device).get_command_encoder(s);
 }

-} // namespace cu
-
-} // namespace mlx::core
+} // namespace mlx::core::cu
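The reworked `CaptureContext` above wraps library calls in `cudaStreamBeginCapture`/`cudaStreamEndCapture` and then grafts the captured graph into the encoder's graph, either as a single kernel node or as a child graph. A minimal standalone sketch of that capture-then-embed pattern (simplified; error checks and the single-kernel fast path from the diff are omitted, and the three-argument `cudaGraphInstantiate` assumes a CUDA 12 runtime):

```cpp
#include <cuda_runtime.h>

#include <cstdio>

int main() {
  const int n = 1024;
  float* x = nullptr;
  cudaMalloc(&x, n * sizeof(float));

  cudaStream_t stream;
  cudaStreamCreate(&stream);

  // Capture the work issued on the stream into a graph, as CaptureContext
  // does around calls such as cublasLtMatmul / cudnnBackendExecute.
  cudaGraph_t captured;
  cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
  cudaMemsetAsync(x, 0, n * sizeof(float), stream);
  cudaStreamEndCapture(stream, &captured);

  // Embed the captured graph into a parent graph as a child-graph node,
  // mirroring CommandEncoder::add_graph_node in the diff.
  cudaGraph_t parent;
  cudaGraphCreate(&parent, 0);
  cudaGraphNode_t node;
  cudaGraphAddChildGraphNode(&node, parent, nullptr, 0, captured);

  // Instantiate and launch the parent graph.
  cudaGraphExec_t exec;
  cudaGraphInstantiate(&exec, parent, 0);
  cudaGraphLaunch(exec, stream);
  cudaStreamSynchronize(stream);

  cudaGraphExecDestroy(exec);
  cudaGraphDestroy(parent);
  cudaGraphDestroy(captured);
  cudaStreamDestroy(stream);
  cudaFree(x);
  std::printf("captured graph embedded and launched\n");
  return 0;
}
```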
mlx/backend/cuda/device.h

@@ -8,6 +8,7 @@

 #include <cublasLt.h>
 #include <cuda.h>
+#include <cudnn.h>
 #include <thrust/execution_policy.h>

 #include <unordered_map>
@@ -21,6 +22,7 @@ class CommandEncoder {
     ~CaptureContext();
     cudaGraph_t graph;
     CommandEncoder& enc;
+    bool discard{false};
   };
   struct ConcurrentContext {
     ConcurrentContext(CommandEncoder& enc);
@@ -65,6 +67,11 @@ class CommandEncoder {
   void
   add_kernel_node(void* func, dim3 grid_dim, dim3 block_dim, void** params);

+  // Low-level graph helpers.
+  void add_kernel_node(const cudaKernelNodeParams& params);
+  void add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params);
+  void add_graph_node(cudaGraph_t child);
+
   void add_temporary(const array& arr) {
     temporaries_.push_back(arr.data_shared_ptr());
   }
@@ -73,6 +80,10 @@ class CommandEncoder {
   void maybe_commit();
   void commit();

+  Device& device() {
+    return device_;
+  }
+
   CudaStream& stream() {
     return stream_;
   }
@@ -137,12 +148,16 @@ class Device {
   cublasLtHandle_t lt_handle() const {
     return lt_;
   }
+  cudnnHandle_t cudnn_handle() const {
+    return cudnn_;
+  }

  private:
   int device_;
   int compute_capability_major_;
   int compute_capability_minor_;
   cublasLtHandle_t lt_;
+  cudnnHandle_t cudnn_;
   std::unordered_map<int, CommandEncoder> encoders_;
 };
mlx/backend/cuda/gemms/cublas_batched_gemm_12_0.cpp (new file, 73 lines)

@@ -0,0 +1,73 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/common/utils.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/gemms/cublas_gemm.h"

namespace mlx::core::cu {

void Matmul::run_batched(
    cu::CommandEncoder& encoder,
    array& out,
    const array& a,
    const array& b,
    const mlx::core::Shape& batch_shape,
    const mlx::core::Strides& a_batch_strides,
    const mlx::core::Strides& b_batch_strides) {
  encoder.set_input_array(a);
  encoder.set_input_array(b);
  encoder.set_output_array(out);
  auto nbatch = out.size() / (M_ * N_ * batch_shape.back());
  ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1);
  ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
  auto concurrent = encoder.concurrent_context();
  for (size_t i = 0; i < nbatch; ++i) {
    run_impl(
        encoder,
        out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M_ * N_,
        a.data<int8_t>() + a.itemsize() * a_it.loc,
        b.data<int8_t>() + b.itemsize() * b_it.loc,
        nullptr);
    a_it.step();
    b_it.step();
  }
}

void Matmul::run_batched(
    cu::CommandEncoder& encoder,
    array& out,
    const array& a,
    const array& b,
    const array& c,
    const mlx::core::Shape& batch_shape,
    const mlx::core::Strides& a_batch_strides,
    const mlx::core::Strides& b_batch_strides,
    const mlx::core::Strides& c_batch_strides,
    float alpha,
    float beta) {
  encoder.set_input_array(a);
  encoder.set_input_array(b);
  encoder.set_input_array(c);
  encoder.set_output_array(out);

  auto nbatch = out.size() / (M_ * N_ * batch_shape.back());
  ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1);
  ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
  ContiguousIterator c_it(batch_shape, c_batch_strides, batch_shape.size() - 1);
  auto concurrent = encoder.concurrent_context();
  for (size_t i = 0; i < nbatch; ++i) {
    run_impl(
        encoder,
        out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M_ * N_,
        a.data<int8_t>() + a.itemsize() * a_it.loc,
        b.data<int8_t>() + b.itemsize() * b_it.loc,
        c.data<int8_t>() + c.itemsize() * c_it.loc,
        alpha,
        beta);
    a_it.step();
    b_it.step();
    c_it.step();
  }
}

} // namespace mlx::core::cu
mlx/backend/cuda/gemms/cublas_batched_gemm_12_9.cu (new file, 206 lines)

@@ -0,0 +1,206 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/gemms/cublas_gemm.h"
#include "mlx/backend/cuda/kernel_utils.cuh"

#include <cooperative_groups.h>

namespace mlx::core::cu {

namespace cg = cooperative_groups;

__global__ void set_mm_device_pointers(
    int8_t** pointers,
    int8_t* a_start,
    int8_t* b_start,
    int8_t* out_start,
    int item_size,
    const __grid_constant__ Shape batch_shape,
    const __grid_constant__ Strides a_batch_strides,
    const __grid_constant__ Strides b_batch_strides,
    int64_t batch_stride,
    int batch_ndim,
    int batch_count) {
  auto index = cg::this_grid().thread_rank();
  if (index >= batch_count) {
    return;
  }
  auto [a_offset, b_offset] = elem_to_loc(
      index,
      batch_shape.data(),
      a_batch_strides.data(),
      b_batch_strides.data(),
      batch_ndim);
  pointers[index] = a_start + item_size * a_offset;
  pointers[index + batch_count] = b_start + item_size * b_offset;
  pointers[index + 2 * batch_count] =
      out_start + item_size * index * batch_stride;
}

__global__ void set_addmm_device_pointers(
    int8_t** pointers,
    int8_t* a_start,
    int8_t* b_start,
    int8_t* c_start,
    int8_t* out_start,
    int item_size,
    const __grid_constant__ Shape batch_shape,
    const __grid_constant__ Strides a_batch_strides,
    const __grid_constant__ Strides b_batch_strides,
    const __grid_constant__ Strides c_batch_strides,
    int64_t batch_stride,
    int batch_ndim,
    int batch_count) {
  auto index = cg::this_grid().thread_rank();
  if (index >= batch_count) {
    return;
  }
  auto [a_offset, b_offset, c_offset] = elem_to_loc(
      index,
      batch_shape.data(),
      a_batch_strides.data(),
      b_batch_strides.data(),
      c_batch_strides.data(),
      batch_ndim);
  pointers[index] = a_start + item_size * a_offset;
  pointers[index + batch_count] = b_start + item_size * b_offset;
  pointers[index + 2 * batch_count] = c_start + item_size * c_offset;
  pointers[index + 3 * batch_count] =
      out_start + item_size * index * batch_stride;
}

void set_pointer_mode(cublasLtMatrixLayout_t desc, int batch_count) {
  auto batch_mode = CUBLASLT_BATCH_MODE_POINTER_ARRAY;
  CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
      desc,
      CUBLASLT_MATRIX_LAYOUT_BATCH_MODE,
      &batch_mode,
      sizeof(batch_mode)));
  CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
      desc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count, sizeof(int32_t)));
}

void Matmul::run_batched(
    cu::CommandEncoder& encoder,
    array& out,
    const array& a,
    const array& b,
    const mlx::core::Shape& batch_shape,
    const mlx::core::Strides& a_batch_strides,
    const mlx::core::Strides& b_batch_strides) {
  auto batch_count = out.size() / (M_ * N_);
  set_pointer_mode(a_desc_, batch_count);
  set_pointer_mode(b_desc_, batch_count);
  set_pointer_mode(out_desc_, batch_count);

  // Launch kernel to set device offsets
  auto pointers = array(
      allocator::malloc(batch_count * sizeof(uint64_t) * 3),
      {static_cast<int>(batch_count * 3)},
      uint64);

  encoder.add_temporary(pointers);
  int block_size = 512;
  encoder.set_output_array(pointers);

  encoder.add_kernel_node(
      cu::set_mm_device_pointers,
      cuda::ceil_div(pointers.size(), block_size),
      block_size,
      pointers.data<int8_t*>(),
      a.data<int8_t>(),
      b.data<int8_t>(),
      out.data<int8_t>(),
      static_cast<int>(out.dtype().size()),
      const_param(batch_shape),
      const_param(a_batch_strides),
      const_param(b_batch_strides),
      static_cast<int64_t>(M_) * N_,
      static_cast<int>(batch_shape.size()),
      batch_count);

  // Run matmul
  encoder.set_input_array(pointers);
  encoder.set_input_array(a);
  encoder.set_input_array(b);
  encoder.set_output_array(out);

  auto a_pointers = pointers.data<int8_t*>();
  auto b_pointers = a_pointers + batch_count;
  auto out_pointers = b_pointers + batch_count;
  run_impl(
      encoder,
      reinterpret_cast<void*>(out_pointers),
      reinterpret_cast<void*>(a_pointers),
      reinterpret_cast<void*>(b_pointers),
      nullptr);
}

void Matmul::run_batched(
    cu::CommandEncoder& encoder,
    array& out,
    const array& a,
    const array& b,
    const array& c,
    const mlx::core::Shape& batch_shape,
    const mlx::core::Strides& a_batch_strides,
    const mlx::core::Strides& b_batch_strides,
    const mlx::core::Strides& c_batch_strides,
    float alpha,
    float beta) {
  auto batch_count = out.size() / (M_ * N_);
  set_pointer_mode(a_desc_, batch_count);
  set_pointer_mode(b_desc_, batch_count);
  set_pointer_mode(c_desc_, batch_count);
  set_pointer_mode(out_desc_, batch_count);

  // Launch kernel to set device offsets
  auto pointers = array(
      allocator::malloc(batch_count * sizeof(uint64_t) * 4),
      {static_cast<int>(batch_count * 4)},
      uint64);

  encoder.add_temporary(pointers);
  int block_size = 512;
  encoder.set_output_array(pointers);
  encoder.add_kernel_node(
      cu::set_addmm_device_pointers,
      cuda::ceil_div(pointers.size(), block_size),
      block_size,
      pointers.data<int8_t*>(),
      a.data<int8_t>(),
      b.data<int8_t>(),
      c.data<int8_t>(),
      out.data<int8_t>(),
      static_cast<int>(out.dtype().size()),
      const_param(batch_shape),
      const_param(a_batch_strides),
      const_param(b_batch_strides),
      const_param(c_batch_strides),
      static_cast<int64_t>(M_) * N_,
      static_cast<int>(batch_shape.size()),
      batch_count);

  // Run matmul
  encoder.set_input_array(pointers);
  encoder.set_input_array(a);
  encoder.set_input_array(b);
  encoder.set_input_array(c);
  encoder.set_output_array(out);

  auto a_pointers = pointers.data<int8_t*>();
  auto b_pointers = a_pointers + batch_count;
  auto c_pointers = b_pointers + batch_count;
  auto out_pointers = c_pointers + batch_count;
  run_impl(
      encoder,
      reinterpret_cast<void*>(out_pointers),
      reinterpret_cast<void*>(a_pointers),
      reinterpret_cast<void*>(b_pointers),
      reinterpret_cast<void*>(c_pointers),
      alpha,
      beta);
}

} // namespace mlx::core::cu
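In the pointer-array path above, a small kernel unravels each flat batch index against `batch_shape` and dots it with the per-input batch strides to find where every GEMM operand starts (the `elem_to_loc` call). A host-side sketch of that same index-to-offset arithmetic on plain vectors (hypothetical shapes and strides, not from the diff):

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Unravel a flat batch index over `shape` (row-major) and dot it with
// `strides`; this is what elem_to_loc computes per thread in the kernel above.
int64_t batch_offset(
    int64_t index,
    const std::vector<int32_t>& shape,
    const std::vector<int64_t>& strides) {
  int64_t loc = 0;
  for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
    loc += (index % shape[i]) * strides[i];
    index /= shape[i];
  }
  return loc;
}

int main() {
  // Example: batch shape {2, 3}; `a` is broadcast along the first batch
  // dimension, so its stride for that dimension is 0.
  std::vector<int32_t> batch_shape = {2, 3};
  std::vector<int64_t> a_strides = {0, 4096};
  for (int64_t i = 0; i < 6; ++i) {
    std::printf(
        "batch %lld -> a offset %lld\n",
        static_cast<long long>(i),
        static_cast<long long>(batch_offset(i, batch_shape, a_strides)));
  }
  return 0;
}
```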
mlx/backend/cuda/gemms/cublas_gemm.cpp (new file, 282 lines)

@@ -0,0 +1,282 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/gemms/cublas_gemm.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/dtype_utils.h"
#include "mlx/utils.h"

#include <fmt/format.h>

namespace mlx::core::cu {

struct CublasPreference {
  CublasPreference(Device& device) {
    // The recommended cublas workspace size is 4 MiB for pre-Hopper and 32 MiB
    // for Hopper+:
    // https://docs.nvidia.com/cuda/cublas/#cublassetworkspace
    uint64_t MiB = 1024 * 1024;
    uint64_t workspace_size =
        device.compute_capability_major() >= 9 ? 32 * MiB : 4 * MiB;

    CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceCreate(&pref_));
    CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceSetAttribute(
        pref_,
        CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
        &workspace_size,
        sizeof(uint64_t)));
  }

  ~CublasPreference() {
    CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceDestroy(pref_));
  }

  cublasLtMatmulPreference_t pref_{nullptr};
};

cublasLtMatmulPreference_t cublas_preference(Device& device) {
  static CublasPreference pref(device);
  return pref.pref_;
}

cublasComputeType_t dtype_to_compute_type(Dtype dtype) {
  switch (dtype) {
    case float16:
      return CUBLAS_COMPUTE_32F;
    case bfloat16:
      return CUBLAS_COMPUTE_32F;
    case float32:
      return mlx::core::env::enable_tf32() ? CUBLAS_COMPUTE_32F_FAST_TF32
                                           : CUBLAS_COMPUTE_32F;
    case float64:
    case complex64:
      return CUBLAS_COMPUTE_64F;
    default:
      throw std::runtime_error(fmt::format(
          "Unsupported dtype in Matmul: {}.", dtype_to_string(dtype)));
  }
}

cudaDataType_t dtype_to_cublas_type(Dtype dtype) {
  switch (dtype) {
    case float16:
      return CUDA_R_16F;
    case bfloat16:
      return CUDA_R_16BF;
    case float32:
      return CUDA_R_32F;
    case float64:
      return CUDA_R_64F;
    case complex64:
      return CUDA_C_32F;
    default:
      throw std::runtime_error(fmt::format(
          "Unsupported dtype in Matmul: {}.", dtype_to_string(dtype)));
  }
}

cublasLtMatrixLayout_t create_matrix_layout(
    cudaDataType_t type,
    uint64_t rows,
    uint64_t cols,
    bool transposed,
    int64_t ld,
    int32_t batch_count,
    int64_t batch_stride) {
  cublasLtMatrixLayout_t desc;
  CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&desc, type, rows, cols, ld));
  cublasLtOrder_t order = transposed ? CUBLASLT_ORDER_COL : CUBLASLT_ORDER_ROW;
  CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
      desc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order, sizeof(cublasLtOrder_t)));
  if (batch_count > 1) {
    CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
        desc,
        CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT,
        &batch_count,
        sizeof(int32_t)));
    CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
        desc,
        CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET,
        &batch_stride,
        sizeof(int64_t)));
  }
  return desc;
}

Matmul::Matmul(
    Device& device,
    Dtype dtype,
    bool a_transposed,
    uint64_t a_rows,
    uint64_t a_cols,
    int64_t lda,
    bool b_transposed,
    uint64_t b_rows,
    uint64_t b_cols,
    int64_t ldb,
    int32_t batch_count,
    int64_t a_batch_stride,
    int64_t b_batch_stride)
    : handle_(device.lt_handle()),
      pref_(cublas_preference(device)),
      M_(a_rows),
      N_(b_cols) {
  heuristic_.state = CUBLAS_STATUS_NOT_INITIALIZED;

  auto scale_type = dtype_to_cublas_type(dtype);
  if (dtype == bfloat16 || dtype == float16) {
    scale_type = CUDA_R_32F;
  }
  CHECK_CUBLAS_ERROR(cublasLtMatmulDescCreate(
      &matmul_desc_, dtype_to_compute_type(dtype), scale_type));
  int32_t pointer_mode = CUBLASLT_POINTER_MODE_HOST;
  CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
      matmul_desc_,
      CUBLASLT_MATMUL_DESC_POINTER_MODE,
      &pointer_mode,
      sizeof(int32_t)));
  cublasOperation_t op = CUBLAS_OP_N;
  CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
      matmul_desc_,
      CUBLASLT_MATMUL_DESC_TRANSA,
      &op,
      sizeof(cublasOperation_t)));
  CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
      matmul_desc_,
      CUBLASLT_MATMUL_DESC_TRANSB,
      &op,
      sizeof(cublasOperation_t)));

  auto type = dtype_to_cublas_type(dtype);
  a_desc_ = create_matrix_layout(
      type, a_rows, a_cols, a_transposed, lda, batch_count, a_batch_stride);
  b_desc_ = create_matrix_layout(
      type, b_rows, b_cols, b_transposed, ldb, batch_count, b_batch_stride);
  out_desc_ = create_matrix_layout(
      type, a_rows, b_cols, false, b_cols, batch_count, a_rows * b_cols);
}

Matmul::Matmul(
    Device& device,
    Dtype dtype,
    bool a_transposed,
    uint64_t a_rows,
    uint64_t a_cols,
    int64_t lda,
    bool b_transposed,
    uint64_t b_rows,
    uint64_t b_cols,
    int64_t ldb,
    int64_t ldc,
    int32_t batch_count,
    int64_t a_batch_stride,
    int64_t b_batch_stride,
    int64_t c_batch_stride)
    : Matmul(
          device,
          dtype,
          a_transposed,
          a_rows,
          a_cols,
          lda,
          b_transposed,
          b_rows,
          b_cols,
          ldb,
          batch_count,
          a_batch_stride,
          b_batch_stride) {
  auto type = dtype_to_cublas_type(dtype);
  c_desc_ = create_matrix_layout(
      type, a_rows, b_cols, false, ldc, batch_count, c_batch_stride);
}

Matmul::~Matmul() {
  CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(a_desc_));
  CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(b_desc_));
  CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(c_desc_));
  CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(out_desc_));
  CHECK_CUBLAS_ERROR(cublasLtMatmulDescDestroy(matmul_desc_));
}

void Matmul::run_impl(
    cu::CommandEncoder& encoder,
    void* out,
    const void* a,
    const void* b,
    const void* c,
    float alpha /* = 1 */,
    float beta /* = 0 */) {
  if (heuristic_.state != CUBLAS_STATUS_SUCCESS) {
    int ret = 0;
    CHECK_CUBLAS_ERROR(cublasLtMatmulAlgoGetHeuristic(
        handle_,
        matmul_desc_,
        a_desc_,
        b_desc_,
        out_desc_, // TODO should that be c_desc is it's set?
        out_desc_,
        pref_,
        1,
        &heuristic_,
        &ret));
    if (ret == 0) {
      throw std::runtime_error("Can not find algorithm for matmul.");
    }
  }

  void* workspace_ptr = nullptr;
  if (heuristic_.workspaceSize > 0) {
    array workspace(
        allocator::malloc(heuristic_.workspaceSize),
        {static_cast<int>(heuristic_.workspaceSize)},
        int8);
    encoder.add_temporary(workspace);
    workspace_ptr = workspace.data<void>();
  }

  auto capture = encoder.capture_context();
  CHECK_CUBLAS_ERROR(cublasLtMatmul(
      handle_,
      matmul_desc_,
      &alpha,
      a,
      a_desc_,
      b,
      b_desc_,
      &beta,
      c ? c : out,
      c ? c_desc_ : out_desc_,
      out,
      out_desc_,
      &heuristic_.algo,
      workspace_ptr,
      heuristic_.workspaceSize,
      encoder.stream()));
}

void Matmul::run(
    cu::CommandEncoder& encoder,
    array& out,
    const array& a,
    const array& b,
    const std::optional<array>& c /* = std::nullopt */,
    float alpha /* = 1 */,
    float beta /* = 0 */) {
  encoder.set_input_array(a);
  encoder.set_input_array(b);
  if (c) {
    encoder.set_input_array(*c);
  }
  encoder.set_output_array(out);

  run_impl(
      encoder,
      out.data<void>(),
      a.data<void>(),
      b.data<void>(),
      c ? c->data<void>() : nullptr,
      alpha,
      beta);
}

} // namespace mlx::core::cu
mlx/backend/cuda/gemms/cublas_gemm.h (new file, 100 lines)

@@ -0,0 +1,100 @@
// Copyright © 2025 Apple Inc.
#pragma once

#include "mlx/array.h"
#include "mlx/backend/cuda/device.h"

#include <cublasLt.h>
#include <optional>

namespace mlx::core::cu {
class Matmul {
 public:
  Matmul(
      Device& device,
      Dtype dtype,
      bool a_transposed,
      uint64_t a_rows,
      uint64_t a_cols,
      int64_t lda,
      bool b_transposed,
      uint64_t b_rows,
      uint64_t b_cols,
      int64_t ldb,
      int32_t batch_count,
      int64_t a_batch_stride,
      int64_t b_batch_stride);

  Matmul(
      Device& device,
      Dtype dtype,
      bool a_transposed,
      uint64_t a_rows,
      uint64_t a_cols,
      int64_t lda,
      bool b_transposed,
      uint64_t b_rows,
      uint64_t b_cols,
      int64_t ldb,
      int64_t ldc,
      int32_t batch_count,
      int64_t a_batch_stride,
      int64_t b_batch_stride,
      int64_t c_batch_stride);

  ~Matmul();

  void run(
      cu::CommandEncoder& encoder,
      array& out,
      const array& a,
      const array& b,
      const std::optional<array>& c = std::nullopt,
      float alpha = 1,
      float beta = 0);

  void run_batched(
      cu::CommandEncoder& encoder,
      array& out,
      const array& a,
      const array& b,
      const mlx::core::Shape& batch_shape,
      const mlx::core::Strides& a_batch_strides,
      const mlx::core::Strides& b_batch_strides);

  void run_batched(
      cu::CommandEncoder& encoder,
      array& out,
      const array& a,
      const array& b,
      const array& c,
      const mlx::core::Shape& batch_shape,
      const mlx::core::Strides& a_batch_strides,
      const mlx::core::Strides& b_batch_strides,
      const mlx::core::Strides& c_batch_strides,
      float alpha,
      float beta);

 private:
  void run_impl(
      cu::CommandEncoder& encoder,
      void* out,
      const void* a,
      const void* b,
      const void* c,
      float alpha = 1,
      float beta = 0);

  uint64_t M_;
  uint64_t N_;
  cublasLtMatmulPreference_t pref_{nullptr};
  cublasLtHandle_t handle_{nullptr};
  cublasLtMatmulDesc_t matmul_desc_{nullptr};
  cublasLtMatrixLayout_t a_desc_{nullptr};
  cublasLtMatrixLayout_t b_desc_{nullptr};
  cublasLtMatrixLayout_t c_desc_{nullptr};
  cublasLtMatrixLayout_t out_desc_{nullptr};
  cublasLtMatmulHeuristicResult_t heuristic_;
};

} // namespace mlx::core::cu
@@ -1,6 +1,6 @@
 // Copyright © 2025 Apple Inc.
 
-#include "mlx/backend/cuda/gemv.h"
+#include "mlx/backend/cuda/gemms/gemv.h"
 #include "mlx/backend/cuda/kernel_utils.cuh"
 #include "mlx/dtype_utils.h"
 

146  mlx/backend/cuda/lru_cache.h  Normal file
@@ -0,0 +1,146 @@
// Copyright © 2025 Apple Inc.

#pragma once

#include <list>
#include <unordered_map>
#include <utility>

namespace mlx::core {

template <
    typename K,
    typename V,
    template <typename...> typename M = std::unordered_map>
class LRUCache {
 public:
  using value_type = std::pair<K, V>;
  using list_type = std::list<value_type>;
  using iterator = typename list_type::iterator;
  using const_iterator = typename list_type::const_iterator;
  using map_type = M<K, iterator>;

  explicit LRUCache(size_t capacity) : capacity_(capacity) {}

  size_t size() const {
    return map_.size();
  }
  size_t capacity() const {
    return capacity_;
  }
  bool empty() const {
    return vlist_.empty();
  }

  void resize(size_t new_capacity) {
    capacity_ = new_capacity;
    trim();
  }

  iterator begin() {
    return vlist_.begin();
  }
  const_iterator begin() const {
    return vlist_.begin();
  }
  iterator end() {
    return vlist_.end();
  }
  const_iterator end() const {
    return vlist_.end();
  }

  void clear() {
    map_.clear();
    vlist_.clear();
  }

  iterator find(const K& key) {
    auto it = map_.find(key);
    if (it == map_.end())
      return end();
    vlist_.splice(vlist_.begin(), vlist_, it->second);
    return it->second;
  }

  template <typename U>
  std::pair<iterator, bool> emplace(const K& key, U&& value) {
    auto it = map_.find(key);
    if (it != map_.end()) {
      vlist_.splice(vlist_.begin(), vlist_, it->second);
      return {it->second, false};
    }

    vlist_.emplace_front(key, std::forward<U>(value));
    map_[key] = vlist_.begin();

    trim();

    return {vlist_.begin(), true};
  }

  iterator erase(iterator pos) {
    map_.erase(pos->first);
    return vlist_.erase(pos);
  }

 private:
  void trim() {
    while (map_.size() > capacity_) {
      auto last = std::prev(vlist_.end());
      map_.erase(last->first);
      vlist_.pop_back();
    }
  }

  list_type vlist_;
  map_type map_;
  size_t capacity_;
};

// Turn a POD struct into a container key by doing bytes compare.
template <typename T>
struct BytesKey {
  T pod;
  static_assert(std::is_standard_layout_v<T>, "T is not POD");

  BytesKey(T pod) : pod(std::move(pod)) {}

  BytesKey(const BytesKey& other) {
    memcpy(&pod, &other.pod, sizeof(T));
  }

  BytesKey(BytesKey&& other) {
    memcpy(&pod, &other.pod, sizeof(T));
  }

  bool operator==(const BytesKey& other) const {
    auto* ptr1 = reinterpret_cast<const uint8_t*>(&pod);
    auto* ptr2 = reinterpret_cast<const uint8_t*>(&other.pod);
    return memcmp(ptr1, ptr2, sizeof(T)) == 0;
  }
};

// Compute hash according to the bytes value of T.
template <typename T>
struct BytesHash {
  static_assert(std::is_standard_layout_v<T>, "T is not POD");

  size_t operator()(const T& pod) const {
    auto* ptr = reinterpret_cast<const uint8_t*>(&pod);
    uint32_t value = 0x811C9DC5;
    for (int i = 0; i < sizeof(T); ++i) {
      value ^= ptr[i];
      value *= 0x01000193;
    }
    return value;
  }
};

template <typename K, typename V>
using BytesKeyHashMap = std::unordered_map<K, V, BytesHash<K>>;

template <typename K, typename V>
using LRUBytesKeyCache = LRUCache<BytesKey<K>, V, BytesKeyHashMap>;

} // namespace mlx::core
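A minimal usage sketch of the byte-keyed LRU cache added above (this snippet is not part of the commit): the GemmKey struct, its fields, and the capacity are illustrative assumptions; the hashing is the FNV-1a byte hash from BytesHash, and entries are compared by raw bytes via BytesKey.

#include "mlx/backend/cuda/lru_cache.h"

// Hypothetical POD descriptor, used only for this example.
struct GemmKey {
  int m;
  int n;
  int k;
  int transposed;
};

// Look up a cached value for a descriptor, inserting a placeholder on a miss.
// find() moves a hit to the front of the LRU list; emplace() inserts at the
// front and trims the least recently used entries once capacity is exceeded.
int lookup_or_insert(
    mlx::core::LRUBytesKeyCache<GemmKey, int>& cache,
    const GemmKey& key) {
  if (auto it = cache.find(key); it != cache.end()) {
    return it->second;
  }
  return cache.emplace(key, 0).first->second;
}

// Usage: mlx::core::LRUBytesKeyCache<GemmKey, int> cache(/*capacity=*/16);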
@@ -2,290 +2,15 @@
 
 #include "mlx/backend/common/matmul.h"
 #include "mlx/backend/cuda/device.h"
-#include "mlx/backend/cuda/gemv.h"
+#include "mlx/backend/cuda/gemms/cublas_gemm.h"
+#include "mlx/backend/cuda/gemms/gemv.h"
 #include "mlx/backend/gpu/copy.h"
-#include "mlx/dtype_utils.h"
 #include "mlx/primitives.h"
-#include "mlx/utils.h"
 
-#include <cublasLt.h>
-#include <fmt/format.h>
 #include <nvtx3/nvtx3.hpp>
 
 #include <numeric>
 
 namespace mlx::core {
 
-namespace cu {
-
-#define CHECK_CUBLAS_ERROR(cmd) check_cublas_error(#cmd, (cmd))
-
-void check_cublas_error(const char* name, cublasStatus_t err) {
-  if (err != CUBLAS_STATUS_SUCCESS) {
-    // TODO: Use cublasGetStatusString when it is widely available.
-    throw std::runtime_error(
-        fmt::format("{} failed with code: {}.", name, static_cast<int>(err)));
-  }
-}
-
-struct CublasPreference {
-  CublasPreference(Device& device) {
-    // The recommended cublas workspace size is 4 MiB for pre-Hopper and 32 MiB
-    // for Hopper+:
-    // https://docs.nvidia.com/cuda/cublas/#cublassetworkspace
-    uint64_t MiB = 1024 * 1024;
-    uint64_t workspace_size =
-        device.compute_capability_major() >= 9 ? 32 * MiB : 4 * MiB;
-
-    CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceCreate(&pref_));
-    CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceSetAttribute(
-        pref_,
-        CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
-        &workspace_size,
-        sizeof(uint64_t)));
-  }
-
-  ~CublasPreference() {
-    CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceDestroy(pref_));
-  }
-
-  cublasLtMatmulPreference_t pref_{nullptr};
-};
-
-cublasLtMatmulPreference_t cublas_preference(Device& device) {
-  static CublasPreference pref(device);
-  return pref.pref_;
-}
-
-class MatMul {
- public:
-  MatMul(
-      Device& device,
-      Dtype dtype,
-      bool a_transposed,
-      uint64_t a_rows,
-      uint64_t a_cols,
-      int64_t lda,
-      bool b_transposed,
-      uint64_t b_rows,
-      uint64_t b_cols,
-      int64_t ldb,
-      int32_t batch_count,
-      int64_t a_batch_stride,
-      int64_t b_batch_stride)
-      : handle_(device.lt_handle()), pref_(cublas_preference(device)) {
-    heuristic_.state = CUBLAS_STATUS_NOT_INITIALIZED;
-
-    auto scale_type = dtype_to_cuda_type(dtype);
-    if (dtype == bfloat16 || dtype == float16) {
-      scale_type = CUDA_R_32F;
-    }
-    CHECK_CUBLAS_ERROR(cublasLtMatmulDescCreate(
-        &matmul_desc_, dtype_to_compute_type(dtype), scale_type));
-    int32_t pointer_mode = CUBLASLT_POINTER_MODE_HOST;
-    CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
-        matmul_desc_,
-        CUBLASLT_MATMUL_DESC_POINTER_MODE,
-        &pointer_mode,
-        sizeof(int32_t)));
-    cublasOperation_t op = CUBLAS_OP_N;
-    CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
-        matmul_desc_,
-        CUBLASLT_MATMUL_DESC_TRANSA,
-        &op,
-        sizeof(cublasOperation_t)));
-    CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
-        matmul_desc_,
-        CUBLASLT_MATMUL_DESC_TRANSB,
-        &op,
-        sizeof(cublasOperation_t)));
-
-    auto type = dtype_to_cuda_type(dtype);
-    a_desc_ = create_matrix_layout(
-        type, a_rows, a_cols, a_transposed, lda, batch_count, a_batch_stride);
-    b_desc_ = create_matrix_layout(
-        type, b_rows, b_cols, b_transposed, ldb, batch_count, b_batch_stride);
-    out_desc_ = create_matrix_layout(
-        type, a_rows, b_cols, false, b_cols, batch_count, a_rows * b_cols);
-  }
-
-  MatMul(
-      Device& device,
-      Dtype dtype,
-      bool a_transposed,
-      uint64_t a_rows,
-      uint64_t a_cols,
-      int64_t lda,
-      bool b_transposed,
-      uint64_t b_rows,
-      uint64_t b_cols,
-      int64_t ldb,
-      int64_t ldc,
-      int32_t batch_count,
-      int64_t a_batch_stride,
-      int64_t b_batch_stride,
-      int64_t c_batch_stride)
-      : MatMul(
-            device,
-            dtype,
-            a_transposed,
-            a_rows,
-            a_cols,
-            lda,
-            b_transposed,
-            b_rows,
-            b_cols,
-            ldb,
-            batch_count,
-            a_batch_stride,
-            b_batch_stride) {
-    auto type = dtype_to_cuda_type(dtype);
-    c_desc_ = create_matrix_layout(
-        type, a_rows, b_cols, false, ldc, batch_count, c_batch_stride);
-  }
-
-  ~MatMul() {
-    CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(a_desc_));
-    CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(b_desc_));
-    CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(c_desc_));
-    CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(out_desc_));
-    CHECK_CUBLAS_ERROR(cublasLtMatmulDescDestroy(matmul_desc_));
-  }
-
-  void run(
-      cu::CommandEncoder& encoder,
-      void* out,
-      void* a,
-      void* b,
-      void* c = nullptr,
-      float alpha = 1,
-      float beta = 0) {
-    if (heuristic_.state != CUBLAS_STATUS_SUCCESS) {
-      int ret = 0;
-      CHECK_CUBLAS_ERROR(cublasLtMatmulAlgoGetHeuristic(
-          handle_,
-          matmul_desc_,
-          a_desc_,
-          b_desc_,
-          out_desc_,
-          out_desc_,
-          pref_,
-          1,
-          &heuristic_,
-          &ret));
-      if (ret == 0) {
-        throw std::runtime_error("Can not find algorithm for matmul.");
-      }
-    }
-
-    void* workspace_ptr = nullptr;
-    if (heuristic_.workspaceSize > 0) {
-      array workspace(
-          allocator::malloc(heuristic_.workspaceSize),
-          {static_cast<int>(heuristic_.workspaceSize)},
-          int8);
-      encoder.add_temporary(workspace);
-      workspace_ptr = workspace.data<void>();
-    }
-
-    auto capture = encoder.capture_context();
-    CHECK_CUBLAS_ERROR(cublasLtMatmul(
-        handle_,
-        matmul_desc_,
-        &alpha,
-        a,
-        a_desc_,
-        b,
-        b_desc_,
-        &beta,
-        c ? c : out,
-        c ? c_desc_ : out_desc_,
-        out,
-        out_desc_,
-        &heuristic_.algo,
-        workspace_ptr,
-        heuristic_.workspaceSize,
-        encoder.stream()));
-  }
-
- private:
-  cublasComputeType_t dtype_to_compute_type(Dtype dtype) {
-    switch (dtype) {
-      case float16:
-        return CUBLAS_COMPUTE_32F;
-      case bfloat16:
-        return CUBLAS_COMPUTE_32F;
-      case float32:
-        return mlx::core::env::enable_tf32() ? CUBLAS_COMPUTE_32F_FAST_TF32
-                                             : CUBLAS_COMPUTE_32F;
-      case float64:
-      case complex64:
-        return CUBLAS_COMPUTE_64F;
-      default:
-        throw std::runtime_error(fmt::format(
-            "Unsupported dtype in MatMul: {}.", dtype_to_string(dtype)));
-    }
-  }
-
-  cudaDataType_t dtype_to_cuda_type(Dtype dtype) {
-    switch (dtype) {
-      case float16:
-        return CUDA_R_16F;
-      case bfloat16:
-        return CUDA_R_16BF;
-      case float32:
-        return CUDA_R_32F;
-      case float64:
-        return CUDA_R_64F;
-      case complex64:
-        return CUDA_C_32F;
-      default:
-        throw std::runtime_error(fmt::format(
-            "Unsupported dtype in MatMul: {}.", dtype_to_string(dtype)));
-    }
-  }
-
-  cublasLtMatrixLayout_t create_matrix_layout(
-      cudaDataType_t type,
-      uint64_t rows,
-      uint64_t cols,
-      bool transposed,
-      int64_t ld,
-      int32_t batch_count,
-      int64_t batch_stride) {
-    cublasLtMatrixLayout_t desc;
-    CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&desc, type, rows, cols, ld));
-    cublasLtOrder_t order =
-        transposed ? CUBLASLT_ORDER_COL : CUBLASLT_ORDER_ROW;
-    CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
-        desc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order, sizeof(cublasLtOrder_t)));
-    if (batch_count > 1) {
-      CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
-          desc,
-          CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT,
-          &batch_count,
-          sizeof(int32_t)));
-      CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
-          desc,
-          CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET,
-          &batch_stride,
-          sizeof(int64_t)));
-    }
-    return desc;
-  }
-
-  cublasLtMatmulPreference_t pref_{nullptr};
-  cublasLtHandle_t handle_{nullptr};
-  cublasLtMatmulDesc_t matmul_desc_{nullptr};
-  cublasLtMatrixLayout_t a_desc_{nullptr};
-  cublasLtMatrixLayout_t b_desc_{nullptr};
-  cublasLtMatrixLayout_t c_desc_{nullptr};
-  cublasLtMatrixLayout_t out_desc_{nullptr};
-  cublasLtMatmulHeuristicResult_t heuristic_;
-};
-
-} // namespace cu
-
 namespace {
 
 std::tuple<bool, int64_t, array>
@@ -372,8 +97,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
 
   /////////////////////////////////////////////////////////////////////////////
   // Invoke cublasLt
-
-  cu::MatMul matmul(
+  cu::Matmul matmul(
       cu::device(s.device),
       a.dtype(),
       a_transposed,
@@ -388,27 +112,13 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
       a_batch_strides.back(),
       b_batch_strides.back());
 
-  encoder.set_input_array(a);
-  encoder.set_input_array(b);
-  encoder.set_output_array(out);
-
-  auto nbatch = batch_count / batch_shape.back();
-  if (nbatch == 1) {
-    matmul.run(encoder, out.data<int8_t>(), a.data<int8_t>(), b.data<int8_t>());
+  if ((batch_count / batch_shape.back()) == 1) {
+    matmul.run(encoder, out, a, b);
     return;
   }
 
-  ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1);
-  ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
-  auto concurrent = encoder.concurrent_context();
-  for (size_t i = 0; i < nbatch; ++i) {
-    matmul.run(
-        encoder,
-        out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M * N,
-        a.data<int8_t>() + a.itemsize() * a_it.loc,
-        b.data<int8_t>() + b.itemsize() * b_it.loc);
-    a_it.step();
-    b_it.step();
-  }
+  matmul.run_batched(
+      encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides);
 }
 
 void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -476,7 +186,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   /////////////////////////////////////////////////////////////////////////////
   // Invoke cublasLt
 
-  cu::MatMul matmul(
+  cu::Matmul matmul(
       cu::device(s.device),
       a.dtype(),
       a_transposed,
@@ -493,41 +203,22 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
       b_batch_strides.back(),
       c_batch_strides.back());
 
-  encoder.set_input_array(a);
-  encoder.set_input_array(b);
-  encoder.set_input_array(c);
-  encoder.set_output_array(out);
-
-  auto nbatch = batch_count / batch_shape.back();
-  if (nbatch == 1) {
-    matmul.run(
-        encoder,
-        out.data<int8_t>(),
-        a.data<int8_t>(),
-        b.data<int8_t>(),
-        c.data<int8_t>(),
-        alpha_,
-        beta_);
+  if ((batch_count / batch_shape.back()) == 1) {
+    matmul.run(encoder, out, a, b, c, alpha_, beta_);
     return;
   }
-
-  ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1);
-  ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
-  ContiguousIterator c_it(batch_shape, c_batch_strides, batch_shape.size() - 1);
-  auto concurrent = encoder.concurrent_context();
-  for (size_t i = 0; i < nbatch; ++i) {
-    matmul.run(
-        encoder,
-        out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M * N,
-        a.data<int8_t>() + a.itemsize() * a_it.loc,
-        b.data<int8_t>() + b.itemsize() * b_it.loc,
-        c.data<int8_t>() + c.itemsize() * c_it.loc,
-        alpha_,
-        beta_);
-    a_it.step();
-    b_it.step();
-    c_it.step();
-  }
+  matmul.run_batched(
+      encoder,
+      out,
+      a,
+      b,
+      c,
+      batch_shape,
+      a_batch_strides,
+      b_batch_strides,
+      c_batch_strides,
+      alpha_,
+      beta_);
 }
 
 } // namespace mlx::core
@@ -71,7 +71,6 @@ bool fast::ScaledDotProductAttention::use_fallback(
 }
 
 NO_GPU(BlockMaskedMM)
-NO_GPU(Convolution)
 NO_GPU(DynamicSlice)
 NO_GPU(DynamicSliceUpdate)
 NO_GPU(FFT)
@@ -17,6 +17,14 @@ CudaStream::~CudaStream() {
   CHECK_CUDA_ERROR(cudaStreamDestroy(stream_));
 }
 
+void check_cublas_error(const char* name, cublasStatus_t err) {
+  if (err != CUBLAS_STATUS_SUCCESS) {
+    // TODO: Use cublasGetStatusString when it is widely available.
+    throw std::runtime_error(
+        fmt::format("{} failed with code: {}.", name, static_cast<int>(err)));
+  }
+}
+
 void check_cuda_error(const char* name, cudaError_t err) {
   if (err != cudaSuccess) {
     throw std::runtime_error(

@@ -4,6 +4,7 @@
 
 #pragma once
 
+#include <cublasLt.h>
 #include <cuda.h>
 #include <cuda_runtime.h>
 
@@ -33,10 +34,12 @@ class CudaStream {
 };
 
 // Throw exception if the cuda API does not succeed.
+void check_cublas_error(const char* name, cublasStatus_t err);
 void check_cuda_error(const char* name, cudaError_t err);
 void check_cuda_error(const char* name, CUresult err);
 
 // The macro version that prints the command that failed.
+#define CHECK_CUBLAS_ERROR(cmd) check_cublas_error(#cmd, (cmd))
 #define CHECK_CUDA_ERROR(cmd) check_cuda_error(#cmd, (cmd))
 
 // Convert Dtype to CUDA C++ types.
@@ -14,6 +14,10 @@ Event::Event(Stream stream) : stream_(stream) {
   auto p = metal::new_scoped_memory_pool();
   event_ = std::shared_ptr<void>(
      metal::device(Device::gpu).mtl_device()->newSharedEvent(), dtor);
+  if (event_ == nullptr) {
+    throw std::runtime_error(
+        "[Event::Event] Failed to create Metal shared event.");
+  }
 }
 
 void Event::wait() {
@@ -265,9 +265,15 @@ void qvm_split_k(
   MTL::Size group_dims = MTL::Size(bk, 2, 1);
   MTL::Size grid_dims = MTL::Size(M, N / bn, B);
 
-  int x_batch_ndims = x.ndim() - 2;
   auto x_shape = x.shape();
   auto x_strides = x.strides();
+  if (x_shape.size() == 1) {
+    x_shape.insert(x_shape.begin(), 1);
+    x_strides.insert(x_strides.begin(), 0);
+  }
+
+  int x_ndim = x_shape.size();
+  int x_batch_ndims = x_ndim - 2;
   int w_batch_ndims = w.ndim() - 2;
   auto w_shape = w.shape();
   auto w_strides = w.strides();
@@ -278,7 +284,7 @@ void qvm_split_k(
   x_shape.insert(x_shape.end() - 2, split_k);
   x_shape.back() /= split_k;
   x_strides.insert(x_strides.end() - 2, split_D);
-  x_strides[x.ndim() - 1] = split_D;
+  x_strides[x_ndim - 1] = split_D;
   x_batch_ndims += 1;
 
   w_shape.insert(w_shape.end() - 2, split_k);
@@ -291,6 +297,9 @@ void qvm_split_k(
   int final_block_size = K - (split_k - 1) * split_D;
 
   auto temp_shape = out.shape();
+  if (temp_shape.size() == 1) {
+    temp_shape.insert(temp_shape.begin(), 1);
+  }
   temp_shape.insert(temp_shape.end() - 2, split_k);
   array intermediate(temp_shape, x.dtype(), nullptr, {});
   intermediate.set_data(allocator::malloc(intermediate.nbytes()));
@@ -72,7 +72,12 @@ array eval_impl(std::vector<array> outputs, bool async) {
 
   // Stream events for synchronization after eval
   std::unordered_map<uint32_t, Event> events;
-  events.emplace(stream.index, Event{stream});
+  {
+    auto e = Event{stream};
+    e.set_value(1);
+    synchronizer.attach_event(e);
+    events.emplace(stream.index, std::move(e));
+  }
 
   {
     // Record the degree of each input
@@ -184,21 +189,26 @@ array eval_impl(std::vector<array> outputs, bool async) {
     }
   }
 
+  std::unordered_set<int> open_streams;
+
   while (!tape.empty()) {
     auto arr = std::move(tape.back());
     tape.pop_back();
 
     auto stream = arr.primitive().stream();
+    open_streams.insert(stream.index);
 
-    // Lookup corresponding event
-    auto e = events.find(stream.index);
-    if (e == events.end()) {
-      e = events.emplace(stream.index, Event{stream}).first;
-    }
-    e->second.set_value(1);
-    arr.attach_event(e->second);
-    for (auto& s : arr.siblings()) {
-      s.attach_event(e->second);
+    if (async) {
+      // Lookup corresponding event
+      auto e = events.find(stream.index);
+      if (e == events.end()) {
+        e = events.emplace(stream.index, Event{stream}).first;
+      }
+      e->second.set_value(1);
+      arr.attach_event(e->second);
+      for (auto& s : arr.siblings()) {
+        s.attach_event(e->second);
+      }
     }
 
     for (auto& in : arr.inputs()) {
@@ -227,9 +237,10 @@ array eval_impl(std::vector<array> outputs, bool async) {
       (get_active_memory() > get_memory_limit() &&
        scheduler::n_active_tasks() > 0)) {
     // Commit any open streams
-    for (auto& [_, e] : events) {
-      if (e.stream().device == Device::gpu) {
-        gpu::finalize(e.stream());
+    for (auto i : open_streams) {
+      auto s = get_stream(i);
+      if (s.device == Device::gpu) {
+        gpu::finalize(s);
       }
     }
     scheduler::wait_for_one();
@@ -263,9 +274,11 @@ array eval_impl(std::vector<array> outputs, bool async) {
   }
 
   // Signal the event in its stream
-  for (auto& [_, e] : events) {
-    auto s = e.stream();
-    e.signal(s);
+  for (auto i : open_streams) {
+    auto s = get_stream(i);
+    if (auto e = events.find(i); e != events.end()) {
+      e->second.signal(s);
+    }
     if (s.device == Device::gpu) {
       gpu::finalize(s);
     }
@@ -302,7 +315,7 @@ void eval(std::vector<array> outputs) {
     return;
   }
 
-  eval_impl(std::move(outputs), false).event().wait();
+  eval_impl(std::move(outputs), false).wait();
 }
 
 std::pair<std::vector<array>, std::vector<array>> vjp(
@@ -3,8 +3,8 @@
 #pragma once
 
 #define MLX_VERSION_MAJOR 0
-#define MLX_VERSION_MINOR 26
+#define MLX_VERSION_MINOR 27
-#define MLX_VERSION_PATCH 5
+#define MLX_VERSION_PATCH 1
 #define MLX_VERSION_NUMERIC \
   (100000 * MLX_VERSION_MAJOR + 1000 * MLX_VERSION_MINOR + MLX_VERSION_PATCH)
@@ -477,7 +477,7 @@ class Adam(Optimizer):
 
         m_{t+1} &= \beta_1 m_t + (1 - \beta_1) g_t \\
         v_{t+1} &= \beta_2 v_t + (1 - \beta_2) g_t^2 \\
-        w_{t+1} &= w_t - \lambda \frac{m_{t+1}}{\sqrt{v_{t+1} + \epsilon}}
+        w_{t+1} &= w_t - \lambda \frac{m_{t+1}}{\sqrt{v_{t+1}} + \epsilon}
 
     Args:
         learning_rate (float or callable): The learning rate :math:`\lambda`.
@@ -546,7 +546,7 @@ class AdamW(Adam):
 
         m_{t+1} &= \beta_1 m_t + (1 - \beta_1) g_t \\
         v_{t+1} &= \beta_2 v_t + (1 - \beta_2) g_t^2 \\
-        w_{t+1} &= w_t - \alpha (\frac{m_{t+1}}{\sqrt{v_{t+1} + \epsilon}} + \lambda w_t)
+        w_{t+1} &= w_t - \alpha (\frac{m_{t+1}}{\sqrt{v_{t+1}} + \epsilon} + \lambda w_t)
 
     Args:
         learning_rate (float or callable): The learning rate :math:`\alpha`.
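As a cross-check of the corrected docstrings above, here is a standalone scalar sketch of the Adam step with the epsilon added outside the square root. This is an illustrative example only, not the mlx.optimizers implementation, and all names and defaults here are assumptions.

#include <cmath>

struct AdamState {
  float m = 0.0f; // first moment estimate
  float v = 0.0f; // second moment estimate
};

// One Adam step for a single scalar parameter, following
// w_{t+1} = w_t - lr * m_{t+1} / (sqrt(v_{t+1}) + eps).
float adam_step(
    float w,
    float g,
    AdamState& s,
    float lr = 1e-3f,
    float b1 = 0.9f,
    float b2 = 0.999f,
    float eps = 1e-8f) {
  s.m = b1 * s.m + (1.0f - b1) * g;
  s.v = b2 * s.v + (1.0f - b2) * g * g;
  return w - lr * s.m / (std::sqrt(s.v) + eps);
}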
@@ -5,6 +5,7 @@ auditwheel repair dist/* \
   --exclude libcublas* \
   --exclude libnvrtc* \
   --exclude libcuda* \
+  --exclude libcudnn* \
   -w wheel_tmp
 
 
@@ -16,7 +17,7 @@ rm "${repaired_wheel}"
 mlx_so="mlx/lib/libmlx.so"
 rpath=$(patchelf --print-rpath "${mlx_so}")
 base="\$ORIGIN/../../nvidia"
-rpath=$rpath:${base}/cublas/lib:${base}/cuda_nvrtc/lib
+rpath=$rpath:${base}/cublas/lib:${base}/cuda_nvrtc/lib:${base}/cudnn/lib
 patchelf --force-rpath --set-rpath "$rpath" "$mlx_so"
 python ../python/scripts/repair_record.py ${mlx_so}
 
@@ -2,6 +2,7 @@
 
 auditwheel repair dist/* \
   --plat manylinux_2_35_x86_64 \
+  --only-plat \
   --exclude libmlx* \
   -w wheel_tmp
 
@@ -15,19 +15,12 @@ cuda_skip = {
     "TestOps.test_hadamard_grad_vmap",
     # Convolutions NYI
     "TestConv.test_1d_conv_with_2d",
-    "TestConv.test_asymmetric_padding",
-    "TestConv.test_basic_grad_shapes",
-    "TestConv.test_conv2d_unaligned_channels",
     "TestConv.test_conv_1d_groups_flipped",
     "TestConv.test_conv_general_flip_grad",
     "TestConv.test_conv_groups_grad",
-    "TestConv.test_numpy_conv",
-    "TestConv.test_repeated_conv",
-    "TestConv.test_torch_conv_1D",
     "TestConv.test_torch_conv_1D_grad",
     "TestConv.test_torch_conv_2D",
     "TestConv.test_torch_conv_2D_grad",
-    "TestConv.test_torch_conv_3D",
     "TestConv.test_torch_conv_3D_grad",
     "TestConv.test_torch_conv_depthwise",
     "TestConv.test_torch_conv_general",
@@ -40,10 +33,6 @@ cuda_skip = {
     "TestConvTranspose.test_torch_conv_transpose_3D",
     "TestConvTranspose.test_torch_conv_transpose_3D_grad",
     "TestConvTranspose.test_torch_conv_transpose_3d_output_padding",
-    "TestExportImport.test_export_conv",
-    "TestLayers.test_conv1d",
-    "TestLayers.test_conv2d",
-    "TestVmap.test_vmap_conv",
     # FFTs NYI
     "TestFFT.test_fft",
     "TestFFT.test_fft_big_powers_of_two",
@@ -220,6 +220,19 @@ class TestQuantized(mlx_tests.MLXTestCase):
         self.assertEqual(y_q.shape, y_hat.shape)
         self.assertLess((y_q - y_hat).abs().max(), 2e-3)
 
+        # Test with 1D vector
+        group_size = 32
+        bits = 8
+        N = 2048
+        x = 1e-1 * mx.random.normal(shape=(N,), key=k1)
+        w = 1e-1 * mx.random.normal(shape=(N, N), key=k2)
+        w_q, scales, biases = mx.quantize(w, group_size, bits)
+        w_hat = mx.dequantize(w_q, scales, biases, group_size, bits)
+        y_q = mx.quantized_matmul(x, w_q, scales, biases, False, group_size, bits)
+        y_hat = x @ w_hat
+        self.assertEqual(y_q.shape, y_hat.shape)
+        self.assertLess((y_q - y_hat).abs().max(), 2e-3)
+
     def test_throw(self):
         x = mx.random.normal(shape=(10, 512))
         w = mx.random.normal(shape=(32, 512))
1  setup.py
@@ -289,6 +289,7 @@ if __name__ == "__main__":
         install_requires += [
             "nvidia-cublas-cu12==12.9.*",
             "nvidia-cuda-nvrtc-cu12==12.9.*",
+            "nvidia-cudnn-cu12==9.*",
         ]
     else:
         name = "mlx-cpu"