mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-26 02:33:21 +08:00
Spelling (#342)
* spelling: accumulates
* spelling: across
* spelling: additional
* spelling: against
* spelling: among
* spelling: array
* spelling: at least
* spelling: available
* spelling: axes
* spelling: basically
* spelling: bfloat
* spelling: bounds
* spelling: broadcast
* spelling: buffer
* spelling: class
* spelling: coefficients
* spelling: collision
* spelling: combinations
* spelling: committing
* spelling: computation
* spelling: consider
* spelling: constructing
* spelling: conversions
* spelling: correctly
* spelling: corresponding
* spelling: declaration
* spelling: default
* spelling: dependency
* spelling: destination
* spelling: destructor
* spelling: dimensions
* spelling: divided
* spelling: element-wise
* spelling: elements
* spelling: endianness
* spelling: equivalent
* spelling: explicitly
* spelling: github
* spelling: indices
* spelling: irregularly
* spelling: memory
* spelling: metallib
* spelling: negative
* spelling: notable
* spelling: optional
* spelling: otherwise
* spelling: overridden
* spelling: partially
* spelling: partition
* spelling: perform
* spelling: perturbations
* spelling: positively
* spelling: primitive
* spelling: repeat
* spelling: repeats
* spelling: respect
* spelling: respectively
* spelling: result
* spelling: rounding
* spelling: separate
* spelling: skipping
* spelling: structure
* spelling: the
* spelling: transpose
* spelling: unnecessary
* spelling: unneeded
* spelling: unsupported

---------

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
parent 144ecff849
commit 44c1ce5e6a
@@ -133,7 +133,7 @@ def get_gbyte_size(in_vec_len, out_vec_len, np_dtype):
     return float(N_iter_bench * N_iter_func * n_elem * item_size) / float(1024**3)
 
 
-def bench_with_in_len(ax, in_vec_len, out_vector_lens, dtype, tranpose):
+def bench_with_in_len(ax, in_vec_len, out_vector_lens, dtype, transpose):
     np_dtype = getattr(np, dtype)
     mlx_gb_s = []
     mlx_gflops = []
@@ -164,7 +164,7 @@ def bench_with_in_len(ax, in_vec_len, out_vector_lens, dtype, tranpose):
     ax.legend()
 
 
-def bench_with_out_len(ax, out_vec_len, in_vector_lens, dtype, tranpose):
+def bench_with_out_len(ax, out_vec_len, in_vector_lens, dtype, transpose):
     np_dtype = getattr(np, dtype)
     mlx_gb_s = []
     mlx_gflops = []
@@ -62,7 +62,7 @@ def make_predicate(positive_filter, negative_filter):
 
 
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Run comparisons agains PyTorch")
+    parser = argparse.ArgumentParser(description="Run comparisons against PyTorch")
     parser.add_argument(
         "--filter", "-f", help="Regex filter to select benchmarks", nargs="+"
     )
@@ -12,7 +12,7 @@ include(CMakeParseArguments)
 # OUTPUT_DIRECTORY: Where to place ${TITLE}.metallib
 # SOURCES: List of source files
 # INCLUDE_DIRS: List of include dirs
-# DEPS: List of depedency files (like headers)
+# DEPS: List of dependency files (like headers)
 #
 macro(mlx_build_metallib)
   # Parse args
@@ -32,7 +32,7 @@ macro(mlx_build_metallib)
   # Collect compile options
   set(MTLLIB_COMPILE_OPTIONS -Wall -Wextra -fno-fast-math)
 
-  # Prepare metllib build command
+  # Prepare metallib build command
   add_custom_command(
     OUTPUT ${MTLLIB_BUILD_TARGET}
     COMMAND xcrun -sdk macosx metal
@@ -26,7 +26,7 @@ python -m http.server <port>
 
 and point your browser to `http://localhost:<port>`.
 
-### Push to Github Pages
+### Push to GitHub Pages
 
 Check-out the `gh-pages` branch (`git switch gh-pages`) and build
 the docs. Then force add the `build/html` directory:
@@ -15,7 +15,7 @@ Introducing the Example
 -----------------------
 
 Let's say that you would like an operation that takes in two arrays,
-``x`` and ``y``, scales them both by some coefficents ``alpha`` and ``beta``
+``x`` and ``y``, scales them both by some coefficients ``alpha`` and ``beta``
 respectively, and then adds them together to get the result
 ``z = alpha * x + beta * y``. Well, you can very easily do that by just
 writing out a function as follows:
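The plain-ops function the tutorial goes on to define looks roughly like the sketch below (a minimal sketch; the name simple_axpby follows the MLX docs and is illustrative):

    import mlx.core as mx

    def simple_axpby(x: mx.array, y: mx.array, alpha: float, beta: float) -> mx.array:
        # Scale each input and add element-wise; numpy-style broadcasting
        # applies, so x and y need not share a shape.
        return alpha * x + beta * y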
@@ -69,7 +69,7 @@ C++ API:
 .. code-block:: C++
 
   /**
-  * Scale and sum two vectors elementwise
+  * Scale and sum two vectors element-wise
   * z = alpha * x + beta * y
   *
   * Follow numpy style broadcasting between x and y
@@ -230,7 +230,7 @@ Let's re-implement our operation now in terms of our :class:`Axpby` primitive.
 
 This operation now handles the following:
 
-#. Upcast inputs and resolve the the output data type.
+#. Upcast inputs and resolve the output data type.
 #. Broadcast the inputs and resolve the output shape.
 #. Construct the primitive :class:`Axpby` using the given stream, ``alpha``, and ``beta``.
 #. Construct the output :class:`array` using the primitive and the inputs.
@@ -284,14 +284,14 @@ pointwise. This is captured in the templated function :meth:`axpby_impl`.
   T alpha = static_cast<T>(alpha_);
   T beta = static_cast<T>(beta_);
 
-  // Do the elementwise operation for each output
+  // Do the element-wise operation for each output
   for (size_t out_idx = 0; out_idx < out.size(); out_idx++) {
     // Map linear indices to offsets in x and y
     auto x_offset = elem_to_loc(out_idx, x.shape(), x.strides());
     auto y_offset = elem_to_loc(out_idx, y.shape(), y.strides());
 
     // We allocate the output to be contiguous and regularly strided
-    // (defaults to row major) and hence it doesn't need additonal mapping
+    // (defaults to row major) and hence it doesn't need additional mapping
     out_ptr[out_idx] = alpha * x_ptr[x_offset] + beta * y_ptr[y_offset];
   }
 }
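The elem_to_loc call above is what makes this loop correct for non-contiguous inputs: it maps a linear output index to a strided memory offset. A minimal Python sketch of that mapping (strides counted in elements, not bytes):

    def elem_to_loc(idx, shape, strides):
        # Peel off one dimension at a time, starting from the rightmost,
        # and accumulate each coordinate times its stride.
        loc = 0
        for dim in range(len(shape) - 1, -1, -1):
            loc += (idx % shape[dim]) * strides[dim]
            idx //= shape[dim]
        return loc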
@@ -305,7 +305,7 @@ if we encounter an unexpected type.
 
   /** Fall back implementation for evaluation on CPU */
   void Axpby::eval(const std::vector<array>& inputs, array& out) {
-    // Check the inputs (registered in the op while contructing the out array)
+    // Check the inputs (registered in the op while constructing the out array)
     assert(inputs.size() == 2);
     auto& x = inputs[0];
     auto& y = inputs[1];
@@ -485,7 +485,7 @@ each data type.
 
   instantiate_axpby(float32, float);
   instantiate_axpby(float16, half);
-  instantiate_axpby(bflot16, bfloat16_t);
+  instantiate_axpby(bfloat16, bfloat16_t);
   instantiate_axpby(complex64, complex64_t);
 
 This kernel will be compiled into a metal library ``mlx_ext.metallib`` as we
@@ -537,7 +537,7 @@ below.
   compute_encoder->setComputePipelineState(kernel);
 
   // Kernel parameters are registered with buffer indices corresponding to
-  // those in the kernel decelaration at axpby.metal
+  // those in the kernel declaration at axpby.metal
   int ndim = out.ndim();
   size_t nelem = out.size();
 
@@ -568,7 +568,7 @@ below.
   // Fix the 3D size of the launch grid (in terms of threads)
   MTL::Size grid_dims = MTL::Size(nelem, 1, 1);
 
-  // Launch the grid with the given number of threads divded among
+  // Launch the grid with the given number of threads divided among
   // the given threadgroups
   compute_encoder->dispatchThreads(grid_dims, group_dims);
 }
@@ -581,7 +581,7 @@ to give us the active metal compute command encoder instead of building a
 new one and calling :meth:`compute_encoder->end_encoding` at the end.
 MLX keeps adding kernels (compute pipelines) to the active command encoder
 until some specified limit is hit or the compute encoder needs to be flushed
-for synchronization. MLX also handles enqueuing and commiting the associated
+for synchronization. MLX also handles enqueuing and committing the associated
 command buffers as needed. We suggest taking a deeper dive into
 :class:`metal::Device` if you would like to study this routine further.
 
@@ -601,8 +601,8 @@ us the following :meth:`Axpby::jvp` and :meth:`Axpby::vjp` implementations.
     const std::vector<array>& tangents,
     const std::vector<int>& argnums) {
   // Forward mode diff that pushes along the tangents
-  // The jvp transform on the the primitive can built with ops
-  // that are scheduled on the same stream as the primtive
+  // The jvp transform on the primitive can built with ops
+  // that are scheduled on the same stream as the primitive
 
   // If argnums = {0}, we only push along x in which case the
   // jvp is just the tangent scaled by alpha
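The comment spells out the JVP math: for z = alpha * x + beta * y the tangent of z is alpha * dx + beta * dy, and pushing along only one argument keeps just that term. A small sketch of the arithmetic (function name hypothetical, tangents aligned with argnums):

    def axpby_jvp(tangents, argnums, alpha, beta):
        # The scale for argument 0 (x) is alpha; for argument 1 (y), beta.
        scales = {0: alpha, 1: beta}
        return sum(scales[a] * t for a, t in zip(argnums, tangents))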
@@ -642,7 +642,7 @@ own :class:`Primitive`.
 
 .. code-block:: C++
 
-  /** Vectorize primitve along given axis */
+  /** Vectorize primitive along given axis */
  std::pair<array, int> Axpby::vmap(
       const std::vector<array>& inputs,
       const std::vector<int>& axes) {
@@ -666,7 +666,7 @@ Let's look at the overall directory structure first.
 | └── setup.py
 
 * ``extensions/axpby/`` defines the C++ extension library
-* ``extensions/mlx_sample_extensions`` sets out the strucutre for the
+* ``extensions/mlx_sample_extensions`` sets out the structure for the
   associated python package
 * ``extensions/bindings.cpp`` provides python bindings for our operation
 * ``extensions/CMakeLists.txt`` holds CMake rules to build the library and
@@ -697,7 +697,7 @@ are already provided, adding our :meth:`axpby` becomes very simple!
         py::kw_only(),
         "stream"_a = py::none(),
         R"pbdoc(
-        Scale and sum two vectors elementwise
+        Scale and sum two vectors element-wise
         ``z = alpha * x + beta * y``
 
         Follows numpy style broadcasting between ``x`` and ``y``
@@ -840,7 +840,7 @@ This will result in a directory structure as follows:
 | ...
 
 When you try to install using the command ``python -m pip install .``
-(in ``extensions/``), the package will be installed with the same strucutre as
+(in ``extensions/``), the package will be installed with the same structure as
 ``extensions/mlx_sample_extensions`` and the C++ and metal library will be
 copied along with the python binding since they are specified as ``package_data``.
@@ -19,7 +19,7 @@ The main differences between MLX and NumPy are:
 
 The design of MLX is inspired by frameworks like `PyTorch
 <https://pytorch.org/>`_, `Jax <https://github.com/google/jax>`_, and
-`ArrayFire <https://arrayfire.org/>`_. A noteable difference from these
+`ArrayFire <https://arrayfire.org/>`_. A notable difference from these
 frameworks and MLX is the *unified memory model*. Arrays in MLX live in shared
 memory. Operations on MLX arrays can be performed on any of the supported
 device types without performing data copies. Currently supported device types
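The unified-memory point is easiest to see in code; a short example in the spirit of the MLX docs (the same arrays feed ops on both devices, only the stream differs):

    import mlx.core as mx

    a = mx.random.normal((64, 64))
    b = mx.random.normal((64, 64))

    # The same buffers back both calls; no host/device copies are made.
    c = mx.add(a, b, stream=mx.cpu)
    d = mx.add(a, c, stream=mx.gpu)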
@@ -57,7 +57,7 @@ void array_basics() {
   assert(z.shape(0) == 2);
   assert(z.shape(1) == 2);
 
-  // To actually run the compuation you must evaluate `z`.
+  // To actually run the computation you must evaluate `z`.
   // Under the hood, mlx records operations in a graph.
   // The variable `z` is a node in the graph which points to its operation
   // and inputs. When `eval` is called on an array (or arrays), the array and
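A Python rendering of the same lazy-evaluation behavior:

    import mlx.core as mx

    x = mx.array([1.0, 2.0])
    y = mx.array([3.0, 4.0])
    z = mx.add(x, y)  # records a node in the graph; nothing runs yet
    mx.eval(z)        # evaluates the graph; z now holds [4.0, 6.0]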
@@ -26,7 +26,7 @@ namespace mlx::core {
 ///////////////////////////////////////////////////////////////////////////////
 
 /**
-* Scale and sum two vectors elementwise
+* Scale and sum two vectors element-wise
 * z = alpha * x + beta * y
 *
 * Follow numpy style broadcasting between x and y
@@ -91,21 +91,21 @@ void axpby_impl(
   T alpha = static_cast<T>(alpha_);
   T beta = static_cast<T>(beta_);
 
-  // Do the elementwise operation for each output
+  // Do the element-wise operation for each output
   for (size_t out_idx = 0; out_idx < out.size(); out_idx++) {
     // Map linear indices to offsets in x and y
     auto x_offset = elem_to_loc(out_idx, x.shape(), x.strides());
     auto y_offset = elem_to_loc(out_idx, y.shape(), y.strides());
 
     // We allocate the output to be contiguous and regularly strided
-    // (defaults to row major) and hence it doesn't need additonal mapping
+    // (defaults to row major) and hence it doesn't need additional mapping
     out_ptr[out_idx] = alpha * x_ptr[x_offset] + beta * y_ptr[y_offset];
   }
 }
 
 /** Fall back implementation for evaluation on CPU */
 void Axpby::eval(const std::vector<array>& inputs, array& out) {
-  // Check the inputs (registered in the op while contructing the out array)
+  // Check the inputs (registered in the op while constructing the out array)
   assert(inputs.size() == 2);
   auto& x = inputs[0];
   auto& y = inputs[1];
@@ -192,7 +192,7 @@ void Axpby::eval_cpu(const std::vector<array>& inputs, array& out) {
   eval(inputs, out);
 }
 
-#else // Accelerate not avaliable
+#else // Accelerate not available
 
 /** Evaluate primitive on CPU falling back to common backend */
 void Axpby::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -254,7 +254,7 @@ void Axpby::eval_gpu(const std::vector<array>& inputs, array& out) {
   compute_encoder->setComputePipelineState(kernel);
 
   // Kernel parameters are registered with buffer indices corresponding to
-  // those in the kernel decelaration at axpby.metal
+  // those in the kernel declaration at axpby.metal
   int ndim = out.ndim();
   size_t nelem = out.size();
 
@@ -287,7 +287,7 @@ void Axpby::eval_gpu(const std::vector<array>& inputs, array& out) {
   // Fix the 3D size of the launch grid (in terms of threads)
   MTL::Size grid_dims = MTL::Size(nelem, 1, 1);
 
-  // Launch the grid with the given number of threads divded among
+  // Launch the grid with the given number of threads divided among
   // the given threadgroups
   compute_encoder->dispatchThreads(grid_dims, group_dims);
 }
@@ -311,8 +311,8 @@ array Axpby::jvp(
     const std::vector<array>& tangents,
     const std::vector<int>& argnums) {
   // Forward mode diff that pushes along the tangents
-  // The jvp transform on the the primitive can built with ops
-  // that are scheduled on the same stream as the primtive
+  // The jvp transform on the primitive can built with ops
+  // that are scheduled on the same stream as the primitive
 
   // If argnums = {0}, we only push along x in which case the
   // jvp is just the tangent scaled by alpha
@@ -345,7 +345,7 @@ std::vector<array> Axpby::vjp(
   return vjps;
 }
 
-/** Vectorize primitve along given axis */
+/** Vectorize primitive along given axis */
 std::pair<array, int> Axpby::vmap(
     const std::vector<array>& inputs,
     const std::vector<int>& axes) {
@@ -12,7 +12,7 @@ namespace mlx::core {
 ///////////////////////////////////////////////////////////////////////////////
 
 /**
-* Scale and sum two vectors elementwise
+* Scale and sum two vectors element-wise
 * z = alpha * x + beta * y
 *
 * Follow numpy style broadcasting between x and y
@@ -39,7 +39,7 @@ class Axpby : public Primitive {
   * A primitive must know how to evaluate itself on the CPU/GPU
   * for the given inputs and populate the output array.
   *
-  * To avoid unecessary allocations, the evaluation function
+  * To avoid unnecessary allocations, the evaluation function
   * is responsible for allocating space for the array.
   */
  void eval_cpu(const std::vector<array>& inputs, array& out) override;
@@ -59,5 +59,5 @@ template <typename T>
 
 instantiate_axpby(float32, float);
 instantiate_axpby(float16, half);
-instantiate_axpby(bflot16, bfloat16_t);
+instantiate_axpby(bfloat16, bfloat16_t);
 instantiate_axpby(complex64, complex64_t);
@@ -23,7 +23,7 @@ PYBIND11_MODULE(mlx_sample_extensions, m) {
       py::kw_only(),
       "stream"_a = py::none(),
       R"pbdoc(
-        Scale and sum two vectors elementwise
+        Scale and sum two vectors element-wise
         ``z = alpha * x + beta * y``
 
         Follows numpy style broadcasting between ``x`` and ``y``
@@ -37,7 +37,7 @@ void free(Buffer buffer);
 Buffer malloc_or_wait(size_t size);
 
 class Allocator {
-  /** Abstract base clase for a memory allocator. */
+  /** Abstract base class for a memory allocator. */
 public:
  virtual Buffer malloc(size_t size) = 0;
  virtual void free(Buffer buffer) = 0;
@@ -129,7 +129,7 @@ array::ArrayDesc::ArrayDesc(
 }
 
 // Needed because the Primitive type used in array.h is incomplete and the
-// compiler needs to see the call to the desctructor after the type is complete.
+// compiler needs to see the call to the destructor after the type is complete.
 array::ArrayDesc::~ArrayDesc() = default;
 
 array::ArrayIterator::reference array::ArrayIterator::operator*() const {
@@ -13,7 +13,7 @@ namespace mlx::core {
 namespace {
 
 template <const uint8_t scalar_size>
-void swap_endianess(uint8_t* data_bytes, size_t N) {
+void swap_endianness(uint8_t* data_bytes, size_t N) {
   struct Elem {
     uint8_t bytes[scalar_size];
   };
@@ -39,13 +39,13 @@ void Load::eval(const std::vector<array>& inputs, array& out) {
   if (swap_endianness_) {
     switch (out.itemsize()) {
       case 2:
-        swap_endianess<2>(out.data<uint8_t>(), out.data_size());
+        swap_endianness<2>(out.data<uint8_t>(), out.data_size());
         break;
       case 4:
-        swap_endianess<4>(out.data<uint8_t>(), out.data_size());
+        swap_endianness<4>(out.data<uint8_t>(), out.data_size());
         break;
       case 8:
-        swap_endianess<8>(out.data<uint8_t>(), out.data_size());
+        swap_endianness<8>(out.data<uint8_t>(), out.data_size());
         break;
     }
   }
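The byte reversal that swap_endianness performs can be sketched in a few lines of Python (numpy used purely for illustration; the C++ version works in place on the array's bytes):

    import numpy as np

    def swap_endianness(buf: bytes, scalar_size: int) -> bytes:
        # Reverse the byte order of every scalar_size-byte element.
        a = np.frombuffer(buf, dtype=np.uint8).reshape(-1, scalar_size)
        return a[:, ::-1].tobytes()

    assert swap_endianness(b"\x01\x00\x02\x00", 2) == b"\x00\x01\x00\x02"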
@@ -165,7 +165,7 @@ Buffer MetalAllocator::malloc(size_t size) {
 
   // Prepare to allocate new memory as needed
   if (!buf) {
-    // If we are under very high memoory pressure, we don't allocate further
+    // If we are under very high memory pressure, we don't allocate further
     if (device_->currentAllocatedSize() >= block_limit_) {
       return Buffer{nullptr};
     }
@@ -68,7 +68,7 @@ void explicit_gemm_conv_1D_gpu(
   array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
   copy_gpu(in_strided_view, in_strided, CopyType::General, s);
 
-  // Peform gemm
+  // Perform gemm
   std::vector<array> copies = {in_padded, in_strided};
   mlx_matmul(
       s,
@@ -260,7 +260,7 @@ void explicit_gemm_conv_2D_gpu(
   array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
   copy_gpu(in_strided_view, in_strided, CopyType::General, s);
 
-  // Peform gemm
+  // Perform gemm
   std::vector<array> copies = {in_padded, in_strided};
   mlx_matmul(
       s,
@@ -102,7 +102,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
         static_cast<size_t*>(idx_strides_buf.raw_ptr()) + i * idx_ndim);
   }
 
-  // Allocate the argument bufer
+  // Allocate the argument buffer
   auto arg_buf = allocator::malloc_or_wait(arg_enc->encodedLength());
 
   // Register data with the encoder
@@ -246,7 +246,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
         static_cast<size_t*>(idx_strides_buf.raw_ptr()) + i * idx_ndim);
   }
 
-  // Allocate the argument bufer
+  // Allocate the argument buffer
   auto arg_buf = allocator::malloc_or_wait(arg_enc->encodedLength());
 
   // Register data with the encoder
@@ -114,7 +114,7 @@ template <typename T, typename Op, int N_READS>
 // 4. Reduce among them and go to 3
 // 4. Reduce in each simd_group
 // 6. Write in the thread local memory
-// 6. Reduce them accross thread group
+// 6. Reduce them across thread group
 // 7. Write the output without need for atomic
 Op op;
 
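The numbered comments describe a staged threadgroup reduction; a schematic Python model of the data flow (sizes are illustrative, and the real kernel uses threadgroup memory rather than lists):

    def simulate_reduce(values, n_reads=4, simd_size=32):
        # Each thread privately accumulates n_reads consecutive inputs.
        per_thread = [sum(values[i:i + n_reads])
                      for i in range(0, len(values), n_reads)]
        # Reduce within each simd group.
        per_simd = [sum(per_thread[i:i + simd_size])
                    for i in range(0, len(per_thread), simd_size)]
        # Combine the simd partials across the threadgroup and write out.
        return sum(per_simd)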
@@ -45,7 +45,7 @@ struct complex64_t {
       typename = typename enable_if<can_convert_to_complex64<T>>::type>
   constexpr complex64_t(T x) constant : real(x), imag(0) {}
 
-  // Converstions from complex64_t
+  // Conversions from complex64_t
   template <
       typename T,
       typename = typename enable_if<can_convert_from_complex64<T>>::type>
@@ -105,7 +105,7 @@ struct Conv2DInputBlockLoader {
       }
     }
 
-    // Zero pad otherwize
+    // Zero pad otherwise
     else {
 #pragma clang loop unroll(full)
       for (short j = 0; j < vec_size; ++j) {
@@ -334,7 +334,7 @@ struct Conv2DBlockMMA {
     }
 
     simdgroup_barrier(mem_flags::mem_none);
-    // Multiply and accumulate into resulr simdgroup matrices
+    // Multiply and accumulate into result simdgroup matrices
 #pragma clang loop unroll(full)
     for (short i = 0; i < TM; i++) {
 #pragma clang loop unroll(full)
@@ -93,13 +93,13 @@ struct BlockLoader {
       tmp_idx[j] = bj + j < src_tile_dim.x ? j : 0;
     }
 
-    // Read all valid indcies into tmp_val
+    // Read all valid indices into tmp_val
 #pragma clang loop unroll(full)
     for (short j = 0; j < vec_size; j++) {
       tmp_val[j] = src[i * src_ld + tmp_idx[j]];
     }
 
-    // Zero out uneeded values
+    // Zero out unneeded values
 #pragma clang loop unroll(full)
     for (short j = 0; j < vec_size; j++) {
       tmp_val[j] = bj + j < src_tile_dim.x ? tmp_val[j] : T(0);
@@ -241,7 +241,7 @@ struct BlockMMA {
     }
 
     simdgroup_barrier(mem_flags::mem_none);
-    // Multiply and accumulate into resulr simdgroup matrices
+    // Multiply and accumulate into result simdgroup matrices
 #pragma clang loop unroll(full)
     for (short i = 0; i < TM; i++) {
 #pragma clang loop unroll(full)
@@ -28,7 +28,7 @@ struct GEMVKernel {
   static_assert(BN == SIMD_SIZE, "gemv block must have a width of SIMD_SIZE");
 
   // - The matrix of size (M = out_vec_size, N = in_vec_size) is divided up
-  //   into blocks of (BM * TM, BN * TN) divided amoung threadgroups
+  //   into blocks of (BM * TM, BN * TN) divided among threadgroups
   // - Every thread works on a block of (TM, TN)
   // - We assume each thead group is launched with (BN, BM, 1) threads
   //
@@ -42,7 +42,7 @@ struct GEMVKernel {
   // Edge case handling:
   // - The threadgroup with the largest tid will have blocks that exceed the matrix
   //   * The blocks that start outside the matrix are never read (thread results remain zero)
-  //   * The last thread that partialy overlaps with the matrix is shifted inwards
+  //   * The last thread that partially overlaps with the matrix is shifted inwards
   //     such that the thread block fits exactly in the matrix
 
   MLX_MTL_CONST short tgp_mem_size = BN * TN * 2;
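The tiling arithmetic in those comments, written out as a rough sketch (block sizes are template parameters in the real kernel; the grid layout here is an assumption for illustration):

    def gemv_blocks(M, N, BM, TM, BN, TN):
        # One threadgroup covers a (BM * TM, BN * TN) tile; each of its
        # (BN, BM) threads owns a (TM, TN) sub-tile. Ceiling division means
        # edge tiles can overhang the matrix, which the edge-case rules
        # above handle by shifting the last thread inwards.
        rows = -(-M // (BM * TM))
        cols = -(-N // (BN * TN))
        return rows, cols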
@@ -166,7 +166,7 @@ template <
 struct GEMVTKernel {
 
   // - The matrix of size (M = in_vec_size, N = out_vec_size) is divided up
-  //   into blocks of (BM * TM, BN * TN) divided amoung threadgroups
+  //   into blocks of (BM * TM, BN * TN) divided among threadgroups
   // - Every thread works on a block of (TM, TN)
   // - We assume each thead group is launched with (BN, BM, 1) threads
   //
@@ -180,7 +180,7 @@ struct GEMVTKernel {
   // Edge case handling:
   // - The threadgroup with the largest tid will have blocks that exceed the matrix
   //   * The blocks that start outside the matrix are never read (thread results remain zero)
-  //   * The last thread that partialy overlaps with the matrix is shifted inwards
+  //   * The last thread that partially overlaps with the matrix is shifted inwards
   //     such that the thread block fits exactly in the matrix
 
@@ -65,7 +65,7 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
     in += grid_size * N_READS;
   }
 
-  // Sepate case for the last set as we close the reduction size
+  // Separate case for the last set as we close the reduction size
   size_t curr_idx = (gid + r * (size_t)grid_size) * N_READS;
   if (curr_idx < in_size) {
     int max_reads = in_size - curr_idx;
@@ -592,7 +592,7 @@ template <
     bool ARG_SORT,
     short BLOCK_THREADS,
     short N_PER_THREAD>
-[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void mb_block_partiton(
+[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void mb_block_partition(
     device idx_t* block_partitions [[buffer(0)]],
     const device val_t* dev_vals [[buffer(1)]],
     const device idx_t* dev_idxs [[buffer(2)]],
@@ -777,8 +777,8 @@ template <
       const device size_t* nc_strides [[buffer(7)]], \
       uint3 tid [[threadgroup_position_in_grid]], \
       uint3 lid [[thread_position_in_threadgroup]]); \
-  template [[host_name("mb_block_partiton_" #vtname "_" #itname "_bn" #bn "_tn" #tn)]] \
-  [[kernel]] void mb_block_partiton<vtype, itype, arg_sort, bn, tn>( \
+  template [[host_name("mb_block_partition_" #vtname "_" #itname "_bn" #bn "_tn" #tn)]] \
+  [[kernel]] void mb_block_partition<vtype, itype, arg_sort, bn, tn>( \
      device itype* block_partitions [[buffer(0)]], \
      const device vtype* dev_vals [[buffer(1)]], \
      const device itype* dev_idxs [[buffer(2)]], \
@@ -61,7 +61,7 @@ inline void mps_matmul(
   // 2. Only one of a or b has batch_size_out matrices worth of data and
   //    the other has matrix worth of data
 
-  // The matrix dimsenisons of a and b are sure to be regularly strided
+  // The matrix dimensions of a and b are sure to be regularly strided
   if (batch_size_out > 1) {
     // No broadcasting defaults
     auto batch_size_a = a.data_size() / (M * K);
@@ -40,7 +40,7 @@ void all_reduce_dispatch(
   // Set grid dimensions
 
   // We make sure each thread has enough to do by making it read in
-  // atleast n_reads inputs
+  // at least n_reads inputs
   int n_reads = REDUCE_N_READS;
 
   // mod_in_size gives us the groups of n_reads needed to go over the entire
@@ -176,7 +176,7 @@ void strided_reduce_general_dispatch(
 
   // We spread outputs over the x dimension and inputs over the y dimension
   // Threads with the same lid.x in a given threadgroup work on the same
-  // output and each thread in the y dimension accumlates for that output
+  // output and each thread in the y dimension accumulates for that output
   uint threadgroup_dim_x = std::min(out_size, 128ul);
   uint threadgroup_dim_y =
       kernel->maxTotalThreadsPerThreadgroup() / threadgroup_dim_x;
@@ -165,10 +165,10 @@ void multi_block_sort(
     dev_idxs_out = ping ? dev_idxs_0 : dev_idxs_1;
     ping = !ping;
 
-    // Do partiton
+    // Do partition
     {
       std::ostringstream kname;
-      kname << "mb_block_partiton_" << type_to_name(dev_vals_in) << "_"
+      kname << "mb_block_partition_" << type_to_name(dev_vals_in) << "_"
             << type_to_name(dev_idxs_in) << "_bn" << bn << "_tn" << tn;
 
       auto kernel = d.get_kernel(kname.str());
@@ -18,7 +18,7 @@ void set_array_buffer(
   auto offset = a.data<char>() -
       static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
   enc->setBuffer(a_buf, offset, idx);
-  // MTL::Resource usage through argument buffer needs to be explicity
+  // MTL::Resource usage through argument buffer needs to be explicitly
   // flagged to enable hazard tracking
   compute_encoder->useResource(a_buf, MTL::ResourceUsageRead);
 }
@@ -45,7 +45,7 @@ array fft_impl(
     throw std::invalid_argument(msg.str());
   }
 
-  // In the following shape manipulations there are three cases to consdier:
+  // In the following shape manipulations there are three cases to consider:
   // 1. In a complex to complex transform (fftn / ifftn) the output
   //    and input shapes are the same.
   // 2. In a real to complex transform (rfftn) n specifies the input dims
@@ -155,7 +155,7 @@ array load(std::shared_ptr<io::Reader> in_stream, StreamOrDevice s) {
   // Read and check version
   if (read_magic_and_ver[6] != 1 && read_magic_and_ver[6] != 2) {
     throw std::runtime_error(
-        "[load] Unsupport npy format version in " + in_stream->label());
+        "[load] Unsupported npy format version in " + in_stream->label());
   }
 
   // Read header len and header
mlx/ops.cpp
@@ -247,7 +247,7 @@ array tri(int n, int m, int k, Dtype type, StreamOrDevice s /* = {} */) {
 
 array tril(array x, int k, StreamOrDevice s /* = {} */) {
   if (x.ndim() < 2) {
-    throw std::invalid_argument("[tril] array must be atleast 2-D");
+    throw std::invalid_argument("[tril] array must be at least 2-D");
   }
   auto mask = tri(x.shape(-2), x.shape(-1), k, x.dtype(), s);
   return where(mask, x, zeros_like(x, s), s);
@@ -255,7 +255,7 @@ array tril(array x, int k, StreamOrDevice s /* = {} */) {
 
 array triu(array x, int k, StreamOrDevice s /* = {} */) {
   if (x.ndim() < 2) {
-    throw std::invalid_argument("[triu] array must be atleast 2-D");
+    throw std::invalid_argument("[triu] array must be at least 2-D");
   }
   auto mask = tri(x.shape(-2), x.shape(-1), k - 1, x.dtype(), s);
   return where(mask, zeros_like(x, s), x, s);
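For reference, the user-visible behavior of the two ops mirrors numpy:

    import mlx.core as mx

    x = mx.ones((3, 3))
    mx.tril(x)  # keeps entries on and below the main diagonal, zeros above
    mx.triu(x)  # keeps entries on and above the main diagonal, zeros below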
@@ -350,7 +350,7 @@ array squeeze(
     ax = ax < 0 ? ax + a.ndim() : ax;
     if (ax < 0 || ax >= a.ndim()) {
       std::ostringstream msg;
-      msg << "[squeeze] Invalid axies " << ax << " for array with " << a.ndim()
+      msg << "[squeeze] Invalid axes " << ax << " for array with " << a.ndim()
           << " dimensions.";
       throw std::invalid_argument(msg.str());
     }
@@ -405,7 +405,7 @@ array expand_dims(
     ax = ax < 0 ? ax + out_ndim : ax;
     if (ax < 0 || ax >= out_ndim) {
       std::ostringstream msg;
-      msg << "[squeeze] Invalid axies " << ax << " for output array with "
+      msg << "[squeeze] Invalid axes " << ax << " for output array with "
          << a.ndim() << " dimensions.";
       throw std::invalid_argument(msg.str());
     }
@@ -478,7 +478,7 @@ array slice(
 
   // If strides are negative, slice and then make a copy with axes flipped
   if (negatively_strided_axes.size() > 0) {
-    // First, take the slice of the positvely strided axes
+    // First, take the slice of the positively strided axes
    auto out = array(
        out_shape,
        a.dtype(),
@@ -517,7 +517,7 @@ array slice(
       // Gather moves the axis up, remainder needs to be squeezed
       out_reshape[i] = indices[i].size();
 
-      // Gather moves the axis up, needs to be tranposed
+      // Gather moves the axis up, needs to be transposed
       t_axes[ax] = i;
     }
 
@@ -214,7 +214,7 @@ array concatenate(const std::vector<array>& arrays, StreamOrDevice s = {});
 array stack(const std::vector<array>& arrays, int axis, StreamOrDevice s = {});
 array stack(const std::vector<array>& arrays, StreamOrDevice s = {});
 
-/** Repeate an array along an axis. */
+/** Repeat an array along an axis. */
 array repeat(const array& arr, int repeats, int axis, StreamOrDevice s = {});
 array repeat(const array& arr, int repeats, StreamOrDevice s = {});
 
@@ -49,7 +49,7 @@ class Primitive {
   * A primitive must know how to evaluate itself on
   * the CPU/GPU for the given inputs and populate the output array.
   *
-  * To avoid unecessary allocations, the evaluation function
+  * To avoid unnecessary allocations, the evaluation function
   * is responsible for allocating space for the array.
   */
  virtual void eval_cpu(const std::vector<array>& inputs, array& out) = 0;
@@ -84,7 +84,7 @@ class Primitive {
   /** Print the primitive. */
   virtual void print(std::ostream& os) = 0;
 
-  /** Equivalence check defaults to false unless overriden by the primitive */
+  /** Equivalence check defaults to false unless overridden by the primitive */
   virtual bool is_equivalent(const Primitive& other) const {
     return false;
   }
@@ -232,7 +232,7 @@ array truncated_normal(
   auto u = uniform(a, b, shape, dtype, key, s);
   auto out = multiply(sqrt2, erfinv(u, s), s);
 
-  // Clip in bouds
+  // Clip in bounds
   return maximum(minimum(upper_t, out, s), lower_t, s);
 }
 
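This snippet is inverse-CDF sampling: the bounds are mapped through the normal CDF (via erf), a uniform draw is taken between the mapped bounds, and erfinv maps it back; the final clip guards against rounding pushing a sample just outside the bounds. A sketch of the same construction in Python:

    import mlx.core as mx

    def trunc_normal(lower, upper, shape):
        # Map bounds through the standard-normal CDF, sample uniformly
        # between them, then invert; clip for numerical safety.
        a = mx.erf(mx.array(lower / 2**0.5))
        b = mx.erf(mx.array(upper / 2**0.5))
        u = mx.random.uniform(a, b, shape)
        out = 2**0.5 * mx.erfinv(u)
        return mx.maximum(mx.minimum(out, mx.array(upper)), mx.array(lower))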
@@ -16,7 +16,7 @@ class KeySequence {
   void seed(uint64_t seed);
   array next();
 
-  // static defualt
+  // static default
   static KeySequence& default_() {
     static KeySequence ks(0);
     return ks;
@@ -80,7 +80,7 @@ ValueAndGradFn value_and_grad(
 
 /**
 * Returns a function which computes the value and gradient of the input
-* function with repsect to a single input array.
+* function with respect to a single input array.
 **/
 ValueAndGradFn inline value_and_grad(
    const std::function<std::vector<array>(const std::vector<array>&)>& fun,
@@ -132,7 +132,7 @@ std::function<std::vector<array>(const std::vector<array>&)> inline grad(
 
 /**
 * Returns a function which computes the gradient of the input function with
-* repsect to a single input array.
+* respect to a single input array.
 *
 * The function being differentiated takes a vector of arrays and returns an
 * array. The optional `argnum` index specifies which the argument to compute
@@ -68,7 +68,7 @@ struct _MLX_Float16 {
   inf_scale.u = uint32_t(0x77800000);
   zero_scale.u = uint32_t(0x08800000);
 
-  // Combine with magic and let addition do rouding
+  // Combine with magic and let addition do rounding
   magic_bits.u = x_expo_32;
   magic_bits.f += (std::abs(x) * inf_scale.f) * zero_scale.f;
 
@@ -198,7 +198,7 @@ class BatchNorm(Module):
     batch, ``C`` is the number of features or channels, and ``L`` is the
     sequence length. The output has the same shape as the input. For
     four-dimensional arrays, the shape is ``NHWC``, where ``H`` and ``W`` are
-    the height and width respecitvely.
+    the height and width respectively.
 
     For more information on Batch Normalization, see the original paper `Batch
     Normalization: Accelerating Deep Network Training by Reducing Internal
@@ -253,7 +253,7 @@ class AdaDelta(Optimizer):
         rho (float, optional): The coefficient :math:`\rho` used for computing a
             running average of squared gradients. Default: ``0.9``
         eps (float, optional): The term :math:`\epsilon` added to the denominator to improve
-            numerical stability. Ddefault: `1e-8`
+            numerical stability. Default: `1e-8`
     """

    def __init__(self, learning_rate: float, rho: float = 0.9, eps: float = 1e-6):
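The two hyperparameters feed the standard AdaDelta recurrences; schematically, per parameter (a sketch of the textbook update, not MLX's internal code; v tracks squared gradients, u squared updates):

    def adadelta_step(p, g, v, u, learning_rate, rho=0.9, eps=1e-6):
        # Running average of squared gradients
        v = rho * v + (1 - rho) * g**2
        # Rescaled step, then running average of squared updates
        d = (u + eps) ** 0.5 / (v + eps) ** 0.5 * g
        u = rho * u + (1 - rho) * d**2
        return p - learning_rate * d, v, u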
@@ -507,7 +507,7 @@ void init_array(py::module_& m) {
 
   array_class
       .def_property_readonly(
-          "size", &array::size, R"pbdoc(Number of elments in the array.)pbdoc")
+          "size", &array::size, R"pbdoc(Number of elements in the array.)pbdoc")
       .def_property_readonly(
           "ndim", &array::ndim, R"pbdoc(The array's dimension.)pbdoc")
       .def_property_readonly(
@@ -559,7 +559,7 @@ void init_array(py::module_& m) {
         If the array has more than one dimension then the result is a nested
         list of lists.
 
-        The value type of the list correpsonding to the last dimension is either
+        The value type of the list corresponding to the last dimension is either
         ``bool``, ``int`` or ``float`` depending on the ``dtype`` of the array.
       )pbdoc")
       .def("__array__", &mlx_array_to_np)
@@ -1263,7 +1263,7 @@ void init_ops(py::module_& m) {
         If the axis is not specified the array is treated as a flattened
         1-D array prior to performing the take.
 
-        As an example, if the ``axis=1`` this is equialent to ``a[:, indices, ...]``.
+        As an example, if the ``axis=1`` this is equivalent to ``a[:, indices, ...]``.
 
         Args:
             a (array): Input array.
@@ -1742,7 +1742,7 @@ void init_ops(py::module_& m) {
       "a"_a,
       py::pos_only(),
       "source"_a,
-      "destiantion"_a,
+      "destination"_a,
       py::kw_only(),
       "stream"_a = none,
       R"pbdoc(
@@ -2253,7 +2253,7 @@ void init_ops(py::module_& m) {
             will be of elements less or equal to the element at the ``kth``
             index and all indices after will be of elements greater or equal
             to the element at the ``kth`` index.
-            axis (int or None, optional): Optional axis to partiton over.
+            axis (int or None, optional): Optional axis to partition over.
              If ``None``, this partitions over the flattened array.
              If unspecified, it defaults to ``-1``.
 
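A quick illustration of the documented contract:

    import mlx.core as mx

    a = mx.array([5, 1, 4, 2, 3])
    b = mx.partition(a, kth=2)
    # b[2] is the third-smallest value (3); b[:2] holds smaller values and
    # b[3:] larger ones, each in unspecified order.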
@@ -2426,13 +2426,13 @@ void init_ops(py::module_& m) {
       R"pbdoc(
         repeat(array: array, repeats: int, axis: Optional[int] = None, *, stream: Union[None, Stream, Device] = None) -> array
 
-        Repeate an array along a specified axis.
+        Repeat an array along a specified axis.
 
         Args:
            array (array): Input array.
            repeats (int): The number of repetitions for each element.
            axis (int, optional): The axis in which to repeat the array along. If
-             unspecified it uses the flattened array of the input and repeates
+             unspecified it uses the flattened array of the input and repeats
              along axis 0.
            stream (Stream, optional): Stream or device. Defaults to ``None``.
 
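And the documented behavior in practice:

    import mlx.core as mx

    a = mx.array([[1, 2], [3, 4]])
    mx.repeat(a, 2)          # flattened: [1, 1, 2, 2, 3, 3, 4, 4]
    mx.repeat(a, 2, axis=0)  # [[1, 2], [1, 2], [3, 4], [3, 4]]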
@@ -3050,7 +3050,7 @@ void init_ops(py::module_& m) {
 
         Round to the given number of decimals.
 
-        Bascially performs:
+        Basically performs:
 
         .. code-block:: python
 
@@ -212,7 +212,7 @@ void init_random(py::module_& parent_module) {
             upper (scalar or array): Upper bound of the domain.
             shape (list(int), optional): The shape of the output.
               Default is ``()``.
-            dtype (Dtype, optinoal): The data type of the output.
+            dtype (Dtype, optional): The data type of the output.
              Default is ``float32``.
             key (array, optional): A PRNG key. Default: None.
 
@@ -952,7 +952,7 @@ class TestArray(mlx_tests.MLXTestCase):
         b_mx = a_mx[25:-50:-3]
         self.assertTrue(np.array_equal(b_np, b_mx))
 
-        # Negatie slice and ascending bounds
+        # Negative slice and ascending bounds
         b_np = a_np[0:20:-3]
         b_mx = a_mx[0:20:-3]
         self.assertTrue(np.array_equal(b_np, b_mx))
@@ -53,10 +53,10 @@ class TestBlas(mlx_tests.MLXTestCase):
         for dtype in self.dtypes:
             np_dtype = getattr(np, dtype)
             base_shapes = [4, 8, 16, 32, 64, 128]
-            pertubations = [-2, -1, 0, 1, 2]
+            perturbations = [-2, -1, 0, 1, 2]
 
             for dim in base_shapes:
-                for p in pertubations:
+                for p in perturbations:
                     shape_a = (dim + p, dim + p)
                     shape_b = (dim + p, dim + p)
                     self.__gemm_test(shape_a, shape_b, np_dtype)
@@ -81,12 +81,12 @@ class TestBlas(mlx_tests.MLXTestCase):
 
         for B, M, N, K in shapes:
 
-            with self.subTest(tranpose="nn"):
+            with self.subTest(transpose="nn"):
                 shape_a = (B, M, K)
                 shape_b = (B, K, N)
                 self.__gemm_test(shape_a, shape_b, np_dtype)
 
-            with self.subTest(tranpose="nt"):
+            with self.subTest(transpose="nt"):
                 shape_a = (B, M, K)
                 shape_b = (B, N, K)
                 self.__gemm_test(
@@ -97,7 +97,7 @@ class TestBlas(mlx_tests.MLXTestCase):
                     f_mx_b=lambda x: mx.transpose(x, (0, 2, 1)),
                 )
 
-            with self.subTest(tranpose="tn"):
+            with self.subTest(transpose="tn"):
                 shape_a = (B, K, M)
                 shape_b = (B, K, N)
                 self.__gemm_test(
@@ -108,7 +108,7 @@ class TestBlas(mlx_tests.MLXTestCase):
                     f_mx_a=lambda x: mx.transpose(x, (0, 2, 1)),
                 )
 
-            with self.subTest(tranpose="tt"):
+            with self.subTest(transpose="tt"):
                 shape_a = (B, K, M)
                 shape_b = (B, N, K)
                 self.__gemm_test(
@@ -191,7 +191,7 @@ class TestBlas(mlx_tests.MLXTestCase):
         self.assertListEqual(list(c_npy.shape), list(c_mlx.shape))
         self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-6))
 
-        # Batched matmul with simple broadast
+        # Batched matmul with simple broadcast
         a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
         b_npy = np.random.normal(0.0, 1.0 / 128, (16, 16)).astype(np.float32)
         c_npy = a_npy @ b_npy
@@ -213,7 +213,7 @@ class TestBlas(mlx_tests.MLXTestCase):
         self.assertListEqual(list(e_npy.shape), list(e_mlx.shape))
         self.assertTrue(np.allclose(e_mlx, e_npy, atol=1e-6))
 
-        # Batched and transposed matmul with simple broadast
+        # Batched and transposed matmul with simple broadcast
         a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
         b_npy = np.random.normal(0.0, 1.0 / 128, (128, 16)).astype(np.float32)
         a_mlx = mx.array(a_npy)
@@ -88,7 +88,7 @@ class TestOps(mlx_tests.MLXTestCase):
         self.assertEqual(a.dtype, mx.float32)
         self.assertEqual(a.item(), 3.0)
 
-        # Check comibinations with mlx arrays
+        # Check combinations with mlx arrays
         a = mx.add(mx.array(True), False)
         self.assertEqual(a.dtype, mx.bool_)
         self.assertEqual(a.item(), True)
@@ -76,7 +76,7 @@ TEST_CASE("test arg reduce small") {
       {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
 
   if (!metal::is_available()) {
-    INFO("Skiping arg reduction gpu tests");
+    INFO("Skipping arg reduction gpu tests");
     return;
   }
 
@@ -106,7 +106,7 @@ TEST_CASE("test arg reduce small") {
 
 TEST_CASE("test arg reduce against cpu") {
   if (!metal::is_available()) {
-    INFO("Skiping arg reduction gpu tests");
+    INFO("Skipping arg reduction gpu tests");
     return;
   }
 
@@ -148,7 +148,7 @@ void test_arg_reduce_small_bool(
 
 TEST_CASE("test arg reduce bool") {
   if (!metal::is_available()) {
-    INFO("Skiping arg reduction gpu tests");
+    INFO("Skipping arg reduction gpu tests");
     return;
   }
   auto x = array(
@@ -201,7 +201,7 @@ TEST_CASE("test arg reduce irregular strides") {
       Device::cpu, x, ArgReduce::ArgMin, {4, 2}, 2, {0, 0, 1, 1, 1, 1, 2, 2});
 
   if (!metal::is_available()) {
-    INFO("Skiping arg reduction gpu tests");
+    INFO("Skipping arg reduction gpu tests");
     return;
   }
 }
@@ -989,7 +989,7 @@ TEST_CASE("test as_strided grads") {
 }
 
 TEST_CASE("test jvp from vjp") {
-  // Unary elementwise ops
+  // Unary element-wise ops
   {
     auto x = random::uniform({5, 10});
     eval(x);
@@ -1022,7 +1022,7 @@ TEST_CASE("test jvp from vjp") {
     CHECK(compute_derivs(mlx::core::rsqrt));
   }
 
-  // Binary elementwise ops
+  // Binary element-wise ops
  {
    auto x = random::uniform({5, 10});
    auto y = random::uniform({5, 10});
@@ -7,7 +7,7 @@
 using namespace mlx::core;
 
 TEST_CASE("test arange") {
-  // Check type is inferred correclty
+  // Check type is inferred correctly
   {
     auto x = arange(10);
     CHECK_EQ(x.dtype(), int32);
@@ -1411,7 +1411,7 @@ TEST_CASE("test broadcast") {
   x.eval();
   CHECK_EQ(x.strides(), std::vector<size_t>{0, 0, 1});
 
-  // Broadcast on transposed arrray works
+  // Broadcast on transposed array works
   x = array({0, 1, 2, 3, 4, 5}, {2, 3});
   x = broadcast_to(transpose(x), {2, 3, 2});
   CHECK_EQ(x.shape(), std::vector<int>{2, 3, 2});
@@ -1733,7 +1733,7 @@ TEST_CASE("test scatter") {
   out = scatter(in, inds, updates, 0);
   CHECK(array_equal(out, reshape(arange(16, float32), {4, 4})).item<bool>());
 
-  // Irregular strided index and reduce collison test
+  // Irregular strided index and reduce collision test
   in = zeros({10}, float32);
   inds = broadcast_to(array(3), {10});
   updates = ones({10, 1}, float32);
@@ -1750,7 +1750,7 @@ TEST_CASE("test scatter") {
   out = scatter_max(array(1), {}, array(2), std::vector<int>{});
   CHECK_EQ(out.item<int>(), 2);
 
-  // Irregularaly strided updates test
+  // Irregularly strided updates test
   in = ones({3, 3});
   updates = broadcast_to(array({0, 0, 0}), {1, 3, 3});
   inds = array({0});