Compare commits

...

4 Commits

Author SHA1 Message Date
Cheng
56cc858af9 Add contiguous_copy_cpu util for copying array (#2397) 2025-07-21 07:30:35 -07:00
Cheng
f55c4ed1d6 Remove thrust iterators (#2396) 2025-07-21 07:30:27 -07:00
Awni Hannun
93d70419e7 [CUDA] speedup handling scalars (#2389)
* speedup scalars in cuda

* comment
2025-07-18 21:47:31 -07:00
Awni Hannun
63f663d9c6 fix cuda manylinux version to match others (#2388) 2025-07-18 21:02:16 -07:00
19 changed files with 131 additions and 219 deletions

View File

@@ -366,7 +366,7 @@ jobs:
type: string type: string
default: "" default: ""
machine: machine:
image: linux-cuda-12:default image: linux-cuda-12:2024.11.1
resource_class: gpu.nvidia.small.gen2 resource_class: gpu.nvidia.small.gen2
steps: steps:
- checkout - checkout

View File

@@ -377,4 +377,10 @@ void copy_cpu_inplace(
}); });
} }
array contiguous_copy_cpu(const array& arr, Stream stream) {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy_cpu(arr, arr_copy, CopyType::General, stream);
return arr_copy;
}
} // namespace mlx::core } // namespace mlx::core

View File

@@ -30,4 +30,7 @@ void copy_cpu_inplace(
const std::optional<array>& dynamic_i_offset = std::nullopt, const std::optional<array>& dynamic_i_offset = std::nullopt,
const std::optional<array>& dynamic_o_offset = std::nullopt); const std::optional<array>& dynamic_o_offset = std::nullopt);
// Return a contiguous array with same shape that copies the data of |arr|.
array contiguous_copy_cpu(const array& arr, Stream stream);
} // namespace mlx::core } // namespace mlx::core

View File

@@ -13,9 +13,7 @@ std::pair<array, bool> ensure_row_contiguous(const array& arr, Stream stream) {
if (arr.flags().row_contiguous) { if (arr.flags().row_contiguous) {
return {arr, false}; return {arr, false};
} else { } else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); return {contiguous_copy_cpu(arr, stream), true};
copy_cpu(arr, arr_copy, CopyType::General, stream);
return {arr_copy, true};
} }
}; };
@@ -34,8 +32,7 @@ void AllReduce::eval_cpu(
} }
return in; return in;
} else { } else {
array arr_copy(in.shape(), in.dtype(), nullptr, {}); array arr_copy = contiguous_copy_cpu(in, s);
copy_cpu(in, arr_copy, CopyType::General, s);
out.copy_shared_buffer(arr_copy); out.copy_shared_buffer(arr_copy);
return arr_copy; return arr_copy;
} }

View File

@@ -87,8 +87,7 @@ void LogSumExp::eval_cpu(const std::vector<array>& inputs, array& out) {
if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) { if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
return x; return x;
} else { } else {
auto x_copy = array(x.shape(), x.dtype(), nullptr, {}); array x_copy = contiguous_copy_cpu(x, s);
copy_cpu(x, x_copy, CopyType::General, s);
encoder.add_temporary(x_copy); encoder.add_temporary(x_copy);
return x_copy; return x_copy;
} }

View File

@@ -136,9 +136,8 @@ void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
} }
return std::make_tuple(true, sty, arr, false); return std::make_tuple(true, sty, arr, false);
} else { } else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy_cpu(arr, arr_copy, CopyType::General, s);
int64_t stx = arr.shape(-1); int64_t stx = arr.shape(-1);
array arr_copy = contiguous_copy_cpu(arr, s);
return std::make_tuple(false, stx, arr_copy, true); return std::make_tuple(false, stx, arr_copy, true);
} }
}; };

View File

@@ -712,9 +712,7 @@ void fast::AffineQuantize::eval_cpu(
if (arr.flags().row_contiguous) { if (arr.flags().row_contiguous) {
return std::make_pair(arr, false); return std::make_pair(arr, false);
} else { } else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); return std::make_pair(contiguous_copy_cpu(arr, s), true);
copy_cpu(arr, arr_copy, CopyType::General, s);
return std::make_pair(arr_copy, true);
} }
}; };

View File

