* spelling: accumulates
* spelling: across
* spelling: additional
* spelling: against
* spelling: among
* spelling: array
* spelling: at least
* spelling: available
* spelling: axes
* spelling: basically
* spelling: bfloat
* spelling: bounds
* spelling: broadcast
* spelling: buffer
* spelling: class
* spelling: coefficients
* spelling: collision
* spelling: combinations
* spelling: committing
* spelling: computation
* spelling: consider
* spelling: constructing
* spelling: conversions
* spelling: correctly
* spelling: corresponding
* spelling: declaration
* spelling: default
* spelling: dependency
* spelling: destination
* spelling: destructor
* spelling: dimensions
* spelling: divided
* spelling: element-wise
* spelling: elements
* spelling: endianness
* spelling: equivalent
* spelling: explicitly
* spelling: github
* spelling: indices
* spelling: irregularly
* spelling: memory
* spelling: metallib
* spelling: negative
* spelling: notable
* spelling: optional
* spelling: otherwise
* spelling: overridden
* spelling: partially
* spelling: partition
* spelling: perform
* spelling: perturbations
* spelling: positively
* spelling: primitive
* spelling: repeat
* spelling: repeats
* spelling: respect
* spelling: respectively
* spelling: result
* spelling: rounding
* spelling: separate
* spelling: skipping
* spelling: structure
* spelling: the
* spelling: transpose
* spelling: unnecessary
* spelling: unneeded
* spelling: unsupported

---------

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
Josh Soref 2024-01-02 00:08:17 -05:00 committed by GitHub
parent 144ecff849
commit 44c1ce5e6a
49 changed files with 117 additions and 117 deletions

View File

@ -133,7 +133,7 @@ def get_gbyte_size(in_vec_len, out_vec_len, np_dtype):
return float(N_iter_bench * N_iter_func * n_elem * item_size) / float(1024**3)
-def bench_with_in_len(ax, in_vec_len, out_vector_lens, dtype, tranpose):
+def bench_with_in_len(ax, in_vec_len, out_vector_lens, dtype, transpose):
np_dtype = getattr(np, dtype)
mlx_gb_s = []
mlx_gflops = []
@ -164,7 +164,7 @@ def bench_with_in_len(ax, in_vec_len, out_vector_lens, dtype, tranpose):
ax.legend()
-def bench_with_out_len(ax, out_vec_len, in_vector_lens, dtype, tranpose):
+def bench_with_out_len(ax, out_vec_len, in_vector_lens, dtype, transpose):
np_dtype = getattr(np, dtype)
mlx_gb_s = []
mlx_gflops = []

View File

@ -62,7 +62,7 @@ def make_predicate(positive_filter, negative_filter):
if __name__ == "__main__":
-parser = argparse.ArgumentParser(description="Run comparisons agains PyTorch")
+parser = argparse.ArgumentParser(description="Run comparisons against PyTorch")
parser.add_argument(
"--filter", "-f", help="Regex filter to select benchmarks", nargs="+"
)

View File

