This commit is contained in:
hdeng-apple 2025-04-29 22:26:05 +08:00 committed by GitHub
parent 99b9868859
commit 167b759a38
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 13 additions and 13 deletions

View File

@ -356,7 +356,7 @@ class array {
} }
enum Status { enum Status {
// The ouptut of a computation which has not been scheduled. // The output of a computation which has not been scheduled.
// For example, the status of `x` in `auto x = a + b`. // For example, the status of `x` in `auto x = a + b`.
unscheduled, unscheduled,

View File

@ -10,7 +10,7 @@ For many sizes, GPU FFTs are memory bandwidth bound so
read/write performance is important. read/write performance is important.
Where possible, we read 128 bits sequentially in each thread, Where possible, we read 128 bits sequentially in each thread,
coalesced with accesses from adajcent threads for optimal performance. coalesced with accesses from adjacent threads for optimal performance.
We implement specialized reading/writing for: We implement specialized reading/writing for:
- FFT - FFT

View File

@ -95,7 +95,7 @@ template <
Q += tidl.z * params->Q_strides[0] + // Batch Q += tidl.z * params->Q_strides[0] + // Batch
tidl.y * params->Q_strides[1] + // Head tidl.y * params->Q_strides[1] + // Head
tidl.x * BQ * params->Q_strides[2]; // Seqeunce tidl.x * BQ * params->Q_strides[2]; // Sequence
ulong kv_head_idx = int(tid.y) / params->gqa_factor; ulong kv_head_idx = int(tid.y) / params->gqa_factor;
K += tidl.z * params->K_strides[0] + // Batch K += tidl.z * params->K_strides[0] + // Batch
@ -106,7 +106,7 @@ template <
O += tidl.z * params->O_strides[0] + // Batch O += tidl.z * params->O_strides[0] + // Batch
tidl.y * params->O_strides[1] + // Head tidl.y * params->O_strides[1] + // Head
tidl.x * BQ * params->O_strides[2]; // Seqeunce tidl.x * BQ * params->O_strides[2]; // Sequence
if (has_mask) { if (has_mask) {
mask += tidl.z * mask_params->M_strides[0] + // Batch mask += tidl.z * mask_params->M_strides[0] + // Batch

View File

@ -113,7 +113,7 @@ struct BlockLoader {
tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)]; tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)];
} }
// Zero out uneeded values // Zero out unneeded values
STEEL_PRAGMA_UNROLL STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) { for (short j = 0; j < vec_size; j++) {
tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0); tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0);
@ -240,7 +240,7 @@ struct BlockLoaderT {
tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)]; tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)];
} }
// Zero out uneeded values // Zero out unneeded values
STEEL_PRAGMA_UNROLL STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) { for (short j = 0; j < vec_size; j++) {
tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0); tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0);

View File

@ -141,7 +141,7 @@ implicit_gemm_conv_2d_general(
// Store results to device memory // Store results to device memory
{ {
// Adjust for simdgroup and thread locatio // Adjust for simdgroup and thread location
int offset_m = c_row + mma_op.sm; int offset_m = c_row + mma_op.sm;
int offset_n = c_col + mma_op.sn; int offset_n = c_col + mma_op.sn;
C += offset_n; C += offset_n;

View File

@ -113,7 +113,7 @@ struct BlockLoader {
tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)]; tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)];
} }
// Zero out uneeded values // Zero out unneeded values
STEEL_PRAGMA_UNROLL STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) { for (short j = 0; j < vec_size; j++) {
tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0); tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0);

View File

@ -18,7 +18,7 @@ void Compiled::eval_cpu(
const std::vector<array>& inputs, const std::vector<array>& inputs,
std::vector<array>& outputs) { std::vector<array>& outputs) {
throw std::runtime_error( throw std::runtime_error(
"[Compiled::eval_cpu] CPU compialtion not supported on the platform."); "[Compiled::eval_cpu] CPU compilation not supported on the platform.");
} }
} // namespace mlx::core } // namespace mlx::core

View File

@ -569,7 +569,7 @@ inline array std(const array& a, StreamOrDevice s = {}) {
return std(a, false, 0, to_stream(s)); return std(a, false, 0, to_stream(s));
} }
/** Computes the standard deviatoin of the elements of an array along the given /** Computes the standard deviation of the elements of an array along the given
* axes */ * axes */
array std( array std(
const array& a, const array& a,

View File

@ -223,7 +223,7 @@ array multivariate_normal(
auto n = mean.shape(-1); auto n = mean.shape(-1);
// Check shapes comatibility of mean and cov // Check shapes compatibility of mean and cov
if (cov.shape(-1) != cov.shape(-2)) { if (cov.shape(-1) != cov.shape(-2)) {
throw std::invalid_argument( throw std::invalid_argument(
"[multivariate_normal] last two dimensions of cov must be equal."); "[multivariate_normal] last two dimensions of cov must be equal.");
@ -402,7 +402,7 @@ array categorical(
if (broadcast_shapes(shape, reduced_shape) != shape) { if (broadcast_shapes(shape, reduced_shape) != shape) {
std::ostringstream msg; std::ostringstream msg;
msg << "[categorical] Requested shape " << shape msg << "[categorical] Requested shape " << shape
<< " is not broadcast compatable with reduced logits shape" << " is not broadcast compatible with reduced logits shape"
<< reduced_shape << "."; << reduced_shape << ".";
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
} }

View File

@ -422,7 +422,7 @@ void init_random(nb::module_& parent_module) {
axis (int, optional): The axis which specifies the distribution. axis (int, optional): The axis which specifies the distribution.
Default: ``-1``. Default: ``-1``.
shape (list(int), optional): The shape of the output. This must shape (list(int), optional): The shape of the output. This must
be broadcast compatable with ``logits.shape`` with the ``axis`` be broadcast compatible with ``logits.shape`` with the ``axis``
dimension removed. Default: ``None`` dimension removed. Default: ``None``
num_samples (int, optional): The number of samples to draw from each num_samples (int, optional): The number of samples to draw from each
of the categorical distributions in ``logits``. The output will have of the categorical distributions in ``logits``. The output will have