mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-22 02:58:16 +08:00
Spelling (#342)
* spelling: accumulates Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: across Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: additional Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: against Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: among Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: array Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: at least Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: available Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: axes Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: basically Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: bfloat Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: bounds Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: broadcast Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: buffer Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: class Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: coefficients Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: collision Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: combinations Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: committing Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: computation Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: consider Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: constructing Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: conversions Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: correctly Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: corresponding Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: declaration Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: default Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: dependency Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: destination Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: destructor Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: dimensions Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: divided Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: element-wise Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: elements Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: endianness Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: equivalent Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: explicitly Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: github Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: indices Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: irregularly Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: memory Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: metallib Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: negative Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: notable Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: optional Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: otherwise Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: overridden Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: partially Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: partition Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: perform Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: perturbations Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: positively Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: primitive Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: repeat Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: repeats Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: respect Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: respectively Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: result Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: rounding Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: separate Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: skipping Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: structure Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: the Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: transpose Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: unnecessary Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: unneeded Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: unsupported Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> --------- Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
This commit is contained in:
@@ -165,7 +165,7 @@ Buffer MetalAllocator::malloc(size_t size) {
|
||||
|
||||
// Prepare to allocate new memory as needed
|
||||
if (!buf) {
|
||||
// If we are under very high memoory pressure, we don't allocate further
|
||||
// If we are under very high memory pressure, we don't allocate further
|
||||
if (device_->currentAllocatedSize() >= block_limit_) {
|
||||
return Buffer{nullptr};
|
||||
}
|
||||
|
@@ -68,7 +68,7 @@ void explicit_gemm_conv_1D_gpu(
|
||||
array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
|
||||
copy_gpu(in_strided_view, in_strided, CopyType::General, s);
|
||||
|
||||
// Peform gemm
|
||||
// Perform gemm
|
||||
std::vector<array> copies = {in_padded, in_strided};
|
||||
mlx_matmul(
|
||||
s,
|
||||
@@ -260,7 +260,7 @@ void explicit_gemm_conv_2D_gpu(
|
||||
array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
|
||||
copy_gpu(in_strided_view, in_strided, CopyType::General, s);
|
||||
|
||||
// Peform gemm
|
||||
// Perform gemm
|
||||
std::vector<array> copies = {in_padded, in_strided};
|
||||
mlx_matmul(
|
||||
s,
|
||||
|
@@ -102,7 +102,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
static_cast<size_t*>(idx_strides_buf.raw_ptr()) + i * idx_ndim);
|
||||
}
|
||||
|
||||
// Allocate the argument bufer
|
||||
// Allocate the argument buffer
|
||||
auto arg_buf = allocator::malloc_or_wait(arg_enc->encodedLength());
|
||||
|
||||
// Register data with the encoder
|
||||
@@ -246,7 +246,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
static_cast<size_t*>(idx_strides_buf.raw_ptr()) + i * idx_ndim);
|
||||
}
|
||||
|
||||
// Allocate the argument bufer
|
||||
// Allocate the argument buffer
|
||||
auto arg_buf = allocator::malloc_or_wait(arg_enc->encodedLength());
|
||||
|
||||
// Register data with the encoder
|
||||
|
@@ -114,7 +114,7 @@ template <typename T, typename Op, int N_READS>
|
||||
// 4. Reduce among them and go to 3
|
||||
// 4. Reduce in each simd_group
|
||||
// 6. Write in the thread local memory
|
||||
// 6. Reduce them accross thread group
|
||||
// 6. Reduce them across thread group
|
||||
// 7. Write the output without need for atomic
|
||||
Op op;
|
||||
|
||||
|
@@ -45,7 +45,7 @@ struct complex64_t {
|
||||
typename = typename enable_if<can_convert_to_complex64<T>>::type>
|
||||
constexpr complex64_t(T x) constant : real(x), imag(0) {}
|
||||
|
||||
// Converstions from complex64_t
|
||||
// Conversions from complex64_t
|
||||
template <
|
||||
typename T,
|
||||
typename = typename enable_if<can_convert_from_complex64<T>>::type>
|
||||
|
@@ -105,7 +105,7 @@ struct Conv2DInputBlockLoader {
|
||||
}
|
||||
}
|
||||
|
||||
// Zero pad otherwize
|
||||
// Zero pad otherwise
|
||||
else {
|
||||
#pragma clang loop unroll(full)
|
||||
for (short j = 0; j < vec_size; ++j) {
|
||||
@@ -334,7 +334,7 @@ struct Conv2DBlockMMA {
|
||||
}
|
||||
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
// Multiply and accumulate into resulr simdgroup matrices
|
||||
// Multiply and accumulate into result simdgroup matrices
|
||||
#pragma clang loop unroll(full)
|
||||
for (short i = 0; i < TM; i++) {
|
||||
#pragma clang loop unroll(full)
|
||||
|
@@ -93,13 +93,13 @@ struct BlockLoader {
|
||||
tmp_idx[j] = bj + j < src_tile_dim.x ? j : 0;
|
||||
}
|
||||
|
||||
// Read all valid indcies into tmp_val
|
||||
// Read all valid indices into tmp_val
|
||||
#pragma clang loop unroll(full)
|
||||
for (short j = 0; j < vec_size; j++) {
|
||||
tmp_val[j] = src[i * src_ld + tmp_idx[j]];
|
||||
}
|
||||
|
||||
// Zero out uneeded values
|
||||
// Zero out unneeded values
|
||||
#pragma clang loop unroll(full)
|
||||
for (short j = 0; j < vec_size; j++) {
|
||||
tmp_val[j] = bj + j < src_tile_dim.x ? tmp_val[j] : T(0);
|
||||
@@ -241,7 +241,7 @@ struct BlockMMA {
|
||||
}
|
||||
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
// Multiply and accumulate into resulr simdgroup matrices
|
||||
// Multiply and accumulate into result simdgroup matrices
|
||||
#pragma clang loop unroll(full)
|
||||
for (short i = 0; i < TM; i++) {
|
||||
#pragma clang loop unroll(full)
|
||||
|
@@ -28,7 +28,7 @@ struct GEMVKernel {
|
||||
static_assert(BN == SIMD_SIZE, "gemv block must have a width of SIMD_SIZE");
|
||||
|
||||
// - The matrix of size (M = out_vec_size, N = in_vec_size) is divided up
|
||||
// into blocks of (BM * TM, BN * TN) divided amoung threadgroups
|
||||
// into blocks of (BM * TM, BN * TN) divided among threadgroups
|
||||
// - Every thread works on a block of (TM, TN)
|
||||
// - We assume each thead group is launched with (BN, BM, 1) threads
|
||||
//
|
||||
@@ -42,7 +42,7 @@ struct GEMVKernel {
|
||||
// Edge case handling:
|
||||
// - The threadgroup with the largest tid will have blocks that exceed the matrix
|
||||
// * The blocks that start outside the matrix are never read (thread results remain zero)
|
||||
// * The last thread that partialy overlaps with the matrix is shifted inwards
|
||||
// * The last thread that partially overlaps with the matrix is shifted inwards
|
||||
// such that the thread block fits exactly in the matrix
|
||||
|
||||
MLX_MTL_CONST short tgp_mem_size = BN * TN * 2;
|
||||
@@ -166,7 +166,7 @@ template <
|
||||
struct GEMVTKernel {
|
||||
|
||||
// - The matrix of size (M = in_vec_size, N = out_vec_size) is divided up
|
||||
// into blocks of (BM * TM, BN * TN) divided amoung threadgroups
|
||||
// into blocks of (BM * TM, BN * TN) divided among threadgroups
|
||||
// - Every thread works on a block of (TM, TN)
|
||||
// - We assume each thead group is launched with (BN, BM, 1) threads
|
||||
//
|
||||
@@ -180,7 +180,7 @@ struct GEMVTKernel {
|
||||
// Edge case handling:
|
||||
// - The threadgroup with the largest tid will have blocks that exceed the matrix
|
||||
// * The blocks that start outside the matrix are never read (thread results remain zero)
|
||||
// * The last thread that partialy overlaps with the matrix is shifted inwards
|
||||
// * The last thread that partially overlaps with the matrix is shifted inwards
|
||||
// such that the thread block fits exactly in the matrix
|
||||
|
||||
|
||||
|
@@ -65,7 +65,7 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
|
||||
in += grid_size * N_READS;
|
||||
}
|
||||
|
||||
// Sepate case for the last set as we close the reduction size
|
||||
// Separate case for the last set as we close the reduction size
|
||||
size_t curr_idx = (gid + r * (size_t)grid_size) * N_READS;
|
||||
if (curr_idx < in_size) {
|
||||
int max_reads = in_size - curr_idx;
|
||||
|
@@ -592,7 +592,7 @@ template <
|
||||
bool ARG_SORT,
|
||||
short BLOCK_THREADS,
|
||||
short N_PER_THREAD>
|
||||
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void mb_block_partiton(
|
||||
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void mb_block_partition(
|
||||
device idx_t* block_partitions [[buffer(0)]],
|
||||
const device val_t* dev_vals [[buffer(1)]],
|
||||
const device idx_t* dev_idxs [[buffer(2)]],
|
||||
@@ -777,8 +777,8 @@ template <
|
||||
const device size_t* nc_strides [[buffer(7)]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]]); \
|
||||
template [[host_name("mb_block_partiton_" #vtname "_" #itname "_bn" #bn "_tn" #tn)]] \
|
||||
[[kernel]] void mb_block_partiton<vtype, itype, arg_sort, bn, tn>( \
|
||||
template [[host_name("mb_block_partition_" #vtname "_" #itname "_bn" #bn "_tn" #tn)]] \
|
||||
[[kernel]] void mb_block_partition<vtype, itype, arg_sort, bn, tn>( \
|
||||
device itype* block_partitions [[buffer(0)]], \
|
||||
const device vtype* dev_vals [[buffer(1)]], \
|
||||
const device itype* dev_idxs [[buffer(2)]], \
|
||||
|
@@ -61,7 +61,7 @@ inline void mps_matmul(
|
||||
// 2. Only one of a or b has batch_size_out matrices worth of data and
|
||||
// the other has matrix worth of data
|
||||
|
||||
// The matrix dimsenisons of a and b are sure to be regularly strided
|
||||
// The matrix dimensions of a and b are sure to be regularly strided
|
||||
if (batch_size_out > 1) {
|
||||
// No broadcasting defaults
|
||||
auto batch_size_a = a.data_size() / (M * K);
|
||||
|
@@ -40,7 +40,7 @@ void all_reduce_dispatch(
|
||||
// Set grid dimensions
|
||||
|
||||
// We make sure each thread has enough to do by making it read in
|
||||
// atleast n_reads inputs
|
||||
// at least n_reads inputs
|
||||
int n_reads = REDUCE_N_READS;
|
||||
|
||||
// mod_in_size gives us the groups of n_reads needed to go over the entire
|
||||
@@ -176,7 +176,7 @@ void strided_reduce_general_dispatch(
|
||||
|
||||
// We spread outputs over the x dimension and inputs over the y dimension
|
||||
// Threads with the same lid.x in a given threadgroup work on the same
|
||||
// output and each thread in the y dimension accumlates for that output
|
||||
// output and each thread in the y dimension accumulates for that output
|
||||
uint threadgroup_dim_x = std::min(out_size, 128ul);
|
||||
uint threadgroup_dim_y =
|
||||
kernel->maxTotalThreadsPerThreadgroup() / threadgroup_dim_x;
|
||||
|
@@ -165,10 +165,10 @@ void multi_block_sort(
|
||||
dev_idxs_out = ping ? dev_idxs_0 : dev_idxs_1;
|
||||
ping = !ping;
|
||||
|
||||
// Do partiton
|
||||
// Do partition
|
||||
{
|
||||
std::ostringstream kname;
|
||||
kname << "mb_block_partiton_" << type_to_name(dev_vals_in) << "_"
|
||||
kname << "mb_block_partition_" << type_to_name(dev_vals_in) << "_"
|
||||
<< type_to_name(dev_idxs_in) << "_bn" << bn << "_tn" << tn;
|
||||
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
|
@@ -18,7 +18,7 @@ void set_array_buffer(
|
||||
auto offset = a.data<char>() -
|
||||
static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
|
||||
enc->setBuffer(a_buf, offset, idx);
|
||||
// MTL::Resource usage through argument buffer needs to be explicity
|
||||
// MTL::Resource usage through argument buffer needs to be explicitly
|
||||
// flagged to enable hazard tracking
|
||||
compute_encoder->useResource(a_buf, MTL::ResourceUsageRead);
|
||||
}
|
||||
|
Reference in New Issue
Block a user