@@ -250,10 +250,8 @@ void Scan::eval_cpu(const std::vector<array>& inputs, array& out) {
// Ensure contiguity // Ensure contiguity
auto in = inputs[0]; auto in = inputs[0];
if (!in.flags().row_contiguous) { if (!in.flags().row_contiguous) {
array arr_copy(in.shape(), in.dtype(), nullptr, {}); in = contiguous_copy_cpu(in, stream());
copy_cpu(in, arr_copy, CopyType::General, stream()); encoder.add_temporary(in);
in = arr_copy;
encoder.add_temporary(arr_copy);
} }
out.set_data(allocator::malloc(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));

View File

@@ -131,8 +131,7 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
} }
return x; return x;
} else { } else {
array x_copy(x.shape(), x.dtype(), nullptr, {}); array x_copy = contiguous_copy_cpu(x, s);
copy_cpu(x, x_copy, CopyType::General, s);
out.copy_shared_buffer(x_copy); out.copy_shared_buffer(x_copy);
return x_copy; return x_copy;
} }

View File

@@ -17,6 +17,52 @@ namespace cu {
constexpr int page_size = 16384; constexpr int page_size = 16384;
// Any allocations smaller than this will try to use the small pool
constexpr int small_block_size = 8;
// The small pool size in bytes. This should be a multiple of the host page
// size and small_block_size.
constexpr int small_pool_size = 4 * page_size;
SmallSizePool::SmallSizePool() {
CHECK_CUDA_ERROR(cudaMallocManaged(&buffer_, small_pool_size));
end_ = reinterpret_cast<void*>(
reinterpret_cast<char*>(buffer_) + small_pool_size);
next_free_ = reinterpret_cast<Block*>(buffer_);
auto num_blocks = small_pool_size / small_block_size;
auto curr = next_free_;
for (size_t i = 0; i < num_blocks - 1; ++i) {
curr->next = reinterpret_cast<Block*>(
reinterpret_cast<char*>(buffer_) + (i + 1) * small_block_size);
curr = curr->next;
}
curr->next = nullptr;
}
SmallSizePool::~SmallSizePool() {
CHECK_CUDA_ERROR(cudaFree(buffer_));
}
void* SmallSizePool::malloc() {
if (next_free_ == nullptr) {
return nullptr;
}
Block* b = next_free_;
next_free_ = next_free_->next;
return static_cast<void*>(b);
}
void SmallSizePool::free(void* p) {
auto b = static_cast<Block*>(p);
b->next = next_free_;
next_free_ = b;
}
bool SmallSizePool::in_pool(void* p) {
return (p >= buffer_) && (p < end_);
}
CudaAllocator::CudaAllocator() CudaAllocator::CudaAllocator()
: buffer_cache_( : buffer_cache_(
page_size, page_size,
@@ -36,7 +82,9 @@ Buffer CudaAllocator::malloc(size_t size) {
// Find available buffer from cache. // Find available buffer from cache.
auto orig_size = size; auto orig_size = size;
std::unique_lock lock(mutex_); std::unique_lock lock(mutex_);
if (size < page_size) { if (size <= small_block_size) {
size = 8;
} else if (size < page_size) {
size = next_power_of_2(size); size = next_power_of_2(size);
} else { } else {
size = page_size * ((size + page_size - 1) / page_size); size = page_size * ((size + page_size - 1) / page_size);
@@ -53,11 +101,19 @@ Buffer CudaAllocator::malloc(size_t size) {
lock.unlock(); lock.unlock();
buf = new CudaBuffer{nullptr, size}; buf = new CudaBuffer{nullptr, size};
// Try the scalar pool first
if (size <= small_block_size) {
buf->data = scalar_pool_.malloc();
}
if (!buf->data) {
cudaError_t err = cudaMallocManaged(&buf->data, size); cudaError_t err = cudaMallocManaged(&buf->data, size);
if (err != cudaSuccess && err != cudaErrorMemoryAllocation) { if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
throw std::runtime_error(fmt::format( throw std::runtime_error(fmt::format(
"cudaMallocManaged failed: {}.", cudaGetErrorString(err))); "cudaMallocManaged failed: {}.", cudaGetErrorString(err)));
} }
}
lock.lock(); lock.lock();
} }
active_memory_ += size; active_memory_ += size;
@@ -116,8 +172,12 @@ void CudaAllocator::cuda_free(void* buf) {
return; return;
} }
} }
if (scalar_pool_.in_pool(buf)) {
scalar_pool_.free(buf);
} else {
cudaFree(buf); cudaFree(buf);
} }
}
size_t CudaAllocator::get_active_memory() const { size_t CudaAllocator::get_active_memory() const {
return active_memory_; return active_memory_;

View File

@@ -22,6 +22,28 @@ struct CudaBuffer {
size_t size; size_t size;
}; };
class SmallSizePool {
private:
struct Block {
Block* next;
};
void* buffer_{nullptr};
Block* next_free_{nullptr};
void* end_{nullptr};
public:
SmallSizePool();
~SmallSizePool();
SmallSizePool(const SmallSizePool&) = delete;
SmallSizePool& operator=(const SmallSizePool&) = delete;
void* malloc();
void free(void* p);
bool in_pool(void* p);
};
class CudaAllocator : public allocator::Allocator { class CudaAllocator : public allocator::Allocator {
public: public:
Buffer malloc(size_t size) override; Buffer malloc(size_t size) override;
@@ -60,6 +82,7 @@ class CudaAllocator : public allocator::Allocator {
BufferCache<CudaBuffer> buffer_cache_; BufferCache<CudaBuffer> buffer_cache_;
size_t active_memory_{0}; size_t active_memory_{0};
size_t peak_memory_{0}; size_t peak_memory_{0};
SmallSizePool scalar_pool_;
}; };
CudaAllocator& allocator(); CudaAllocator& allocator();

View File

@@ -1,8 +1,8 @@
// Copyright © 2025 Apple Inc. // Copyright © 2025 Apple Inc.
#include "mlx/backend/common/utils.h" #include "mlx/backend/common/utils.h"
#include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/fp16_math.cuh" #include "mlx/backend/cuda/device/fp16_math.cuh"
#include "mlx/backend/cuda/iterators/strided_iterator.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh" #include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/dtype_utils.h" #include "mlx/dtype_utils.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"
@@ -115,7 +115,7 @@ __global__ void arg_reduce_general(
T vals[N_READS]; T vals[N_READS];
auto tid = r * BLOCK_DIM + block.thread_index().x; auto tid = r * BLOCK_DIM + block.thread_index().x;
cub::LoadDirectBlocked( cub::LoadDirectBlocked(
tid, strided_iterator(in + in_idx, axis_stride), vals, axis_size, init); tid, StridedIterator(in + in_idx, axis_stride), vals, axis_size, init);
best = op.reduce_many(best, vals, tid * N_READS); best = op.reduce_many(best, vals, tid * N_READS);
} }

View File

@@ -49,6 +49,20 @@ store_vector(T* ptr, uint32_t offset, const AlignedVector<T, N>& vec) {
to[offset] = vec; to[offset] = vec;
} }
// Helper for accessing strided data.
template <typename T>
struct StridedIterator {
T it;
int64_t stride;
__host__ __device__ StridedIterator(T it, int64_t stride)
: it(it), stride(stride) {}
__host__ __device__ auto operator[](int i) const {
return it[i * stride];
}
};
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
// Type limits utils // Type limits utils
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////

