mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
format
This commit is contained in:
@@ -1,9 +1,6 @@
|
|||||||
// Copyright © 2023-2024 Apple Inc.
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
#include <cassert>
|
|
||||||
|
|
||||||
#include "mlx/backend/cuda/device.h"
|
#include "mlx/backend/cuda/device.h"
|
||||||
#include "mlx/backend/cuda/device/binary_ops.cuh"
|
|
||||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||||
#include "mlx/backend/gpu/copy.h"
|
#include "mlx/backend/gpu/copy.h"
|
||||||
#include "mlx/dtype_utils.h"
|
#include "mlx/dtype_utils.h"
|
||||||
@@ -30,12 +27,8 @@ inline constexpr __device__ short get_bytes_per_pack() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, int group_size, int bits>
|
template <typename T, int group_size, int bits>
|
||||||
__global__ void affine_quantize(
|
__global__ void
|
||||||
const T* w,
|
affine_quantize(const T* w, uint8_t* out, T* scales, T* biases, size_t size) {
|
||||||
uint8_t* out,
|
|
||||||
T* scales,
|
|
||||||
T* biases,
|
|
||||||
size_t size) {
|
|
||||||
auto block_size = cg::this_thread_block().dim_threads();
|
auto block_size = cg::this_thread_block().dim_threads();
|
||||||
auto block_idx = cg::this_thread_block().group_index();
|
auto block_idx = cg::this_thread_block().group_index();
|
||||||
auto idx_in_block = cg::this_thread_block().thread_index();
|
auto idx_in_block = cg::this_thread_block().thread_index();
|
||||||
@@ -139,9 +132,9 @@ __global__ void affine_quantize(
|
|||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if constexpr (writes_per_reduce > 0) {
|
if constexpr (writes_per_reduce > 0) {
|
||||||
if (out_index % writes_per_reduce == 0) {
|
if (out_index % writes_per_reduce == 0) {
|
||||||
out[out_index / writes_per_reduce] = output;
|
out[out_index / writes_per_reduce] = output;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -153,7 +146,6 @@ __global__ void affine_dequantize(
|
|||||||
const T* biases,
|
const T* biases,
|
||||||
T* out,
|
T* out,
|
||||||
size_t size) {
|
size_t size) {
|
||||||
|
|
||||||
auto block_size = cg::this_thread_block().dim_threads();
|
auto block_size = cg::this_thread_block().dim_threads();
|
||||||
auto block_idx = cg::this_thread_block().group_index();
|
auto block_idx = cg::this_thread_block().group_index();
|
||||||
auto idx_in_block = cg::this_thread_block().thread_index();
|
auto idx_in_block = cg::this_thread_block().thread_index();
|
||||||
@@ -224,8 +216,10 @@ __global__ void affine_dequantize(
|
|||||||
} // namespace cu
|
} // namespace cu
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
inline array
|
inline array ensure_row_contiguous(
|
||||||
ensure_row_contiguous(const array& x, cu::CommandEncoder& enc, const Stream& s) {
|
const array& x,
|
||||||
|
cu::CommandEncoder& enc,
|
||||||
|
const Stream& s) {
|
||||||
if (!x.flags().row_contiguous) {
|
if (!x.flags().row_contiguous) {
|
||||||
array x_copy(x.shape(), x.dtype(), nullptr, {});
|
array x_copy(x.shape(), x.dtype(), nullptr, {});
|
||||||
copy_gpu(x, x_copy, CopyType::General, s);
|
copy_gpu(x, x_copy, CopyType::General, s);
|
||||||
|
|||||||
Reference in New Issue
Block a user