Mirror of https://github.com/ml-explore/mlx.git, synced 2025-08-28 22:27:50 +08:00
Compare commits
No commits in common. "a6d780154f2fe79e893045659d17fbace243802a" and "8402a2acf4325a7213211dd7fcb4f397981ca695" have entirely different histories.
a6d780154f...8402a2acf4
@@ -145,7 +145,7 @@ bool compiler_supports_device_sass(Device& device) {
   }
 }
 
-#define INCLUDE_PREFIX "mlx/backend/cuda/device/"
+#define INCLUDE_PREFIX "mlx/backend/cuda/kernels/"
 
 constexpr const char* g_include_names[] = {
     INCLUDE_PREFIX "atomic_ops.cuh",
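
The changed INCLUDE_PREFIX feeds the g_include_names table of pre-bundled headers that the CUDA backend's JIT path resolves in memory rather than on disk. Below is a minimal sketch of that pattern, assuming NVRTC as the JIT compiler; the helper and parameter names are hypothetical, not code from this repository:

// Sketch only: how an include table like g_include_names is commonly handed
// to NVRTC so that #include directives resolve against embedded header
// sources. `header_sources` pairs index-for-index with `header_names`.
#include <nvrtc.h>

#include <vector>

nvrtcProgram create_program(
    const char* kernel_source,
    const std::vector<const char*>& header_sources,
    const std::vector<const char*>& header_names) {
  nvrtcProgram prog;
  // NVRTC matches each #include "name" against header_names[i] and compiles
  // header_sources[i] in its place, so no filesystem lookup is needed.
  nvrtcCreateProgram(
      &prog,
      kernel_source,
      "jit_kernel.cu",
      static_cast<int>(header_names.size()),
      header_sources.data(),
      header_names.data());
  return prog;
}

Registering header sources this way keeps JIT compilation independent of where the package is installed, since the include prefix only has to match the names in the table, not a real directory.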
@@ -44,12 +44,9 @@ class MatMul {
       int64_t b_batch_stride) {
     heuristic_.state = CUBLAS_STATUS_NOT_INITIALIZED;
 
-    auto scale_type = dtype_to_cuda_type(dtype);
-    if (dtype == bfloat16) {
-      scale_type = CUDA_R_32F;
-    }
+    auto type = dtype_to_cuda_type(dtype);
     CHECK_CUBLAS_ERROR(cublasLtMatmulDescCreate(
-        &matmul_desc_, dtype_to_compute_type(dtype), scale_type));
+        &matmul_desc_, dtype_to_compute_type(dtype), type));
     int32_t pointer_mode = CUBLASLT_POINTER_MODE_HOST;
     CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
         matmul_desc_,
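
The left side of this hunk widens the alpha/beta scale type when bfloat16 inputs are accumulated in FP32: cuBLASLt ties the scale type to the compute precision rather than the storage type. A hedged sketch of that rule in isolation (the helper name is an assumption, not code from this repository):

#include <cublasLt.h>

// Sketch: choose the cuBLASLt scale type for alpha/beta. With bf16 storage
// and CUBLAS_COMPUTE_32F accumulation, alpha/beta must be CUDA_R_32F floats;
// a bf16 scale type is not a supported combination for FP32-accumulated GEMM.
cudaDataType_t scale_type_for(cudaDataType_t data_type) {
  return (data_type == CUDA_R_16BF) ? CUDA_R_32F : data_type;
}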
@@ -68,7 +65,6 @@ class MatMul {
         &op,
         sizeof(cublasOperation_t)));
 
-    auto type = dtype_to_cuda_type(dtype);
     a_desc_ = create_matrix_layout(
         type, a_rows, a_cols, a_transposed, lda, batch_count, a_batch_stride);
     b_desc_ = create_matrix_layout(
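
Both sides build operand layouts through a create_matrix_layout helper. Below is a plausible sketch of what such a helper wraps; the signature is inferred from these call sites and the body is an assumption based on cuBLASLt's layout attributes, not the repository's implementation:

#include <cublasLt.h>

#include <cstdint>

// Sketch (assumed implementation): build a cuBLASLt matrix layout with an
// order flag and strided-batch attributes matching the call sites above.
cublasLtMatrixLayout_t create_matrix_layout(
    cudaDataType_t type,
    uint64_t rows,
    uint64_t cols,
    bool transposed,
    int64_t ld,
    int32_t batch_count,
    int64_t batch_stride) {
  cublasLtMatrixLayout_t desc;
  cublasLtMatrixLayoutCreate(&desc, type, rows, cols, ld);
  // A transposed operand is expressed as a column-major layout.
  cublasLtOrder_t order = transposed ? CUBLASLT_ORDER_COL : CUBLASLT_ORDER_ROW;
  cublasLtMatrixLayoutSetAttribute(
      desc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order, sizeof(order));
  if (batch_count > 1) {
    // Batched GEMM: element count and the stride between consecutive matrices.
    cublasLtMatrixLayoutSetAttribute(
        desc,
        CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT,
        &batch_count,
        sizeof(batch_count));
    cublasLtMatrixLayoutSetAttribute(
        desc,
        CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET,
        &batch_stride,
        sizeof(batch_stride));
  }
  return desc;
}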
@@ -191,10 +187,15 @@ class MatMul {
  private:
   cublasComputeType_t dtype_to_compute_type(Dtype dtype) {
     switch (dtype) {
+      case uint8:
+      case uint16:
+      case int8:
+      case int16:
+      case int32:
+        return CUBLAS_COMPUTE_32I;
       case float16:
-        return CUBLAS_COMPUTE_16F;
       case bfloat16:
-        return CUBLAS_COMPUTE_32F;
+        return CUBLAS_COMPUTE_16F;
       case float32:
         return CUBLAS_COMPUTE_32F;
       case float64:
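
With CUBLAS_COMPUTE_32F selected for bfloat16 on one side of this hunk (and the matching CUDA_R_32F scale type from the earlier hunk), alpha and beta become host floats at the cublasLtMatmul call. A hedged sketch of that call shape, with every name a placeholder rather than this repository's code:

#include <cublasLt.h>

#include <cstddef>

// Sketch: epilogue scalars are host floats, matching the CUDA_R_32F scale
// type set on the matmul descriptor for FP32-accumulated bf16 GEMM.
void run_bf16_matmul(
    cublasLtHandle_t handle,
    cublasLtMatmulDesc_t matmul_desc,
    const void* a, cublasLtMatrixLayout_t a_desc,
    const void* b, cublasLtMatrixLayout_t b_desc,
    void* c, cublasLtMatrixLayout_t c_desc,
    const cublasLtMatmulAlgo_t* algo,
    void* workspace, size_t workspace_size,
    cudaStream_t stream) {
  float alpha = 1.0f;  // FP32 scalars: required by the CUDA_R_32F scale type
  float beta = 0.0f;
  cublasLtMatmul(
      handle, matmul_desc, &alpha,
      a, a_desc, b, b_desc,
      &beta, c, c_desc,  // C: accumulator input
      c, c_desc,         // D: output, aliased to C here
      algo, workspace, workspace_size, stream);
}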
@@ -208,6 +209,16 @@ class MatMul {
 
   cudaDataType_t dtype_to_cuda_type(Dtype dtype) {
     switch (dtype) {
+      case uint8:
+        return CUDA_R_8U;
+      case uint16:
+        return CUDA_R_16U;
+      case int8:
+        return CUDA_R_8I;
+      case int16:
+        return CUDA_R_16I;
+      case int32:
+        return CUDA_R_32I;
       case float16:
         return CUDA_R_16F;
       case bfloat16:
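
Only one side of the comparison carries these integer mappings. As a hedged aside drawn from cuBLASLt's documented supported configurations, not from this repository: the library's integer GEMM path is essentially limited to int8 inputs with int32 accumulation, which would leave most of the mappings above unreachable from a valid matmul configuration.

#include <cublasLt.h>

// Sketch (assumption from cuBLASLt docs): the one broadly supported integer
// GEMM setup is CUDA_R_8I A/B, CUDA_R_32I C/D, with CUBLAS_COMPUTE_32I
// compute and a CUDA_R_32I scale type on the descriptor.
void make_int8_gemm_desc(cublasLtMatmulDesc_t* desc) {
  cublasLtMatmulDescCreate(desc, CUBLAS_COMPUTE_32I, CUDA_R_32I);
}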