View File

@@ -1,121 +0,0 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include <thrust/iterator/iterator_adaptor.h>
#include <cuda/std/utility>
#include "mlx/backend/cuda/kernel_utils.cuh"
namespace mlx::core::cu {
// Iterating non-contiguous array.
template <typename Iterator, typename IdxT = int64_t>
class general_iterator
: public thrust::
iterator_adaptor<general_iterator<Iterator, IdxT>, Iterator> {
public:
using super_t =
thrust::iterator_adaptor<general_iterator<Iterator, IdxT>, Iterator>;
using reference = typename super_t::reference;
using difference_type = typename super_t::difference_type;
__host__ __device__ general_iterator(
Iterator it,
IdxT index,
int ndim,
Shape shape,
Strides strides)
: super_t(it),
index_(index),
ndim_(ndim),
shape_(cuda::std::move(shape)),
strides_(cuda::std::move(strides)) {}
__host__ __device__ IdxT index() const {
return index_;
}
__host__ __device__ const Shape& shape() const {
return shape_;
}
__host__ __device__ const Strides& strides() const {
return strides_;
}
private:
friend class thrust::iterator_core_access;
__host__ __device__ bool equal(const general_iterator& other) const {
return this->base() == other.base() && this->index() == other.index();
}
__host__ __device__ void advance(difference_type n) {
this->index_ += n;
}
__host__ __device__ void increment() {
this->index_ += 1;
}
__host__ __device__ void decrement() {
this->index_ -= 1;
}
__host__ __device__ difference_type
distance_to(const general_iterator& other) const {
_CCCL_ASSERT(
this->base() == other.base(),
"Underlying iterator must point to same base iterator");
return other.index() - this->index();
}
// The dereference is device-only to avoid accidental running in host.
__device__ typename super_t::reference dereference() const {
IdxT offset = elem_to_loc(index_, shape_.data(), strides_.data(), ndim_);
return *(this->base() + offset);
}
IdxT index_;
int ndim_;
Shape shape_;
Strides strides_;
};
template <typename IdxT, typename Iterator>
__host__ __device__ auto make_general_iterator(
Iterator it,
IdxT index,
int ndim,
Shape shape,
Strides strides) {
return general_iterator<Iterator, IdxT>(
it, index, ndim, cuda::std::move(shape), cuda::std::move(strides));
}
template <typename IdxT, typename Iterator>
auto make_general_iterator(
Iterator it,
const std::vector<int32_t>& shape,
const std::vector<int64_t>& strides) {
return make_general_iterator<IdxT>(
it, 0, shape.size(), const_param(shape), const_param(strides));
}
template <typename IdxT, typename Iterator>
auto make_general_iterators(
Iterator it,
IdxT size,
const std::vector<int32_t>& shape,
const std::vector<int64_t>& strides) {
auto ndim = shape.size();
auto shape_arg = const_param(shape);
auto strides_arg = const_param(strides);
return std::make_pair(
make_general_iterator<IdxT>(it, 0, ndim, shape_arg, strides_arg),
make_general_iterator<IdxT>(it, size, ndim, shape_arg, strides_arg));
}
} // namespace mlx::core::cu

