mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Compare commits
3 Commits
60939d010c
...
0d68efd461
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0d68efd461 | ||
|
|
f9e1a14135 | ||
|
|
d8e9ded928 |
@@ -92,7 +92,7 @@ CudaAllocator::CudaAllocator()
|
||||
[this](CudaBuffer* buf) { cuda_free(buf); }) {
|
||||
size_t free, total;
|
||||
CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
|
||||
memory_limit_ = total * 0.95;
|
||||
memory_limit_ = total * 0.9;
|
||||
max_pool_size_ = memory_limit_;
|
||||
|
||||
int device_count = 0;
|
||||
@@ -176,7 +176,7 @@ CudaAllocator::malloc_async(size_t size, int device, cudaStream_t stream) {
|
||||
buffer_cache_.release_cached_buffers(get_cache_memory() - max_pool_size_);
|
||||
}
|
||||
// Copy to managed here if the buffer is not on the right device
|
||||
if (buf->device != device) {
|
||||
if (buf->device >= 0 && buf->device != device) {
|
||||
copy_to_managed(*buf);
|
||||
}
|
||||
return Buffer{buf};
|
||||
@@ -219,9 +219,9 @@ void CudaAllocator::cuda_free(CudaBuffer* buf) {
|
||||
scalar_pool_.free(buf);
|
||||
} else {
|
||||
if (buf->device >= 0) {
|
||||
cudaFreeAsync(buf->data, free_streams_[buf->device]);
|
||||
CHECK_CUDA_ERROR(cudaFreeAsync(buf->data, free_streams_[buf->device]));
|
||||
} else {
|
||||
cudaFree(buf->data);
|
||||
CHECK_CUDA_ERROR(cudaFree(buf->data));
|
||||
}
|
||||
delete buf;
|
||||
}
|
||||
|
||||
@@ -139,10 +139,10 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
// keys has shape (N1, ..., NK, 2)
|
||||
// out has shape (N1, ..., NK, M1, M2, ...)
|
||||
auto& keys = inputs[0];
|
||||
uint32_t num_keys = keys.size() / 2;
|
||||
size_t num_keys = keys.size() / 2;
|
||||
|
||||
uint32_t elems_per_key = out.size() / num_keys;
|
||||
uint32_t bytes_per_key = out.itemsize() * elems_per_key;
|
||||
size_t elems_per_key = out.size() / num_keys;
|
||||
size_t bytes_per_key = out.itemsize() * elems_per_key;
|
||||
auto& s = stream();
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
out.set_data(cu::malloc_async(out.nbytes(), encoder));
|
||||
@@ -150,19 +150,25 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
return;
|
||||
}
|
||||
|
||||
uint32_t out_per_key = (bytes_per_key + 4 - 1) / 4;
|
||||
uint32_t half_size = out_per_key / 2;
|
||||
size_t out_per_key = (bytes_per_key + 4 - 1) / 4;
|
||||
size_t half_size = out_per_key / 2;
|
||||
|
||||
bool odd = out_per_key % 2;
|
||||
if ((half_size + odd) >= UINT32_MAX || num_keys >= UINT32_MAX) {
|
||||
throw std::runtime_error("[RandomBits::eval_gpu] Large size unsupported");
|
||||
}
|
||||
|
||||
encoder.set_input_array(keys);
|
||||
encoder.set_output_array(out);
|
||||
dim3 grid_dims{num_keys, half_size + odd};
|
||||
int64_t total = grid_dims.x * grid_dims.y;
|
||||
int32_t threads_y = 1;
|
||||
while ((total / threads_y) >= (1U << 31)) {
|
||||
int64_t total = num_keys * (half_size + odd);
|
||||
uint32_t threads_y = 1;
|
||||
while ((total / threads_y) >= UINT_MAX) {
|
||||
threads_y *= 2;
|
||||
}
|
||||
int32_t threads_x = cuda::ceil_div(total, threads_y);
|
||||
uint32_t threads_x = cuda::ceil_div(total, threads_y);
|
||||
|
||||
dim3 grid_dims{
|
||||
static_cast<uint32_t>(num_keys), static_cast<uint32_t>(half_size + odd)};
|
||||
auto [grid, block] = get_grid_and_block(threads_x, threads_y, 1);
|
||||
auto& stream = encoder.stream();
|
||||
if (keys.flags().row_contiguous) {
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
|
||||
#define MLX_VERSION_MAJOR 0
|
||||
#define MLX_VERSION_MINOR 30
|
||||
#define MLX_VERSION_PATCH 0
|
||||
#define MLX_VERSION_PATCH 1
|
||||
#define MLX_VERSION_NUMERIC \
|
||||
(100000 * MLX_VERSION_MAJOR + 1000 * MLX_VERSION_MINOR + MLX_VERSION_PATCH)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user