mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Compare commits
5 Commits
7fde1b6a1e
...
jagrit06/c
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
400f8457ea | ||
|
|
dfb5022eab | ||
|
|
ac207ce7aa | ||
|
|
fce53b61d6 | ||
|
|
8ae4a76308 |
@@ -1,4 +1,5 @@
|
|||||||
sphinx
|
sphinx
|
||||||
breathe
|
breathe
|
||||||
sphinx-book-theme
|
sphinx-book-theme
|
||||||
|
sphinx-copybutton
|
||||||
mlx
|
mlx
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ release = version
|
|||||||
# -- General configuration ---------------------------------------------------
|
# -- General configuration ---------------------------------------------------
|
||||||
|
|
||||||
extensions = [
|
extensions = [
|
||||||
|
"sphinx_copybutton",
|
||||||
"sphinx.ext.autodoc",
|
"sphinx.ext.autodoc",
|
||||||
"sphinx.ext.autosummary",
|
"sphinx.ext.autosummary",
|
||||||
"sphinx.ext.intersphinx",
|
"sphinx.ext.intersphinx",
|
||||||
|
|||||||
@@ -491,19 +491,27 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
switch (in.dtype()) {
|
switch (in.dtype()) {
|
||||||
case bool_:
|
case bool_:
|
||||||
case uint8:
|
case uint8:
|
||||||
|
reduce_dispatch_sum_prod<uint8_t>(in, out, reduce_type_, axes_);
|
||||||
|
break;
|
||||||
|
case uint16:
|
||||||
|
reduce_dispatch_sum_prod<uint16_t>(in, out, reduce_type_, axes_);
|
||||||
|
break;
|
||||||
|
case uint32:
|
||||||
|
reduce_dispatch_sum_prod<uint32_t>(in, out, reduce_type_, axes_);
|
||||||
|
break;
|
||||||
|
case uint64:
|
||||||
|
reduce_dispatch_sum_prod<uint64_t>(in, out, reduce_type_, axes_);
|
||||||
|
break;
|
||||||
case int8:
|
case int8:
|
||||||
reduce_dispatch_sum_prod<int8_t>(in, out, reduce_type_, axes_);
|
reduce_dispatch_sum_prod<int8_t>(in, out, reduce_type_, axes_);
|
||||||
break;
|
break;
|
||||||
case int16:
|
case int16:
|
||||||
case uint16:
|
|
||||||
reduce_dispatch_sum_prod<int16_t>(in, out, reduce_type_, axes_);
|
reduce_dispatch_sum_prod<int16_t>(in, out, reduce_type_, axes_);
|
||||||
break;
|
break;
|
||||||
case int32:
|
case int32:
|
||||||
case uint32:
|
|
||||||
reduce_dispatch_sum_prod<int32_t>(in, out, reduce_type_, axes_);
|
reduce_dispatch_sum_prod<int32_t>(in, out, reduce_type_, axes_);
|
||||||
break;
|
break;
|
||||||
case int64:
|
case int64:
|
||||||
case uint64:
|
|
||||||
reduce_dispatch_sum_prod<int64_t>(in, out, reduce_type_, axes_);
|
reduce_dispatch_sum_prod<int64_t>(in, out, reduce_type_, axes_);
|
||||||
break;
|
break;
|
||||||
case float16:
|
case float16:
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ target_sources(
|
|||||||
${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/gemms/gemv.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/gemms/gemv.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/gemms/steel_gemm.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/jit_module.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/jit_module.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu
|
||||||
@@ -53,10 +54,10 @@ target_sources(
|
|||||||
|
|
||||||
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.9.0)
|
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.9.0)
|
||||||
target_sources(
|
target_sources(
|
||||||
mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_batched_gemm_12_9.cu)
|
mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm_batched_12_9.cu)
|
||||||
else()
|
else()
|
||||||
target_sources(
|
target_sources(
|
||||||
mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_batched_gemm_12_0.cpp)
|
mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm_batched_12_0.cpp)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
target_compile_definitions(mlx PRIVATE MLX_USE_CUDA)
|
target_compile_definitions(mlx PRIVATE MLX_USE_CUDA)
|
||||||
|
|||||||
@@ -7,10 +7,12 @@
|
|||||||
|
|
||||||
#include <fmt/format.h>
|
#include <fmt/format.h>
|
||||||
|
|
||||||
namespace mlx::core::cu {
|
namespace mlx::core {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
struct CublasPreference {
|
struct CublasPreference {
|
||||||
CublasPreference(Device& device) {
|
CublasPreference(cu::Device& device) {
|
||||||
// The recommended cublas workspace size is 4 MiB for pre-Hopper and 32 MiB
|
// The recommended cublas workspace size is 4 MiB for pre-Hopper and 32 MiB
|
||||||
// for Hopper+:
|
// for Hopper+:
|
||||||
// https://docs.nvidia.com/cuda/cublas/#cublassetworkspace
|
// https://docs.nvidia.com/cuda/cublas/#cublassetworkspace
|
||||||
@@ -33,7 +35,7 @@ struct CublasPreference {
|
|||||||
cublasLtMatmulPreference_t pref_{nullptr};
|
cublasLtMatmulPreference_t pref_{nullptr};
|
||||||
};
|
};
|
||||||
|
|
||||||
cublasLtMatmulPreference_t cublas_preference(Device& device) {
|
cublasLtMatmulPreference_t cublas_preference(cu::Device& device) {
|
||||||
static CublasPreference pref(device);
|
static CublasPreference pref(device);
|
||||||
return pref.pref_;
|
return pref.pref_;
|
||||||
}
|
}
|
||||||
@@ -52,7 +54,7 @@ cublasComputeType_t dtype_to_compute_type(Dtype dtype) {
|
|||||||
return CUBLAS_COMPUTE_64F;
|
return CUBLAS_COMPUTE_64F;
|
||||||
default:
|
default:
|
||||||
throw std::runtime_error(fmt::format(
|
throw std::runtime_error(fmt::format(
|
||||||
"Unsupported dtype in Matmul: {}.", dtype_to_string(dtype)));
|
"Unsupported dtype in CublasGemm: {}.", dtype_to_string(dtype)));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -70,7 +72,7 @@ cudaDataType_t dtype_to_cublas_type(Dtype dtype) {
|
|||||||
return CUDA_C_32F;
|
return CUDA_C_32F;
|
||||||
default:
|
default:
|
||||||
throw std::runtime_error(fmt::format(
|
throw std::runtime_error(fmt::format(
|
||||||
"Unsupported dtype in Matmul: {}.", dtype_to_string(dtype)));
|
"Unsupported dtype in CublasGemm: {}.", dtype_to_string(dtype)));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -102,8 +104,10 @@ cublasLtMatrixLayout_t create_matrix_layout(
|
|||||||
return desc;
|
return desc;
|
||||||
}
|
}
|
||||||
|
|
||||||
Matmul::Matmul(
|
} // namespace
|
||||||
Device& device,
|
|
||||||
|
CublasGemm::CublasGemm(
|
||||||
|
cu::Device& device,
|
||||||
Dtype dtype,
|
Dtype dtype,
|
||||||
bool a_transposed,
|
bool a_transposed,
|
||||||
uint64_t a_rows,
|
uint64_t a_rows,
|
||||||
@@ -155,8 +159,8 @@ Matmul::Matmul(
|
|||||||
type, a_rows, b_cols, false, b_cols, batch_count, a_rows * b_cols);
|
type, a_rows, b_cols, false, b_cols, batch_count, a_rows * b_cols);
|
||||||
}
|
}
|
||||||
|
|
||||||
Matmul::Matmul(
|
CublasGemm::CublasGemm(
|
||||||
Device& device,
|
cu::Device& device,
|
||||||
Dtype dtype,
|
Dtype dtype,
|
||||||
bool a_transposed,
|
bool a_transposed,
|
||||||
uint64_t a_rows,
|
uint64_t a_rows,
|
||||||
@@ -171,7 +175,7 @@ Matmul::Matmul(
|
|||||||
int64_t a_batch_stride,
|
int64_t a_batch_stride,
|
||||||
int64_t b_batch_stride,
|
int64_t b_batch_stride,
|
||||||
int64_t c_batch_stride)
|
int64_t c_batch_stride)
|
||||||
: Matmul(
|
: CublasGemm(
|
||||||
device,
|
device,
|
||||||
dtype,
|
dtype,
|
||||||
a_transposed,
|
a_transposed,
|
||||||
@@ -190,7 +194,7 @@ Matmul::Matmul(
|
|||||||
type, a_rows, b_cols, false, ldc, batch_count, c_batch_stride);
|
type, a_rows, b_cols, false, ldc, batch_count, c_batch_stride);
|
||||||
}
|
}
|
||||||
|
|
||||||
Matmul::~Matmul() {
|
CublasGemm::~CublasGemm() {
|
||||||
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(a_desc_));
|
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(a_desc_));
|
||||||
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(b_desc_));
|
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(b_desc_));
|
||||||
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(c_desc_));
|
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(c_desc_));
|
||||||
@@ -198,7 +202,73 @@ Matmul::~Matmul() {
|
|||||||
CHECK_CUBLAS_ERROR(cublasLtMatmulDescDestroy(matmul_desc_));
|
CHECK_CUBLAS_ERROR(cublasLtMatmulDescDestroy(matmul_desc_));
|
||||||
}
|
}
|
||||||
|
|
||||||
void Matmul::run_impl(
|
void CublasGemm::run(
|
||||||
|
cu::CommandEncoder& encoder,
|
||||||
|
array& out,
|
||||||
|
const array& a,
|
||||||
|
const array& b,
|
||||||
|
const Shape& batch_shape,
|
||||||
|
const Strides& a_batch_strides,
|
||||||
|
const Strides& b_batch_strides) {
|
||||||
|
int batch_count = out.size() / (M_ * N_);
|
||||||
|
if (batch_count / batch_shape.back() > 1) {
|
||||||
|
run_batched(
|
||||||
|
encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
encoder.set_input_array(a);
|
||||||
|
encoder.set_input_array(b);
|
||||||
|
encoder.set_output_array(out);
|
||||||
|
|
||||||
|
execute(encoder, out.data<void>(), a.data<void>(), b.data<void>(), nullptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
void CublasGemm::run(
|
||||||
|
cu::CommandEncoder& encoder,
|
||||||
|
array& out,
|
||||||
|
const array& a,
|
||||||
|
const array& b,
|
||||||
|
const array& c,
|
||||||
|
const Shape& batch_shape,
|
||||||
|
const Strides& a_batch_strides,
|
||||||
|
const Strides& b_batch_strides,
|
||||||
|
const Strides& c_batch_strides,
|
||||||
|
float alpha,
|
||||||
|
float beta) {
|
||||||
|
int batch_count = out.size() / (M_ * N_);
|
||||||
|
if (batch_count / batch_shape.back() > 1) {
|
||||||
|
run_batched(
|
||||||
|
encoder,
|
||||||
|
out,
|
||||||
|
a,
|
||||||
|
b,
|
||||||
|
c,
|
||||||
|
batch_shape,
|
||||||
|
a_batch_strides,
|
||||||
|
b_batch_strides,
|
||||||
|
c_batch_strides,
|
||||||
|
alpha,
|
||||||
|
beta);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
encoder.set_input_array(a);
|
||||||
|
encoder.set_input_array(b);
|
||||||
|
encoder.set_input_array(c);
|
||||||
|
encoder.set_output_array(out);
|
||||||
|
|
||||||
|
execute(
|
||||||
|
encoder,
|
||||||
|
out.data<void>(),
|
||||||
|
a.data<void>(),
|
||||||
|
b.data<void>(),
|
||||||
|
c.data<void>(),
|
||||||
|
alpha,
|
||||||
|
beta);
|
||||||
|
}
|
||||||
|
|
||||||
|
void CublasGemm::execute(
|
||||||
cu::CommandEncoder& encoder,
|
cu::CommandEncoder& encoder,
|
||||||
void* out,
|
void* out,
|
||||||
const void* a,
|
const void* a,
|
||||||
@@ -256,29 +326,4 @@ void Matmul::run_impl(
|
|||||||
encoder.stream()));
|
encoder.stream()));
|
||||||
}
|
}
|
||||||
|
|
||||||
void Matmul::run(
|
} // namespace mlx::core
|
||||||
cu::CommandEncoder& encoder,
|
|
||||||
array& out,
|
|
||||||
const array& a,
|
|
||||||
const array& b,
|
|
||||||
const std::optional<array>& c /* = std::nullopt */,
|
|
||||||
float alpha /* = 1 */,
|
|
||||||
float beta /* = 0 */) {
|
|
||||||
encoder.set_input_array(a);
|
|
||||||
encoder.set_input_array(b);
|
|
||||||
if (c) {
|
|
||||||
encoder.set_input_array(*c);
|
|
||||||
}
|
|
||||||
encoder.set_output_array(out);
|
|
||||||
|
|
||||||
run_impl(
|
|
||||||
encoder,
|
|
||||||
out.data<void>(),
|
|
||||||
a.data<void>(),
|
|
||||||
b.data<void>(),
|
|
||||||
c ? c->data<void>() : nullptr,
|
|
||||||
alpha,
|
|
||||||
beta);
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace mlx::core::cu
|
|
||||||
|
|||||||
@@ -5,13 +5,13 @@
|
|||||||
#include "mlx/backend/cuda/device.h"
|
#include "mlx/backend/cuda/device.h"
|
||||||
|
|
||||||
#include <cublasLt.h>
|
#include <cublasLt.h>
|
||||||
#include <optional>
|
|
||||||
|
|
||||||
namespace mlx::core::cu {
|
namespace mlx::core {
|
||||||
class Matmul {
|
|
||||||
|
class CublasGemm {
|
||||||
public:
|
public:
|
||||||
Matmul(
|
CublasGemm(
|
||||||
Device& device,
|
cu::Device& device,
|
||||||
Dtype dtype,
|
Dtype dtype,
|
||||||
bool a_transposed,
|
bool a_transposed,
|
||||||
uint64_t a_rows,
|
uint64_t a_rows,
|
||||||
@@ -25,8 +25,8 @@ class Matmul {
|
|||||||
int64_t a_batch_stride,
|
int64_t a_batch_stride,
|
||||||
int64_t b_batch_stride);
|
int64_t b_batch_stride);
|
||||||
|
|
||||||
Matmul(
|
CublasGemm(
|
||||||
Device& device,
|
cu::Device& device,
|
||||||
Dtype dtype,
|
Dtype dtype,
|
||||||
bool a_transposed,
|
bool a_transposed,
|
||||||
uint64_t a_rows,
|
uint64_t a_rows,
|
||||||
@@ -42,25 +42,39 @@ class Matmul {
|
|||||||
int64_t b_batch_stride,
|
int64_t b_batch_stride,
|
||||||
int64_t c_batch_stride);
|
int64_t c_batch_stride);
|
||||||
|
|
||||||
~Matmul();
|
~CublasGemm();
|
||||||
|
|
||||||
void run(
|
void run(
|
||||||
cu::CommandEncoder& encoder,
|
cu::CommandEncoder& encoder,
|
||||||
array& out,
|
array& out,
|
||||||
const array& a,
|
const array& a,
|
||||||
const array& b,
|
const array& b,
|
||||||
const std::optional<array>& c = std::nullopt,
|
const Shape& batch_shape,
|
||||||
float alpha = 1,
|
const Strides& a_batch_strides,
|
||||||
float beta = 0);
|
const Strides& b_batch_strides);
|
||||||
|
|
||||||
|
void run(
|
||||||
|
cu::CommandEncoder& encoder,
|
||||||
|
array& out,
|
||||||
|
const array& a,
|
||||||
|
const array& b,
|
||||||
|
const array& c,
|
||||||
|
const Shape& batch_shape,
|
||||||
|
const Strides& a_batch_strides,
|
||||||
|
const Strides& b_batch_strides,
|
||||||
|
const Strides& c_batch_strides,
|
||||||
|
float alpha,
|
||||||
|
float beta);
|
||||||
|
|
||||||
|
private:
|
||||||
void run_batched(
|
void run_batched(
|
||||||
cu::CommandEncoder& encoder,
|
cu::CommandEncoder& encoder,
|
||||||
array& out,
|
array& out,
|
||||||
const array& a,
|
const array& a,
|
||||||
const array& b,
|
const array& b,
|
||||||
const mlx::core::Shape& batch_shape,
|
const Shape& batch_shape,
|
||||||
const mlx::core::Strides& a_batch_strides,
|
const Strides& a_batch_strides,
|
||||||
const mlx::core::Strides& b_batch_strides);
|
const Strides& b_batch_strides);
|
||||||
|
|
||||||
void run_batched(
|
void run_batched(
|
||||||
cu::CommandEncoder& encoder,
|
cu::CommandEncoder& encoder,
|
||||||
@@ -68,15 +82,14 @@ class Matmul {
|
|||||||
const array& a,
|
const array& a,
|
||||||
const array& b,
|
const array& b,
|
||||||
const array& c,
|
const array& c,
|
||||||
const mlx::core::Shape& batch_shape,
|
const Shape& batch_shape,
|
||||||
const mlx::core::Strides& a_batch_strides,
|
const Strides& a_batch_strides,
|
||||||
const mlx::core::Strides& b_batch_strides,
|
const Strides& b_batch_strides,
|
||||||
const mlx::core::Strides& c_batch_strides,
|
const Strides& c_batch_strides,
|
||||||
float alpha,
|
float alpha,
|
||||||
float beta);
|
float beta);
|
||||||
|
|
||||||
private:
|
void execute(
|
||||||
void run_impl(
|
|
||||||
cu::CommandEncoder& encoder,
|
cu::CommandEncoder& encoder,
|
||||||
void* out,
|
void* out,
|
||||||
const void* a,
|
const void* a,
|
||||||
@@ -97,4 +110,4 @@ class Matmul {
|
|||||||
cublasLtMatmulHeuristicResult_t heuristic_;
|
cublasLtMatmulHeuristicResult_t heuristic_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace mlx::core::cu
|
} // namespace mlx::core
|
||||||
|
|||||||
@@ -4,16 +4,16 @@
|
|||||||
#include "mlx/backend/cuda/device.h"
|
#include "mlx/backend/cuda/device.h"
|
||||||
#include "mlx/backend/cuda/gemms/cublas_gemm.h"
|
#include "mlx/backend/cuda/gemms/cublas_gemm.h"
|
||||||
|
|
||||||
namespace mlx::core::cu {
|
namespace mlx::core {
|
||||||
|
|
||||||
void Matmul::run_batched(
|
void CublasGemm::run_batched(
|
||||||
cu::CommandEncoder& encoder,
|
cu::CommandEncoder& encoder,
|
||||||
array& out,
|
array& out,
|
||||||
const array& a,
|
const array& a,
|
||||||
const array& b,
|
const array& b,
|
||||||
const mlx::core::Shape& batch_shape,
|
const Shape& batch_shape,
|
||||||
const mlx::core::Strides& a_batch_strides,
|
const Strides& a_batch_strides,
|
||||||
const mlx::core::Strides& b_batch_strides) {
|
const Strides& b_batch_strides) {
|
||||||
encoder.set_input_array(a);
|
encoder.set_input_array(a);
|
||||||
encoder.set_input_array(b);
|
encoder.set_input_array(b);
|
||||||
encoder.set_output_array(out);
|
encoder.set_output_array(out);
|
||||||
@@ -22,7 +22,7 @@ void Matmul::run_batched(
|
|||||||
ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
|
ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
|
||||||
auto concurrent = encoder.concurrent_context();
|
auto concurrent = encoder.concurrent_context();
|
||||||
for (size_t i = 0; i < nbatch; ++i) {
|
for (size_t i = 0; i < nbatch; ++i) {
|
||||||
run_impl(
|
execute(
|
||||||
encoder,
|
encoder,
|
||||||
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M_ * N_,
|
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M_ * N_,
|
||||||
a.data<int8_t>() + a.itemsize() * a_it.loc,
|
a.data<int8_t>() + a.itemsize() * a_it.loc,
|
||||||
@@ -33,16 +33,16 @@ void Matmul::run_batched(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void Matmul::run_batched(
|
void CublasGemm::run_batched(
|
||||||
cu::CommandEncoder& encoder,
|
cu::CommandEncoder& encoder,
|
||||||
array& out,
|
array& out,
|
||||||
const array& a,
|
const array& a,
|
||||||
const array& b,
|
const array& b,
|
||||||
const array& c,
|
const array& c,
|
||||||
const mlx::core::Shape& batch_shape,
|
const Shape& batch_shape,
|
||||||
const mlx::core::Strides& a_batch_strides,
|
const Strides& a_batch_strides,
|
||||||
const mlx::core::Strides& b_batch_strides,
|
const Strides& b_batch_strides,
|
||||||
const mlx::core::Strides& c_batch_strides,
|
const Strides& c_batch_strides,
|
||||||
float alpha,
|
float alpha,
|
||||||
float beta) {
|
float beta) {
|
||||||
encoder.set_input_array(a);
|
encoder.set_input_array(a);
|
||||||
@@ -56,7 +56,7 @@ void Matmul::run_batched(
|
|||||||
ContiguousIterator c_it(batch_shape, c_batch_strides, batch_shape.size() - 1);
|
ContiguousIterator c_it(batch_shape, c_batch_strides, batch_shape.size() - 1);
|
||||||
auto concurrent = encoder.concurrent_context();
|
auto concurrent = encoder.concurrent_context();
|
||||||
for (size_t i = 0; i < nbatch; ++i) {
|
for (size_t i = 0; i < nbatch; ++i) {
|
||||||
run_impl(
|
execute(
|
||||||
encoder,
|
encoder,
|
||||||
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M_ * N_,
|
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M_ * N_,
|
||||||
a.data<int8_t>() + a.itemsize() * a_it.loc,
|
a.data<int8_t>() + a.itemsize() * a_it.loc,
|
||||||
@@ -70,4 +70,4 @@ void Matmul::run_batched(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mlx::core::cu
|
} // namespace mlx::core
|
||||||
@@ -6,7 +6,9 @@
|
|||||||
|
|
||||||
#include <cooperative_groups.h>
|
#include <cooperative_groups.h>
|
||||||
|
|
||||||
namespace mlx::core::cu {
|
namespace mlx::core {
|
||||||
|
|
||||||
|
namespace cu {
|
||||||
|
|
||||||
namespace cg = cooperative_groups;
|
namespace cg = cooperative_groups;
|
||||||
|
|
||||||
@@ -128,6 +130,10 @@ __global__ void set_addmm_device_pointers_g(
|
|||||||
out_start + item_size * index * batch_stride;
|
out_start + item_size * index * batch_stride;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
} // namespace cu
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
void set_pointer_mode(cublasLtMatrixLayout_t desc, int batch_count) {
|
void set_pointer_mode(cublasLtMatrixLayout_t desc, int batch_count) {
|
||||||
auto batch_mode = CUBLASLT_BATCH_MODE_POINTER_ARRAY;
|
auto batch_mode = CUBLASLT_BATCH_MODE_POINTER_ARRAY;
|
||||||
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
|
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
|
||||||
@@ -139,14 +145,16 @@ void set_pointer_mode(cublasLtMatrixLayout_t desc, int batch_count) {
|
|||||||
desc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count, sizeof(int32_t)));
|
desc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count, sizeof(int32_t)));
|
||||||
}
|
}
|
||||||
|
|
||||||
void Matmul::run_batched(
|
} // namespace
|
||||||
|
|
||||||
|
void CublasGemm::run_batched(
|
||||||
cu::CommandEncoder& encoder,
|
cu::CommandEncoder& encoder,
|
||||||
array& out,
|
array& out,
|
||||||
const array& a,
|
const array& a,
|
||||||
const array& b,
|
const array& b,
|
||||||
const mlx::core::Shape& batch_shape,
|
const Shape& batch_shape,
|
||||||
const mlx::core::Strides& a_batch_strides,
|
const Strides& a_batch_strides,
|
||||||
const mlx::core::Strides& b_batch_strides) {
|
const Strides& b_batch_strides) {
|
||||||
int batch_count = out.size() / (M_ * N_);
|
int batch_count = out.size() / (M_ * N_);
|
||||||
set_pointer_mode(a_desc_, batch_count);
|
set_pointer_mode(a_desc_, batch_count);
|
||||||
set_pointer_mode(b_desc_, batch_count);
|
set_pointer_mode(b_desc_, batch_count);
|
||||||
@@ -213,7 +221,7 @@ void Matmul::run_batched(
|
|||||||
auto a_pointers = pointers.data<int8_t*>();
|
auto a_pointers = pointers.data<int8_t*>();
|
||||||
auto b_pointers = a_pointers + batch_count;
|
auto b_pointers = a_pointers + batch_count;
|
||||||
auto out_pointers = b_pointers + batch_count;
|
auto out_pointers = b_pointers + batch_count;
|
||||||
run_impl(
|
execute(
|
||||||
encoder,
|
encoder,
|
||||||
reinterpret_cast<void*>(out_pointers),
|
reinterpret_cast<void*>(out_pointers),
|
||||||
reinterpret_cast<void*>(a_pointers),
|
reinterpret_cast<void*>(a_pointers),
|
||||||
@@ -221,16 +229,16 @@ void Matmul::run_batched(
|
|||||||
nullptr);
|
nullptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
void Matmul::run_batched(
|
void CublasGemm::run_batched(
|
||||||
cu::CommandEncoder& encoder,
|
cu::CommandEncoder& encoder,
|
||||||
array& out,
|
array& out,
|
||||||
const array& a,
|
const array& a,
|
||||||
const array& b,
|
const array& b,
|
||||||
const array& c,
|
const array& c,
|
||||||
const mlx::core::Shape& batch_shape,
|
const Shape& batch_shape,
|
||||||
const mlx::core::Strides& a_batch_strides,
|
const Strides& a_batch_strides,
|
||||||
const mlx::core::Strides& b_batch_strides,
|
const Strides& b_batch_strides,
|
||||||
const mlx::core::Strides& c_batch_strides,
|
const Strides& c_batch_strides,
|
||||||
float alpha,
|
float alpha,
|
||||||
float beta) {
|
float beta) {
|
||||||
int batch_count = out.size() / (M_ * N_);
|
int batch_count = out.size() / (M_ * N_);
|
||||||
@@ -306,7 +314,7 @@ void Matmul::run_batched(
|
|||||||
auto b_pointers = a_pointers + batch_count;
|
auto b_pointers = a_pointers + batch_count;
|
||||||
auto c_pointers = b_pointers + batch_count;
|
auto c_pointers = b_pointers + batch_count;
|
||||||
auto out_pointers = c_pointers + batch_count;
|
auto out_pointers = c_pointers + batch_count;
|
||||||
run_impl(
|
execute(
|
||||||
encoder,
|
encoder,
|
||||||
reinterpret_cast<void*>(out_pointers),
|
reinterpret_cast<void*>(out_pointers),
|
||||||
reinterpret_cast<void*>(a_pointers),
|
reinterpret_cast<void*>(a_pointers),
|
||||||
@@ -316,4 +324,4 @@ void Matmul::run_batched(
|
|||||||
beta);
|
beta);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mlx::core::cu
|
} // namespace mlx::core
|
||||||
301
mlx/backend/cuda/gemms/steel_gemm.cu
Normal file
301
mlx/backend/cuda/gemms/steel_gemm.cu
Normal file
@@ -0,0 +1,301 @@
|
|||||||
|
#include "mlx/backend/common/matmul.h"
|
||||||
|
#include "mlx/backend/cuda/device.h"
|
||||||
|
#include "mlx/backend/cuda/device/utils.cuh"
|
||||||
|
#include "mlx/backend/cuda/gemms/steel_gemm.h"
|
||||||
|
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||||
|
#include "mlx/primitives.h"
|
||||||
|
|
||||||
|
#include <nvtx3/nvtx3.hpp>
|
||||||
|
#include <numeric>
|
||||||
|
|
||||||
|
#include <cooperative_groups.h>
|
||||||
|
|
||||||
|
#include "mlx/backend/cuda/steel/gemm.cuh"
|
||||||
|
#include "mlx/backend/cuda/steel/mma.cuh"
|
||||||
|
#include "mlx/backend/cuda/steel/tiles.cuh"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
namespace cu {
|
||||||
|
|
||||||
|
namespace cg = cooperative_groups;
|
||||||
|
|
||||||
|
struct GemmParams {
|
||||||
|
int M;
|
||||||
|
int N;
|
||||||
|
int K;
|
||||||
|
int lda;
|
||||||
|
int ldb;
|
||||||
|
int ldd;
|
||||||
|
|
||||||
|
int NblockM;
|
||||||
|
int NblockN;
|
||||||
|
int NblockK;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <
|
||||||
|
typename T,
|
||||||
|
int BM,
|
||||||
|
int BN,
|
||||||
|
int BK,
|
||||||
|
int WM,
|
||||||
|
int WN,
|
||||||
|
bool transpose_a,
|
||||||
|
bool transpose_b,
|
||||||
|
int SL,
|
||||||
|
int Nstages>
|
||||||
|
__global__ void kernel_steel_gemm(
|
||||||
|
const T* a,
|
||||||
|
const T* b,
|
||||||
|
T* d,
|
||||||
|
__grid_constant__ const GemmParams params) {
|
||||||
|
const int bM_idx = (blockIdx.y << SL) + (blockIdx.x & ((1 << SL) - 1));
|
||||||
|
const int bN_idx = blockIdx.x >> SL;
|
||||||
|
|
||||||
|
if (params.NblockN <= bN_idx || params.NblockM <= bM_idx) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const int d_row = bM_idx * BM;
|
||||||
|
const int d_col = bN_idx * BN;
|
||||||
|
const size_t d_row_long = size_t(d_row);
|
||||||
|
const size_t d_col_long = size_t(d_col);
|
||||||
|
|
||||||
|
a += transpose_a ? d_row_long : d_row_long * params.K;
|
||||||
|
b += transpose_b ? d_col_long * params.K : d_col_long;
|
||||||
|
d += d_row_long * params.ldd + d_col_long;
|
||||||
|
|
||||||
|
auto block = cg::this_thread_block();
|
||||||
|
auto warp = cg::tiled_partition<32>(block);
|
||||||
|
|
||||||
|
const int lane_idx = warp.thread_rank();
|
||||||
|
const int warp_idx = warp.meta_group_rank();
|
||||||
|
|
||||||
|
const int wm = warp_idx / WN;
|
||||||
|
const int wn = warp_idx % WN;
|
||||||
|
|
||||||
|
constexpr int SM = BM / WM;
|
||||||
|
constexpr int SN = BN / WN;
|
||||||
|
constexpr int SK = BK;
|
||||||
|
constexpr int TK = SK / 16;
|
||||||
|
|
||||||
|
constexpr int NUM_WARPS = WM * WN;
|
||||||
|
|
||||||
|
// Allocate shared memory
|
||||||
|
extern __shared__ char shmem[];
|
||||||
|
SharedTile<T, BM, BK>(&as)[Nstages] =
|
||||||
|
*(SharedTile<T, BM, BK>(*)[Nstages])(&shmem[0]);
|
||||||
|
SharedTile<T, BN, BK>(&bs)[Nstages] = *(SharedTile<T, BN, BK>(*)[Nstages])(
|
||||||
|
&shmem[sizeof(T) * Nstages * BM * BK]);
|
||||||
|
|
||||||
|
// Allocate registers for the MMA
|
||||||
|
RegisterTile<float, SM, SN> C;
|
||||||
|
RegisterTile<T, SM, 16> A[TK];
|
||||||
|
RegisterTile<T, SN, 16> B[TK];
|
||||||
|
|
||||||
|
// Zero the accumulators
|
||||||
|
C.fill(0);
|
||||||
|
|
||||||
|
// Start gmem -> smem copies
|
||||||
|
int k_block_read = 0;
|
||||||
|
|
||||||
|
MLX_UNROLL
|
||||||
|
for (int bk = 0; bk < (Nstages - 1); bk++) {
|
||||||
|
load_async<NUM_WARPS>(
|
||||||
|
as[bk], as[bk].base_addr(), a + k_block_read, params.K);
|
||||||
|
load_async<NUM_WARPS>(
|
||||||
|
bs[bk], bs[bk].base_addr(), b + k_block_read, params.K);
|
||||||
|
k_block_read += BK;
|
||||||
|
cp_async_commit();
|
||||||
|
}
|
||||||
|
|
||||||
|
int smem_pipe_read = 0;
|
||||||
|
int smem_pipe_write = Nstages - 1;
|
||||||
|
|
||||||
|
// Wait till only 1 remains laoding
|
||||||
|
cp_async_wait<1>();
|
||||||
|
block.sync();
|
||||||
|
|
||||||
|
const int offset_m = wm * SM;
|
||||||
|
const int offset_n = wn * SN;
|
||||||
|
|
||||||
|
// Start smem -> register copy
|
||||||
|
A[0].load(
|
||||||
|
as[smem_pipe_read],
|
||||||
|
as[smem_pipe_read].base_addr(),
|
||||||
|
offset_m + lane_idx % 16,
|
||||||
|
lane_idx / 16 * 8);
|
||||||
|
B[0].load(
|
||||||
|
bs[smem_pipe_read],
|
||||||
|
bs[smem_pipe_read].base_addr(),
|
||||||
|
offset_n + lane_idx % 16,
|
||||||
|
lane_idx / 16 * 8);
|
||||||
|
|
||||||
|
// Main loop
|
||||||
|
for (int kb = 0; kb < params.NblockK; kb++) {
|
||||||
|
// Prepare next registers
|
||||||
|
{
|
||||||
|
A[1].load(
|
||||||
|
as[smem_pipe_read],
|
||||||
|
as[smem_pipe_read].base_addr(),
|
||||||
|
offset_m + lane_idx % 16,
|
||||||
|
16 + lane_idx / 16 * 8);
|
||||||
|
B[1].load(
|
||||||
|
bs[smem_pipe_read],
|
||||||
|
bs[smem_pipe_read].base_addr(),
|
||||||
|
offset_n + lane_idx % 16,
|
||||||
|
16 + lane_idx / 16 * 8);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prepare next smem
|
||||||
|
if ((kb + Nstages - 1) < params.NblockK) {
|
||||||
|
load_async<NUM_WARPS>(
|
||||||
|
as[smem_pipe_write],
|
||||||
|
as[smem_pipe_write].base_addr(),
|
||||||
|
a + k_block_read,
|
||||||
|
params.K);
|
||||||
|
load_async<NUM_WARPS>(
|
||||||
|
bs[smem_pipe_write],
|
||||||
|
bs[smem_pipe_write].base_addr(),
|
||||||
|
b + k_block_read,
|
||||||
|
params.K);
|
||||||
|
}
|
||||||
|
k_block_read += BK;
|
||||||
|
|
||||||
|
cp_async_commit();
|
||||||
|
|
||||||
|
smem_pipe_write = smem_pipe_read;
|
||||||
|
smem_pipe_read = smem_pipe_read + 1;
|
||||||
|
smem_pipe_read = (smem_pipe_read == Nstages) ? 0 : smem_pipe_read;
|
||||||
|
|
||||||
|
// Do current gemm
|
||||||
|
mma_t(C, A[0], B[0]);
|
||||||
|
|
||||||
|
// Do wait for next register
|
||||||
|
cp_async_wait<1>();
|
||||||
|
block.sync();
|
||||||
|
|
||||||
|
// Prepare next register (smem_pipe_read has moved to the next)
|
||||||
|
{
|
||||||
|
A[0].load(
|
||||||
|
as[smem_pipe_read],
|
||||||
|
as[smem_pipe_read].base_addr(),
|
||||||
|
offset_m + lane_idx % 16,
|
||||||
|
lane_idx / 16 * 8);
|
||||||
|
B[0].load(
|
||||||
|
bs[smem_pipe_read],
|
||||||
|
bs[smem_pipe_read].base_addr(),
|
||||||
|
offset_n + lane_idx % 16,
|
||||||
|
lane_idx / 16 * 8);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Do current gemm
|
||||||
|
mma_t(C, A[1], B[1]);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait and clear
|
||||||
|
cp_async_wait_all();
|
||||||
|
block.sync();
|
||||||
|
|
||||||
|
C.store_global(d, params.ldd, offset_m, offset_n);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace cu
|
||||||
|
|
||||||
|
void dispatch_steel_gemm(
|
||||||
|
const Stream& s,
|
||||||
|
cu::CommandEncoder& encoder,
|
||||||
|
const array& a,
|
||||||
|
const array& b,
|
||||||
|
array& d,
|
||||||
|
int M,
|
||||||
|
int N,
|
||||||
|
int K,
|
||||||
|
int lda,
|
||||||
|
int ldb,
|
||||||
|
int ldd,
|
||||||
|
bool a_transposed,
|
||||||
|
bool b_transposed) {
|
||||||
|
using DataType = cuda_type_t<float16_t>;
|
||||||
|
|
||||||
|
encoder.set_input_array(a);
|
||||||
|
encoder.set_input_array(b);
|
||||||
|
encoder.set_output_array(d);
|
||||||
|
|
||||||
|
constexpr int BM = 128;
|
||||||
|
constexpr int BN = 128;
|
||||||
|
constexpr int BK = 32;
|
||||||
|
|
||||||
|
constexpr int WM = 2;
|
||||||
|
constexpr int WN = 2;
|
||||||
|
|
||||||
|
constexpr int SL = 0;
|
||||||
|
constexpr int Nstages = 3;
|
||||||
|
|
||||||
|
constexpr uint32_t smem_bytes = BK * (BM + BN) * Nstages * sizeof(DataType);
|
||||||
|
|
||||||
|
const int NblockM = (M + BM - 1) / BM;
|
||||||
|
const int NblockN = (N + BN - 1) / BN;
|
||||||
|
const int NblockK = (K + BK - 1) / BK;
|
||||||
|
|
||||||
|
cu::GemmParams params{
|
||||||
|
/* int M = */ M,
|
||||||
|
/* int N = */ N,
|
||||||
|
/* int K = */ K,
|
||||||
|
/* int lda = */ lda,
|
||||||
|
/* int ldb = */ ldb,
|
||||||
|
/* int ldd = */ ldd,
|
||||||
|
|
||||||
|
/* int NblockM = */ NblockM,
|
||||||
|
/* int NblockN = */ NblockN,
|
||||||
|
/* int NblockK = */ NblockK,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Prepare launch grid params
|
||||||
|
int tile = 1 << SL;
|
||||||
|
int tm = (NblockM + tile - 1) / tile;
|
||||||
|
int tn = NblockN * tile;
|
||||||
|
|
||||||
|
dim3 grid_dim(tn, tm, 1);
|
||||||
|
dim3 block_dim(32 * WM * WN, 1, 1);
|
||||||
|
|
||||||
|
dispatch_bool(a_transposed, [&](auto ta_) {
|
||||||
|
dispatch_bool(b_transposed, [&](auto tb_) {
|
||||||
|
constexpr bool ta = ta_.value;
|
||||||
|
constexpr bool tb = tb_.value;
|
||||||
|
|
||||||
|
auto kernel = cu::ab_t_aligned<DataType, BM, BN, BK>;
|
||||||
|
cudaFuncSetAttribute(
|
||||||
|
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
|
||||||
|
|
||||||
|
encoder.add_kernel_node(
|
||||||
|
kernel,
|
||||||
|
grid_dim,
|
||||||
|
block_dim,
|
||||||
|
smem_bytes,
|
||||||
|
a.data<DataType>(),
|
||||||
|
b.data<DataType>(),
|
||||||
|
d.data<DataType>(),
|
||||||
|
N,
|
||||||
|
K);
|
||||||
|
|
||||||
|
// auto kernel = cu::kernel_steel_gemm<DataType, BM, BN, BK, WM, WN, ta,
|
||||||
|
// tb, SL, Nstages>;
|
||||||
|
|
||||||
|
// cudaFuncSetAttribute(kernel,
|
||||||
|
// cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
|
||||||
|
|
||||||
|
// encoder.add_kernel_node(
|
||||||
|
// kernel,
|
||||||
|
// grid_dim,
|
||||||
|
// block_dim,
|
||||||
|
// smem_bytes,
|
||||||
|
// a.data<DataType>(),
|
||||||
|
// b.data<DataType>(),
|
||||||
|
// d.data<DataType>(),
|
||||||
|
// params);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
||||||
27
mlx/backend/cuda/gemms/steel_gemm.h
Normal file
27
mlx/backend/cuda/gemms/steel_gemm.h
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlx/backend/common/matmul.h"
|
||||||
|
#include "mlx/backend/cuda/device.h"
|
||||||
|
#include "mlx/primitives.h"
|
||||||
|
|
||||||
|
#include <nvtx3/nvtx3.hpp>
|
||||||
|
#include <numeric>
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
void dispatch_steel_gemm(
|
||||||
|
const Stream& s,
|
||||||
|
cu::CommandEncoder& encoder,
|
||||||
|
const array& a,
|
||||||
|
const array& b,
|
||||||
|
array& d,
|
||||||
|
int M,
|
||||||
|
int N,
|
||||||
|
int K,
|
||||||
|
int lda,
|
||||||
|
int ldb,
|
||||||
|
int ldd,
|
||||||
|
bool a_transposed,
|
||||||
|
bool b_transposed);
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
||||||
@@ -7,6 +7,8 @@
|
|||||||
#include "mlx/backend/gpu/copy.h"
|
#include "mlx/backend/gpu/copy.h"
|
||||||
#include "mlx/primitives.h"
|
#include "mlx/primitives.h"
|
||||||
|
|
||||||
|
#include "mlx/backend/cuda/gemms/steel_gemm.h"
|
||||||
|
|
||||||
#include <nvtx3/nvtx3.hpp>
|
#include <nvtx3/nvtx3.hpp>
|
||||||
#include <numeric>
|
#include <numeric>
|
||||||
|
|
||||||
@@ -95,9 +97,27 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (out.dtype() == float16 && batch_count == 1 && !a_transposed &&
|
||||||
|
b_transposed) {
|
||||||
|
return dispatch_steel_gemm(
|
||||||
|
/* const Stream& s = */ s,
|
||||||
|
/* cu::CommandEncoder& encoder = */ encoder,
|
||||||
|
/* const array& a = */ a,
|
||||||
|
/* const array& b = */ b,
|
||||||
|
/* array& d = */ out,
|
||||||
|
/* int M = */ M,
|
||||||
|
/* int N = */ N,
|
||||||
|
/* int K = */ K,
|
||||||
|
/* int lda = */ lda,
|
||||||
|
/* int ldb = */ ldb,
|
||||||
|
/* int ldd = */ N,
|
||||||
|
/* bool a_transposed = */ a_transposed,
|
||||||
|
/* bool b_transposed = */ b_transposed);
|
||||||
|
}
|
||||||
|
|
||||||
/////////////////////////////////////////////////////////////////////////////
|
/////////////////////////////////////////////////////////////////////////////
|
||||||
// Invoke cublasLt
|
// Invoke cublasLt
|
||||||
cu::Matmul matmul(
|
CublasGemm gemm(
|
||||||
cu::device(s.device),
|
cu::device(s.device),
|
||||||
a.dtype(),
|
a.dtype(),
|
||||||
a_transposed,
|
a_transposed,
|
||||||
@@ -111,14 +131,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
batch_shape.back(),
|
batch_shape.back(),
|
||||||
a_batch_strides.back(),
|
a_batch_strides.back(),
|
||||||
b_batch_strides.back());
|
b_batch_strides.back());
|
||||||
|
gemm.run(encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides);
|
||||||
if ((batch_count / batch_shape.back()) == 1) {
|
|
||||||
matmul.run(encoder, out, a, b);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
matmul.run_batched(
|
|
||||||
encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
@@ -186,7 +199,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
/////////////////////////////////////////////////////////////////////////////
|
/////////////////////////////////////////////////////////////////////////////
|
||||||
// Invoke cublasLt
|
// Invoke cublasLt
|
||||||
|
|
||||||
cu::Matmul matmul(
|
CublasGemm gemm(
|
||||||
cu::device(s.device),
|
cu::device(s.device),
|
||||||
a.dtype(),
|
a.dtype(),
|
||||||
a_transposed,
|
a_transposed,
|
||||||
@@ -202,12 +215,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
a_batch_strides.back(),
|
a_batch_strides.back(),
|
||||||
b_batch_strides.back(),
|
b_batch_strides.back(),
|
||||||
c_batch_strides.back());
|
c_batch_strides.back());
|
||||||
|
gemm.run(
|
||||||
if ((batch_count / batch_shape.back()) == 1) {
|
|
||||||
matmul.run(encoder, out, a, b, c, alpha_, beta_);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
matmul.run_batched(
|
|
||||||
encoder,
|
encoder,
|
||||||
out,
|
out,
|
||||||
a,
|
a,
|
||||||
|
|||||||
@@ -143,85 +143,87 @@ struct Tile16x16 {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
/**
|
// /**
|
||||||
* A simple container of multiple Tile16x16.
|
// * A simple container of multiple Tile16x16.
|
||||||
*
|
// *
|
||||||
* Provides utility functions for loading and manipulating collections of basic
|
// * Provides utility functions for loading and manipulating collections of
|
||||||
* tiles.
|
// basic
|
||||||
*/
|
// * tiles.
|
||||||
template <typename T, int ROWS_, int COLS_>
|
// */
|
||||||
struct RegisterTile {
|
// template <typename T, int ROWS_, int COLS_>
|
||||||
static constexpr int ROWS = ROWS_;
|
// struct RegisterTile {
|
||||||
static constexpr int COLS = COLS_;
|
// static constexpr int ROWS = ROWS_;
|
||||||
static constexpr int TILES_X = COLS / 16;
|
// static constexpr int COLS = COLS_;
|
||||||
static constexpr int TILES_Y = ROWS / 16;
|
// static constexpr int TILES_X = COLS / 16;
|
||||||
|
// static constexpr int TILES_Y = ROWS / 16;
|
||||||
|
|
||||||
Tile16x16<T> data[TILES_X * TILES_Y];
|
// Tile16x16<T> data[TILES_X * TILES_Y];
|
||||||
|
|
||||||
__device__ inline void fill(T v) {
|
// __device__ inline void fill(T v) {
|
||||||
MLX_UNROLL
|
// MLX_UNROLL
|
||||||
for (int i = 0; i < TILES_Y; i++) {
|
// for (int i = 0; i < TILES_Y; i++) {
|
||||||
MLX_UNROLL
|
// MLX_UNROLL
|
||||||
for (int j = 0; j < TILES_X; j++) {
|
// for (int j = 0; j < TILES_X; j++) {
|
||||||
data[i * TILES_X + j].fill(v);
|
// data[i * TILES_X + j].fill(v);
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
|
|
||||||
template <typename Tile>
|
// template <typename Tile>
|
||||||
__device__ __forceinline__ void
|
// __device__ __forceinline__ void
|
||||||
load(Tile& tile, uint32_t base_address, int row, int col) {
|
// load(Tile& tile, uint32_t base_address, int row, int col) {
|
||||||
MLX_UNROLL
|
// MLX_UNROLL
|
||||||
for (int i = 0; i < TILES_Y; i++) {
|
// for (int i = 0; i < TILES_Y; i++) {
|
||||||
MLX_UNROLL
|
// MLX_UNROLL
|
||||||
for (int j = 0; j < TILES_X; j++) {
|
// for (int j = 0; j < TILES_X; j++) {
|
||||||
data[i * TILES_X + j].load(
|
// data[i * TILES_X + j].load(
|
||||||
tile.loc(base_address, row + i * 16, col + j * 16));
|
// tile.loc(base_address, row + i * 16, col + j * 16));
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
|
|
||||||
template <typename Tile, typename F>
|
// template <typename Tile, typename F>
|
||||||
__device__ __forceinline__ void
|
// __device__ __forceinline__ void
|
||||||
load(Tile& tile, F f, uint32_t base_address, int row, int col) {
|
// load(Tile& tile, F f, uint32_t base_address, int row, int col) {
|
||||||
MLX_UNROLL
|
// MLX_UNROLL
|
||||||
for (int i = 0; i < TILES_Y; i++) {
|
// for (int i = 0; i < TILES_Y; i++) {
|
||||||
MLX_UNROLL
|
// MLX_UNROLL
|
||||||
for (int j = 0; j < TILES_X; j++) {
|
// for (int j = 0; j < TILES_X; j++) {
|
||||||
f(data[i * TILES_X + j],
|
// f(data[i * TILES_X + j],
|
||||||
tile,
|
// tile,
|
||||||
base_address,
|
// base_address,
|
||||||
row + i * 16,
|
// row + i * 16,
|
||||||
col + j * 16);
|
// col + j * 16);
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
|
|
||||||
template <typename U>
|
// template <typename U>
|
||||||
__device__ inline void store_global(U* x, int N, int row, int col) {
|
// __device__ inline void store_global(U* x, int N, int row, int col) {
|
||||||
MLX_UNROLL
|
// MLX_UNROLL
|
||||||
for (int i = 0; i < TILES_Y; i++) {
|
// for (int i = 0; i < TILES_Y; i++) {
|
||||||
MLX_UNROLL
|
// MLX_UNROLL
|
||||||
for (int j = 0; j < TILES_X; j++) {
|
// for (int j = 0; j < TILES_X; j++) {
|
||||||
data[i * TILES_X + j].store_global(
|
// data[i * TILES_X + j].store_global(
|
||||||
x + (row + i * 16) * N + col + j * 16, N);
|
// x + (row + i * 16) * N + col + j * 16, N);
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
|
|
||||||
template <typename U>
|
// template <typename U>
|
||||||
__device__ inline void
|
// __device__ inline void
|
||||||
store_global_safe(U* x, int N, int row, int col, int max_rows) {
|
// store_global_safe(U* x, int N, int row, int col, int max_rows) {
|
||||||
MLX_UNROLL
|
// MLX_UNROLL
|
||||||
for (int i = 0; i < TILES_Y; i++) {
|
// for (int i = 0; i < TILES_Y; i++) {
|
||||||
MLX_UNROLL
|
// MLX_UNROLL
|
||||||
for (int j = 0; j < TILES_X; j++) {
|
// for (int j = 0; j < TILES_X; j++) {
|
||||||
data[i * TILES_X + j].store_global_safe(
|
// data[i * TILES_X + j].store_global_safe(
|
||||||
x + (row + i * 16) * N + col + j * 16, N, max_rows - row - i * 16);
|
// x + (row + i * 16) * N + col + j * 16, N, max_rows - row - i *
|
||||||
}
|
// 16);
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
};
|
// }
|
||||||
|
// };
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A simple container of multiple Tile16x16.
|
* A simple container of multiple Tile16x16.
|
||||||
|
|||||||
@@ -134,6 +134,10 @@ instantiate_and_or(and, And)
|
|||||||
instantiate_and_or(or, Or)
|
instantiate_and_or(or, Or)
|
||||||
|
|
||||||
#define instantiate_sum_prod(name, op) \
|
#define instantiate_sum_prod(name, op) \
|
||||||
|
instantiate_reduce_functions(name, uint8, uint8_t, int32_t, op) \
|
||||||
|
instantiate_reduce_functions(name, uint16, uint16_t, uint32_t, op) \
|
||||||
|
instantiate_reduce_functions(name, uint32, uint32_t, uint32_t, op) \
|
||||||
|
instantiate_reduce_functions(name, uint64, uint64_t, uint64_t, op) \
|
||||||
instantiate_reduce_functions(name, int8, int8_t, int32_t, op) \
|
instantiate_reduce_functions(name, int8, int8_t, int32_t, op) \
|
||||||
instantiate_reduce_functions(name, int16, int16_t, int32_t, op) \
|
instantiate_reduce_functions(name, int16, int16_t, int32_t, op) \
|
||||||
instantiate_reduce_functions(name, int32, int32_t, int32_t, op) \
|
instantiate_reduce_functions(name, int32, int32_t, int32_t, op) \
|
||||||
|
|||||||
@@ -247,15 +247,25 @@ std::pair<Dtype, Dtype> remap_reduce_types(
|
|||||||
const std::string& op_name) {
|
const std::string& op_name) {
|
||||||
if (op_name == "sum" || op_name == "prod") {
|
if (op_name == "sum" || op_name == "prod") {
|
||||||
if (issubdtype(in.dtype(), integer)) {
|
if (issubdtype(in.dtype(), integer)) {
|
||||||
switch (in.dtype().size()) {
|
switch (in.dtype()) {
|
||||||
case 1:
|
case uint8:
|
||||||
|
return {uint8, uint32};
|
||||||
|
case uint16:
|
||||||
|
return {uint16, uint32};
|
||||||
|
case uint32:
|
||||||
|
return {uint32, uint32};
|
||||||
|
case uint64:
|
||||||
|
return {uint64, uint64};
|
||||||
|
case int8:
|
||||||
return {int8, int32};
|
return {int8, int32};
|
||||||
case 2:
|
case int16:
|
||||||
return {int16, int32};
|
return {int16, int32};
|
||||||
case 4:
|
case int32:
|
||||||
return {int32, int32};
|
return {int32, int32};
|
||||||
case 8:
|
case int64:
|
||||||
return {int64, int64};
|
return {int64, int64};
|
||||||
|
default:
|
||||||
|
throw std::runtime_error("Unsupported integer type");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (in.dtype() == bool_) {
|
if (in.dtype() == bool_) {
|
||||||
|
|||||||
@@ -2,6 +2,6 @@
|
|||||||
requires = [
|
requires = [
|
||||||
"setuptools>=80",
|
"setuptools>=80",
|
||||||
"nanobind==2.4.0",
|
"nanobind==2.4.0",
|
||||||
"cmake>=3.25",
|
"cmake>=3.25,<4.1",
|
||||||
]
|
]
|
||||||
build-backend = "setuptools.build_meta"
|
build-backend = "setuptools.build_meta"
|
||||||
|
|||||||
@@ -155,6 +155,19 @@ TEST_CASE("test gpu reduce") {
|
|||||||
CHECK_EQ(prod(a, Device::gpu).item<int32_t>(), 1);
|
CHECK_EQ(prod(a, Device::gpu).item<int32_t>(), 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// sum and prod overflow
|
||||||
|
{
|
||||||
|
auto a = full({256, 2, 2}, 1u, uint8);
|
||||||
|
CHECK_EQ(sum(a, Device::gpu).item<uint32_t>(), 256 * 4);
|
||||||
|
CHECK_EQ(prod(a, Device::gpu).item<uint32_t>(), 1);
|
||||||
|
|
||||||
|
a = full({65535, 2, 2}, 1u, uint16);
|
||||||
|
CHECK_EQ(sum(a, Device::gpu).item<uint32_t>(), 65535 * 4);
|
||||||
|
CHECK_EQ(prod(a, Device::gpu).item<uint32_t>(), 1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test gpu reduce with axes") {
|
||||||
// reducing only some axes and irregular layouts
|
// reducing only some axes and irregular layouts
|
||||||
{
|
{
|
||||||
array a(1.0f);
|
array a(1.0f);
|
||||||
|
|||||||
@@ -915,6 +915,23 @@ TEST_CASE("test reduction ops") {
|
|||||||
CHECK(array_equal(sum(x, 1), array({3.0f, 6.0f}, {2})).item<bool>());
|
CHECK(array_equal(sum(x, 1), array({3.0f, 6.0f}, {2})).item<bool>());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Test unsigned sum
|
||||||
|
{
|
||||||
|
const int num_elems = 1000;
|
||||||
|
|
||||||
|
auto x = astype(full({num_elems}, 255), uint8);
|
||||||
|
CHECK_EQ(sum(x, Device::cpu).item<uint32_t>(), 255 * num_elems);
|
||||||
|
|
||||||
|
x = astype(full({num_elems}, 65535), uint16);
|
||||||
|
CHECK_EQ(sum(x, Device::cpu).item<uint32_t>(), 65535 * num_elems);
|
||||||
|
|
||||||
|
x = full({3, 3, 3}, 10000, uint32);
|
||||||
|
CHECK_EQ(sum(x, Device::cpu).item<uint32_t>(), 270000);
|
||||||
|
|
||||||
|
x = full({3, 3, 3}, 10000, uint64);
|
||||||
|
CHECK_EQ(sum(x, Device::cpu).item<uint64_t>(), 270000);
|
||||||
|
}
|
||||||
|
|
||||||
// Test prod
|
// Test prod
|
||||||
{
|
{
|
||||||
auto x = array({});
|
auto x = array({});
|
||||||
@@ -947,6 +964,21 @@ TEST_CASE("test reduction ops") {
|
|||||||
CHECK(array_equal(prod(x, 1), array({true, false})).item<bool>());
|
CHECK(array_equal(prod(x, 1), array({true, false})).item<bool>());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Test unsigned prod
|
||||||
|
{
|
||||||
|
auto x = array({255, 255}, {2}, uint8);
|
||||||
|
CHECK_EQ(prod(x, Device::cpu).item<uint32_t>(), 65025);
|
||||||
|
|
||||||
|
x = array({65535, 2}, {2}, uint16);
|
||||||
|
CHECK_EQ(prod(x, Device::cpu).item<uint32_t>(), 131070);
|
||||||
|
|
||||||
|
x = array({100000, 2}, {2}, uint32);
|
||||||
|
CHECK_EQ(prod(x, Device::cpu).item<uint32_t>(), 200000);
|
||||||
|
|
||||||
|
x = array({100000, 2}, {2}, uint64);
|
||||||
|
CHECK_EQ(prod(x, Device::cpu).item<uint64_t>(), 200000);
|
||||||
|
}
|
||||||
|
|
||||||
// Test all
|
// Test all
|
||||||
{
|
{
|
||||||
auto x = array({});
|
auto x = array({});
|
||||||
|
|||||||
Reference in New Issue
Block a user