View File

@@ -1,60 +0,0 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include <thrust/iterator/iterator_adaptor.h>
#include <thrust/iterator/iterator_facade.h>
namespace mlx::core::cu {
// RandomAccessIterator for strided access to array entries.
template <typename Iterator, typename Stride = int64_t>
class strided_iterator
: public thrust::
iterator_adaptor<strided_iterator<Iterator, Stride>, Iterator> {
public:
using super_t =
thrust::iterator_adaptor<strided_iterator<Iterator, Stride>, Iterator>;
using reference = typename super_t::reference;
using difference_type = typename super_t::difference_type;
__host__ __device__ strided_iterator(Iterator it, Stride stride)
: super_t(it), stride_(stride) {}
__host__ __device__ Stride stride() const {
return stride_;
}
private:
friend class thrust::iterator_core_access;
__host__ __device__ bool equal(const strided_iterator& other) const {
return this->base() == other.base();
}
__host__ __device__ void advance(difference_type n) {
this->base_reference() += n * stride_;
}
__host__ __device__ void increment() {
this->base_reference() += stride_;
}
__host__ __device__ void decrement() {
this->base_reference() -= stride_;
}
__host__ __device__ difference_type
distance_to(const strided_iterator& other) const {
const difference_type dist = other.base() - this->base();
_CCCL_ASSERT(
dist % stride() == 0,
"Underlying iterator difference must be divisible by the stride");
return dist / stride();
}
Stride stride_;
};
} // namespace mlx::core::cu

View File