@ -12,7 +12,7 @@ include(CMakeParseArguments)
# OUTPUT_DIRECTORY: Where to place ${TITLE}.metallib
# SOURCES: List of source files
# INCLUDE_DIRS: List of include dirs
-# DEPS: List of depedency files (like headers)
+# DEPS: List of dependency files (like headers)
#
macro(mlx_build_metallib)
# Parse args
@ -32,7 +32,7 @@ macro(mlx_build_metallib)
# Collect compile options
set(MTLLIB_COMPILE_OPTIONS -Wall -Wextra -fno-fast-math)
-# Prepare metllib build command
+# Prepare metallib build command
add_custom_command(
OUTPUT ${MTLLIB_BUILD_TARGET}
COMMAND xcrun -sdk macosx metal

View File

@ -26,7 +26,7 @@ python -m http.server <port>
and point your browser to `http://localhost:<port>`.
-### Push to Github Pages
+### Push to GitHub Pages
Check-out the `gh-pages` branch (`git switch gh-pages`) and build
the docs. Then force add the `build/html` directory:

View File

@ -15,7 +15,7 @@ Introducing the Example
-----------------------
Let's say that you would like an operation that takes in two arrays,
-``x`` and ``y``, scales them both by some coefficents ``alpha`` and ``beta``
+``x`` and ``y``, scales them both by some coefficients ``alpha`` and ``beta``
respectively, and then adds them together to get the result
``z = alpha * x + beta * y``. Well, you can very easily do that by just
writing out a function as follows:
@ -69,7 +69,7 @@ C++ API:
.. code-block:: C++
/**
-* Scale and sum two vectors elementwise
+* Scale and sum two vectors element-wise
* z = alpha * x + beta * y
*
* Follow numpy style broadcasting between x and y
@ -230,7 +230,7 @@ Let's re-implement our operation now in terms of our :class:`Axpby` primitive.
This operation now handles the following:
-#. Upcast inputs and resolve the the output data type.
+#. Upcast inputs and resolve the output data type.
#. Broadcast the inputs and resolve the output shape.
#. Construct the primitive :class:`Axpby` using the given stream, ``alpha``, and ``beta``.
#. Construct the output :class:`array` using the primitive and the inputs.
@ -284,14 +284,14 @@ pointwise. This is captured in the templated function :meth:`axpby_impl`.
T alpha = static_cast<T>(alpha_);
T beta = static_cast<T>(beta_);
-// Do the elementwise operation for each output
+// Do the element-wise operation for each output
for (size_t out_idx = 0; out_idx < out.size(); out_idx++) {
// Map linear indices to offsets in x and y
auto x_offset = elem_to_loc(out_idx, x.shape(), x.strides());
auto y_offset = elem_to_loc(out_idx, y.shape(), y.strides());
// We allocate the output to be contiguous and regularly strided
-// (defaults to row major) and hence it doesn't need additonal mapping
+// (defaults to row major) and hence it doesn't need additional mapping
out_ptr[out_idx] = alpha * x_ptr[x_offset] + beta * y_ptr[y_offset];
}
}
@ -305,7 +305,7 @@ if we encounter an unexpected type.
/** Fall back implementation for evaluation on CPU */
void Axpby::eval(const std::vector<array>& inputs, array& out) {
-// Check the inputs (registered in the op while contructing the out array)
+// Check the inputs (registered in the op while constructing the out array)
assert(inputs.size() == 2);
auto& x = inputs[0];
auto& y = inputs[1];
@ -485,7 +485,7 @@ each data type.
instantiate_axpby(float32, float);
instantiate_axpby(float16, half);
-instantiate_axpby(bflot16, bfloat16_t);
+instantiate_axpby(bfloat16, bfloat16_t);
instantiate_axpby(complex64, complex64_t);
This kernel will be compiled into a metal library ``mlx_ext.metallib`` as we
@ -537,7 +537,7 @@ below.
compute_encoder->setComputePipelineState(kernel);
// Kernel parameters are registered with buffer indices corresponding to
-// those in the kernel decelaration at axpby.metal
+// those in the kernel declaration at axpby.metal
int ndim = out.ndim();
size_t nelem = out.size();
@ -568,7 +568,7 @@ below.
// Fix the 3D size of the launch grid (in terms of threads)
MTL::Size grid_dims = MTL::Size(nelem, 1, 1);
-// Launch the grid with the given number of threads divded among
+// Launch the grid with the given number of threads divided among
// the given threadgroups
compute_encoder->dispatchThreads(grid_dims, group_dims);
}
@ -581,7 +581,7 @@ to give us the active metal compute command encoder instead of building a
new one and calling :meth:`compute_encoder->end_encoding` at the end.
MLX keeps adding kernels (compute pipelines) to the active command encoder
until some specified limit is hit or the compute encoder needs to be flushed
-for synchronization. MLX also handles enqueuing and commiting the associated
+for synchronization. MLX also handles enqueuing and committing the associated
command buffers as needed. We suggest taking a deeper dive into
:class:`metal::Device` if you would like to study this routine further.
@ -601,8 +601,8 @@ us the following :meth:`Axpby::jvp` and :meth:`Axpby::vjp` implementations.
const std::vector<array>& tangents,
const std::vector<int>& argnums) {
// Forward mode diff that pushes along the tangents
-// The jvp transform on the the primitive can built with ops
-// that are scheduled on the same stream as the primtive
+// The jvp transform on the primitive can built with ops
+// that are scheduled on the same stream as the primitive
// If argnums = {0}, we only push along x in which case the
// jvp is just the tangent scaled by alpha
@ -642,7 +642,7 @@ own :class:`Primitive`.
.. code-block:: C++
-/** Vectorize primitve along given axis */
+/** Vectorize primitive along given axis */
std::pair<array, int> Axpby::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
@ -666,7 +666,7 @@ Let's look at the overall directory structure first.
| └── setup.py
* ``extensions/axpby/`` defines the C++ extension library
-* ``extensions/mlx_sample_extensions`` sets out the strucutre for the
+* ``extensions/mlx_sample_extensions`` sets out the structure for the
associated python package
* ``extensions/bindings.cpp`` provides python bindings for our operation
* ``extensions/CMakeLists.txt`` holds CMake rules to build the library and
@ -697,7 +697,7 @@ are already provided, adding our :meth:`axpby` becomes very simple!
py::kw_only(),
"stream"_a = py::none(),
R"pbdoc(
-Scale and sum two vectors elementwise
+Scale and sum two vectors element-wise
``z = alpha * x + beta * y``
Follows numpy style broadcasting between ``x`` and ``y``
@ -840,7 +840,7 @@ This will result in a directory structure as follows:
| ...
When you try to install using the command ``python -m pip install .``
-(in ``extensions/``), the package will be installed with the same strucutre as
+(in ``extensions/``), the package will be installed with the same structure as
``extensions/mlx_sample_extensions`` and the C++ and metal library will be
copied along with the python binding since they are specified as ``package_data``.

View File

@ -19,7 +19,7 @@ The main differences between MLX and NumPy are:
The design of MLX is inspired by frameworks like `PyTorch
<https://pytorch.org/>`_, `Jax <https://github.com/google/jax>`_, and
-`ArrayFire <https://arrayfire.org/>`_. A noteable difference from these
+`ArrayFire <https://arrayfire.org/>`_. A notable difference from these
frameworks and MLX is the *unified memory model*. Arrays in MLX live in shared
memory. Operations on MLX arrays can be performed on any of the supported
device types without performing data copies. Currently supported device types

View File

@ -57,7 +57,7 @@ void array_basics() {
assert(z.shape(0) == 2);
assert(z.shape(1) == 2);
-// To actually run the compuation you must evaluate `z`.
+// To actually run the computation you must evaluate `z`.
// Under the hood, mlx records operations in a graph.
// The variable `z` is a node in the graph which points to its operation
// and inputs. When `eval` is called on an array (or arrays), the array and

View File

@ -26,7 +26,7 @@ namespace mlx::core {
///////////////////////////////////////////////////////////////////////////////
/**
-* Scale and sum two vectors elementwise
+* Scale and sum two vectors element-wise
* z = alpha * x + beta * y
*
* Follow numpy style broadcasting between x and y
@ -91,21 +91,21 @@ void axpby_impl(
T alpha = static_cast<T>(alpha_);
T beta = static_cast<T>(beta_);
-// Do the elementwise operation for each output
+// Do the element-wise operation for each output
for (size_t out_idx = 0; out_idx < out.size(); out_idx++) {
// Map linear indices to offsets in x and y
auto x_offset = elem_to_loc(out_idx, x.shape(), x.strides());
auto y_offset = elem_to_loc(out_idx, y.shape(), y.strides());
// We allocate the output to be contiguous and regularly strided
-// (defaults to row major) and hence it doesn't need additonal mapping
+// (defaults to row major) and hence it doesn't need additional mapping
out_ptr[out_idx] = alpha * x_ptr[x_offset] + beta * y_ptr[y_offset];
}
}
/** Fall back implementation for evaluation on CPU */
void Axpby::eval(const std::vector<array>& inputs, array& out) {
-// Check the inputs (registered in the op while contructing the out array)
+// Check the inputs (registered in the op while constructing the out array)
assert(inputs.size() == 2);
auto& x = inputs[0];
auto& y = inputs[1];
@ -192,7 +192,7 @@ void Axpby::eval_cpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
-#else // Accelerate not avaliable
+#else // Accelerate not available
/** Evaluate primitive on CPU falling back to common backend */
void Axpby::eval_cpu(const std::vector<array>& inputs, array& out) {
@ -254,7 +254,7 @@ void Axpby::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder->setComputePipelineState(kernel);
// Kernel parameters are registered with buffer indices corresponding to
-// those in the kernel decelaration at axpby.metal
+// those in the kernel declaration at axpby.metal
int ndim = out.ndim();
size_t nelem = out.size();
@ -287,7 +287,7 @@ void Axpby::eval_gpu(const std::vector<array>& inputs, array& out) {
// Fix the 3D size of the launch grid (in terms of threads)
MTL::Size grid_dims = MTL::Size(nelem, 1, 1);
-// Launch the grid with the given number of threads divded among
+// Launch the grid with the given number of threads divided among
// the given threadgroups
compute_encoder->dispatchThreads(grid_dims, group_dims);
}
@ -311,8 +311,8 @@ array Axpby::jvp(
const std::vector<array>& tangents,
const std::vector<int>& argnums) {
// Forward mode diff that pushes along the tangents
-// The jvp transform on the the primitive can built with ops
-// that are scheduled on the same stream as the primtive
+// The jvp transform on the primitive can built with ops
+// that are scheduled on the same stream as the primitive
// If argnums = {0}, we only push along x in which case the
// jvp is just the tangent scaled by alpha
@ -345,7 +345,7 @@ std::vector<array> Axpby::vjp(
return vjps;
}
-/** Vectorize primitve along given axis */
+/** Vectorize primitive along given axis */
std::pair<array, int> Axpby::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {

View File

@ -12,7 +12,7 @@ namespace mlx::core {
///////////////////////////////////////////////////////////////////////////////
/**
-* Scale and sum two vectors elementwise
+* Scale and sum two vectors element-wise
* z = alpha * x + beta * y
*
* Follow numpy style broadcasting between x and y
@ -39,7 +39,7 @@ class Axpby : public Primitive {
* A primitive must know how to evaluate itself on the CPU/GPU
* for the given inputs and populate the output array.
*
-* To avoid unecessary allocations, the evaluation function
+* To avoid unnecessary allocations, the evaluation function
* is responsible for allocating space for the array.
*/
void eval_cpu(const std::vector<array>& inputs, array& out) override;

View File

@ -59,5 +59,5 @@ template <typename T>
instantiate_axpby(float32, float);
instantiate_axpby(float16, half);
-instantiate_axpby(bflot16, bfloat16_t);
+instantiate_axpby(bfloat16, bfloat16_t);
instantiate_axpby(complex64, complex64_t);

View File

@ -23,7 +23,7 @@ PYBIND11_MODULE(mlx_sample_extensions, m) {
py::kw_only(),
"stream"_a = py::none(),
R"pbdoc(
-Scale and sum two vectors elementwise
+Scale and sum two vectors element-wise
``z = alpha * x + beta * y``
Follows numpy style broadcasting between ``x`` and ``y``

View File

@ -37,7 +37,7 @@ void free(Buffer buffer);
Buffer malloc_or_wait(size_t size);
class Allocator {
-/** Abstract base clase for a memory allocator. */
+/** Abstract base class for a memory allocator. */
public:
virtual Buffer malloc(size_t size) = 0;
virtual void free(Buffer buffer) = 0;

View File

@ -129,7 +129,7 @@ array::ArrayDesc::ArrayDesc(
}
// Needed because the Primitive type used in array.h is incomplete and the
-// compiler needs to see the call to the desctructor after the type is complete.
+// compiler needs to see the call to the destructor after the type is complete.
array::ArrayDesc::~ArrayDesc() = default;
array::ArrayIterator::reference array::ArrayIterator::operator*() const {

View File

@ -13,7 +13,7 @@ namespace mlx::core {
namespace {
template <const uint8_t scalar_size>
-void swap_endianess(uint8_t* data_bytes, size_t N) {
+void swap_endianness(uint8_t* data_bytes, size_t N) {
struct Elem {
uint8_t bytes[scalar_size];
};
@ -39,13 +39,13 @@ void Load::eval(const std::vector<array>& inputs, array& out) {
if (swap_endianness_) {
switch (out.itemsize()) {
case 2:
-swap_endianess<2>(out.data<uint8_t>(), out.data_size());
+swap_endianness<2>(out.data<uint8_t>(), out.data_size());
break;
case 4:
-swap_endianess<4>(out.data<uint8_t>(), out.data_size());
+swap_endianness<4>(out.data<uint8_t>(), out.data_size());
break;
case 8:
-swap_endianess<8>(out.data<uint8_t>(), out.data_size());
+swap_endianness<8>(out.data<uint8_t>(), out.data_size());
break;
}
}

View File

@ -165,7 +165,7 @@ Buffer MetalAllocator::malloc(size_t size) {
// Prepare to allocate new memory as needed
if (!buf) {
-// If we are under very high memoory pressure, we don't allocate further
+// If we are under very high memory pressure, we don't allocate further
if (device_->currentAllocatedSize() >= block_limit_) {
return Buffer{nullptr};
}

View File

@ -68,7 +68,7 @@ void explicit_gemm_conv_1D_gpu(
array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
copy_gpu(in_strided_view, in_strided, CopyType::General, s);
-// Peform gemm
+// Perform gemm
std::vector<array> copies = {in_padded, in_strided};
mlx_matmul(
s,
@ -260,7 +260,7 @@ void explicit_gemm_conv_2D_gpu(
array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
copy_gpu(in_strided_view, in_strided, CopyType::General, s);
-// Peform gemm
+// Perform gemm
std::vector<array> copies = {in_padded, in_strided};
mlx_matmul(
s,

View File

@ -102,7 +102,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
static_cast<size_t*>(idx_strides_buf.raw_ptr()) + i * idx_ndim);
}
-// Allocate the argument bufer
+// Allocate the argument buffer
auto arg_buf = allocator::malloc_or_wait(arg_enc->encodedLength());
// Register data with the encoder
@ -246,7 +246,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
static_cast<size_t*>(idx_strides_buf.raw_ptr()) + i * idx_ndim);
}
-// Allocate the argument bufer
+// Allocate the argument buffer
auto arg_buf = allocator::malloc_or_wait(arg_enc->encodedLength());
// Register data with the encoder

View File

@ -114,7 +114,7 @@ template <typename T, typename Op, int N_READS>
// 4. Reduce among them and go to 3
// 4. Reduce in each simd_group
// 6. Write in the thread local memory
-// 6. Reduce them accross thread group
+// 6. Reduce them across thread group
// 7. Write the output without need for atomic
Op op;

View File

@ -45,7 +45,7 @@ struct complex64_t {
typename = typename enable_if<can_convert_to_complex64<T>>::type>
constexpr complex64_t(T x) constant : real(x), imag(0) {}
-// Converstions from complex64_t
+// Conversions from complex64_t
template <
typename T,
typename = typename enable_if<can_convert_from_complex64<T>>::type>

View File

@ -105,7 +105,7 @@ struct Conv2DInputBlockLoader {
}
}
-// Zero pad otherwize
+// Zero pad otherwise
else {
#pragma clang loop unroll(full)
for (short j = 0; j < vec_size; ++j) {
@ -334,7 +334,7 @@ struct Conv2DBlockMMA {
}
simdgroup_barrier(mem_flags::mem_none);
-// Multiply and accumulate into resulr simdgroup matrices
+// Multiply and accumulate into result simdgroup matrices
#pragma clang loop unroll(full)
for (short i = 0; i < TM; i++) {
#pragma clang loop unroll(full)

View File

@ -93,13 +93,13 @@ struct BlockLoader {
tmp_idx[j] = bj + j < src_tile_dim.x ? j : 0;
}
-// Read all valid indcies into tmp_val
+// Read all valid indices into tmp_val
#pragma clang loop unroll(full)
for (short j = 0; j < vec_size; j++) {
tmp_val[j] = src[i * src_ld + tmp_idx[j]];
}
-// Zero out uneeded values
+// Zero out unneeded values
#pragma clang loop unroll(full)
for (short j = 0; j < vec_size; j++) {
tmp_val[j] = bj + j < src_tile_dim.x ? tmp_val[j] : T(0);
@ -241,7 +241,7 @@ struct BlockMMA {
}
simdgroup_barrier(mem_flags::mem_none);
-// Multiply and accumulate into resulr simdgroup matrices
+// Multiply and accumulate into result simdgroup matrices
#pragma clang loop unroll(full)
for (short i = 0; i < TM; i++) {
#pragma clang loop unroll(full)

View File

@ -28,7 +28,7 @@ struct GEMVKernel {
static_assert(BN == SIMD_SIZE, "gemv block must have a width of SIMD_SIZE");
// - The matrix of size (M = out_vec_size, N = in_vec_size) is divided up
-// into blocks of (BM * TM, BN * TN) divided amoung threadgroups
+// into blocks of (BM * TM, BN * TN) divided among threadgroups
// - Every thread works on a block of (TM, TN)
// - We assume each thead group is launched with (BN, BM, 1) threads
//
@ -42,7 +42,7 @@ struct GEMVKernel {
// Edge case handling:
// - The threadgroup with the largest tid will have blocks that exceed the matrix
// * The blocks that start outside the matrix are never read (thread results remain zero)
-// * The last thread that partialy overlaps with the matrix is shifted inwards
+// * The last thread that partially overlaps with the matrix is shifted inwards
// such that the thread block fits exactly in the matrix
MLX_MTL_CONST short tgp_mem_size = BN * TN * 2;
@ -166,7 +166,7 @@ template <
struct GEMVTKernel {
// - The matrix of size (M = in_vec_size, N = out_vec_size) is divided up
-// into blocks of (BM * TM, BN * TN) divided amoung threadgroups
+// into blocks of (BM * TM, BN * TN) divided among threadgroups
// - Every thread works on a block of (TM, TN)
// - We assume each thead group is launched with (BN, BM, 1) threads
//
@ -180,7 +180,7 @@ struct GEMVTKernel {
// Edge case handling:
// - The threadgroup with the largest tid will have blocks that exceed the matrix
// * The blocks that start outside the matrix are never read (thread results remain zero)
-// * The last thread that partialy overlaps with the matrix is shifted inwards
+// * The last thread that partially overlaps with the matrix is shifted inwards
// such that the thread block fits exactly in the matrix

View File

@ -65,7 +65,7 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
in += grid_size * N_READS;
}
-// Sepate case for the last set as we close the reduction size
+// Separate case for the last set as we close the reduction size
size_t curr_idx = (gid + r * (size_t)grid_size) * N_READS;
if (curr_idx < in_size) {
int max_reads = in_size - curr_idx;

View File

@ -592,7 +592,7 @@ template <
bool ARG_SORT,
short BLOCK_THREADS,
short N_PER_THREAD>
-[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void mb_block_partiton(
+[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void mb_block_partition(
device idx_t* block_partitions [[buffer(0)]],
const device val_t* dev_vals [[buffer(1)]],
const device idx_t* dev_idxs [[buffer(2)]],
@ -777,8 +777,8 @@ template <
const device size_t* nc_strides [[buffer(7)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]); \
template [[host_name("mb_block_partiton_" #vtname "_" #itname "_bn" #bn "_tn" #tn)]] \
[[kernel]] void mb_block_partiton<vtype, itype, arg_sort, bn, tn>( \
template [[host_name("mb_block_partition_" #vtname "_" #itname "_bn" #bn "_tn" #tn)]] \
[[kernel]] void mb_block_partition<vtype, itype, arg_sort, bn, tn>( \
device itype* block_partitions [[buffer(0)]], \
const device vtype* dev_vals [[buffer(1)]], \
const device itype* dev_idxs [[buffer(2)]], \

View File

@ -61,7 +61,7 @@ inline void mps_matmul(
// 2. Only one of a or b has batch_size_out matrices worth of data and
// the other has matrix worth of data
-// The matrix dimsenisons of a and b are sure to be regularly strided
+// The matrix dimensions of a and b are sure to be regularly strided
if (batch_size_out > 1) {
// No broadcasting defaults
auto batch_size_a = a.data_size() / (M * K);

View File

@ -40,7 +40,7 @@ void all_reduce_dispatch(
// Set grid dimensions
// We make sure each thread has enough to do by making it read in
-// atleast n_reads inputs
+// at least n_reads inputs
int n_reads = REDUCE_N_READS;
// mod_in_size gives us the groups of n_reads needed to go over the entire
@ -176,7 +176,7 @@ void strided_reduce_general_dispatch(
// We spread outputs over the x dimension and inputs over the y dimension
// Threads with the same lid.x in a given threadgroup work on the same
-// output and each thread in the y dimension accumlates for that output
+// output and each thread in the y dimension accumulates for that output
uint threadgroup_dim_x = std::min(out_size, 128ul);
uint threadgroup_dim_y =
kernel->maxTotalThreadsPerThreadgroup() / threadgroup_dim_x;

View File

@ -165,10 +165,10 @@ void multi_block_sort(
dev_idxs_out = ping ? dev_idxs_0 : dev_idxs_1;
ping = !ping;
-// Do partiton
+// Do partition
{
std::ostringstream kname;
kname << "mb_block_partiton_" << type_to_name(dev_vals_in) << "_"
kname << "mb_block_partition_" << type_to_name(dev_vals_in) << "_"
<< type_to_name(dev_idxs_in) << "_bn" << bn << "_tn" << tn;
auto kernel = d.get_kernel(kname.str());

View File

@ -18,7 +18,7 @@ void set_array_buffer(
auto offset = a.data<char>() -
static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
enc->setBuffer(a_buf, offset, idx);
-// MTL::Resource usage through argument buffer needs to be explicity
+// MTL::Resource usage through argument buffer needs to be explicitly
// flagged to enable hazard tracking
compute_encoder->useResource(a_buf, MTL::ResourceUsageRead);
}

View File

@ -45,7 +45,7 @@ array fft_impl(
throw std::invalid_argument(msg.str());
}
-// In the following shape manipulations there are three cases to consdier:
+// In the following shape manipulations there are three cases to consider:
// 1. In a complex to complex transform (fftn / ifftn) the output
// and input shapes are the same.
// 2. In a real to complex transform (rfftn) n specifies the input dims

View File

@ -155,7 +155,7 @@ array load(std::shared_ptr<io::Reader> in_stream, StreamOrDevice s) {
// Read and check version
if (read_magic_and_ver[6] != 1 && read_magic_and_ver[6] != 2) {
throw std::runtime_error(
"[load] Unsupport npy format version in " + in_stream->label());
"[load] Unsupported npy format version in " + in_stream->label());
}
// Read header len and header

View File

@ -247,7 +247,7 @@ array tri(int n, int m, int k, Dtype type, StreamOrDevice s /* = {} */) {
array tril(array x, int k, StreamOrDevice s /* = {} */) {
if (x.ndim() < 2) {
throw std::invalid_argument("[tril] array must be atleast 2-D");
throw std::invalid_argument("[tril] array must be at least 2-D");
}
auto mask = tri(x.shape(-2), x.shape(-1), k, x.dtype(), s);
return where(mask, x, zeros_like(x, s), s);
@ -255,7 +255,7 @@ array tril(array x, int k, StreamOrDevice s /* = {} */) {
array triu(array x, int k, StreamOrDevice s /* = {} */) {
if (x.ndim() < 2) {
throw std::invalid_argument("[triu] array must be atleast 2-D");
throw std::invalid_argument("[triu] array must be at least 2-D");
}
auto mask = tri(x.shape(-2), x.shape(-1), k - 1, x.dtype(), s);
return where(mask, zeros_like(x, s), x, s);
@ -350,7 +350,7 @@ array squeeze(
ax = ax < 0 ? ax + a.ndim() : ax;
if (ax < 0 || ax >= a.ndim()) {
std::ostringstream msg;
msg << "[squeeze] Invalid axies " << ax << " for array with " << a.ndim()
msg << "[squeeze] Invalid axes " << ax << " for array with " << a.ndim()
<< " dimensions.";
throw std::invalid_argument(msg.str());
}
@ -405,7 +405,7 @@ array expand_dims(
ax = ax < 0 ? ax + out_ndim : ax;
if (ax < 0 || ax >= out_ndim) {
std::ostringstream msg;
msg << "[squeeze] Invalid axies " << ax << " for output array with "
msg << "[squeeze] Invalid axes " << ax << " for output array with "
<< a.ndim() << " dimensions.";
throw std::invalid_argument(msg.str());
}
@ -478,7 +478,7 @@ array slice(
// If strides are negative, slice and then make a copy with axes flipped
if (negatively_strided_axes.size() > 0) {
-// First, take the slice of the positvely strided axes
+// First, take the slice of the positively strided axes
auto out = array(
out_shape,
a.dtype(),
@ -517,7 +517,7 @@ array slice(
// Gather moves the axis up, remainder needs to be squeezed
out_reshape[i] = indices[i].size();
-// Gather moves the axis up, needs to be tranposed
+// Gather moves the axis up, needs to be transposed
t_axes[ax] = i;
}

View File

@ -214,7 +214,7 @@ array concatenate(const std::vector<array>& arrays, StreamOrDevice s = {});
array stack(const std::vector<array>& arrays, int axis, StreamOrDevice s = {});
array stack(const std::vector<array>& arrays, StreamOrDevice s = {});
-/** Repeate an array along an axis. */
+/** Repeat an array along an axis. */
array repeat(const array& arr, int repeats, int axis, StreamOrDevice s = {});
array repeat(const array& arr, int repeats, StreamOrDevice s = {});

View File

@ -49,7 +49,7 @@ class Primitive {
* A primitive must know how to evaluate itself on
* the CPU/GPU for the given inputs and populate the output array.
*
-* To avoid unecessary allocations, the evaluation function
+* To avoid unnecessary allocations, the evaluation function
* is responsible for allocating space for the array.
*/
virtual void eval_cpu(const std::vector<array>& inputs, array& out) = 0;
@ -84,7 +84,7 @@ class Primitive {
/** Print the primitive. */
virtual void print(std::ostream& os) = 0;
-/** Equivalence check defaults to false unless overriden by the primitive */
+/** Equivalence check defaults to false unless overridden by the primitive */
virtual bool is_equivalent(const Primitive& other) const {
return false;
}

View File

@ -232,7 +232,7 @@ array truncated_normal(
auto u = uniform(a, b, shape, dtype, key, s);
auto out = multiply(sqrt2, erfinv(u, s), s);
-// Clip in bouds
+// Clip in bounds
return maximum(minimum(upper_t, out, s), lower_t, s);
}

View File

@ -16,7 +16,7 @@ class KeySequence {
void seed(uint64_t seed);
array next();
-// static defualt
+// static default
static KeySequence& default_() {
static KeySequence ks(0);
return ks;

View File

@ -80,7 +80,7 @@ ValueAndGradFn value_and_grad(
/**
* Returns a function which computes the value and gradient of the input
-* function with repsect to a single input array.
+* function with respect to a single input array.
**/
ValueAndGradFn inline value_and_grad(
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
@ -132,7 +132,7 @@ std::function<std::vector<array>(const std::vector<array>&)> inline grad(
/**
* Returns a function which computes the gradient of the input function with
-* repsect to a single input array.
+* respect to a single input array.
*
* The function being differentiated takes a vector of arrays and returns an
* array. The optional `argnum` index specifies which the argument to compute

View File

@ -68,7 +68,7 @@ struct _MLX_Float16 {
inf_scale.u = uint32_t(0x77800000);
zero_scale.u = uint32_t(0x08800000);
-// Combine with magic and let addition do rouding
+// Combine with magic and let addition do rounding
magic_bits.u = x_expo_32;
magic_bits.f += (std::abs(x) * inf_scale.f) * zero_scale.f;

View File

@ -198,7 +198,7 @@ class BatchNorm(Module):
batch, ``C`` is the number of features or channels, and ``L`` is the
sequence length. The output has the same shape as the input. For
four-dimensional arrays, the shape is ``NHWC``, where ``H`` and ``W`` are
-the height and width respecitvely.
+the height and width respectively.
For more information on Batch Normalization, see the original paper `Batch
Normalization: Accelerating Deep Network Training by Reducing Internal

View File

@ -253,7 +253,7 @@ class AdaDelta(Optimizer):
rho (float, optional): The coefficient :math:`\rho` used for computing a
running average of squared gradients. Default: ``0.9``
eps (float, optional): The term :math:`\epsilon` added to the denominator to improve
-numerical stability. Ddefault: `1e-8`
+numerical stability. Default: `1e-8`
"""
def __init__(self, learning_rate: float, rho: float = 0.9, eps: float = 1e-6):

View File

@ -507,7 +507,7 @@ void init_array(py::module_& m) {
array_class
.def_property_readonly(
"size", &array::size, R"pbdoc(Number of elments in the array.)pbdoc")
"size", &array::size, R"pbdoc(Number of elements in the array.)pbdoc")
.def_property_readonly(
"ndim", &array::ndim, R"pbdoc(The array's dimension.)pbdoc")
.def_property_readonly(
@ -559,7 +559,7 @@ void init_array(py::module_& m) {
If the array has more than one dimension then the result is a nested
list of lists.
-The value type of the list correpsonding to the last dimension is either
+The value type of the list corresponding to the last dimension is either
``bool``, ``int`` or ``float`` depending on the ``dtype`` of the array.
)pbdoc")
.def("__array__", &mlx_array_to_np)

View File

@ -1263,7 +1263,7 @@ void init_ops(py::module_& m) {
If the axis is not specified the array is treated as a flattened
1-D array prior to performing the take.
-As an example, if the ``axis=1`` this is equialent to ``a[:, indices, ...]``.
+As an example, if the ``axis=1`` this is equivalent to ``a[:, indices, ...]``.
Args:
a (array): Input array.
@ -1742,7 +1742,7 @@ void init_ops(py::module_& m) {
"a"_a,
py::pos_only(),
"source"_a,
"destiantion"_a,
"destination"_a,
py::kw_only(),
"stream"_a = none,
R"pbdoc(
@ -2253,7 +2253,7 @@ void init_ops(py::module_& m) {
will be of elements less or equal to the element at the ``kth``
index and all indices after will be of elements greater or equal
to the element at the ``kth`` index.
-axis (int or None, optional): Optional axis to partiton over.
+axis (int or None, optional): Optional axis to partition over.
If ``None``, this partitions over the flattened array.
If unspecified, it defaults to ``-1``.
@ -2426,13 +2426,13 @@ void init_ops(py::module_& m) {
R"pbdoc(
repeat(array: array, repeats: int, axis: Optional[int] = None, *, stream: Union[None, Stream, Device] = None) -> array
-Repeate an array along a specified axis.
+Repeat an array along a specified axis.
Args:
array (array): Input array.
repeats (int): The number of repetitions for each element.
axis (int, optional): The axis in which to repeat the array along. If
-unspecified it uses the flattened array of the input and repeates
+unspecified it uses the flattened array of the input and repeats
along axis 0.
stream (Stream, optional): Stream or device. Defaults to ``None``.
@ -3050,7 +3050,7 @@ void init_ops(py::module_& m) {
Round to the given number of decimals.
-Bascially performs:
+Basically performs:
.. code-block:: python

View File

@ -212,7 +212,7 @@ void init_random(py::module_& parent_module) {
upper (scalar or array): Upper bound of the domain.
shape (list(int), optional): The shape of the output.
Default is ``()``.
-dtype (Dtype, optinoal): The data type of the output.
+dtype (Dtype, optional): The data type of the output.
Default is ``float32``.
key (array, optional): A PRNG key. Default: None.

View File

@ -952,7 +952,7 @@ class TestArray(mlx_tests.MLXTestCase):
b_mx = a_mx[25:-50:-3]
self.assertTrue(np.array_equal(b_np, b_mx))
-# Negatie slice and ascending bounds
+# Negative slice and ascending bounds
b_np = a_np[0:20:-3]
b_mx = a_mx[0:20:-3]
self.assertTrue(np.array_equal(b_np, b_mx))

View File

@ -53,10 +53,10 @@ class TestBlas(mlx_tests.MLXTestCase):
for dtype in self.dtypes:
np_dtype = getattr(np, dtype)
base_shapes = [4, 8, 16, 32, 64, 128]
-pertubations = [-2, -1, 0, 1, 2]
+perturbations = [-2, -1, 0, 1, 2]
for dim in base_shapes:
-for p in pertubations:
+for p in perturbations:
shape_a = (dim + p, dim + p)
shape_b = (dim + p, dim + p)
self.__gemm_test(shape_a, shape_b, np_dtype)
@ -81,12 +81,12 @@ class TestBlas(mlx_tests.MLXTestCase):
for B, M, N, K in shapes:
with self.subTest(tranpose="nn"):
with self.subTest(transpose="nn"):
shape_a = (B, M, K)
shape_b = (B, K, N)
self.__gemm_test(shape_a, shape_b, np_dtype)
with self.subTest(tranpose="nt"):
with self.subTest(transpose="nt"):
shape_a = (B, M, K)
shape_b = (B, N, K)
self.__gemm_test(
@ -97,7 +97,7 @@ class TestBlas(mlx_tests.MLXTestCase):
f_mx_b=lambda x: mx.transpose(x, (0, 2, 1)),
)
with self.subTest(tranpose="tn"):
with self.subTest(transpose="tn"):
shape_a = (B, K, M)
shape_b = (B, K, N)
self.__gemm_test(
@ -108,7 +108,7 @@ class TestBlas(mlx_tests.MLXTestCase):
f_mx_a=lambda x: mx.transpose(x, (0, 2, 1)),
)
with self.subTest(tranpose="tt"):
with self.subTest(transpose="tt"):
shape_a = (B, K, M)
shape_b = (B, N, K)
self.__gemm_test(
@ -191,7 +191,7 @@ class TestBlas(mlx_tests.MLXTestCase):
self.assertListEqual(list(c_npy.shape), list(c_mlx.shape))
self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-6))
-# Batched matmul with simple broadast
+# Batched matmul with simple broadcast
a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
b_npy = np.random.normal(0.0, 1.0 / 128, (16, 16)).astype(np.float32)
c_npy = a_npy @ b_npy
@ -213,7 +213,7 @@ class TestBlas(mlx_tests.MLXTestCase):
self.assertListEqual(list(e_npy.shape), list(e_mlx.shape))
self.assertTrue(np.allclose(e_mlx, e_npy, atol=1e-6))
-# Batched and transposed matmul with simple broadast
+# Batched and transposed matmul with simple broadcast
a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
b_npy = np.random.normal(0.0, 1.0 / 128, (128, 16)).astype(np.float32)
a_mlx = mx.array(a_npy)

View File

@ -88,7 +88,7 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertEqual(a.dtype, mx.float32)
self.assertEqual(a.item(), 3.0)
-# Check comibinations with mlx arrays
+# Check combinations with mlx arrays
a = mx.add(mx.array(True), False)
self.assertEqual(a.dtype, mx.bool_)
self.assertEqual(a.item(), True)

View File

@ -76,7 +76,7 @@ TEST_CASE("test arg reduce small") {
{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
if (!metal::is_available()) {
INFO("Skiping arg reduction gpu tests");
INFO("Skipping arg reduction gpu tests");
return;
}
@ -106,7 +106,7 @@ TEST_CASE("test arg reduce small") {
TEST_CASE("test arg reduce against cpu") {
if (!metal::is_available()) {
INFO("Skiping arg reduction gpu tests");
INFO("Skipping arg reduction gpu tests");
return;
}
@ -148,7 +148,7 @@ void test_arg_reduce_small_bool(
TEST_CASE("test arg reduce bool") {
if (!metal::is_available()) {
INFO("Skiping arg reduction gpu tests");
INFO("Skipping arg reduction gpu tests");
return;
}
auto x = array(
@ -201,7 +201,7 @@ TEST_CASE("test arg reduce irregular strides") {
Device::cpu, x, ArgReduce::ArgMin, {4, 2}, 2, {0, 0, 1, 1, 1, 1, 2, 2});
if (!metal::is_available()) {
INFO("Skiping arg reduction gpu tests");
INFO("Skipping arg reduction gpu tests");
return;
}
}

View File

@ -989,7 +989,7 @@ TEST_CASE("test as_strided grads") {
}
TEST_CASE("test jvp from vjp") {
-// Unary elementwise ops
+// Unary element-wise ops
{
auto x = random::uniform({5, 10});
eval(x);
@ -1022,7 +1022,7 @@ TEST_CASE("test jvp from vjp") {
CHECK(compute_derivs(mlx::core::rsqrt));
}
-// Binary elementwise ops
+// Binary element-wise ops
{
auto x = random::uniform({5, 10});
auto y = random::uniform({5, 10});

View File

@ -7,7 +7,7 @@
using namespace mlx::core;
TEST_CASE("test arange") {
-// Check type is inferred correclty
+// Check type is inferred correctly
{
auto x = arange(10);
CHECK_EQ(x.dtype(), int32);

View File

@ -1411,7 +1411,7 @@ TEST_CASE("test broadcast") {
x.eval();
CHECK_EQ(x.strides(), std::vector<size_t>{0, 0, 1});
-// Broadcast on transposed arrray works
+// Broadcast on transposed array works
x = array({0, 1, 2, 3, 4, 5}, {2, 3});
x = broadcast_to(transpose(x), {2, 3, 2});
CHECK_EQ(x.shape(), std::vector<int>{2, 3, 2});
@ -1733,7 +1733,7 @@ TEST_CASE("test scatter") {
out = scatter(in, inds, updates, 0);
CHECK(array_equal(out, reshape(arange(16, float32), {4, 4})).item<bool>());
-// Irregular strided index and reduce collison test
+// Irregular strided index and reduce collision test
in = zeros({10}, float32);
inds = broadcast_to(array(3), {10});
updates = ones({10, 1}, float32);
@ -1750,7 +1750,7 @@ TEST_CASE("test scatter") {
out = scatter_max(array(1), {}, array(2), std::vector<int>{});
CHECK_EQ(out.item<int>(), 2);
-// Irregularaly strided updates test
+// Irregularly strided updates test
in = ones({3, 3});
updates = broadcast_to(array({0, 0, 0}), {1, 3, 3});
inds = array({0});