mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 10:38:10 +08:00 
			
		
		
		
	Use async cuda malloc managed with cuda 13
This commit is contained in:
		@@ -1,6 +1,7 @@
 | 
				
			|||||||
// Copyright © 2025 Apple Inc.
 | 
					// Copyright © 2025 Apple Inc.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
#include "mlx/backend/cuda/allocator.h"
 | 
					#include "mlx/backend/cuda/allocator.h"
 | 
				
			||||||
 | 
					#include "mlx/backend/cuda/device.h"
 | 
				
			||||||
#include "mlx/backend/cuda/utils.h"
 | 
					#include "mlx/backend/cuda/utils.h"
 | 
				
			||||||
#include "mlx/utils.h"
 | 
					#include "mlx/utils.h"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -93,9 +94,17 @@ CudaAllocator::CudaAllocator()
 | 
				
			|||||||
  CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
 | 
					  CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
 | 
				
			||||||
  memory_limit_ = total * 0.95;
 | 
					  memory_limit_ = total * 0.95;
 | 
				
			||||||
  max_pool_size_ = memory_limit_;
 | 
					  max_pool_size_ = memory_limit_;
 | 
				
			||||||
 | 
					#if CUDART_VERSION >= 13000
 | 
				
			||||||
 | 
					  cudaMemLocation loc;
 | 
				
			||||||
 | 
					  loc.id = 0;
 | 
				
			||||||
 | 
					  loc.type = cudaMemLocationTypeNone;
 | 
				
			||||||
 | 
					  cudaMemGetDefaultMemPool(&cuda_pool_, &loc, cudaMemAllocationTypeManaged);
 | 
				
			||||||
 | 
					  // TODO set that.
 | 
				
			||||||
 | 
					  // uint64_t threshold = UINT64_MAX;
 | 
				
			||||||
 | 
					#endif
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
Buffer CudaAllocator::malloc(size_t size) {
 | 
					Buffer CudaAllocator::malloc_impl(size_t size, cudaStream_t stream) {
 | 
				
			||||||
  // Find available buffer from cache.
 | 
					  // Find available buffer from cache.
 | 
				
			||||||
  auto orig_size = size;
 | 
					  auto orig_size = size;
 | 
				
			||||||
  std::unique_lock lock(mutex_);
 | 
					  std::unique_lock lock(mutex_);
 | 
				
			||||||
@@ -123,7 +132,12 @@ Buffer CudaAllocator::malloc(size_t size) {
 | 
				
			|||||||
    lock.unlock();
 | 
					    lock.unlock();
 | 
				
			||||||
    if (!buf) {
 | 
					    if (!buf) {
 | 
				
			||||||
      buf = new CudaBuffer{nullptr, size};
 | 
					      buf = new CudaBuffer{nullptr, size};
 | 
				
			||||||
      cudaError_t err = cudaMallocManaged(&buf->data, size);
 | 
					      cudaError_t err;
 | 
				
			||||||
 | 
					      if (stream != nullptr && cuda_pool_ != nullptr) {
 | 
				
			||||||
 | 
					        err = cudaMallocFromPoolAsync(&buf->data, size, cuda_pool_, stream);
 | 
				
			||||||
 | 
					      } else {
 | 
				
			||||||
 | 
					        err = cudaMallocManaged(&buf->data, size);
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
      if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
 | 
					      if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
 | 
				
			||||||
        throw std::runtime_error(fmt::format(
 | 
					        throw std::runtime_error(fmt::format(
 | 
				
			||||||
            "cudaMallocManaged failed: {}.", cudaGetErrorString(err)));
 | 
					            "cudaMallocManaged failed: {}.", cudaGetErrorString(err)));
 | 
				
			||||||
@@ -141,6 +155,14 @@ Buffer CudaAllocator::malloc(size_t size) {
 | 
				
			|||||||
  return Buffer{buf};
 | 
					  return Buffer{buf};
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Buffer CudaAllocator::malloc_async(size_t size, cudaStream_t stream) {
 | 
				
			||||||
 | 
					  return malloc_impl(size, stream);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Buffer CudaAllocator::malloc(size_t size) {
 | 
				
			||||||
 | 
					  return malloc_impl(size, nullptr);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
void CudaAllocator::free(Buffer buffer) {
 | 
					void CudaAllocator::free(Buffer buffer) {
 | 
				
			||||||
  auto* buf = static_cast<CudaBuffer*>(buffer.ptr());
 | 
					  auto* buf = static_cast<CudaBuffer*>(buffer.ptr());
 | 
				
			||||||
  if (!buf) {
 | 
					  if (!buf) {
 | 
				
			||||||
@@ -220,6 +242,16 @@ CudaAllocator& allocator() {
 | 
				
			|||||||
  return *allocator_;
 | 
					  return *allocator_;
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Buffer malloc_async(size_t size, cudaStream_t stream) {
 | 
				
			||||||
 | 
					  auto buffer = allocator().malloc_async(size, stream);
 | 
				
			||||||
 | 
					  if (size && !buffer.ptr()) {
 | 
				
			||||||
 | 
					    std::ostringstream msg;
 | 
				
			||||||
 | 
					    msg << "[malloc_async] Unable to allocate " << size << " bytes.";
 | 
				
			||||||
 | 
					    throw std::runtime_error(msg.str());
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  return buffer;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
} // namespace cu
 | 
					} // namespace cu
 | 
				
			||||||
 | 
					
 | 
				
			||||||
namespace allocator {
 | 
					namespace allocator {
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -5,6 +5,7 @@
 | 
				
			|||||||
#include "mlx/allocator.h"
 | 
					#include "mlx/allocator.h"
 | 
				
			||||||
#include "mlx/backend/common/buffer_cache.h"
 | 
					#include "mlx/backend/common/buffer_cache.h"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include <cuda_runtime.h>
 | 
				
			||||||
#include <mutex>
 | 
					#include <mutex>
 | 
				
			||||||
#include <set>
 | 
					#include <set>
 | 
				
			||||||
#include <utility>
 | 
					#include <utility>
 | 
				
			||||||
@@ -45,6 +46,7 @@ class SmallSizePool {
 | 
				
			|||||||
class CudaAllocator : public allocator::Allocator {
 | 
					class CudaAllocator : public allocator::Allocator {
 | 
				
			||||||
 public:
 | 
					 public:
 | 
				
			||||||
  Buffer malloc(size_t size) override;
 | 
					  Buffer malloc(size_t size) override;
 | 
				
			||||||
 | 
					  Buffer malloc_async(size_t size, cudaStream_t stream);
 | 
				
			||||||
  void free(Buffer buffer) override;
 | 
					  void free(Buffer buffer) override;
 | 
				
			||||||
  size_t size(Buffer buffer) const override;
 | 
					  size_t size(Buffer buffer) const override;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -58,6 +60,7 @@ class CudaAllocator : public allocator::Allocator {
 | 
				
			|||||||
  void clear_cache();
 | 
					  void clear_cache();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 private:
 | 
					 private:
 | 
				
			||||||
 | 
					  Buffer malloc_impl(size_t size, cudaStream_t stream);
 | 
				
			||||||
  void cuda_free(CudaBuffer* buf);
 | 
					  void cuda_free(CudaBuffer* buf);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  CudaAllocator();
 | 
					  CudaAllocator();
 | 
				
			||||||
@@ -70,8 +73,11 @@ class CudaAllocator : public allocator::Allocator {
 | 
				
			|||||||
  size_t active_memory_{0};
 | 
					  size_t active_memory_{0};
 | 
				
			||||||
  size_t peak_memory_{0};
 | 
					  size_t peak_memory_{0};
 | 
				
			||||||
  SmallSizePool scalar_pool_;
 | 
					  SmallSizePool scalar_pool_;
 | 
				
			||||||
 | 
					  cudaMemPool_t cuda_pool_{nullptr};
 | 
				
			||||||
};
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
CudaAllocator& allocator();
 | 
					CudaAllocator& allocator();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Buffer malloc_async(size_t size, cudaStream_t stream);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
} // namespace mlx::core::cu
 | 
					} // namespace mlx::core::cu
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -41,9 +41,8 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
 | 
				
			|||||||
  if (out.size() == 0) {
 | 
					  if (out.size() == 0) {
 | 
				
			||||||
    return;
 | 
					    return;
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
  out.set_data(allocator::malloc(out.nbytes()));
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  auto& encoder = cu::get_command_encoder(stream());
 | 
					  auto& encoder = cu::get_command_encoder(stream());
 | 
				
			||||||
 | 
					  out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
 | 
				
			||||||
  encoder.set_output_array(out);
 | 
					  encoder.set_output_array(out);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  dispatch_int_float_types(out.dtype(), "Arange", [&](auto type_tag) {
 | 
					  dispatch_int_float_types(out.dtype(), "Arange", [&](auto type_tag) {
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -140,8 +140,10 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
 | 
				
			|||||||
  nvtx3::scoped_range r("ArgReduce::eval_gpu");
 | 
					  nvtx3::scoped_range r("ArgReduce::eval_gpu");
 | 
				
			||||||
  assert(inputs.size() == 1);
 | 
					  assert(inputs.size() == 1);
 | 
				
			||||||
  auto& in = inputs[0];
 | 
					  auto& in = inputs[0];
 | 
				
			||||||
  out.set_data(allocator::malloc(out.nbytes()));
 | 
					
 | 
				
			||||||
  auto& s = stream();
 | 
					  auto& s = stream();
 | 
				
			||||||
 | 
					  auto& encoder = cu::get_command_encoder(s);
 | 
				
			||||||
 | 
					  out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  // Prepare the shapes, strides and axis arguments.
 | 
					  // Prepare the shapes, strides and axis arguments.
 | 
				
			||||||
  Shape shape = remove_index(in.shape(), axis_);
 | 
					  Shape shape = remove_index(in.shape(), axis_);
 | 
				
			||||||
@@ -154,7 +156,6 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
 | 
				
			|||||||
  int32_t ndim = shape.size();
 | 
					  int32_t ndim = shape.size();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  // ArgReduce.
 | 
					  // ArgReduce.
 | 
				
			||||||
  auto& encoder = cu::get_command_encoder(s);
 | 
					 | 
				
			||||||
  encoder.set_input_array(in);
 | 
					  encoder.set_input_array(in);
 | 
				
			||||||
  encoder.set_output_array(out);
 | 
					  encoder.set_output_array(out);
 | 
				
			||||||
  dispatch_real_types(in.dtype(), "ArgReduce", [&](auto type_tag) {
 | 
					  dispatch_real_types(in.dtype(), "ArgReduce", [&](auto type_tag) {
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -87,8 +87,8 @@ void fill_gpu(const array& in, array& out, const Stream& s) {
 | 
				
			|||||||
  if (out.size() == 0) {
 | 
					  if (out.size() == 0) {
 | 
				
			||||||
    return;
 | 
					    return;
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
  out.set_data(allocator::malloc(out.nbytes()));
 | 
					 | 
				
			||||||
  auto& encoder = cu::get_command_encoder(s);
 | 
					  auto& encoder = cu::get_command_encoder(s);
 | 
				
			||||||
 | 
					  out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
 | 
				
			||||||
  encoder.set_input_array(in);
 | 
					  encoder.set_input_array(in);
 | 
				
			||||||
  encoder.set_output_array(out);
 | 
					  encoder.set_output_array(out);
 | 
				
			||||||
  copy_contiguous(encoder, CopyType::Scalar, in, out, 0, 0);
 | 
					  copy_contiguous(encoder, CopyType::Scalar, in, out, 0, 0);
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -3,6 +3,7 @@
 | 
				
			|||||||
#pragma once
 | 
					#pragma once
 | 
				
			||||||
 | 
					
 | 
				
			||||||
#include "mlx/array.h"
 | 
					#include "mlx/array.h"
 | 
				
			||||||
 | 
					#include "mlx/backend/cuda/allocator.h"
 | 
				
			||||||
#include "mlx/backend/cuda/lru_cache.h"
 | 
					#include "mlx/backend/cuda/lru_cache.h"
 | 
				
			||||||
#include "mlx/backend/cuda/worker.h"
 | 
					#include "mlx/backend/cuda/worker.h"
 | 
				
			||||||
#include "mlx/stream.h"
 | 
					#include "mlx/stream.h"
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -370,7 +370,7 @@ void CublasGemm::execute(
 | 
				
			|||||||
    // Ensure workspace is 256-byte aligned
 | 
					    // Ensure workspace is 256-byte aligned
 | 
				
			||||||
    int nbytes = cuda::ceil_div(heuristic_.workspaceSize, 256) * 256;
 | 
					    int nbytes = cuda::ceil_div(heuristic_.workspaceSize, 256) * 256;
 | 
				
			||||||
    array workspace(
 | 
					    array workspace(
 | 
				
			||||||
        allocator::malloc(nbytes),
 | 
					        cu::malloc_async(nbytes, encoder.stream()),
 | 
				
			||||||
        {static_cast<int>(heuristic_.workspaceSize)},
 | 
					        {static_cast<int>(heuristic_.workspaceSize)},
 | 
				
			||||||
        int8);
 | 
					        int8);
 | 
				
			||||||
    encoder.add_temporary(workspace);
 | 
					    encoder.add_temporary(workspace);
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -163,7 +163,7 @@ void CublasGemm::run_batched(
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
  // Launch kernel to set device offsets
 | 
					  // Launch kernel to set device offsets
 | 
				
			||||||
  auto pointers = array(
 | 
					  auto pointers = array(
 | 
				
			||||||
      allocator::malloc(batch_count * sizeof(void*) * 3),
 | 
					      cu::malloc_async(batch_count * sizeof(void*) * 3, encoder.stream()),
 | 
				
			||||||
      {batch_count * 3},
 | 
					      {batch_count * 3},
 | 
				
			||||||
      uint64);
 | 
					      uint64);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -251,7 +251,7 @@ void CublasGemm::run_batched(
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
  // Launch kernel to set device offsets
 | 
					  // Launch kernel to set device offsets
 | 
				
			||||||
  auto pointers = array(
 | 
					  auto pointers = array(
 | 
				
			||||||
      allocator::malloc(batch_count * sizeof(uint64_t) * 4),
 | 
					      cu::malloc_async(batch_count * sizeof(uint64_t) * 4, encoder.stream()),
 | 
				
			||||||
      {batch_count * 4},
 | 
					      {batch_count * 4},
 | 
				
			||||||
      uint64);
 | 
					      uint64);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -59,7 +59,9 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
 | 
				
			|||||||
  assert(inputs.size() > 0);
 | 
					  assert(inputs.size() > 0);
 | 
				
			||||||
  const auto& src = inputs[0];
 | 
					  const auto& src = inputs[0];
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  out.set_data(allocator::malloc(out.nbytes()));
 | 
					  auto& s = stream();
 | 
				
			||||||
 | 
					  auto& encoder = cu::get_command_encoder(s);
 | 
				
			||||||
 | 
					  out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
 | 
				
			||||||
  if (out.size() == 0) {
 | 
					  if (out.size() == 0) {
 | 
				
			||||||
    return;
 | 
					    return;
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
@@ -80,7 +82,6 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
 | 
				
			|||||||
      dtype_to_string(idx_dtype),
 | 
					      dtype_to_string(idx_dtype),
 | 
				
			||||||
      nidx);
 | 
					      nidx);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  auto& s = stream();
 | 
					 | 
				
			||||||
  cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() {
 | 
					  cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() {
 | 
				
			||||||
    std::vector<std::string> kernel_names;
 | 
					    std::vector<std::string> kernel_names;
 | 
				
			||||||
    for (int ndim = 0; ndim <= MAX_NDIM; ++ndim) {
 | 
					    for (int ndim = 0; ndim <= MAX_NDIM; ++ndim) {
 | 
				
			||||||
@@ -121,7 +122,6 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
 | 
				
			|||||||
      idx_ndim,
 | 
					      idx_ndim,
 | 
				
			||||||
      large ? "int64_t" : "int32_t");
 | 
					      large ? "int64_t" : "int32_t");
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  auto& encoder = cu::get_command_encoder(s);
 | 
					 | 
				
			||||||
  for (const auto& in : inputs) {
 | 
					  for (const auto& in : inputs) {
 | 
				
			||||||
    encoder.set_input_array(in);
 | 
					    encoder.set_input_array(in);
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
@@ -239,7 +239,9 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
 | 
				
			|||||||
  const auto& src = inputs[0];
 | 
					  const auto& src = inputs[0];
 | 
				
			||||||
  const auto& idx = inputs[1];
 | 
					  const auto& idx = inputs[1];
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  out.set_data(allocator::malloc(out.nbytes()));
 | 
					  auto& s = stream();
 | 
				
			||||||
 | 
					  auto& encoder = cu::get_command_encoder(s);
 | 
				
			||||||
 | 
					  out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
 | 
				
			||||||
  if (out.size() == 0) {
 | 
					  if (out.size() == 0) {
 | 
				
			||||||
    return;
 | 
					    return;
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
@@ -251,7 +253,6 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
 | 
				
			|||||||
      dtype_to_string(out.dtype()),
 | 
					      dtype_to_string(out.dtype()),
 | 
				
			||||||
      dtype_to_string(idx.dtype()));
 | 
					      dtype_to_string(idx.dtype()));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  auto& s = stream();
 | 
					 | 
				
			||||||
  cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() {
 | 
					  cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() {
 | 
				
			||||||
    std::vector<std::string> kernel_names;
 | 
					    std::vector<std::string> kernel_names;
 | 
				
			||||||
    for (int ndim = 0; ndim <= MAX_NDIM; ++ndim) {
 | 
					    for (int ndim = 0; ndim <= MAX_NDIM; ++ndim) {
 | 
				
			||||||
@@ -312,7 +313,6 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
 | 
				
			|||||||
      idx.flags().row_contiguous,
 | 
					      idx.flags().row_contiguous,
 | 
				
			||||||
      large ? "int64_t" : "int32_t");
 | 
					      large ? "int64_t" : "int32_t");
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  auto& encoder = cu::get_command_encoder(s);
 | 
					 | 
				
			||||||
  for (const auto& in : inputs) {
 | 
					  for (const auto& in : inputs) {
 | 
				
			||||||
    encoder.set_input_array(in);
 | 
					    encoder.set_input_array(in);
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -230,9 +230,10 @@ void LayerNorm::eval_gpu(
 | 
				
			|||||||
  nvtx3::scoped_range r("LayerNorm::eval_gpu");
 | 
					  nvtx3::scoped_range r("LayerNorm::eval_gpu");
 | 
				
			||||||
  auto& s = stream();
 | 
					  auto& s = stream();
 | 
				
			||||||
  auto& out = outputs[0];
 | 
					  auto& out = outputs[0];
 | 
				
			||||||
 | 
					  auto& encoder = cu::get_command_encoder(s);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  // Make sure that the last dimension is contiguous.
 | 
					  // Make sure that the last dimension is contiguous.
 | 
				
			||||||
  auto set_output = [&s, &out](const array& x) {
 | 
					  auto set_output = [&s, &out, &encoder](const array& x) {
 | 
				
			||||||
    bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1;
 | 
					    bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1;
 | 
				
			||||||
    if (no_copy && x.ndim() > 1) {
 | 
					    if (no_copy && x.ndim() > 1) {
 | 
				
			||||||
      auto s = x.strides()[x.ndim() - 2];
 | 
					      auto s = x.strides()[x.ndim() - 2];
 | 
				
			||||||
@@ -243,7 +244,7 @@ void LayerNorm::eval_gpu(
 | 
				
			|||||||
        out.copy_shared_buffer(x);
 | 
					        out.copy_shared_buffer(x);
 | 
				
			||||||
      } else {
 | 
					      } else {
 | 
				
			||||||
        out.set_data(
 | 
					        out.set_data(
 | 
				
			||||||
            allocator::malloc(x.data_size() * x.itemsize()),
 | 
					            cu::malloc_async(x.data_size() * x.itemsize(), encoder.stream()),
 | 
				
			||||||
            x.data_size(),
 | 
					            x.data_size(),
 | 
				
			||||||
            x.strides(),
 | 
					            x.strides(),
 | 
				
			||||||
            x.flags());
 | 
					            x.flags());
 | 
				
			||||||
@@ -265,7 +266,6 @@ void LayerNorm::eval_gpu(
 | 
				
			|||||||
  int64_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0;
 | 
					  int64_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0;
 | 
				
			||||||
  int64_t b_stride = (b.ndim() == 1) ? b.strides()[0] : 0;
 | 
					  int64_t b_stride = (b.ndim() == 1) ? b.strides()[0] : 0;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  auto& encoder = cu::get_command_encoder(s);
 | 
					 | 
				
			||||||
  encoder.set_input_array(x);
 | 
					  encoder.set_input_array(x);
 | 
				
			||||||
  encoder.set_input_array(w);
 | 
					  encoder.set_input_array(w);
 | 
				
			||||||
  encoder.set_input_array(b);
 | 
					  encoder.set_input_array(b);
 | 
				
			||||||
@@ -335,7 +335,7 @@ void LayerNormVJP::eval_gpu(
 | 
				
			|||||||
    gx.copy_shared_buffer(g);
 | 
					    gx.copy_shared_buffer(g);
 | 
				
			||||||
    g_in_gx = true;
 | 
					    g_in_gx = true;
 | 
				
			||||||
  } else {
 | 
					  } else {
 | 
				
			||||||
    gx.set_data(allocator::malloc(gx.nbytes()));
 | 
					    gx.set_data(cu::malloc_async(gx.nbytes(), encoder.stream()));
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
  if (g_copied && !g_in_gx) {
 | 
					  if (g_copied && !g_in_gx) {
 | 
				
			||||||
    encoder.add_temporary(g);
 | 
					    encoder.add_temporary(g);
 | 
				
			||||||
@@ -355,7 +355,7 @@ void LayerNormVJP::eval_gpu(
 | 
				
			|||||||
      g_in_gw = true;
 | 
					      g_in_gw = true;
 | 
				
			||||||
      gw_temp.copy_shared_buffer(g);
 | 
					      gw_temp.copy_shared_buffer(g);
 | 
				
			||||||
    } else {
 | 
					    } else {
 | 
				
			||||||
      gw_temp.set_data(allocator::malloc(gw_temp.nbytes()));
 | 
					      gw_temp.set_data(cu::malloc_async(gw_temp.nbytes(), encoder.stream()));
 | 
				
			||||||
      encoder.add_temporary(gw_temp);
 | 
					      encoder.add_temporary(gw_temp);
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -115,7 +115,7 @@ void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
  auto in = ensure_contiguous(inputs[0]);
 | 
					  auto in = ensure_contiguous(inputs[0]);
 | 
				
			||||||
  if (in.flags().row_contiguous) {
 | 
					  if (in.flags().row_contiguous) {
 | 
				
			||||||
    out.set_data(allocator::malloc(out.nbytes()));
 | 
					    out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
 | 
				
			||||||
  } else {
 | 
					  } else {
 | 
				
			||||||
    auto n = in.shape(-1);
 | 
					    auto n = in.shape(-1);
 | 
				
			||||||
    auto flags = in.flags();
 | 
					    auto flags = in.flags();
 | 
				
			||||||
@@ -130,7 +130,7 @@ void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
 | 
				
			|||||||
    }
 | 
					    }
 | 
				
			||||||
    flags.col_contiguous = col_contig;
 | 
					    flags.col_contiguous = col_contig;
 | 
				
			||||||
    out.set_data(
 | 
					    out.set_data(
 | 
				
			||||||
        allocator::malloc(in.nbytes() / n),
 | 
					        cu::malloc_async(in.nbytes() / n, encoder.stream()),
 | 
				
			||||||
        in.data_size() / n,
 | 
					        in.data_size() / n,
 | 
				
			||||||
        std::move(strides),
 | 
					        std::move(strides),
 | 
				
			||||||
        flags);
 | 
					        flags);
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -121,7 +121,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
 | 
				
			|||||||
    return;
 | 
					    return;
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  out.set_data(allocator::malloc(out.nbytes()));
 | 
					  out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  int M = a_pre.shape(-2);
 | 
					  int M = a_pre.shape(-2);
 | 
				
			||||||
  int N = b_pre.shape(-1);
 | 
					  int N = b_pre.shape(-1);
 | 
				
			||||||
@@ -163,7 +163,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
  if (beta_ == 1 && a.dtype() != complex64 && c.strides(-1) == 1 &&
 | 
					  if (beta_ == 1 && a.dtype() != complex64 && c.strides(-1) == 1 &&
 | 
				
			||||||
      c.data_size() == out.shape(-1)) {
 | 
					      c.data_size() == out.shape(-1)) {
 | 
				
			||||||
    out.set_data(allocator::malloc(out.nbytes()));
 | 
					    out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
 | 
				
			||||||
    gemm_and_bias(
 | 
					    gemm_and_bias(
 | 
				
			||||||
        encoder,
 | 
					        encoder,
 | 
				
			||||||
        M,
 | 
					        M,
 | 
				
			||||||
@@ -187,10 +187,10 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
 | 
				
			|||||||
    auto sty = c.strides()[c.ndim() - 1];
 | 
					    auto sty = c.strides()[c.ndim() - 1];
 | 
				
			||||||
    if (sty == 1 && stx == c.shape(-1)) {
 | 
					    if (sty == 1 && stx == c.shape(-1)) {
 | 
				
			||||||
      ldc = stx;
 | 
					      ldc = stx;
 | 
				
			||||||
      out.set_data(allocator::malloc(out.nbytes()));
 | 
					      out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
 | 
				
			||||||
    } else if (sty == 1 && stx == 0) {
 | 
					    } else if (sty == 1 && stx == 0) {
 | 
				
			||||||
      ldc = 0;
 | 
					      ldc = 0;
 | 
				
			||||||
      out.set_data(allocator::malloc(out.nbytes()));
 | 
					      out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
 | 
				
			||||||
    } else {
 | 
					    } else {
 | 
				
			||||||
      // Copy C into out and set C to out
 | 
					      // Copy C into out and set C to out
 | 
				
			||||||
      ldc = c.shape(-1);
 | 
					      ldc = c.shape(-1);
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -176,9 +176,10 @@ void RMSNorm::eval_gpu(
 | 
				
			|||||||
  nvtx3::scoped_range r("RMSNorm::eval_gpu");
 | 
					  nvtx3::scoped_range r("RMSNorm::eval_gpu");
 | 
				
			||||||
  auto& s = stream();
 | 
					  auto& s = stream();
 | 
				
			||||||
  auto& out = outputs[0];
 | 
					  auto& out = outputs[0];
 | 
				
			||||||
 | 
					  auto& encoder = cu::get_command_encoder(s);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  // Make sure that the last dimension is contiguous.
 | 
					  // Make sure that the last dimension is contiguous.
 | 
				
			||||||
  auto set_output = [&s, &out](const array& x) {
 | 
					  auto set_output = [&s, &out, &encoder](const array& x) {
 | 
				
			||||||
    bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1;
 | 
					    bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1;
 | 
				
			||||||
    if (no_copy && x.ndim() > 1) {
 | 
					    if (no_copy && x.ndim() > 1) {
 | 
				
			||||||
      auto s = x.strides()[x.ndim() - 2];
 | 
					      auto s = x.strides()[x.ndim() - 2];
 | 
				
			||||||
@@ -189,7 +190,7 @@ void RMSNorm::eval_gpu(
 | 
				
			|||||||
        out.copy_shared_buffer(x);
 | 
					        out.copy_shared_buffer(x);
 | 
				
			||||||
      } else {
 | 
					      } else {
 | 
				
			||||||
        out.set_data(
 | 
					        out.set_data(
 | 
				
			||||||
            allocator::malloc(x.data_size() * x.itemsize()),
 | 
					            cu::malloc_async(x.data_size() * x.itemsize(), encoder.stream()),
 | 
				
			||||||
            x.data_size(),
 | 
					            x.data_size(),
 | 
				
			||||||
            x.strides(),
 | 
					            x.strides(),
 | 
				
			||||||
            x.flags());
 | 
					            x.flags());
 | 
				
			||||||
@@ -209,7 +210,6 @@ void RMSNorm::eval_gpu(
 | 
				
			|||||||
  int32_t n_rows = x.data_size() / axis_size;
 | 
					  int32_t n_rows = x.data_size() / axis_size;
 | 
				
			||||||
  int64_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0;
 | 
					  int64_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  auto& encoder = cu::get_command_encoder(s);
 | 
					 | 
				
			||||||
  encoder.set_input_array(x);
 | 
					  encoder.set_input_array(x);
 | 
				
			||||||
  encoder.set_input_array(w);
 | 
					  encoder.set_input_array(w);
 | 
				
			||||||
  encoder.set_output_array(out);
 | 
					  encoder.set_output_array(out);
 | 
				
			||||||
@@ -274,7 +274,7 @@ void RMSNormVJP::eval_gpu(
 | 
				
			|||||||
    gx.copy_shared_buffer(g);
 | 
					    gx.copy_shared_buffer(g);
 | 
				
			||||||
    g_in_gx = true;
 | 
					    g_in_gx = true;
 | 
				
			||||||
  } else {
 | 
					  } else {
 | 
				
			||||||
    gx.set_data(allocator::malloc(gx.nbytes()));
 | 
					    gx.set_data(cu::malloc_async(gx.nbytes(), encoder.stream()));
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
  if (g_copied && !g_in_gx) {
 | 
					  if (g_copied && !g_in_gx) {
 | 
				
			||||||
    encoder.add_temporary(g);
 | 
					    encoder.add_temporary(g);
 | 
				
			||||||
@@ -292,7 +292,7 @@ void RMSNormVJP::eval_gpu(
 | 
				
			|||||||
    if (!g_in_gx && donate_g) {
 | 
					    if (!g_in_gx && donate_g) {
 | 
				
			||||||
      gw_temp.copy_shared_buffer(g);
 | 
					      gw_temp.copy_shared_buffer(g);
 | 
				
			||||||
    } else {
 | 
					    } else {
 | 
				
			||||||
      gw_temp.set_data(allocator::malloc(gw_temp.nbytes()));
 | 
					      gw_temp.set_data(cu::malloc_async(gw_temp.nbytes(), encoder.stream()));
 | 
				
			||||||
      encoder.add_temporary(gw_temp);
 | 
					      encoder.add_temporary(gw_temp);
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -250,6 +250,7 @@ void RoPE::eval_gpu(
 | 
				
			|||||||
  nvtx3::scoped_range r("RoPE::eval_gpu");
 | 
					  nvtx3::scoped_range r("RoPE::eval_gpu");
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  auto& s = stream();
 | 
					  auto& s = stream();
 | 
				
			||||||
 | 
					  auto& encoder = cu::get_command_encoder(s);
 | 
				
			||||||
  auto& in = inputs[0];
 | 
					  auto& in = inputs[0];
 | 
				
			||||||
  auto& offset = inputs[1];
 | 
					  auto& offset = inputs[1];
 | 
				
			||||||
  auto& out = outputs[0];
 | 
					  auto& out = outputs[0];
 | 
				
			||||||
@@ -291,14 +292,14 @@ void RoPE::eval_gpu(
 | 
				
			|||||||
      donated = true;
 | 
					      donated = true;
 | 
				
			||||||
      out.copy_shared_buffer(in);
 | 
					      out.copy_shared_buffer(in);
 | 
				
			||||||
    } else {
 | 
					    } else {
 | 
				
			||||||
      out.set_data(allocator::malloc(out.nbytes()));
 | 
					      out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
    strides[0] = mat_size;
 | 
					    strides[0] = mat_size;
 | 
				
			||||||
    strides[1] = in.strides()[ndim - 2];
 | 
					    strides[1] = in.strides()[ndim - 2];
 | 
				
			||||||
    strides[2] = in.strides()[ndim - 1];
 | 
					    strides[2] = in.strides()[ndim - 1];
 | 
				
			||||||
  } else if (dispatch_ndim == 3) {
 | 
					  } else if (dispatch_ndim == 3) {
 | 
				
			||||||
    // Handle non-contiguous 3D inputs
 | 
					    // Handle non-contiguous 3D inputs
 | 
				
			||||||
    out.set_data(allocator::malloc(out.nbytes()));
 | 
					    out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
 | 
				
			||||||
    strides[0] = in.strides()[ndim - 3];
 | 
					    strides[0] = in.strides()[ndim - 3];
 | 
				
			||||||
    strides[1] = in.strides()[ndim - 2];
 | 
					    strides[1] = in.strides()[ndim - 2];
 | 
				
			||||||
    strides[2] = in.strides()[ndim - 1];
 | 
					    strides[2] = in.strides()[ndim - 1];
 | 
				
			||||||
@@ -319,7 +320,6 @@ void RoPE::eval_gpu(
 | 
				
			|||||||
  bool single = in.flags().row_contiguous && B == 1 && T == 1;
 | 
					  bool single = in.flags().row_contiguous && B == 1 && T == 1;
 | 
				
			||||||
  bool with_freqs = inputs.size() == 3;
 | 
					  bool with_freqs = inputs.size() == 3;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  auto& encoder = cu::get_command_encoder(s);
 | 
					 | 
				
			||||||
  encoder.set_input_array(donated ? out : in);
 | 
					  encoder.set_input_array(donated ? out : in);
 | 
				
			||||||
  encoder.set_input_array(offset);
 | 
					  encoder.set_input_array(offset);
 | 
				
			||||||
  if (with_freqs) {
 | 
					  if (with_freqs) {
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -565,9 +565,10 @@ void sdpa_vector_2pass_fallback(
 | 
				
			|||||||
  array sums(intermediate_shape, float32, nullptr, {});
 | 
					  array sums(intermediate_shape, float32, nullptr, {});
 | 
				
			||||||
  array maxs(std::move(intermediate_shape), float32, nullptr, {});
 | 
					  array maxs(std::move(intermediate_shape), float32, nullptr, {});
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  intermediate.set_data(allocator::malloc(intermediate.nbytes()));
 | 
					  intermediate.set_data(
 | 
				
			||||||
  sums.set_data(allocator::malloc(sums.nbytes()));
 | 
					      cu::malloc_async(intermediate.nbytes(), encoder.stream()));
 | 
				
			||||||
  maxs.set_data(allocator::malloc(maxs.nbytes()));
 | 
					  sums.set_data(cu::malloc_async(sums.nbytes(), encoder.stream()));
 | 
				
			||||||
 | 
					  maxs.set_data(cu::malloc_async(maxs.nbytes(), encoder.stream()));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  encoder.add_temporary(intermediate);
 | 
					  encoder.add_temporary(intermediate);
 | 
				
			||||||
  encoder.add_temporary(sums);
 | 
					  encoder.add_temporary(sums);
 | 
				
			||||||
@@ -787,7 +788,7 @@ void ScaledDotProductAttention::eval_gpu(
 | 
				
			|||||||
      };
 | 
					      };
 | 
				
			||||||
 | 
					
 | 
				
			||||||
      o.set_data(
 | 
					      o.set_data(
 | 
				
			||||||
          allocator::malloc(o.nbytes()),
 | 
					          cu::malloc_async(o.nbytes(), encoder.stream()),
 | 
				
			||||||
          o.size(),
 | 
					          o.size(),
 | 
				
			||||||
          {str_oB, str_oH, str_oL, str_oD},
 | 
					          {str_oB, str_oH, str_oL, str_oD},
 | 
				
			||||||
          flags);
 | 
					          flags);
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -367,13 +367,14 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
 | 
				
			|||||||
  assert(inputs.size() == 1);
 | 
					  assert(inputs.size() == 1);
 | 
				
			||||||
  auto in = inputs[0];
 | 
					  auto in = inputs[0];
 | 
				
			||||||
  auto& s = stream();
 | 
					  auto& s = stream();
 | 
				
			||||||
 | 
					  auto& encoder = cu::get_command_encoder(s);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  if (in.flags().contiguous && in.strides()[axis_] != 0) {
 | 
					  if (in.flags().contiguous && in.strides()[axis_] != 0) {
 | 
				
			||||||
    if (in.is_donatable() && in.itemsize() == out.itemsize()) {
 | 
					    if (in.is_donatable() && in.itemsize() == out.itemsize()) {
 | 
				
			||||||
      out.copy_shared_buffer(in);
 | 
					      out.copy_shared_buffer(in);
 | 
				
			||||||
    } else {
 | 
					    } else {
 | 
				
			||||||
      out.set_data(
 | 
					      out.set_data(
 | 
				
			||||||
          allocator::malloc(in.data_size() * out.itemsize()),
 | 
					          cu::malloc_async(in.data_size() * out.itemsize(), encoder.stream()),
 | 
				
			||||||
          in.data_size(),
 | 
					          in.data_size(),
 | 
				
			||||||
          in.strides(),
 | 
					          in.strides(),
 | 
				
			||||||
          in.flags());
 | 
					          in.flags());
 | 
				
			||||||
@@ -387,7 +388,6 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
 | 
				
			|||||||
  int32_t axis_size = in.shape(axis_);
 | 
					  int32_t axis_size = in.shape(axis_);
 | 
				
			||||||
  bool contiguous = in.strides()[axis_] == 1;
 | 
					  bool contiguous = in.strides()[axis_] == 1;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  auto& encoder = cu::get_command_encoder(s);
 | 
					 | 
				
			||||||
  encoder.set_input_array(in);
 | 
					  encoder.set_input_array(in);
 | 
				
			||||||
  encoder.set_output_array(out);
 | 
					  encoder.set_output_array(out);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -23,14 +23,15 @@ void concatenate_gpu(
 | 
				
			|||||||
  }
 | 
					  }
 | 
				
			||||||
  std::partial_sum(sizes.cbegin(), sizes.cend(), sizes.begin());
 | 
					  std::partial_sum(sizes.cbegin(), sizes.cend(), sizes.begin());
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  out.set_data(allocator::malloc(out.nbytes()));
 | 
					  auto& encoder = cu::get_command_encoder(s);
 | 
				
			||||||
 | 
					  out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  auto strides = out.strides();
 | 
					  auto strides = out.strides();
 | 
				
			||||||
  auto flags = out.flags();
 | 
					  auto flags = out.flags();
 | 
				
			||||||
  flags.row_contiguous = false;
 | 
					  flags.row_contiguous = false;
 | 
				
			||||||
  flags.col_contiguous = false;
 | 
					  flags.col_contiguous = false;
 | 
				
			||||||
  flags.contiguous = false;
 | 
					  flags.contiguous = false;
 | 
				
			||||||
  auto concurrent = cu::get_command_encoder(s).concurrent_context();
 | 
					  auto concurrent = encoder.concurrent_context();
 | 
				
			||||||
  for (int i = 0; i < inputs.size(); i++) {
 | 
					  for (int i = 0; i < inputs.size(); i++) {
 | 
				
			||||||
    array out_slice(inputs[i].shape(), out.dtype(), nullptr, {});
 | 
					    array out_slice(inputs[i].shape(), out.dtype(), nullptr, {});
 | 
				
			||||||
    size_t data_offset = strides[axis] * sizes[i];
 | 
					    size_t data_offset = strides[axis] * sizes[i];
 | 
				
			||||||
@@ -80,6 +81,7 @@ array compute_dynamic_offset(
 | 
				
			|||||||
    return std::make_tuple(false, std::move(source), std::vector{kernel_name});
 | 
					    return std::make_tuple(false, std::move(source), std::vector{kernel_name});
 | 
				
			||||||
  });
 | 
					  });
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  auto& encoder = cu::get_command_encoder(s);
 | 
				
			||||||
  // Prepare output.
 | 
					  // Prepare output.
 | 
				
			||||||
  array offset({1}, int64, nullptr, {});
 | 
					  array offset({1}, int64, nullptr, {});
 | 
				
			||||||
  bool donate = indices.is_donatable() &&
 | 
					  bool donate = indices.is_donatable() &&
 | 
				
			||||||
@@ -87,10 +89,9 @@ array compute_dynamic_offset(
 | 
				
			|||||||
  if (donate) {
 | 
					  if (donate) {
 | 
				
			||||||
    offset.copy_shared_buffer(indices);
 | 
					    offset.copy_shared_buffer(indices);
 | 
				
			||||||
  } else {
 | 
					  } else {
 | 
				
			||||||
    offset.set_data(allocator::malloc(offset.itemsize()));
 | 
					    offset.set_data(cu::malloc_async(offset.itemsize(), encoder.stream()));
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  auto& encoder = cu::get_command_encoder(s);
 | 
					 | 
				
			||||||
  encoder.add_temporary(offset);
 | 
					  encoder.add_temporary(offset);
 | 
				
			||||||
  encoder.set_input_array(indices);
 | 
					  encoder.set_input_array(indices);
 | 
				
			||||||
  encoder.set_output_array(offset);
 | 
					  encoder.set_output_array(offset);
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -109,15 +109,16 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
 | 
				
			|||||||
  nvtx3::scoped_range r("Softmax::eval_gpu");
 | 
					  nvtx3::scoped_range r("Softmax::eval_gpu");
 | 
				
			||||||
  assert(inputs.size() == 1);
 | 
					  assert(inputs.size() == 1);
 | 
				
			||||||
  auto& s = stream();
 | 
					  auto& s = stream();
 | 
				
			||||||
 | 
					  auto& encoder = cu::get_command_encoder(s);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  // Make sure that the last dimension is contiguous.
 | 
					  // Make sure that the last dimension is contiguous.
 | 
				
			||||||
  auto set_output = [&s, &out](const array& x) {
 | 
					  auto set_output = [&s, &out, &encoder](const array& x) {
 | 
				
			||||||
    if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
 | 
					    if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
 | 
				
			||||||
      if (x.is_donatable()) {
 | 
					      if (x.is_donatable()) {
 | 
				
			||||||
        out.copy_shared_buffer(x);
 | 
					        out.copy_shared_buffer(x);
 | 
				
			||||||
      } else {
 | 
					      } else {
 | 
				
			||||||
        out.set_data(
 | 
					        out.set_data(
 | 
				
			||||||
            allocator::malloc(x.data_size() * x.itemsize()),
 | 
					            cu::malloc_async(x.data_size() * x.itemsize(), encoder.stream()),
 | 
				
			||||||
            x.data_size(),
 | 
					            x.data_size(),
 | 
				
			||||||
            x.strides(),
 | 
					            x.strides(),
 | 
				
			||||||
            x.flags());
 | 
					            x.flags());
 | 
				
			||||||
@@ -136,7 +137,6 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
 | 
				
			|||||||
  int axis_size = in.shape().back();
 | 
					  int axis_size = in.shape().back();
 | 
				
			||||||
  int n_rows = in.data_size() / axis_size;
 | 
					  int n_rows = in.data_size() / axis_size;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  auto& encoder = cu::get_command_encoder(s);
 | 
					 | 
				
			||||||
  encoder.set_input_array(in);
 | 
					  encoder.set_input_array(in);
 | 
				
			||||||
  encoder.set_output_array(out);
 | 
					  encoder.set_output_array(out);
 | 
				
			||||||
  dispatch_float_types(out.dtype(), "softmax", [&](auto type_tag) {
 | 
					  dispatch_float_types(out.dtype(), "softmax", [&](auto type_tag) {
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -49,11 +49,14 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
 | 
				
			|||||||
    array trans = swapaxes_in_eval(in, axis, last_dim);
 | 
					    array trans = swapaxes_in_eval(in, axis, last_dim);
 | 
				
			||||||
    in = contiguous_copy_gpu(trans, s);
 | 
					    in = contiguous_copy_gpu(trans, s);
 | 
				
			||||||
    encoder.add_temporary(in);
 | 
					    encoder.add_temporary(in);
 | 
				
			||||||
    out = array(allocator::malloc(out.nbytes()), in.shape(), out.dtype());
 | 
					    out = array(
 | 
				
			||||||
 | 
					        cu::malloc_async(out.nbytes(), encoder.stream()),
 | 
				
			||||||
 | 
					        in.shape(),
 | 
				
			||||||
 | 
					        out.dtype());
 | 
				
			||||||
    encoder.add_temporary(out);
 | 
					    encoder.add_temporary(out);
 | 
				
			||||||
  } else {
 | 
					  } else {
 | 
				
			||||||
    out.set_data(
 | 
					    out.set_data(
 | 
				
			||||||
        allocator::malloc(in.data_size() * out.itemsize()),
 | 
					        cu::malloc_async(in.data_size() * out.itemsize(), encoder.stream()),
 | 
				
			||||||
        in.data_size(),
 | 
					        in.data_size(),
 | 
				
			||||||
        in.strides(),
 | 
					        in.strides(),
 | 
				
			||||||
        in.flags());
 | 
					        in.flags());
 | 
				
			||||||
@@ -70,12 +73,18 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
 | 
				
			|||||||
          thrust::make_counting_iterator(0), OffsetTransform{nsort});
 | 
					          thrust::make_counting_iterator(0), OffsetTransform{nsort});
 | 
				
			||||||
      if (argsort) {
 | 
					      if (argsort) {
 | 
				
			||||||
        // Indices in the sorted dimension.
 | 
					        // Indices in the sorted dimension.
 | 
				
			||||||
        array indices(allocator::malloc(out.nbytes()), in.shape(), out.dtype());
 | 
					        array indices(
 | 
				
			||||||
 | 
					            cu::malloc_async(out.nbytes(), encoder.stream()),
 | 
				
			||||||
 | 
					            in.shape(),
 | 
				
			||||||
 | 
					            out.dtype());
 | 
				
			||||||
        encoder.add_temporary(indices);
 | 
					        encoder.add_temporary(indices);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        // In argsort though we don't need the result of sorted values, the
 | 
					        // In argsort though we don't need the result of sorted values, the
 | 
				
			||||||
        // API requires us to provide an array to store it.
 | 
					        // API requires us to provide an array to store it.
 | 
				
			||||||
        array discard(allocator::malloc(in.nbytes()), in.shape(), in.dtype());
 | 
					        array discard(
 | 
				
			||||||
 | 
					            cu::malloc_async(in.nbytes(), encoder.stream()),
 | 
				
			||||||
 | 
					            in.shape(),
 | 
				
			||||||
 | 
					            in.dtype());
 | 
				
			||||||
        encoder.add_temporary(discard);
 | 
					        encoder.add_temporary(discard);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        size_t size;
 | 
					        size_t size;
 | 
				
			||||||
@@ -94,7 +103,10 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
 | 
				
			|||||||
            sizeof(Type) * 8,
 | 
					            sizeof(Type) * 8,
 | 
				
			||||||
            stream));
 | 
					            stream));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
 | 
					        array temp(
 | 
				
			||||||
 | 
					            cu::malloc_async(size, encoder.stream()),
 | 
				
			||||||
 | 
					            {static_cast<int>(size)},
 | 
				
			||||||
 | 
					            uint8);
 | 
				
			||||||
        encoder.add_temporary(temp);
 | 
					        encoder.add_temporary(temp);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        // Start capturing after allocations
 | 
					        // Start capturing after allocations
 | 
				
			||||||
@@ -135,7 +147,10 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
 | 
				
			|||||||
            sizeof(Type) * 8,
 | 
					            sizeof(Type) * 8,
 | 
				
			||||||
            stream));
 | 
					            stream));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
 | 
					        array temp(
 | 
				
			||||||
 | 
					            cu::malloc_async(size, encoder.stream()),
 | 
				
			||||||
 | 
					            {static_cast<int>(size)},
 | 
				
			||||||
 | 
					            uint8);
 | 
				
			||||||
        encoder.add_temporary(temp);
 | 
					        encoder.add_temporary(temp);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        // Start capturing after allocations
 | 
					        // Start capturing after allocations
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user