mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
@@ -976,7 +976,9 @@ void fast::AffineQuantize::eval_gpu(
|
||||
// Treat uint32 as uint8 in kernel
|
||||
constexpr int uint8_per_uint32 = 4;
|
||||
constexpr int simd_size = 32;
|
||||
int packs_per_int = bits_ == 3 ? 8 : bits_ == 6 ? 4 : 8 / bits_;
|
||||
int packs_per_int = (bits_ == 3 || bits_ == 5) ? 8
|
||||
: bits_ == 6 ? 4
|
||||
: 8 / bits_;
|
||||
int per_thread = dequantize_ ? packs_per_int : group_size_ / simd_size;
|
||||
size_t nthreads =
|
||||
dequantize_ ? out.size() / packs_per_int : w.size() / per_thread;
|
||||
|
||||
Reference in New Issue
Block a user