mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
[CUDA] Matmul utils initial commit (#2441)
This commit is contained in:
committed by
GitHub
parent
86258f292f
commit
be9bc96da4
@@ -108,6 +108,7 @@ void Matmul::run_batched(
|
||||
cu::set_mm_device_pointers,
|
||||
cuda::ceil_div(pointers.size(), block_size),
|
||||
block_size,
|
||||
0,
|
||||
pointers.data<int8_t*>(),
|
||||
a.data<int8_t>(),
|
||||
b.data<int8_t>(),
|
||||
@@ -168,6 +169,7 @@ void Matmul::run_batched(
|
||||
cu::set_addmm_device_pointers,
|
||||
cuda::ceil_div(pointers.size(), block_size),
|
||||
block_size,
|
||||
0,
|
||||
pointers.data<int8_t*>(),
|
||||
a.data<int8_t>(),
|
||||
b.data<int8_t>(),
|
||||
|
||||
@@ -143,6 +143,7 @@ void gemv(
|
||||
kernel,
|
||||
num_blocks_x,
|
||||
block_dims,
|
||||
0,
|
||||
mat,
|
||||
vec,
|
||||
out.data<DataType>(),
|
||||
@@ -154,6 +155,7 @@ void gemv(
|
||||
kernel,
|
||||
dim3{num_blocks_x, batch_count},
|
||||
block_dims,
|
||||
0,
|
||||
mat,
|
||||
vec,
|
||||
out.data<DataType>(),
|
||||
|
||||
Reference in New Issue
Block a user