@@ -1,7 +1,6 @@
// Copyright © 2025 Apple Inc. // Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/iterators/strided_iterator.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh" #include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/reduce/reduce.cuh" #include "mlx/backend/cuda/reduce/reduce.cuh"
#include "mlx/backend/gpu/copy.h" #include "mlx/backend/gpu/copy.h"
@@ -105,8 +104,8 @@ __global__ void layer_norm(
T wn[N_READS]; T wn[N_READS];
T bn[N_READS]; T bn[N_READS];
cub::LoadDirectBlocked(index, x, xn, axis_size); cub::LoadDirectBlocked(index, x, xn, axis_size);
cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size); cub::LoadDirectBlocked(index, StridedIterator(w, w_stride), wn, axis_size);
cub::LoadDirectBlocked(index, strided_iterator(b, b_stride), bn, axis_size); cub::LoadDirectBlocked(index, StridedIterator(b, b_stride), bn, axis_size);
for (int i = 0; i < N_READS; ++i) { for (int i = 0; i < N_READS; ++i) {
float norm = (static_cast<float>(xn[i]) - mean) * normalizer; float norm = (static_cast<float>(xn[i]) - mean) * normalizer;
xn[i] = wn[i] * static_cast<T>(norm) + bn[i]; xn[i] = wn[i] * static_cast<T>(norm) + bn[i];
@@ -162,7 +161,7 @@ __global__ void layer_norm_vjp(
auto index = r * BLOCK_DIM + block.thread_rank(); auto index = r * BLOCK_DIM + block.thread_rank();
cub::LoadDirectBlocked(index, x, xn, axis_size, mean); cub::LoadDirectBlocked(index, x, xn, axis_size, mean);
cub::LoadDirectBlocked(index, g, gn, axis_size); cub::LoadDirectBlocked(index, g, gn, axis_size);
cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size); cub::LoadDirectBlocked(index, StridedIterator(w, w_stride), wn, axis_size);
for (int i = 0; i < N_READS; i++) { for (int i = 0; i < N_READS; i++) {
float t = static_cast<float>(xn[i]) - mean; float t = static_cast<float>(xn[i]) - mean;
float wi = wn[i]; float wi = wn[i];
@@ -185,7 +184,7 @@ __global__ void layer_norm_vjp(
T gn[N_READS]; T gn[N_READS];
cub::LoadDirectBlocked(index, x, xn, axis_size); cub::LoadDirectBlocked(index, x, xn, axis_size);
cub::LoadDirectBlocked(index, g, gn, axis_size); cub::LoadDirectBlocked(index, g, gn, axis_size);
cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size); cub::LoadDirectBlocked(index, StridedIterator(w, w_stride), wn, axis_size);
for (int i = 0; i < N_READS; i++) { for (int i = 0; i < N_READS; i++) {
float xi = (static_cast<float>(xn[i]) - mean) * normalizer; float xi = (static_cast<float>(xn[i]) - mean) * normalizer;
float wi = wn[i]; float wi = wn[i];

View File

@@ -1,7 +1,6 @@
// Copyright © 2025 Apple Inc. // Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/iterators/strided_iterator.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh" #include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/reduce/reduce.cuh" #include "mlx/backend/cuda/reduce/reduce.cuh"
#include "mlx/backend/gpu/copy.h" #include "mlx/backend/gpu/copy.h"
@@ -89,7 +88,7 @@ __global__ void rms_norm(
T xn[N_READS]; T xn[N_READS];
T wn[N_READS]; T wn[N_READS];
cub::LoadDirectBlocked(index, x, xn, axis_size); cub::LoadDirectBlocked(index, x, xn, axis_size);
cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size); cub::LoadDirectBlocked(index, StridedIterator(w, w_stride), wn, axis_size);
for (int i = 0; i < N_READS; ++i) { for (int i = 0; i < N_READS; ++i) {
float norm = static_cast<float>(xn[i]) * normalizer; float norm = static_cast<float>(xn[i]) * normalizer;
xn[i] = wn[i] * static_cast<T>(norm); xn[i] = wn[i] * static_cast<T>(norm);
@@ -132,7 +131,7 @@ __global__ void rms_norm_vjp(
auto index = r * BLOCK_DIM + block.thread_rank(); auto index = r * BLOCK_DIM + block.thread_rank();
cub::LoadDirectBlocked(index, x, xn, axis_size, cast_to<T>(0)); cub::LoadDirectBlocked(index, x, xn, axis_size, cast_to<T>(0));
cub::LoadDirectBlocked(index, g, gn, axis_size); cub::LoadDirectBlocked(index, g, gn, axis_size);
cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size); cub::LoadDirectBlocked(index, StridedIterator(w, w_stride), wn, axis_size);
for (int i = 0; i < N_READS; i++) { for (int i = 0; i < N_READS; i++) {
float t = static_cast<float>(xn[i]); float t = static_cast<float>(xn[i]);
float wi = wn[i]; float wi = wn[i];
@@ -154,7 +153,7 @@ __global__ void rms_norm_vjp(
T gn[N_READS]; T gn[N_READS];
cub::LoadDirectBlocked(index, x, xn, axis_size); cub::LoadDirectBlocked(index, x, xn, axis_size);
cub::LoadDirectBlocked(index, g, gn, axis_size); cub::LoadDirectBlocked(index, g, gn, axis_size);
cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size); cub::LoadDirectBlocked(index, StridedIterator(w, w_stride), wn, axis_size);
for (int i = 0; i < N_READS; i++) { for (int i = 0; i < N_READS; i++) {
float xi = xn[i]; float xi = xn[i];
float wi = wn[i]; float wi = wn[i];

View File

@@ -3,7 +3,6 @@
#include "mlx/backend/common/unary.h" #include "mlx/backend/common/unary.h"
#include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/unary_ops.cuh" #include "mlx/backend/cuda/device/unary_ops.cuh"
#include "mlx/backend/cuda/iterators/general_iterator.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh" #include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/dtype_utils.h" #include "mlx/dtype_utils.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"

View File

@@ -1,7 +1,7 @@
#!/bin/bash #!/bin/bash
auditwheel repair dist/* \ auditwheel repair dist/* \
--plat manylinux_2_39_x86_64 \ --plat manylinux_2_35_x86_64 \
--exclude libcublas* \ --exclude libcublas* \
--exclude libnvrtc* \ --exclude libnvrtc* \
-w wheel_tmp -w wheel_tmp