Compare commits


No commits in common. "a6d780154f2fe79e893045659d17fbace243802a" and "8402a2acf4325a7213211dd7fcb4f397981ca695" have entirely different histories.

2 changed files with 20 additions and 9 deletions

View File

@@ -145,7 +145,7 @@ bool compiler_supports_device_sass(Device& device) {
   }
 }
 
-#define INCLUDE_PREFIX "mlx/backend/cuda/device/"
+#define INCLUDE_PREFIX "mlx/backend/cuda/kernels/"
 
 constexpr const char* g_include_names[] = {
     INCLUDE_PREFIX "atomic_ops.cuh",
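Note on the hunk above: the entries of g_include_names are untouched because INCLUDE_PREFIX and the file name are adjacent string literals, which the compiler concatenates at translation time, so changing the macro body is enough to repoint every include path. A minimal standalone sketch of that mechanism (the DEMO_* names are illustrative, not part of the patch):

#include <cstdio>

#define DEMO_PREFIX "mlx/backend/cuda/kernels/"

constexpr const char* demo_names[] = {
    // Adjacent literals merge into "mlx/backend/cuda/kernels/atomic_ops.cuh".
    DEMO_PREFIX "atomic_ops.cuh",
};

int main() {
  std::puts(demo_names[0]); // prints the fully joined path
  return 0;
}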

View File

@@ -44,12 +44,9 @@ class MatMul {
       int64_t b_batch_stride) {
     heuristic_.state = CUBLAS_STATUS_NOT_INITIALIZED;
 
-    auto scale_type = dtype_to_cuda_type(dtype);
-    if (dtype == bfloat16) {
-      scale_type = CUDA_R_32F;
-    }
+    auto type = dtype_to_cuda_type(dtype);
     CHECK_CUBLAS_ERROR(cublasLtMatmulDescCreate(
-        &matmul_desc_, dtype_to_compute_type(dtype), scale_type));
+        &matmul_desc_, dtype_to_compute_type(dtype), type));
     int32_t pointer_mode = CUBLASLT_POINTER_MODE_HOST;
     CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
         matmul_desc_,
@@ -68,7 +65,6 @@ class MatMul {
         &op,
         sizeof(cublasOperation_t)));
 
-    auto type = dtype_to_cuda_type(dtype);
     a_desc_ = create_matrix_layout(
         type, a_rows, a_cols, a_transposed, lda, batch_count, a_batch_stride);
     b_desc_ = create_matrix_layout(
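Note on the constructor hunks above: cublasLtMatmulDescCreate takes a compute type plus the cudaDataType_t used for the alpha/beta scalars. The cuBLASLt documentation pairs bf16 inputs with FP32 accumulation and an FP32 scale type, which is what the removed if (dtype == bfloat16) branch pinned explicitly; the right-hand version instead passes the raw data type through. A hedged sketch of the bf16 configuration (illustrative only, not MLX code; error handling omitted):

#include <cublasLt.h>

// Descriptor for bf16 inputs: FP32 accumulation with FP32 alpha/beta, the
// combination cuBLASLt documents for CUDA_R_16BF operands.
cublasLtMatmulDesc_t make_bf16_desc() {
  cublasLtMatmulDesc_t desc = nullptr;
  cublasLtMatmulDescCreate(&desc, CUBLAS_COMPUTE_32F, CUDA_R_32F);
  return desc; // release later with cublasLtMatmulDescDestroy(desc)
}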
@@ -191,10 +187,15 @@ class MatMul {
  private:
   cublasComputeType_t dtype_to_compute_type(Dtype dtype) {
     switch (dtype) {
+      case uint8:
+      case uint16:
+      case int8:
+      case int16:
+      case int32:
+        return CUBLAS_COMPUTE_32I;
       case float16:
-        return CUBLAS_COMPUTE_16F;
       case bfloat16:
-        return CUBLAS_COMPUTE_32F;
+        return CUBLAS_COMPUTE_16F;
       case float32:
         return CUBLAS_COMPUTE_32F;
       case float64:
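Note on the compute-type hunk above: the added cases map the uint8/uint16/int8/int16/int32 dtypes onto CUBLAS_COMPUTE_32I. For reference, the integer configuration cuBLASLt documents is 8-bit A/B operands (CUDA_R_8I) with 32-bit accumulation and CUDA_R_32I alpha/beta; wider integer enums such as CUDA_R_16I exist in the CUDA headers but do not appear in the documented cublasLtMatmul support tables. A hedged sketch of the 8-bit case (illustrative only, not MLX code; error handling omitted):

#include <cublasLt.h>

// Descriptor for an int8 GEMM: int32 accumulation with int32 alpha/beta,
// to be paired with CUDA_R_8I matrix layouts for A and B.
cublasLtMatmulDesc_t make_int8_desc() {
  cublasLtMatmulDesc_t desc = nullptr;
  cublasLtMatmulDescCreate(&desc, CUBLAS_COMPUTE_32I, CUDA_R_32I);
  return desc;
}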
@@ -208,6 +209,16 @@ class MatMul {
   cudaDataType_t dtype_to_cuda_type(Dtype dtype) {
     switch (dtype) {
+      case uint8:
+        return CUDA_R_8U;
+      case uint16:
+        return CUDA_R_16U;
+      case int8:
+        return CUDA_R_8I;
+      case int16:
+        return CUDA_R_16I;
+      case int32:
+        return CUDA_R_32I;
       case float16:
         return CUDA_R_16F;
       case bfloat16: