From 36427f6126a1405d87da8b665731f82eaaab5cf0 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Sat, 10 Aug 2024 09:24:35 -0700 Subject: [PATCH] docs update --- docs/build/html/.buildinfo | 2 +- docs/build/html/_sources/dev/extensions.rst | 5 +- .../_sources/examples/llama-inference.rst | 2 +- docs/build/html/_sources/examples/mlp.rst | 2 +- docs/build/html/_sources/install.rst | 26 +- .../_autosummary/mlx.core.DtypeCategory.rst | 19 +- .../mlx.core.linalg.cholesky_inv.rst | 6 + .../_autosummary/mlx.core.linalg.tri_inv.rst | 6 + docs/build/html/_sources/python/linalg.rst | 2 + .../html/_static/documentation_options.js | 2 +- docs/build/html/annotated.html | 299 ++- docs/build/html/array_8h_source.html | 32 +- .../backend_2accelerate_2utils_8h_source.html | 4 +- .../build/html/backend_2metal_2device_8h.html | 4 - .../backend_2metal_2device_8h_source.html | 414 ++-- docs/build/html/backend_2metal_2utils_8h.html | 26 +- .../html/backend_2metal_2utils_8h_source.html | 226 +-- .../html/class_m_p_s_1_1_kernel-members.html | 92 - docs/build/html/class_m_p_s_1_1_kernel.html | 144 -- docs/build/html/class_m_p_s_1_1_kernel.png | Bin 656 -> 0 bytes .../html/class_m_p_s_1_1_matrix-members.html | 93 - docs/build/html/class_m_p_s_1_1_matrix.html | 183 -- docs/build/html/class_m_p_s_1_1_matrix.png | Bin 665 -> 0 bytes ...s_m_p_s_1_1_matrix_descriptor-members.html | 93 - .../class_m_p_s_1_1_matrix_descriptor.html | 221 --- .../class_m_p_s_1_1_matrix_descriptor.png | Bin 813 -> 0 bytes ...p_s_1_1_matrix_multiplication-members.html | 98 - ...class_m_p_s_1_1_matrix_multiplication.html | 318 --- .../class_m_p_s_1_1_matrix_multiplication.png | Bin 1041 -> 0 bytes ..._matrix_vector_multiplication-members.html | 93 - ..._p_s_1_1_matrix_vector_multiplication.html | 213 -- ...m_p_s_1_1_matrix_vector_multiplication.png | Bin 1145 -> 0 bytes .../html/class_m_p_s_1_1_vector-members.html | 93 - docs/build/html/class_m_p_s_1_1_vector.html | 183 -- docs/build/html/class_m_p_s_1_1_vector.png | Bin 658 -> 0 bytes ...s_m_p_s_1_1_vector_descriptor-members.html | 92 - .../class_m_p_s_1_1_vector_descriptor.html | 178 -- .../class_m_p_s_1_1_vector_descriptor.png | Bin 794 -> 0 bytes docs/build/html/classes.html | 10 +- .../classmlx_1_1core_1_1_inverse-members.html | 2 +- .../html/classmlx_1_1core_1_1_inverse.html | 19 +- .../html/classmlx_1_1core_1_1_primitive.html | 6 +- .../classmlx_1_1core_1_1_scatter-members.html | 2 +- .../html/classmlx_1_1core_1_1_scatter.html | 40 +- ...e_1_1fast_1_1_affine_quantize-members.html | 109 ++ ...x_1_1core_1_1fast_1_1_affine_quantize.html | 286 +++ ...lx_1_1core_1_1fast_1_1_affine_quantize.png | Bin 0 -> 1003 bytes .../classmlx_1_1core_1_1fast_1_1_custom.html | 15 +- .../classmlx_1_1core_1_1fast_1_1_custom.png | Bin 2664 -> 3025 bytes ...x_1_1core_1_1metal_1_1_device-members.html | 2 +- .../classmlx_1_1core_1_1metal_1_1_device.html | 14 +- docs/build/html/common_2arange_8h_source.html | 26 +- docs/build/html/common_2binary_8h_source.html | 26 +- .../html/common_2binary__two_8h_source.html | 26 +- .../html/common_2hadamard_8h_source.html | 2 +- docs/build/html/common_2unary_8h_source.html | 26 +- docs/build/html/compiled_8h_source.html | 4 +- docs/build/html/cpp/ops.html | 42 +- docs/build/html/dev/extensions.html | 15 +- docs/build/html/dev/metal_debugger.html | 10 +- .../dir_4336740ec0075891704443b417fef6cb.html | 2 + .../dir_6768c99e6145fb9510ccdb40db8ede25.html | 2 +- .../dir_70a37effa88bcbd6b791977fa1e64356.html | 2 + .../dir_d0c977ea65824390717cdb7efc36c157.html | 2 - docs/build/html/doxygen_crawl.html | 41 +- docs/build/html/dtype_8h.html | 6 - docs/build/html/dtype_8h_source.html | 269 ++- .../html/examples/linear_regression.html | 10 +- docs/build/html/examples/llama-inference.html | 12 +- docs/build/html/examples/mlp.html | 12 +- docs/build/html/fast_8h.html | 6 + docs/build/html/fast_8h_source.html | 26 +- docs/build/html/fast__primitives_8h.html | 2 + .../html/fast__primitives_8h_source.html | 44 +- docs/build/html/files.html | 146 +- docs/build/html/functions_a.html | 2 +- docs/build/html/functions_b.html | 2 + docs/build/html/functions_d.html | 4 +- docs/build/html/functions_e.html | 5 +- docs/build/html/functions_func_a.html | 2 +- docs/build/html/functions_func_d.html | 4 +- docs/build/html/functions_func_e.html | 5 +- docs/build/html/functions_func_i.html | 3 +- docs/build/html/functions_func_l.html | 6 +- docs/build/html/functions_func_m.html | 1 - docs/build/html/functions_func_o.html | 6 +- docs/build/html/functions_func_r.html | 3 +- docs/build/html/functions_func_s.html | 5 - docs/build/html/functions_func_v.html | 3 +- docs/build/html/functions_h.html | 4 + docs/build/html/functions_i.html | 4 +- docs/build/html/functions_l.html | 8 +- docs/build/html/functions_m.html | 1 - docs/build/html/functions_n.html | 1 + docs/build/html/functions_o.html | 4 +- docs/build/html/functions_r.html | 3 +- docs/build/html/functions_s.html | 5 - docs/build/html/functions_t.html | 8 +- docs/build/html/functions_v.html | 3 +- docs/build/html/functions_vars_b.html | 2 + ...a10604f9f12.html => functions_vars_h.html} | 25 +- docs/build/html/functions_vars_n.html | 1 + docs/build/html/functions_vars_t.html | 8 +- ...steel_2gemm_2gemm_8h.html => gemm_8h.html} | 2 +- ...emm_8h_source.html => gemm_8h_source.html} | 2 +- docs/build/html/genindex.html | 20 +- docs/build/html/globals.html | 2 - docs/build/html/globals_a.html | 3 + docs/build/html/globals_b.html | 3 + docs/build/html/globals_c.html | 2 + docs/build/html/globals_defs.html | 6 +- docs/build/html/globals_func_a.html | 3 + docs/build/html/globals_func_b.html | 3 + docs/build/html/globals_func_c.html | 2 + docs/build/html/globals_func_g.html | 2 + docs/build/html/globals_func_t.html | 1 + docs/build/html/globals_func_u.html | 1 + docs/build/html/globals_g.html | 3 + docs/build/html/globals_m.html | 4 +- docs/build/html/globals_n.html | 2 +- docs/build/html/globals_t.html | 1 + docs/build/html/globals_type.html | 2 +- docs/build/html/globals_u.html | 1 + docs/build/html/globals_vars.html | 1 + docs/build/html/group__ops.html | 58 +- docs/build/html/hierarchy.html | 482 +++-- docs/build/html/includes_8h.html | 2 + docs/build/html/includes_8h_source.html | 6 +- docs/build/html/index.html | 10 +- docs/build/html/install.html | 36 +- docs/build/html/jit_2gemv__masked_8h.html | 143 ++ .../html/jit_2gemv__masked_8h_source.html | 118 ++ docs/build/html/kernels_2gemv__masked_8h.html | 396 ++++ .../html/kernels_2gemv__masked_8h_source.html | 972 +++++++++ .../html/kernels_2softmax_8h_source.html | 8 +- docs/build/html/kernels_8h.html | 2 + docs/build/html/kernels_8h_source.html | 110 +- docs/build/html/linalg_8h.html | 4 + docs/build/html/linalg_8h_source.html | 10 +- docs/build/html/matmul_8h.html | 9 +- docs/build/html/matmul_8h_source.html | 94 +- docs/build/html/menudata.js | 7 +- docs/build/html/metal_2binary_8h.html | 16 +- docs/build/html/metal_2binary_8h_source.html | 20 +- .../build/html/metal_2kernels_2binary_8h.html | 123 ++ .../metal_2kernels_2binary_8h_source.html | 202 +- .../html/metal_2kernels_2binary__two_8h.html | 138 ++ ...metal_2kernels_2binary__two_8h_source.html | 239 ++- docs/build/html/metal_2kernels_2copy_8h.html | 72 + .../html/metal_2kernels_2copy_8h_source.html | 298 +-- .../html/metal_2kernels_2ternary_8h.html | 46 + .../metal_2kernels_2ternary_8h_source.html | 203 +- docs/build/html/metal_2kernels_2unary_8h.html | 36 + .../html/metal_2kernels_2unary_8h_source.html | 31 +- docs/build/html/metal_2reduce_8h_source.html | 4 +- docs/build/html/mps_2gemm_8h.html | 235 --- docs/build/html/mps_2gemm_8h_source.html | 572 ------ docs/build/html/namespace_m_p_s.html | 160 -- docs/build/html/namespace_m_t_l.html | 91 - .../html/namespace_m_t_l_1_1_private.html | 97 - ...namespace_m_t_l_1_1_private_1_1_class.html | 227 --- ...espace_m_t_l_1_1_private_1_1_selector.html | 439 ----- docs/build/html/namespacemembers.html | 46 +- docs/build/html/namespacemembers_a.html | 124 -- docs/build/html/namespacemembers_b.html | 4 +- docs/build/html/namespacemembers_c.html | 1 + docs/build/html/namespacemembers_d.html | 10 +- docs/build/html/namespacemembers_enum.html | 1 - docs/build/html/namespacemembers_eval.html | 5 - docs/build/html/namespacemembers_func.html | 45 +- docs/build/html/namespacemembers_func_a.html | 123 -- docs/build/html/namespacemembers_func_b.html | 4 +- docs/build/html/namespacemembers_func_c.html | 1 + docs/build/html/namespacemembers_func_d.html | 4 +- docs/build/html/namespacemembers_func_g.html | 6 +- docs/build/html/namespacemembers_func_i.html | 1 - docs/build/html/namespacemembers_func_m.html | 1 + docs/build/html/namespacemembers_func_p.html | 2 +- docs/build/html/namespacemembers_func_s.html | 1 + docs/build/html/namespacemembers_func_t.html | 2 + docs/build/html/namespacemembers_g.html | 6 +- docs/build/html/namespacemembers_i.html | 1 - docs/build/html/namespacemembers_m.html | 1 + docs/build/html/namespacemembers_p.html | 2 +- docs/build/html/namespacemembers_s.html | 1 + docs/build/html/namespacemembers_t.html | 2 + docs/build/html/namespacemembers_type.html | 1 - docs/build/html/namespacemlx_1_1core.html | 464 ++++- .../html/namespacemlx_1_1core_1_1fast.html | 121 ++ .../html/namespacemlx_1_1core_1_1linalg.html | 56 + .../html/namespacemlx_1_1core_1_1metal.html | 46 +- docs/build/html/namespaces.html | 95 +- docs/build/html/objects.inv | Bin 25204 -> 25261 bytes docs/build/html/ops_8h.html | 20 +- docs/build/html/ops_8h_source.html | 1740 +++++++++-------- docs/build/html/primitives_8h_source.html | 1028 +++++----- .../python/_autosummary/mlx.core.Device.html | 10 +- .../python/_autosummary/mlx.core.Dtype.html | 10 +- .../_autosummary/mlx.core.DtypeCategory.html | 41 +- .../python/_autosummary/mlx.core.Stream.html | 10 +- .../python/_autosummary/mlx.core.abs.html | 10 +- .../python/_autosummary/mlx.core.add.html | 10 +- .../python/_autosummary/mlx.core.addmm.html | 10 +- .../python/_autosummary/mlx.core.all.html | 10 +- .../_autosummary/mlx.core.allclose.html | 10 +- .../python/_autosummary/mlx.core.any.html | 10 +- .../python/_autosummary/mlx.core.arange.html | 10 +- .../python/_autosummary/mlx.core.arccos.html | 10 +- .../python/_autosummary/mlx.core.arccosh.html | 10 +- .../python/_autosummary/mlx.core.arcsin.html | 10 +- .../python/_autosummary/mlx.core.arcsinh.html | 10 +- .../python/_autosummary/mlx.core.arctan.html | 10 +- .../python/_autosummary/mlx.core.arctan2.html | 10 +- .../python/_autosummary/mlx.core.arctanh.html | 10 +- .../python/_autosummary/mlx.core.argmax.html | 10 +- .../python/_autosummary/mlx.core.argmin.html | 10 +- .../_autosummary/mlx.core.argpartition.html | 10 +- .../python/_autosummary/mlx.core.argsort.html | 10 +- .../python/_autosummary/mlx.core.array.T.html | 10 +- .../_autosummary/mlx.core.array.abs.html | 10 +- .../_autosummary/mlx.core.array.all.html | 10 +- .../_autosummary/mlx.core.array.any.html | 10 +- .../_autosummary/mlx.core.array.argmax.html | 10 +- .../_autosummary/mlx.core.array.argmin.html | 10 +- .../_autosummary/mlx.core.array.astype.html | 10 +- .../_autosummary/mlx.core.array.at.html | 10 +- .../_autosummary/mlx.core.array.conj.html | 10 +- .../_autosummary/mlx.core.array.cos.html | 10 +- .../_autosummary/mlx.core.array.cummax.html | 10 +- .../_autosummary/mlx.core.array.cummin.html | 10 +- .../_autosummary/mlx.core.array.cumprod.html | 10 +- .../_autosummary/mlx.core.array.cumsum.html | 10 +- .../_autosummary/mlx.core.array.diag.html | 10 +- .../_autosummary/mlx.core.array.diagonal.html | 10 +- .../_autosummary/mlx.core.array.dtype.html | 10 +- .../_autosummary/mlx.core.array.exp.html | 10 +- .../_autosummary/mlx.core.array.flatten.html | 10 +- .../python/_autosummary/mlx.core.array.html | 10 +- .../_autosummary/mlx.core.array.item.html | 10 +- .../_autosummary/mlx.core.array.itemsize.html | 10 +- .../_autosummary/mlx.core.array.log.html | 10 +- .../_autosummary/mlx.core.array.log10.html | 10 +- .../_autosummary/mlx.core.array.log1p.html | 10 +- .../_autosummary/mlx.core.array.log2.html | 10 +- .../mlx.core.array.logsumexp.html | 10 +- .../_autosummary/mlx.core.array.max.html | 10 +- .../_autosummary/mlx.core.array.mean.html | 10 +- .../_autosummary/mlx.core.array.min.html | 10 +- .../_autosummary/mlx.core.array.moveaxis.html | 10 +- .../_autosummary/mlx.core.array.nbytes.html | 10 +- .../_autosummary/mlx.core.array.ndim.html | 10 +- .../_autosummary/mlx.core.array.prod.html | 10 +- .../mlx.core.array.reciprocal.html | 10 +- .../_autosummary/mlx.core.array.reshape.html | 10 +- .../_autosummary/mlx.core.array.round.html | 10 +- .../_autosummary/mlx.core.array.rsqrt.html | 10 +- .../_autosummary/mlx.core.array.shape.html | 10 +- .../_autosummary/mlx.core.array.sin.html | 10 +- .../_autosummary/mlx.core.array.size.html | 10 +- .../_autosummary/mlx.core.array.split.html | 10 +- .../_autosummary/mlx.core.array.sqrt.html | 10 +- .../_autosummary/mlx.core.array.square.html | 10 +- .../_autosummary/mlx.core.array.squeeze.html | 10 +- .../_autosummary/mlx.core.array.sum.html | 10 +- .../_autosummary/mlx.core.array.swapaxes.html | 10 +- .../_autosummary/mlx.core.array.tolist.html | 10 +- .../mlx.core.array.transpose.html | 10 +- .../_autosummary/mlx.core.array.var.html | 10 +- .../_autosummary/mlx.core.array.view.html | 10 +- .../_autosummary/mlx.core.array_equal.html | 10 +- .../_autosummary/mlx.core.as_strided.html | 10 +- .../_autosummary/mlx.core.atleast_1d.html | 10 +- .../_autosummary/mlx.core.atleast_2d.html | 10 +- .../_autosummary/mlx.core.atleast_3d.html | 10 +- .../_autosummary/mlx.core.bitwise_and.html | 10 +- .../_autosummary/mlx.core.bitwise_or.html | 10 +- .../_autosummary/mlx.core.bitwise_xor.html | 10 +- .../mlx.core.block_masked_mm.html | 10 +- .../_autosummary/mlx.core.broadcast_to.html | 10 +- .../python/_autosummary/mlx.core.ceil.html | 10 +- .../python/_autosummary/mlx.core.clip.html | 10 +- .../python/_autosummary/mlx.core.compile.html | 10 +- .../_autosummary/mlx.core.concatenate.html | 10 +- .../python/_autosummary/mlx.core.conj.html | 10 +- .../_autosummary/mlx.core.conjugate.html | 10 +- .../python/_autosummary/mlx.core.conv1d.html | 10 +- .../python/_autosummary/mlx.core.conv2d.html | 10 +- .../_autosummary/mlx.core.conv_general.html | 10 +- .../_autosummary/mlx.core.convolve.html | 10 +- .../python/_autosummary/mlx.core.cos.html | 10 +- .../python/_autosummary/mlx.core.cosh.html | 10 +- .../python/_autosummary/mlx.core.cummax.html | 10 +- .../python/_autosummary/mlx.core.cummin.html | 10 +- .../python/_autosummary/mlx.core.cumprod.html | 10 +- .../python/_autosummary/mlx.core.cumsum.html | 10 +- .../mlx.core.custom_function.html | 10 +- .../_autosummary/mlx.core.default_device.html | 10 +- .../_autosummary/mlx.core.default_stream.html | 10 +- .../python/_autosummary/mlx.core.degrees.html | 10 +- .../_autosummary/mlx.core.dequantize.html | 10 +- .../python/_autosummary/mlx.core.diag.html | 10 +- .../_autosummary/mlx.core.diagonal.html | 10 +- .../mlx.core.disable_compile.html | 10 +- .../mlx.core.distributed.Group.html | 10 +- .../mlx.core.distributed.all_gather.html | 10 +- .../mlx.core.distributed.all_sum.html | 10 +- .../mlx.core.distributed.init.html | 10 +- .../mlx.core.distributed.is_available.html | 10 +- .../python/_autosummary/mlx.core.divide.html | 10 +- .../python/_autosummary/mlx.core.divmod.html | 10 +- .../python/_autosummary/mlx.core.einsum.html | 10 +- .../_autosummary/mlx.core.einsum_path.html | 10 +- .../_autosummary/mlx.core.enable_compile.html | 10 +- .../python/_autosummary/mlx.core.equal.html | 10 +- .../python/_autosummary/mlx.core.erf.html | 10 +- .../python/_autosummary/mlx.core.erfinv.html | 10 +- .../python/_autosummary/mlx.core.eval.html | 10 +- .../python/_autosummary/mlx.core.exp.html | 10 +- .../_autosummary/mlx.core.expand_dims.html | 10 +- .../python/_autosummary/mlx.core.expm1.html | 10 +- .../python/_autosummary/mlx.core.eye.html | 10 +- .../mlx.core.fast.layer_norm.html | 10 +- .../_autosummary/mlx.core.fast.rms_norm.html | 10 +- .../_autosummary/mlx.core.fast.rope.html | 10 +- ...ore.fast.scaled_dot_product_attention.html | 10 +- .../python/_autosummary/mlx.core.fft.fft.html | 10 +- .../_autosummary/mlx.core.fft.fft2.html | 10 +- .../_autosummary/mlx.core.fft.fftn.html | 10 +- .../_autosummary/mlx.core.fft.ifft.html | 10 +- .../_autosummary/mlx.core.fft.ifft2.html | 10 +- .../_autosummary/mlx.core.fft.ifftn.html | 10 +- .../_autosummary/mlx.core.fft.irfft.html | 10 +- .../_autosummary/mlx.core.fft.irfft2.html | 10 +- .../_autosummary/mlx.core.fft.irfftn.html | 10 +- .../_autosummary/mlx.core.fft.rfft.html | 10 +- .../_autosummary/mlx.core.fft.rfft2.html | 10 +- .../_autosummary/mlx.core.fft.rfftn.html | 10 +- .../python/_autosummary/mlx.core.flatten.html | 10 +- .../python/_autosummary/mlx.core.floor.html | 10 +- .../_autosummary/mlx.core.floor_divide.html | 10 +- .../python/_autosummary/mlx.core.full.html | 10 +- .../_autosummary/mlx.core.gather_mm.html | 10 +- .../_autosummary/mlx.core.gather_qmm.html | 10 +- .../python/_autosummary/mlx.core.grad.html | 10 +- .../python/_autosummary/mlx.core.greater.html | 10 +- .../_autosummary/mlx.core.greater_equal.html | 10 +- .../mlx.core.hadamard_transform.html | 10 +- .../_autosummary/mlx.core.identity.html | 10 +- .../python/_autosummary/mlx.core.inner.html | 10 +- .../python/_autosummary/mlx.core.isclose.html | 10 +- .../python/_autosummary/mlx.core.isinf.html | 10 +- .../python/_autosummary/mlx.core.isnan.html | 10 +- .../_autosummary/mlx.core.isneginf.html | 10 +- .../_autosummary/mlx.core.isposinf.html | 10 +- .../_autosummary/mlx.core.issubdtype.html | 10 +- .../python/_autosummary/mlx.core.jvp.html | 10 +- .../_autosummary/mlx.core.left_shift.html | 10 +- .../python/_autosummary/mlx.core.less.html | 10 +- .../_autosummary/mlx.core.less_equal.html | 10 +- .../mlx.core.linalg.cholesky.html | 18 +- .../mlx.core.linalg.cholesky_inv.html | 984 ++++++++++ .../_autosummary/mlx.core.linalg.inv.html | 16 +- .../_autosummary/mlx.core.linalg.norm.html | 16 +- .../_autosummary/mlx.core.linalg.qr.html | 16 +- .../_autosummary/mlx.core.linalg.svd.html | 10 +- .../_autosummary/mlx.core.linalg.tri_inv.html | 975 +++++++++ .../_autosummary/mlx.core.linspace.html | 10 +- .../python/_autosummary/mlx.core.load.html | 10 +- .../python/_autosummary/mlx.core.log.html | 10 +- .../python/_autosummary/mlx.core.log10.html | 10 +- .../python/_autosummary/mlx.core.log1p.html | 10 +- .../python/_autosummary/mlx.core.log2.html | 10 +- .../_autosummary/mlx.core.logaddexp.html | 10 +- .../_autosummary/mlx.core.logical_and.html | 10 +- .../_autosummary/mlx.core.logical_not.html | 10 +- .../_autosummary/mlx.core.logical_or.html | 10 +- .../_autosummary/mlx.core.logsumexp.html | 10 +- .../python/_autosummary/mlx.core.matmul.html | 10 +- .../python/_autosummary/mlx.core.max.html | 10 +- .../python/_autosummary/mlx.core.maximum.html | 10 +- .../python/_autosummary/mlx.core.mean.html | 10 +- .../_autosummary/mlx.core.meshgrid.html | 10 +- .../mlx.core.metal.clear_cache.html | 10 +- .../mlx.core.metal.device_info.html | 10 +- .../mlx.core.metal.get_active_memory.html | 10 +- .../mlx.core.metal.get_cache_memory.html | 10 +- .../mlx.core.metal.get_peak_memory.html | 10 +- .../mlx.core.metal.is_available.html | 10 +- .../mlx.core.metal.reset_peak_memory.html | 10 +- .../mlx.core.metal.set_cache_limit.html | 10 +- .../mlx.core.metal.set_memory_limit.html | 10 +- .../mlx.core.metal.start_capture.html | 10 +- .../mlx.core.metal.stop_capture.html | 10 +- .../python/_autosummary/mlx.core.min.html | 10 +- .../python/_autosummary/mlx.core.minimum.html | 10 +- .../_autosummary/mlx.core.moveaxis.html | 10 +- .../_autosummary/mlx.core.multiply.html | 10 +- .../_autosummary/mlx.core.nan_to_num.html | 10 +- .../_autosummary/mlx.core.negative.html | 10 +- .../_autosummary/mlx.core.new_stream.html | 10 +- .../_autosummary/mlx.core.not_equal.html | 10 +- .../python/_autosummary/mlx.core.ones.html | 10 +- .../_autosummary/mlx.core.ones_like.html | 10 +- .../python/_autosummary/mlx.core.outer.html | 10 +- .../python/_autosummary/mlx.core.pad.html | 15 +- .../_autosummary/mlx.core.partition.html | 10 +- .../python/_autosummary/mlx.core.power.html | 10 +- .../python/_autosummary/mlx.core.prod.html | 10 +- .../_autosummary/mlx.core.quantize.html | 10 +- .../mlx.core.quantized_matmul.html | 10 +- .../python/_autosummary/mlx.core.radians.html | 10 +- .../mlx.core.random.bernoulli.html | 10 +- .../mlx.core.random.categorical.html | 10 +- .../_autosummary/mlx.core.random.gumbel.html | 10 +- .../_autosummary/mlx.core.random.key.html | 10 +- .../_autosummary/mlx.core.random.laplace.html | 10 +- .../mlx.core.random.multivariate_normal.html | 10 +- .../_autosummary/mlx.core.random.normal.html | 10 +- .../_autosummary/mlx.core.random.randint.html | 10 +- .../_autosummary/mlx.core.random.seed.html | 10 +- .../_autosummary/mlx.core.random.split.html | 10 +- .../mlx.core.random.truncated_normal.html | 10 +- .../_autosummary/mlx.core.random.uniform.html | 10 +- .../_autosummary/mlx.core.reciprocal.html | 10 +- .../_autosummary/mlx.core.remainder.html | 10 +- .../python/_autosummary/mlx.core.repeat.html | 10 +- .../python/_autosummary/mlx.core.reshape.html | 10 +- .../_autosummary/mlx.core.right_shift.html | 10 +- .../python/_autosummary/mlx.core.round.html | 10 +- .../python/_autosummary/mlx.core.rsqrt.html | 10 +- .../python/_autosummary/mlx.core.save.html | 10 +- .../_autosummary/mlx.core.save_gguf.html | 10 +- .../mlx.core.save_safetensors.html | 10 +- .../python/_autosummary/mlx.core.savez.html | 10 +- .../mlx.core.savez_compressed.html | 10 +- .../mlx.core.set_default_device.html | 10 +- .../mlx.core.set_default_stream.html | 10 +- .../python/_autosummary/mlx.core.sigmoid.html | 10 +- .../python/_autosummary/mlx.core.sign.html | 10 +- .../python/_autosummary/mlx.core.sin.html | 10 +- .../python/_autosummary/mlx.core.sinh.html | 10 +- .../python/_autosummary/mlx.core.softmax.html | 10 +- .../python/_autosummary/mlx.core.sort.html | 10 +- .../python/_autosummary/mlx.core.split.html | 10 +- .../python/_autosummary/mlx.core.sqrt.html | 10 +- .../python/_autosummary/mlx.core.square.html | 10 +- .../python/_autosummary/mlx.core.squeeze.html | 10 +- .../python/_autosummary/mlx.core.stack.html | 10 +- .../python/_autosummary/mlx.core.std.html | 10 +- .../_autosummary/mlx.core.stop_gradient.html | 10 +- .../_autosummary/mlx.core.subtract.html | 10 +- .../python/_autosummary/mlx.core.sum.html | 10 +- .../_autosummary/mlx.core.swapaxes.html | 10 +- .../_autosummary/mlx.core.synchronize.html | 10 +- .../python/_autosummary/mlx.core.take.html | 10 +- .../mlx.core.take_along_axis.html | 10 +- .../python/_autosummary/mlx.core.tan.html | 10 +- .../python/_autosummary/mlx.core.tanh.html | 10 +- .../_autosummary/mlx.core.tensordot.html | 10 +- .../python/_autosummary/mlx.core.tile.html | 10 +- .../python/_autosummary/mlx.core.topk.html | 10 +- .../python/_autosummary/mlx.core.trace.html | 10 +- .../_autosummary/mlx.core.transpose.html | 10 +- .../python/_autosummary/mlx.core.tri.html | 10 +- .../python/_autosummary/mlx.core.tril.html | 10 +- .../python/_autosummary/mlx.core.triu.html | 10 +- .../_autosummary/mlx.core.value_and_grad.html | 10 +- .../python/_autosummary/mlx.core.var.html | 10 +- .../python/_autosummary/mlx.core.view.html | 10 +- .../python/_autosummary/mlx.core.vjp.html | 10 +- .../python/_autosummary/mlx.core.vmap.html | 10 +- .../python/_autosummary/mlx.core.where.html | 10 +- .../python/_autosummary/mlx.core.zeros.html | 10 +- .../_autosummary/mlx.core.zeros_like.html | 10 +- .../python/_autosummary/mlx.nn.quantize.html | 10 +- .../_autosummary/mlx.nn.value_and_grad.html | 10 +- .../mlx.optimizers.clip_grad_norm.html | 10 +- .../_autosummary/mlx.utils.tree_flatten.html | 10 +- .../_autosummary/mlx.utils.tree_map.html | 10 +- .../mlx.utils.tree_map_with_path.html | 10 +- .../_autosummary/mlx.utils.tree_reduce.html | 10 +- .../mlx.utils.tree_unflatten.html | 10 +- .../python/_autosummary/stream_class.html | 10 +- docs/build/html/python/array.html | 10 +- docs/build/html/python/data_types.html | 12 +- .../html/python/devices_and_streams.html | 10 +- docs/build/html/python/distributed.html | 10 +- docs/build/html/python/fast.html | 10 +- docs/build/html/python/fft.html | 10 +- docs/build/html/python/linalg.html | 20 +- docs/build/html/python/metal.html | 10 +- docs/build/html/python/nn.html | 10 +- .../python/nn/_autosummary/mlx.nn.ALiBi.html | 10 +- .../nn/_autosummary/mlx.nn.AvgPool1d.html | 10 +- .../nn/_autosummary/mlx.nn.AvgPool2d.html | 10 +- .../nn/_autosummary/mlx.nn.BatchNorm.html | 10 +- .../python/nn/_autosummary/mlx.nn.Conv1d.html | 10 +- .../python/nn/_autosummary/mlx.nn.Conv2d.html | 10 +- .../python/nn/_autosummary/mlx.nn.Conv3d.html | 10 +- .../nn/_autosummary/mlx.nn.Dropout.html | 10 +- .../nn/_autosummary/mlx.nn.Dropout2d.html | 10 +- .../nn/_autosummary/mlx.nn.Dropout3d.html | 10 +- .../nn/_autosummary/mlx.nn.Embedding.html | 10 +- .../python/nn/_autosummary/mlx.nn.GELU.html | 17 +- .../python/nn/_autosummary/mlx.nn.GLU.html | 10 +- .../python/nn/_autosummary/mlx.nn.GRU.html | 10 +- .../nn/_autosummary/mlx.nn.GroupNorm.html | 10 +- .../nn/_autosummary/mlx.nn.HardShrink.html | 10 +- .../nn/_autosummary/mlx.nn.HardTanh.html | 10 +- .../nn/_autosummary/mlx.nn.Hardswish.html | 10 +- .../nn/_autosummary/mlx.nn.InstanceNorm.html | 10 +- .../python/nn/_autosummary/mlx.nn.LSTM.html | 10 +- .../nn/_autosummary/mlx.nn.LayerNorm.html | 10 +- .../nn/_autosummary/mlx.nn.LeakyReLU.html | 10 +- .../python/nn/_autosummary/mlx.nn.Linear.html | 10 +- .../nn/_autosummary/mlx.nn.MaxPool1d.html | 10 +- .../nn/_autosummary/mlx.nn.MaxPool2d.html | 10 +- .../python/nn/_autosummary/mlx.nn.Mish.html | 10 +- .../nn/_autosummary/mlx.nn.Module.apply.html | 10 +- .../mlx.nn.Module.apply_to_modules.html | 10 +- .../_autosummary/mlx.nn.Module.children.html | 10 +- .../nn/_autosummary/mlx.nn.Module.eval.html | 10 +- .../mlx.nn.Module.filter_and_map.html | 10 +- .../nn/_autosummary/mlx.nn.Module.freeze.html | 10 +- .../mlx.nn.Module.leaf_modules.html | 10 +- .../mlx.nn.Module.load_weights.html | 10 +- .../_autosummary/mlx.nn.Module.modules.html | 10 +- .../mlx.nn.Module.named_modules.html | 10 +- .../mlx.nn.Module.parameters.html | 10 +- .../mlx.nn.Module.save_weights.html | 10 +- .../_autosummary/mlx.nn.Module.set_dtype.html | 10 +- .../nn/_autosummary/mlx.nn.Module.state.html | 10 +- .../nn/_autosummary/mlx.nn.Module.train.html | 10 +- .../mlx.nn.Module.trainable_parameters.html | 10 +- .../_autosummary/mlx.nn.Module.training.html | 10 +- .../_autosummary/mlx.nn.Module.unfreeze.html | 10 +- .../nn/_autosummary/mlx.nn.Module.update.html | 10 +- .../mlx.nn.Module.update_modules.html | 10 +- .../mlx.nn.MultiHeadAttention.html | 10 +- .../python/nn/_autosummary/mlx.nn.PReLU.html | 10 +- .../mlx.nn.QuantizedEmbedding.html | 10 +- .../_autosummary/mlx.nn.QuantizedLinear.html | 10 +- .../nn/_autosummary/mlx.nn.RMSNorm.html | 10 +- .../python/nn/_autosummary/mlx.nn.RNN.html | 10 +- .../python/nn/_autosummary/mlx.nn.ReLU.html | 10 +- .../python/nn/_autosummary/mlx.nn.ReLU6.html | 10 +- .../python/nn/_autosummary/mlx.nn.RoPE.html | 10 +- .../python/nn/_autosummary/mlx.nn.SELU.html | 10 +- .../nn/_autosummary/mlx.nn.Sequential.html | 10 +- .../python/nn/_autosummary/mlx.nn.SiLU.html | 10 +- .../mlx.nn.SinusoidalPositionalEncoding.html | 10 +- .../nn/_autosummary/mlx.nn.Softmax.html | 10 +- .../nn/_autosummary/mlx.nn.Softmin.html | 10 +- .../nn/_autosummary/mlx.nn.Softplus.html | 10 +- .../nn/_autosummary/mlx.nn.Softshrink.html | 10 +- .../nn/_autosummary/mlx.nn.Softsign.html | 10 +- .../python/nn/_autosummary/mlx.nn.Step.html | 10 +- .../python/nn/_autosummary/mlx.nn.Tanh.html | 10 +- .../nn/_autosummary/mlx.nn.Transformer.html | 10 +- .../nn/_autosummary/mlx.nn.Upsample.html | 10 +- .../nn/_autosummary/mlx.nn.init.constant.html | 10 +- .../mlx.nn.init.glorot_normal.html | 10 +- .../mlx.nn.init.glorot_uniform.html | 10 +- .../_autosummary/mlx.nn.init.he_normal.html | 10 +- .../_autosummary/mlx.nn.init.he_uniform.html | 10 +- .../nn/_autosummary/mlx.nn.init.identity.html | 10 +- .../nn/_autosummary/mlx.nn.init.normal.html | 10 +- .../nn/_autosummary/mlx.nn.init.uniform.html | 10 +- .../nn/_autosummary_functions/mlx.nn.elu.html | 10 +- .../_autosummary_functions/mlx.nn.gelu.html | 10 +- .../mlx.nn.gelu_approx.html | 10 +- .../mlx.nn.gelu_fast_approx.html | 10 +- .../nn/_autosummary_functions/mlx.nn.glu.html | 10 +- .../mlx.nn.hard_shrink.html | 10 +- .../mlx.nn.hard_tanh.html | 10 +- .../mlx.nn.hardswish.html | 10 +- .../mlx.nn.leaky_relu.html | 10 +- .../mlx.nn.log_sigmoid.html | 10 +- .../mlx.nn.log_softmax.html | 10 +- .../mlx.nn.losses.binary_cross_entropy.html | 10 +- .../mlx.nn.losses.cosine_similarity_loss.html | 10 +- .../mlx.nn.losses.cross_entropy.html | 10 +- .../mlx.nn.losses.gaussian_nll_loss.html | 10 +- .../mlx.nn.losses.hinge_loss.html | 10 +- .../mlx.nn.losses.huber_loss.html | 10 +- .../mlx.nn.losses.kl_div_loss.html | 10 +- .../mlx.nn.losses.l1_loss.html | 10 +- .../mlx.nn.losses.log_cosh_loss.html | 10 +- .../mlx.nn.losses.margin_ranking_loss.html | 10 +- .../mlx.nn.losses.mse_loss.html | 10 +- .../mlx.nn.losses.nll_loss.html | 10 +- .../mlx.nn.losses.smooth_l1_loss.html | 10 +- .../mlx.nn.losses.triplet_loss.html | 10 +- .../_autosummary_functions/mlx.nn.mish.html | 10 +- .../_autosummary_functions/mlx.nn.prelu.html | 10 +- .../_autosummary_functions/mlx.nn.relu.html | 10 +- .../_autosummary_functions/mlx.nn.relu6.html | 10 +- .../_autosummary_functions/mlx.nn.selu.html | 10 +- .../mlx.nn.sigmoid.html | 10 +- .../_autosummary_functions/mlx.nn.silu.html | 10 +- .../mlx.nn.softmax.html | 10 +- .../mlx.nn.softmin.html | 10 +- .../mlx.nn.softplus.html | 10 +- .../mlx.nn.softshrink.html | 10 +- .../_autosummary_functions/mlx.nn.step.html | 10 +- .../_autosummary_functions/mlx.nn.tanh.html | 10 +- docs/build/html/python/nn/functions.html | 10 +- docs/build/html/python/nn/init.html | 10 +- docs/build/html/python/nn/layers.html | 10 +- docs/build/html/python/nn/losses.html | 10 +- docs/build/html/python/nn/module.html | 10 +- docs/build/html/python/ops.html | 12 +- docs/build/html/python/optimizers.html | 10 +- .../_autosummary/mlx.optimizers.AdaDelta.html | 10 +- .../mlx.optimizers.Adafactor.html | 10 +- .../_autosummary/mlx.optimizers.Adagrad.html | 10 +- .../_autosummary/mlx.optimizers.Adam.html | 10 +- .../_autosummary/mlx.optimizers.AdamW.html | 10 +- .../_autosummary/mlx.optimizers.Adamax.html | 10 +- .../_autosummary/mlx.optimizers.Lion.html | 10 +- ....optimizers.Optimizer.apply_gradients.html | 10 +- .../mlx.optimizers.Optimizer.init.html | 10 +- .../mlx.optimizers.Optimizer.state.html | 10 +- .../mlx.optimizers.Optimizer.update.html | 10 +- .../_autosummary/mlx.optimizers.RMSprop.html | 10 +- .../_autosummary/mlx.optimizers.SGD.html | 10 +- .../mlx.optimizers.cosine_decay.html | 10 +- .../mlx.optimizers.exponential_decay.html | 10 +- .../mlx.optimizers.join_schedules.html | 10 +- .../mlx.optimizers.linear_schedule.html | 10 +- .../mlx.optimizers.step_decay.html | 10 +- .../python/optimizers/common_optimizers.html | 10 +- .../html/python/optimizers/optimizer.html | 10 +- .../html/python/optimizers/schedulers.html | 10 +- docs/build/html/python/random.html | 10 +- docs/build/html/python/transforms.html | 10 +- docs/build/html/python/tree_utils.html | 10 +- docs/build/html/quantized_8h.html | 131 +- docs/build/html/quantized_8h_source.html | 188 +- docs/build/html/random_8h_source.html | 6 +- docs/build/html/search.html | 10 +- docs/build/html/search/all_0.js | 8 +- docs/build/html/search/all_1.js | 173 +- docs/build/html/search/all_10.js | 35 +- docs/build/html/search/all_12.js | 9 +- docs/build/html/search/all_13.js | 296 ++- docs/build/html/search/all_14.js | 109 +- docs/build/html/search/all_15.js | 21 +- docs/build/html/search/all_16.js | 22 +- docs/build/html/search/all_2.js | 101 +- docs/build/html/search/all_3.js | 237 +-- docs/build/html/search/all_4.js | 109 +- docs/build/html/search/all_5.js | 73 +- docs/build/html/search/all_7.js | 156 +- docs/build/html/search/all_8.js | 6 +- docs/build/html/search/all_9.js | 45 +- docs/build/html/search/all_b.js | 17 +- docs/build/html/search/all_c.js | 6 +- docs/build/html/search/all_d.js | 203 +- docs/build/html/search/all_e.js | 53 +- docs/build/html/search/all_f.js | 2 +- docs/build/html/search/classes_1.js | 51 +- docs/build/html/search/classes_15.js | 8 +- docs/build/html/search/classes_7.js | 8 +- docs/build/html/search/classes_a.js | 7 +- docs/build/html/search/classes_c.js | 30 +- docs/build/html/search/defines_0.js | 4 +- docs/build/html/search/defines_8.js | 4 +- docs/build/html/search/enums_2.js | 3 +- docs/build/html/search/enumvalues_3.js | 7 +- docs/build/html/search/files_6.js | 7 +- docs/build/html/search/functions_0.js | 4 +- docs/build/html/search/functions_1.js | 135 +- docs/build/html/search/functions_10.js | 6 +- docs/build/html/search/functions_12.js | 7 +- docs/build/html/search/functions_13.js | 200 +- docs/build/html/search/functions_14.js | 47 +- docs/build/html/search/functions_15.js | 11 +- docs/build/html/search/functions_16.js | 13 +- docs/build/html/search/functions_2.js | 65 +- docs/build/html/search/functions_3.js | 157 +- docs/build/html/search/functions_4.js | 62 +- docs/build/html/search/functions_5.js | 57 +- docs/build/html/search/functions_7.js | 115 +- docs/build/html/search/functions_9.js | 39 +- docs/build/html/search/functions_c.js | 6 +- docs/build/html/search/functions_d.js | 14 +- docs/build/html/search/functions_f.js | 2 +- docs/build/html/search/namespaces_0.js | 7 +- docs/build/html/search/typedefs_2.js | 5 +- docs/build/html/search/typedefs_9.js | 2 +- docs/build/html/search/variables_1.js | 16 +- docs/build/html/search/variables_13.js | 30 +- docs/build/html/search/variables_6.js | 19 +- docs/build/html/search/variables_7.js | 6 +- docs/build/html/search/variables_d.js | 11 +- docs/build/html/searchindex.js | 2 +- docs/build/html/steel__gemm__masked_8h.html | 2 +- .../html/steel__gemm__masked_8h_source.html | 20 +- docs/build/html/struct___no_mask-members.html | 4 + docs/build/html/struct___no_mask.html | 121 +- .../html/struct_g_e_m_v_kernel-members.html | 99 + docs/build/html/struct_g_e_m_v_kernel.html | 546 ++++++ .../html/struct_g_e_m_v_t_kernel-members.html | 97 + docs/build/html/struct_g_e_m_v_t_kernel.html | 471 +++++ docs/build/html/struct_scale_op-members.html | 3 +- docs/build/html/struct_scale_op.html | 36 +- ..._1_1metal_1_1_command_encoder-members.html | 8 +- ..._1_1core_1_1metal_1_1_command_encoder.html | 32 - ...structmlx_1_1steel_1_1_g_e_m_m_kernel.html | 4 +- ...structmlx_1_1steel_1_1_loop_alignment.html | 4 +- docs/build/html/types_2complex_8h_source.html | 8 +- docs/build/html/usage/compile.html | 10 +- docs/build/html/usage/distributed.html | 10 +- .../build/html/usage/function_transforms.html | 10 +- docs/build/html/usage/indexing.html | 10 +- docs/build/html/usage/lazy_evaluation.html | 10 +- docs/build/html/usage/numpy.html | 10 +- docs/build/html/usage/quick_start.html | 10 +- docs/build/html/usage/saving_and_loading.html | 10 +- docs/build/html/usage/unified_memory.html | 10 +- docs/build/html/usage/using_streams.html | 10 +- docs/build/html/utils_8h.html | 2 - docs/build/html/utils_8h_source.html | 93 +- 724 files changed, 14529 insertions(+), 11046 deletions(-) create mode 100644 docs/build/html/_sources/python/_autosummary/mlx.core.linalg.cholesky_inv.rst create mode 100644 docs/build/html/_sources/python/_autosummary/mlx.core.linalg.tri_inv.rst delete mode 100644 docs/build/html/class_m_p_s_1_1_kernel-members.html delete mode 100644 docs/build/html/class_m_p_s_1_1_kernel.html delete mode 100644 docs/build/html/class_m_p_s_1_1_kernel.png delete mode 100644 docs/build/html/class_m_p_s_1_1_matrix-members.html delete mode 100644 docs/build/html/class_m_p_s_1_1_matrix.html delete mode 100644 docs/build/html/class_m_p_s_1_1_matrix.png delete mode 100644 docs/build/html/class_m_p_s_1_1_matrix_descriptor-members.html delete mode 100644 docs/build/html/class_m_p_s_1_1_matrix_descriptor.html delete mode 100644 docs/build/html/class_m_p_s_1_1_matrix_descriptor.png delete mode 100644 docs/build/html/class_m_p_s_1_1_matrix_multiplication-members.html delete mode 100644 docs/build/html/class_m_p_s_1_1_matrix_multiplication.html delete mode 100644 docs/build/html/class_m_p_s_1_1_matrix_multiplication.png delete mode 100644 docs/build/html/class_m_p_s_1_1_matrix_vector_multiplication-members.html delete mode 100644 docs/build/html/class_m_p_s_1_1_matrix_vector_multiplication.html delete mode 100644 docs/build/html/class_m_p_s_1_1_matrix_vector_multiplication.png delete mode 100644 docs/build/html/class_m_p_s_1_1_vector-members.html delete mode 100644 docs/build/html/class_m_p_s_1_1_vector.html delete mode 100644 docs/build/html/class_m_p_s_1_1_vector.png delete mode 100644 docs/build/html/class_m_p_s_1_1_vector_descriptor-members.html delete mode 100644 docs/build/html/class_m_p_s_1_1_vector_descriptor.html delete mode 100644 docs/build/html/class_m_p_s_1_1_vector_descriptor.png create mode 100644 docs/build/html/classmlx_1_1core_1_1fast_1_1_affine_quantize-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1fast_1_1_affine_quantize.html create mode 100644 docs/build/html/classmlx_1_1core_1_1fast_1_1_affine_quantize.png rename docs/build/html/{dir_1683daa6c50d5a1449f58a10604f9f12.html => functions_vars_h.html} (64%) rename docs/build/html/{kernels_2steel_2gemm_2gemm_8h.html => gemm_8h.html} (98%) rename docs/build/html/{kernels_2steel_2gemm_2gemm_8h_source.html => gemm_8h_source.html} (99%) create mode 100644 docs/build/html/jit_2gemv__masked_8h.html create mode 100644 docs/build/html/jit_2gemv__masked_8h_source.html create mode 100644 docs/build/html/kernels_2gemv__masked_8h.html create mode 100644 docs/build/html/kernels_2gemv__masked_8h_source.html delete mode 100644 docs/build/html/mps_2gemm_8h.html delete mode 100644 docs/build/html/mps_2gemm_8h_source.html delete mode 100644 docs/build/html/namespace_m_p_s.html delete mode 100644 docs/build/html/namespace_m_t_l.html delete mode 100644 docs/build/html/namespace_m_t_l_1_1_private.html delete mode 100644 docs/build/html/namespace_m_t_l_1_1_private_1_1_class.html delete mode 100644 docs/build/html/namespace_m_t_l_1_1_private_1_1_selector.html delete mode 100644 docs/build/html/namespacemembers_a.html delete mode 100644 docs/build/html/namespacemembers_func_a.html create mode 100644 docs/build/html/python/_autosummary/mlx.core.linalg.cholesky_inv.html create mode 100644 docs/build/html/python/_autosummary/mlx.core.linalg.tri_inv.html create mode 100644 docs/build/html/struct_g_e_m_v_kernel-members.html create mode 100644 docs/build/html/struct_g_e_m_v_kernel.html create mode 100644 docs/build/html/struct_g_e_m_v_t_kernel-members.html create mode 100644 docs/build/html/struct_g_e_m_v_t_kernel.html diff --git a/docs/build/html/.buildinfo b/docs/build/html/.buildinfo index 02b67f3136..e819c46104 100644 --- a/docs/build/html/.buildinfo +++ b/docs/build/html/.buildinfo @@ -1,4 +1,4 @@ # Sphinx build info version 1 # This file hashes the configuration used when building these files. When it is not found, a full rebuild will be done. -config: 3b02955047d4e4b232f01bccda3ed898 +config: bbbe4e54ecfcc611156ee21a8f15e97b tags: 645f666f9bcd5a90fca523b33c5a78b7 diff --git a/docs/build/html/_sources/dev/extensions.rst b/docs/build/html/_sources/dev/extensions.rst index 9a2be90cd4..ecb418468f 100644 --- a/docs/build/html/_sources/dev/extensions.rst +++ b/docs/build/html/_sources/dev/extensions.rst @@ -486,9 +486,8 @@ below. std::ostringstream kname; kname << "axpby_" << "general_" << type_to_name(out); - // Make sure the metal library is available and look for it - // in the same folder as this executable if needed - d.register_library("mlx_ext", metal::get_colocated_mtllib_path); + // Make sure the metal library is available + d.register_library("mlx_ext"); // Make a kernel from this metal library auto kernel = d.get_kernel(kname.str(), "mlx_ext"); diff --git a/docs/build/html/_sources/examples/llama-inference.rst b/docs/build/html/_sources/examples/llama-inference.rst index 0e080146bc..7e06895e35 100644 --- a/docs/build/html/_sources/examples/llama-inference.rst +++ b/docs/build/html/_sources/examples/llama-inference.rst @@ -15,7 +15,7 @@ module to concisely define the model architecture. Attention layer ^^^^^^^^^^^^^^^^ -We will start with the llama attention layer which notably uses the RoPE +We will start with the Llama attention layer which notably uses the RoPE positional encoding. [1]_ In addition, our attention layer will optionally use a key/value cache that will be concatenated with the provided keys and values to support efficient inference. diff --git a/docs/build/html/_sources/examples/mlp.rst b/docs/build/html/_sources/examples/mlp.rst index 36890e95cf..3214af504e 100644 --- a/docs/build/html/_sources/examples/mlp.rst +++ b/docs/build/html/_sources/examples/mlp.rst @@ -64,7 +64,7 @@ set: Next, setup the problem parameters and load the data. To load the data, you need our `mnist data loader `_, which -we will import as `mnist`. +we will import as ``mnist``. .. code-block:: python diff --git a/docs/build/html/_sources/install.rst b/docs/build/html/_sources/install.rst index c2288e46df..c8cf5723bc 100644 --- a/docs/build/html/_sources/install.rst +++ b/docs/build/html/_sources/install.rst @@ -70,36 +70,36 @@ To build and install the MLX python library from source, first, clone MLX from git clone git@github.com:ml-explore/mlx.git mlx && cd mlx -Install `nanobind `_ with: - -.. code-block:: shell - - pip install git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4 - Then simply build and install MLX using pip: .. code-block:: shell - env CMAKE_BUILD_PARALLEL_LEVEL="" pip install . + CMAKE_BUILD_PARALLEL_LEVEL="" pip install . -For developing use an editable install: +For developing, install the package with development dependencies, and use an +editable install: .. code-block:: shell - env CMAKE_BUILD_PARALLEL_LEVEL="" pip install -e . + CMAKE_BUILD_PARALLEL_LEVEL="" pip install -e ".[dev]" -To make sure the install is working run the tests with: +Once the development dependencies are installed, you can build faster with: + +.. code-block:: shell + + CMAKE_BUILD_PARALLEL_LEVEL="" python setup.py build_ext -j --inplace + +Run the tests with: .. code-block:: shell - pip install ".[testing]" python -m unittest discover python/tests -Optional: Install stubs to enable auto completions and type checking from your IDE: +Optional: Install stubs to enable auto completions and type checking from your +IDE: .. code-block:: shell - pip install ".[dev]" python setup.py generate_stubs C++ API diff --git a/docs/build/html/_sources/python/_autosummary/mlx.core.DtypeCategory.rst b/docs/build/html/_sources/python/_autosummary/mlx.core.DtypeCategory.rst index ab504dd356..22664ef416 100644 --- a/docs/build/html/_sources/python/_autosummary/mlx.core.DtypeCategory.rst +++ b/docs/build/html/_sources/python/_autosummary/mlx.core.DtypeCategory.rst @@ -9,14 +9,21 @@ .. automethod:: __init__ - .. rubric:: Methods + + + + + .. rubric:: Attributes .. autosummary:: - ~DtypeCategory.__init__ - - - - + ~DtypeCategory.complexfloating + ~DtypeCategory.floating + ~DtypeCategory.inexact + ~DtypeCategory.signedinteger + ~DtypeCategory.unsignedinteger + ~DtypeCategory.integer + ~DtypeCategory.number + ~DtypeCategory.generic \ No newline at end of file diff --git a/docs/build/html/_sources/python/_autosummary/mlx.core.linalg.cholesky_inv.rst b/docs/build/html/_sources/python/_autosummary/mlx.core.linalg.cholesky_inv.rst new file mode 100644 index 0000000000..8e08050f84 --- /dev/null +++ b/docs/build/html/_sources/python/_autosummary/mlx.core.linalg.cholesky_inv.rst @@ -0,0 +1,6 @@ +mlx.core.linalg.cholesky\_inv +============================= + +.. currentmodule:: mlx.core.linalg + +.. autofunction:: cholesky_inv \ No newline at end of file diff --git a/docs/build/html/_sources/python/_autosummary/mlx.core.linalg.tri_inv.rst b/docs/build/html/_sources/python/_autosummary/mlx.core.linalg.tri_inv.rst new file mode 100644 index 0000000000..4f12fcc8ec --- /dev/null +++ b/docs/build/html/_sources/python/_autosummary/mlx.core.linalg.tri_inv.rst @@ -0,0 +1,6 @@ +mlx.core.linalg.tri\_inv +======================== + +.. currentmodule:: mlx.core.linalg + +.. autofunction:: tri_inv \ No newline at end of file diff --git a/docs/build/html/_sources/python/linalg.rst b/docs/build/html/_sources/python/linalg.rst index 3c34cb3f79..e7fd5ecee3 100644 --- a/docs/build/html/_sources/python/linalg.rst +++ b/docs/build/html/_sources/python/linalg.rst @@ -9,7 +9,9 @@ Linear Algebra :toctree: _autosummary inv + tri_inv norm cholesky + cholesky_inv qr svd diff --git a/docs/build/html/_static/documentation_options.js b/docs/build/html/_static/documentation_options.js index d6799b6791..df3ade8a7b 100644 --- a/docs/build/html/_static/documentation_options.js +++ b/docs/build/html/_static/documentation_options.js @@ -1,5 +1,5 @@ const DOCUMENTATION_OPTIONS = { - VERSION: '0.16.1', + VERSION: '0.16.2', LANGUAGE: 'en', COLLAPSE_INDEX: false, BUILDER: 'html', diff --git a/docs/build/html/annotated.html b/docs/build/html/annotated.html index a846e55595..f07ab420b5 100644 --- a/docs/build/html/annotated.html +++ b/docs/build/html/annotated.html @@ -151,13 +151,14 @@ $(function() {  CDistPrimitive  CGroupA distributed::Group represents a group of independent mlx processes that can communicate  Nfast - CCustom - CLayerNorm - CLayerNormVJP - CRMSNorm - CRMSNormVJP - CRoPE - CScaledDotProductAttention + CAffineQuantize + CCustom + CLayerNorm + CLayerNormVJP + CRMSNorm + CRMSNormVJP + CRoPE + CScaledDotProductAttention  Nio  CFileReader  CFileWriter @@ -318,151 +319,145 @@ $(function() {  CTransformAdd  CTransformAxpby  CTransformNone - NMPS - CKernel - CMatrix - CMatrixDescriptor - CMatrixMultiplication - CMatrixVectorMultiplication - CVector - CVectorDescriptor - Npocketfft - Ndetail - Nthreading - Caligned_allocator - Cconcurrent_queue - Clatch - Cthread_pool - Cadd_vec - Cadd_vec< cmplx< T > > - Carr - Carr_info - Ccfftp - Ccmplx - Ccndarr - CExecC2C - CExecDcst - CExecHartley - CExecR2R - Cfftblue - Cmulti_iter - Cndarr - Cpocketfft_c - Cpocketfft_r - Crev_iter - Crfftp - Csimple_iter - Csincos_2pibyn - CT_dcst23 - CT_dcst4 - CT_dct1 - CT_dst1 - Cutil - CVLEN - CVTYPE - C_MLX_BFloat16 - Cbits_to_bfloat_struct - C_NoMask - CAbs - CAdd - CAnd - CArcCos - CArcCosh - CArcSin - CArcSinh - CArcTan - CArcTan2 - CArcTanh - CBitwiseAnd - CBitwiseOr - CBitwiseXor - CBlockMergeSort - Cbool4_or_uint - CCeil - Ccomplex64_t - CConjugate - CCos - CCosh - CCumMax - CCumMin - CCumProd - CCumProd< bool > - CCumSum - CDivide - CDivMod - CEqual - CErf - CErfInv - CExp - CExpm1 - CFloor - CFloorDivide - CGreater - CGreaterEqual - CIndices - CKernelMergeSort - CKernelMultiBlockMergeSort - CLeftShift - CLess - CLessEqual - CLessThan - CLimits - CLimits< bfloat16_t > - CLimits< bool > - CLimits< float > - CLimits< half > - CLimits< int16_t > - CLimits< int32_t > - CLimits< int64_t > - CLimits< int8_t > - CLimits< uint16_t > - CLimits< uint32_t > - CLimits< uint64_t > - CLimits< uint8_t > - CLog - CLog10 - CLog1p - CLog2 - CLogAddExp - CLogicalAnd - CLogicalNot - CLogicalOr - CMax - CMaximum - CMin - CMinimum - Cmlx_atomic - Cmlx_atomic< T, enable_if_t< is_metal_atomic< T > > > - CMLXConvParams - CMLXFastAttentionParams - CMLXScaledDotProductAttentionParams - CMultiply - CNaNEqual - CNegative - CNone - CNotEqual - COr - CPower - CProd - CQuantizedBlockLoader - CReadWriter - CRemainder - CRightShift - CRound - CRsqrt - CScaleOp - CSelect - CSigmoid - CSign - CSin - CSinh - CSqrt - CSquare - CSubtract - CSum - CTan - CTanh - CThreadSort + Npocketfft + Ndetail + Nthreading + Caligned_allocator + Cconcurrent_queue + Clatch + Cthread_pool + Cadd_vec + Cadd_vec< cmplx< T > > + Carr + Carr_info + Ccfftp + Ccmplx + Ccndarr + CExecC2C + CExecDcst + CExecHartley + CExecR2R + Cfftblue + Cmulti_iter + Cndarr + Cpocketfft_c + Cpocketfft_r + Crev_iter + Crfftp + Csimple_iter + Csincos_2pibyn + CT_dcst23 + CT_dcst4 + CT_dct1 + CT_dst1 + Cutil + CVLEN + CVTYPE + C_MLX_BFloat16 + Cbits_to_bfloat_struct + C_NoMask + CAbs + CAdd + CAnd + CArcCos + CArcCosh + CArcSin + CArcSinh + CArcTan + CArcTan2 + CArcTanh + CBitwiseAnd + CBitwiseOr + CBitwiseXor + CBlockMergeSort + Cbool4_or_uint + CCeil + Ccomplex64_t + CConjugate + CCos + CCosh + CCumMax + CCumMin + CCumProd + CCumProd< bool > + CCumSum + CDivide + CDivMod + CEqual + CErf + CErfInv + CExp + CExpm1 + CFloor + CFloorDivide + CGEMVKernel + CGEMVTKernelVector matrix multiplication + CGreater + CGreaterEqual + CIndices + CKernelMergeSort + CKernelMultiBlockMergeSort + CLeftShift + CLess + CLessEqual + CLessThan + CLimits + CLimits< bfloat16_t > + CLimits< bool > + CLimits< float > + CLimits< half > + CLimits< int16_t > + CLimits< int32_t > + CLimits< int64_t > + CLimits< int8_t > + CLimits< uint16_t > + CLimits< uint32_t > + CLimits< uint64_t > + CLimits< uint8_t > + CLog + CLog10 + CLog1p + CLog2 + CLogAddExp + CLogicalAnd + CLogicalNot + CLogicalOr + CMax + CMaximum + CMin + CMinimum + Cmlx_atomic + Cmlx_atomic< T, enable_if_t< is_metal_atomic< T > > > + CMLXConvParams + CMLXFastAttentionParams + CMLXScaledDotProductAttentionParams + CMultiply + CNaNEqual + CNegative + CNone + CNotEqual + COr + CPower + CProd + CQuantizedBlockLoader + CReadWriter + CRemainder + CRightShift + CRound + CRsqrt + CScaleOp + CSelect + CSigmoid + CSign + CSin + CSinh + CSqrt + CSquare + CSubtract + CSum + CTan + CTanh + CThreadSort diff --git a/docs/build/html/array_8h_source.html b/docs/build/html/array_8h_source.html index 68b0552273..ceb409e9c6 100644 --- a/docs/build/html/array_8h_source.html +++ b/docs/build/html/array_8h_source.html @@ -794,25 +794,25 @@ $(function() { codefold.init(0); });
void free(Buffer buffer)
Definition allocator.h:7
constexpr bool is_array_v
Definition array.h:559
-
constexpr Dtype bool_
Definition dtype.h:60
+
constexpr Dtype bool_
Definition dtype.h:58
std::function< void(allocator::Buffer)> deleter_t
Definition array.h:18
-
constexpr Dtype uint64
Definition dtype.h:65
-
constexpr Dtype uint16
Definition dtype.h:63
-
constexpr Dtype bfloat16
Definition dtype.h:74
-
constexpr Dtype int32
Definition dtype.h:69
-
constexpr Dtype float32
Definition dtype.h:73
-
constexpr Dtype int16
Definition dtype.h:68
-
constexpr Dtype int8
Definition dtype.h:67
-
constexpr Dtype int64
Definition dtype.h:70
+
constexpr Dtype uint64
Definition dtype.h:63
+
constexpr Dtype uint16
Definition dtype.h:61
+
constexpr Dtype bfloat16
Definition dtype.h:72
+
constexpr Dtype int32
Definition dtype.h:67
+
constexpr Dtype float32
Definition dtype.h:71
+
constexpr Dtype int16
Definition dtype.h:66
+
constexpr Dtype int8
Definition dtype.h:65
+
constexpr Dtype int64
Definition dtype.h:68
constexpr bool is_arrays_v
Definition array.h:563
-
constexpr Dtype uint8
Definition dtype.h:62
-
constexpr Dtype float16
Definition dtype.h:72
-
constexpr Dtype uint32
Definition dtype.h:64
-
uint8_t size_of(const Dtype &t)
Definition dtype.h:95
+
constexpr Dtype uint8
Definition dtype.h:60
+
constexpr Dtype float16
Definition dtype.h:70
+
constexpr Dtype uint32
Definition dtype.h:62
+
uint8_t size_of(const Dtype &t)
Definition dtype.h:93
typename std::enable_if_t< is_arrays_v< T... > > enable_for_arrays_t
Definition array.h:566
-
constexpr Dtype complex64
Definition dtype.h:75
-
Definition dtype.h:15
-
Definition dtype.h:102
+
constexpr Dtype complex64
Definition dtype.h:73
+
Definition dtype.h:13
+
Definition dtype.h:100
Definition array.h:141
friend bool operator==(const ArrayIterator &a, const ArrayIterator &b)
Definition array.h:161
diff --git a/docs/build/html/backend_2accelerate_2utils_8h_source.html b/docs/build/html/backend_2accelerate_2utils_8h_source.html index b8057101e7..f5ce8d3f8f 100644 --- a/docs/build/html/backend_2accelerate_2utils_8h_source.html +++ b/docs/build/html/backend_2accelerate_2utils_8h_source.html @@ -117,8 +117,8 @@ $(function() { codefold.init(0); });
Definition allocator.h:7
BNNSDataType to_bnns_dtype(Dtype mlx_dtype)
Definition utils.h:10
Dtype::Kind kindof(const Dtype &t)
-
uint8_t size_of(const Dtype &t)
Definition dtype.h:95
-
Definition dtype.h:15
+
uint8_t size_of(const Dtype &t)
Definition dtype.h:93
+
Definition dtype.h:13
diff --git a/docs/build/html/backend_2metal_2device_8h.html b/docs/build/html/backend_2metal_2device_8h.html index 573f61f73a..9a49c74a9a 100644 --- a/docs/build/html/backend_2metal_2device_8h.html +++ b/docs/build/html/backend_2metal_2device_8h.html @@ -89,8 +89,6 @@ $(function() { #include <string>
#include <unordered_map>
#include <unordered_set>
-#include <dlfcn.h>
-#include <filesystem>
#include "mlx/array.h"
#include "mlx/device.h"
@@ -121,8 +119,6 @@ Typedefs - -

Functions

std::string mlx::core::metal::get_colocated_mtllib_path (const std::string &lib_name)
 
Devicemlx::core::metal::device (mlx::core::Device)
 
diff --git a/docs/build/html/backend_2metal_2device_8h_source.html b/docs/build/html/backend_2metal_2device_8h_source.html index 8a7e381018..13ee2fe02d 100644 --- a/docs/build/html/backend_2metal_2device_8h_source.html +++ b/docs/build/html/backend_2metal_2device_8h_source.html @@ -94,258 +94,190 @@ $(function() { codefold.init(0); });
9#include <unordered_map>
10#include <unordered_set>
11
-
12#include <dlfcn.h>
-
13#include <filesystem>
+
12#include "mlx/array.h"
+
13#include "mlx/device.h"
14
-
15#include "mlx/array.h"
-
16#include "mlx/device.h"
-
17
-
18namespace fs = std::filesystem;
+
15namespace mlx::core::metal {
+
16
+
17using MTLFCList =
+
18 std::vector<std::tuple<const void*, MTL::DataType, NS::UInteger>>;
19
-
20namespace mlx::core::metal {
-
21
-
-
22inline std::string get_colocated_mtllib_path(const std::string& lib_name) {
-
23 Dl_info info;
-
24 std::string mtllib_path;
-
25 std::string lib_ext = lib_name + ".metallib";
-
26
-
27 int success = dladdr((void*)get_colocated_mtllib_path, &info);
-
28 if (success) {
-
29 auto mtllib = fs::path(info.dli_fname).remove_filename() / lib_ext;
-
30 mtllib_path = mtllib.c_str();
-
31 }
-
32
-
33 return mtllib_path;
-
34}
+
+ +
21 CommandEncoder(MTL::CommandBuffer* cbuf);
+ + +
24
+
+ +
+ +
27 enc.concurrent = true;
+
28 }
+
+
+ +
30 enc.concurrent = false;
+
31 enc.outputs.insert(
+
32 enc.concurrent_outputs.begin(), enc.concurrent_outputs.end());
+
33 enc.concurrent_outputs.clear();
+
34 }
35
-
36using MTLFCList =
-
37 std::vector<std::tuple<const void*, MTL::DataType, NS::UInteger>>;
-
38
-
- +
36 private:
+
37 CommandEncoder& enc;
+
38 };
+
+
39
-
40 CommandEncoder(MTL::CommandBuffer* cbuf) : cbuf(cbuf) {
-
41 enc = cbuf->computeCommandEncoder(MTL::DispatchTypeConcurrent);
-
42 enc->retain();
-
43 };
+
40 MTL::ComputeCommandEncoder* operator->() {
+
41 return enc;
+
42 }
- - -
46
-
- -
- -
49 enc.concurrent = true;
-
50 }
-
-
- -
52 enc.concurrent = false;
-
53 enc.outputs.insert(
-
54 enc.concurrent_outputs.begin(), enc.concurrent_outputs.end());
-
55 enc.concurrent_outputs.clear();
-
56 }
+
43
+
44 void set_input_array(const array& a, int idx, int64_t offset = 0);
+
45 void set_output_array(array& a, int idx, int64_t offset = 0);
+
46 void dispatchThreadgroups(MTL::Size grid_dims, MTL::Size group_dims);
+
47 void dispatchThreads(MTL::Size grid_dims, MTL::Size group_dims);
+
48
+
+ +
50 return ConcurrentContext(*this);
+
51 }
+
52
+ +
54
+
55 private:
+
56 void maybe_split();
57
-
58 private:
-
59 CommandEncoder& enc;
-
60 };
-
-
61
-
-
62 MTL::ComputeCommandEncoder* operator->() {
-
63 return enc;
-
64 }
+
58 int num_dispatches{0};
+
59 MTL::CommandBuffer* cbuf;
+
60 MTL::ComputeCommandEncoder* enc;
+
61 bool concurrent{false};
+
62 std::unordered_set<MTL::Resource*> outputs;
+
63 std::unordered_set<MTL::Resource*> concurrent_outputs;
+
64};
65
-
-
66 void set_input_array(const array& a, int idx, int64_t offset = 0) {
-
67 auto r_buf =
-
68 static_cast<MTL::Resource*>(const_cast<void*>(a.buffer().ptr()));
-
69 if (auto it = outputs.find(r_buf); it != outputs.end()) {
-
70 // Insert a barrier
-
71 enc->memoryBarrier(&r_buf, 1);
+
+
66class Device {
+
67 public:
+ +
69 Device(const Device&) = delete;
+
70 Device& operator=(const Device&) = delete;
+
72
-
73 // Remove the output
-
74 outputs.erase(it);
-
75 }
-
76 auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
-
77 auto base_offset = a.data<char>() -
-
78 static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
-
79 base_offset += offset;
-
80 enc->setBuffer(a_buf, base_offset, idx);
-
81 }
+
+
73 MTL::Device* mtl_device() {
+
74 return device_;
+
75 };
-
82
-
-
83 void set_output_array(array& a, int idx, int64_t offset = 0) {
-
84 // Add barriers before adding the output to the output set
-
85 set_input_array(a, idx, offset);
-
86 auto buf = static_cast<MTL::Resource*>(a.buffer().ptr());
-
87 if (concurrent) {
-
88 concurrent_outputs.insert(buf);
-
89 } else {
-
90 outputs.insert(buf);
-
91 }
-
92 }
-
-
93
-
94 void dispatchThreadgroups(MTL::Size grid_dims, MTL::Size group_dims);
-
95 void dispatchThreads(MTL::Size grid_dims, MTL::Size group_dims);
-
96
-
- -
98 return ConcurrentContext(*this);
-
99 }
-
-
100
-
- -
102 enc->endEncoding();
-
103 enc->release();
-
104 }
-
-
105
-
106 private:
-
107 void maybe_split();
+
76
+
77 void new_queue(int index);
+
78 MTL::CommandBuffer* get_command_buffer(int index);
+
79 int get_command_buffer_ops(int index);
+ +
81 void commit_command_buffer(int index);
+ +
83 void end_encoding(int index);
+
84
+ +
86 const std::string& lib_name,
+
87 const std::string& lib_path);
+
88
+
89 void register_library(const std::string& lib_name);
+
90
+
91 MTL::Library* get_library(const std::string& name);
+
92
+
93 MTL::Library* get_library(
+
94 const std::string& name,
+
95 const std::string& source_string,
+
96 bool cache = true);
+
97
+
98 MTL::Library* get_library(
+
99 const std::string& name,
+
100 const MTL::StitchedLibraryDescriptor* desc,
+
101 bool cache = true);
+
102
+
103 MTL::Function* get_function(
+
104 const std::string& base_name,
+
105 MTL::Library* mtl_lib,
+
106 const std::string& specialized_name = "",
+
107 const MTLFCList& func_consts = {});
108
-
109 int num_dispatches{0};
-
110 MTL::CommandBuffer* cbuf;
-
111 MTL::ComputeCommandEncoder* enc;
-
112 bool concurrent{false};
-
113 std::unordered_set<MTL::Resource*> outputs;
-
114 std::unordered_set<MTL::Resource*> concurrent_outputs;
-
115};
-
-
116
-
-
117class Device {
-
118 public:
- -
120 Device(const Device&) = delete;
-
121 Device& operator=(const Device&) = delete;
- -
123
-
-
124 MTL::Device* mtl_device() {
-
125 return device_;
-
126 };
-
-
127
-
128 void new_queue(int index);
-
129 MTL::CommandBuffer* get_command_buffer(int index);
- - -
132 void commit_command_buffer(int index);
- -
134 void end_encoding(int index);
-
135
- -
137 const std::string& lib_name,
-
138 const std::string& lib_path);
- -
140 const std::string& lib_name,
-
141 const std::function<std::string(const std::string&)>& lib_path_func =
- -
143
-
144 MTL::Library* get_library(const std::string& name);
+
109 MTL::Function* get_function(
+
110 const std::string& base_name,
+
111 const std::string& lib_name = "mlx",
+
112 const std::string& specialized_name = "",
+
113 const MTLFCList& func_consts = {});
+
114
+
115 MTL::ComputePipelineState* get_kernel(
+
116 const std::string& base_name,
+
117 MTL::Library* mtl_lib,
+
118 const std::string& hash_name = "",
+
119 const MTLFCList& func_consts = {},
+
120 const std::vector<MTL::Function*>& linked_functions = {});
+
121
+
122 MTL::ComputePipelineState* get_kernel(
+
123 const std::string& base_name,
+
124 const std::string& lib_name = "mlx",
+
125 const std::string& hash_name = "",
+
126 const MTLFCList& func_consts = {},
+
127 const std::vector<MTL::Function*>& linked_functions = {});
+
128
+
129 MTL::ArgumentEncoder* argument_encoder(
+
130 const std::vector<MTL::ArgumentDescriptor*>& arg_descs) const;
+
131
+
132 private:
+
133 MTL::Library* get_library_cache_(const std::string& name);
+
134
+
135 MTL::Library* get_library_(const std::string& source_string);
+
136 MTL::Library* get_library_(const MTL::StitchedLibraryDescriptor* desc);
+
137
+
138 MTL::Function* get_function_(const std::string& name, MTL::Library* mtl_lib);
+
139
+
140 MTL::Function* get_function_(
+
141 const std::string& name,
+
142 const std::string& specialized_name,
+
143 const MTLFCList& func_consts,
+
144 MTL::Library* mtl_lib);
145
-
146 MTL::Library* get_library(
-
147 const std::string& name,
-
148 const std::string& source_string,
-
149 bool cache = true);
-
150
-
151 MTL::Library* get_library(
-
152 const std::string& name,
-
153 const MTL::StitchedLibraryDescriptor* desc,
-
154 bool cache = true);
-
155
-
156 MTL::Function* get_function(
-
157 const std::string& base_name,
-
158 MTL::Library* mtl_lib,
-
159 const std::string& specialized_name = "",
-
160 const MTLFCList& func_consts = {});
-
161
-
162 MTL::Function* get_function(
-
163 const std::string& base_name,
-
164 const std::string& lib_name = "mlx",
-
165 const std::string& specialized_name = "",
-
166 const MTLFCList& func_consts = {});
-
167
-
168 MTL::ComputePipelineState* get_kernel(
-
169 const std::string& base_name,
-
170 MTL::Library* mtl_lib,
-
171 const std::string& hash_name = "",
-
172 const MTLFCList& func_consts = {},
-
173 const std::vector<MTL::Function*>& linked_functions = {});
-
174
-
175 MTL::ComputePipelineState* get_kernel(
-
176 const std::string& base_name,
-
177 const std::string& lib_name = "mlx",
-
178 const std::string& hash_name = "",
-
179 const MTLFCList& func_consts = {},
-
180 const std::vector<MTL::Function*>& linked_functions = {});
-
181
-
182 MTL::ArgumentEncoder* argument_encoder(
-
183 const std::vector<MTL::ArgumentDescriptor*>& arg_descs) const;
-
184
-
185 private:
-
186 MTL::Library* get_library_cache_(const std::string& name);
-
187
-
188 MTL::Library* get_library_(const std::string& source_string);
-
189 MTL::Library* get_library_(const MTL::StitchedLibraryDescriptor* desc);
-
190
-
191 MTL::Function* get_function_(const std::string& name, MTL::Library* mtl_lib);
-
192
-
193 MTL::Function* get_function_(
-
194 const std::string& name,
-
195 const std::string& specialized_name,
-
196 const MTLFCList& func_consts,
-
197 MTL::Library* mtl_lib);
-
198
-
199 MTL::LinkedFunctions* get_linked_functions_(
-
200 const std::vector<MTL::Function*>& funcs);
-
201
-
202 MTL::ComputePipelineState* get_kernel_(
-
203 const std::string& name,
-
204 const MTL::Function* mtl_function);
-
205
-
206 MTL::ComputePipelineState* get_kernel_(
-
207 const std::string& name,
-
208 const MTL::Function* mtl_function,
-
209 const MTL::LinkedFunctions* linked_functions);
-
210
-
211 MTL::Device* device_;
-
212 std::unordered_map<int32_t, MTL::CommandQueue*> queue_map_;
-
213 std::unordered_map<int32_t, std::pair<int, MTL::CommandBuffer*>> buffer_map_;
-
214 std::unordered_map<int32_t, std::unique_ptr<CommandEncoder>> encoder_map_;
-
215 std::unordered_map<std::string, MTL::ComputePipelineState*> kernel_map_;
-
216 std::unordered_map<std::string, MTL::Library*> library_map_;
-
217 std::mutex mtx_;
-
218};
+
146 MTL::LinkedFunctions* get_linked_functions_(
+
147 const std::vector<MTL::Function*>& funcs);
+
148
+
149 MTL::ComputePipelineState* get_kernel_(
+
150 const std::string& name,
+
151 const MTL::Function* mtl_function);
+
152
+
153 MTL::ComputePipelineState* get_kernel_(
+
154 const std::string& name,
+
155 const MTL::Function* mtl_function,
+
156 const MTL::LinkedFunctions* linked_functions);
+
157
+
158 MTL::Device* device_;
+
159 std::unordered_map<int32_t, MTL::CommandQueue*> queue_map_;
+
160 std::unordered_map<int32_t, std::pair<int, MTL::CommandBuffer*>> buffer_map_;
+
161 std::unordered_map<int32_t, std::unique_ptr<CommandEncoder>> encoder_map_;
+
162 std::unordered_map<std::string, MTL::ComputePipelineState*> kernel_map_;
+
163 std::unordered_map<std::string, MTL::Library*> library_map_;
+
164 std::mutex mtx_;
+
165};
-
219
- -
221
-
222} // namespace mlx::core::metal
+
166
+ +
168
+
169} // namespace mlx::core::metal
-
MTL::Buffer * buf
Definition allocator.h:38
-
const void * ptr() const
Definition allocator.h:23
Definition array.h:20
-
T * data()
Definition array.h:313
-
allocator::Buffer & buffer()
Definition array.h:299
-
Definition device.h:117
+
Definition device.h:66
int get_command_buffer_ops(int index)
-
MTL::Device * mtl_device()
Definition device.h:124
+
MTL::Device * mtl_device()
Definition device.h:73
void register_library(const std::string &lib_name, const std::string &lib_path)
MTL::CommandBuffer * get_command_buffer(int index)
void end_encoding(int index)
MTL::ComputePipelineState * get_kernel(const std::string &base_name, MTL::Library *mtl_lib, const std::string &hash_name="", const MTLFCList &func_consts={}, const std::vector< MTL::Function * > &linked_functions={})
-
void register_library(const std::string &lib_name, const std::function< std::string(const std::string &)> &lib_path_func=get_colocated_mtllib_path)
MTL::ArgumentEncoder * argument_encoder(const std::vector< MTL::ArgumentDescriptor * > &arg_descs) const
void increment_command_buffer_ops(int index)
void new_queue(int index)
@@ -353,6 +285,7 @@ $(function() { codefold.init(0); });
MTL::Library * get_library(const std::string &name, const MTL::StitchedLibraryDescriptor *desc, bool cache=true)
void commit_command_buffer(int index)
MTL::Library * get_library(const std::string &name, const std::string &source_string, bool cache=true)
+
void register_library(const std::string &lib_name)
MTL::Function * get_function(const std::string &base_name, MTL::Library *mtl_lib, const std::string &specialized_name="", const MTLFCList &func_consts={})
Device(const Device &)=delete
MTL::Function * get_function(const std::string &base_name, const std::string &lib_name="mlx", const std::string &specialized_name="", const MTLFCList &func_consts={})
@@ -362,23 +295,22 @@ $(function() { codefold.init(0); });
CommandEncoder & get_command_encoder(int index)
Definition allocator.h:12
-
std::string get_colocated_mtllib_path(const std::string &lib_name)
Definition device.h:22
-
std::vector< std::tuple< const void *, MTL::DataType, NS::UInteger > > MTLFCList
Definition device.h:36
+
std::vector< std::tuple< const void *, MTL::DataType, NS::UInteger > > MTLFCList
Definition device.h:17
Device & device(mlx::core::Device)
Definition device.h:7
- - -
ConcurrentContext(CommandEncoder &enc)
Definition device.h:48
-
Definition device.h:39
+ + +
ConcurrentContext(CommandEncoder &enc)
Definition device.h:26
+
Definition device.h:20
void dispatchThreads(MTL::Size grid_dims, MTL::Size group_dims)
-
CommandEncoder(MTL::CommandBuffer *cbuf)
Definition device.h:40
+
CommandEncoder(MTL::CommandBuffer *cbuf)
CommandEncoder & operator=(const CommandEncoder &)=delete
-
ConcurrentContext start_concurrent()
Definition device.h:97
-
void set_output_array(array &a, int idx, int64_t offset=0)
Definition device.h:83
+
ConcurrentContext start_concurrent()
Definition device.h:49
+
void set_output_array(array &a, int idx, int64_t offset=0)
void dispatchThreadgroups(MTL::Size grid_dims, MTL::Size group_dims)
-
~CommandEncoder()
Definition device.h:101
-
MTL::ComputeCommandEncoder * operator->()
Definition device.h:62
-
void set_input_array(const array &a, int idx, int64_t offset=0)
Definition device.h:66
+ +
MTL::ComputeCommandEncoder * operator->()
Definition device.h:40
+
void set_input_array(const array &a, int idx, int64_t offset=0)
CommandEncoder(const CommandEncoder &)=delete
diff --git a/docs/build/html/backend_2metal_2utils_8h.html b/docs/build/html/backend_2metal_2utils_8h.html index 9a2da99d38..b01d6d0d53 100644 --- a/docs/build/html/backend_2metal_2utils_8h.html +++ b/docs/build/html/backend_2metal_2utils_8h.html @@ -76,7 +76,8 @@ $(function() {
utils.h File Reference
@@ -92,6 +93,29 @@ Namespaces   namespace  mlx::core   + + + + + + + + + + + + + + + + + + + + + +

+Functions

template<typename T >
void mlx::core::set_vector_bytes (CommandEncoder &enc, const std::vector< T > &vec, size_t nelems, int idx)
 
template<typename T >
void mlx::core::set_vector_bytes (CommandEncoder &enc, const std::vector< T > &vec, int idx)
 
std::string mlx::core::type_to_name (const array &a)
 
MTL::Size mlx::core::get_block_dims (int dim0, int dim1, int dim2)
 
MTL::Size mlx::core::get_2d_grid_dims (const std::vector< int > &shape, const std::vector< size_t > &strides)
 
NS::String * mlx::core::make_string (std::ostringstream &os)
 
void mlx::core::debug_set_stream_queue_label (MTL::CommandQueue *queue, int index)
 
void mlx::core::debug_set_primitive_buffer_label (MTL::CommandBuffer *command_buffer, Primitive &primitive)
 
std::string mlx::core::get_primitive_string (Primitive *primitive)
 
diff --git a/docs/build/html/backend_2metal_2utils_8h_source.html b/docs/build/html/backend_2metal_2utils_8h_source.html index 7e3a1a80f4..015caa2fa0 100644 --- a/docs/build/html/backend_2metal_2utils_8h_source.html +++ b/docs/build/html/backend_2metal_2utils_8h_source.html @@ -93,155 +93,95 @@ $(function() { codefold.init(0); });
8
9namespace mlx::core {
10
-
11namespace {
+
11using metal::CommandEncoder;
12
-
13using metal::CommandEncoder;
-
14
-
15template <typename T>
-
16inline void set_vector_bytes(
-
17 CommandEncoder& enc,
-
18 const std::vector<T>& vec,
-
19 size_t nelems,
-
20 int idx) {
-
21 enc->setBytes(vec.data(), nelems * sizeof(T), idx);
-
22}
-
23
-
24template <typename T>
-
25inline void
-
26set_vector_bytes(CommandEncoder& enc, const std::vector<T>& vec, int idx) {
-
27 return set_vector_bytes(enc, vec, vec.size(), idx);
-
28}
+
13template <typename T>
+
+
14inline void set_vector_bytes(
+
15 CommandEncoder& enc,
+
16 const std::vector<T>& vec,
+
17 size_t nelems,
+
18 int idx) {
+
19 enc->setBytes(vec.data(), nelems * sizeof(T), idx);
+
20}
+
+
21
+
22template <typename T>
+
23inline void
+
+
24set_vector_bytes(CommandEncoder& enc, const std::vector<T>& vec, int idx) {
+
25 return set_vector_bytes(enc, vec, vec.size(), idx);
+
26}
+
+
27
+
28std::string type_to_name(const array& a);
29
-
30std::string type_to_name(const array& a) {
-
31 std::string tname;
-
32 switch (a.dtype()) {
-
33 case bool_:
-
34 tname = "bool_";
-
35 break;
-
36 case uint8:
-
37 tname = "uint8";
-
38 break;
-
39 case uint16:
-
40 tname = "uint16";
-
41 break;
-
42 case uint32:
-
43 tname = "uint32";
-
44 break;
-
45 case uint64:
-
46 tname = "uint64";
-
47 break;
-
48 case int8:
-
49 tname = "int8";
-
50 break;
-
51 case int16:
-
52 tname = "int16";
-
53 break;
-
54 case int32:
-
55 tname = "int32";
-
56 break;
-
57 case int64:
-
58 tname = "int64";
-
59 break;
-
60 case float16:
-
61 tname = "float16";
-
62 break;
-
63 case float32:
-
64 tname = "float32";
-
65 break;
-
66 case bfloat16:
-
67 tname = "bfloat16";
-
68 break;
-
69 case complex64:
-
70 tname = "complex64";
-
71 break;
-
72 }
-
73 return tname;
-
74}
-
75
-
76MTL::Size get_block_dims(int dim0, int dim1, int dim2) {
-
77 int pows[3] = {0, 0, 0};
-
78 int sum = 0;
-
79 while (true) {
-
80 int presum = sum;
-
81 // Check all the pows
-
82 if (dim0 >= (1 << (pows[0] + 1))) {
-
83 pows[0]++;
-
84 sum++;
-
85 }
-
86 if (sum == 10) {
-
87 break;
-
88 }
-
89 if (dim1 >= (1 << (pows[1] + 1))) {
-
90 pows[1]++;
-
91 sum++;
-
92 }
-
93 if (sum == 10) {
-
94 break;
-
95 }
-
96 if (dim2 >= (1 << (pows[2] + 1))) {
-
97 pows[2]++;
-
98 sum++;
-
99 }
-
100 if (sum == presum || sum == 10) {
-
101 break;
-
102 }
-
103 }
-
104 return MTL::Size{1ul << pows[0], 1ul << pows[1], 1ul << pows[2]};
-
105}
-
106
-
107inline NS::String* make_string(std::ostringstream& os) {
-
108 std::string string = os.str();
-
109 return NS::String::string(string.c_str(), NS::UTF8StringEncoding);
-
110}
-
111
-
112inline void debug_set_stream_queue_label(MTL::CommandQueue* queue, int index) {
-
113#ifdef MLX_METAL_DEBUG
-
114 std::ostringstream label;
-
115 label << "Stream " << index;
-
116 queue->setLabel(make_string(label));
-
117#endif
-
118}
-
119
-
120inline void debug_set_primitive_buffer_label(
-
121 MTL::CommandBuffer* command_buffer,
-
122 Primitive& primitive) {
-
123#ifdef MLX_METAL_DEBUG
-
124 std::ostringstream label;
-
125 if (auto cbuf_label = command_buffer->label(); cbuf_label) {
-
126 label << cbuf_label->utf8String();
-
127 }
-
128 primitive.print(label);
-
129 command_buffer->setLabel(make_string(label));
-
130#endif
-
131}
-
132
-
133std::string get_primitive_string(Primitive* primitive) {
-
134 std::ostringstream op_t;
-
135 primitive->print(op_t);
-
136 return op_t.str();
-
137}
-
138
-
139} // namespace
-
140
-
141} // namespace mlx::core
+
30// Compute the thread block dimensions which fit the given
+
31// input dimensions.
+
32// - The thread block dimensions will be powers of two
+
33// - The thread block size will be less than 1024
+
34MTL::Size get_block_dims(int dim0, int dim1, int dim2);
+
35
+
36// Computes a 2D grid where each element is < UINT_MAX
+
37// Assumes:
+
38// - overall size (product of non-broadcasted dimensions) is < UINT_MAX^2
+
39// - shape and strides correspond to a contiguous (no holes) but
+
40// possibly broadcasted array
+ +
42 const std::vector<int>& shape,
+
43 const std::vector<size_t>& strides);
+
44
+
+
45inline NS::String* make_string(std::ostringstream& os) {
+
46 std::string string = os.str();
+
47 return NS::String::string(string.c_str(), NS::UTF8StringEncoding);
+
48}
+
+
49
+
+
50inline void debug_set_stream_queue_label(MTL::CommandQueue* queue, int index) {
+
51#ifdef MLX_METAL_DEBUG
+
52 std::ostringstream label;
+
53 label << "Stream " << index;
+
54 queue->setLabel(make_string(label));
+
55#endif
+
56}
+
+
57
+
+ +
59 MTL::CommandBuffer* command_buffer,
+
60 Primitive& primitive) {
+
61#ifdef MLX_METAL_DEBUG
+
62 std::ostringstream label;
+
63 if (auto cbuf_label = command_buffer->label(); cbuf_label) {
+
64 label << cbuf_label->utf8String();
+
65 }
+
66 primitive.print(label);
+
67 command_buffer->setLabel(make_string(label));
+
68#endif
+
69}
+
+
70
+
71std::string get_primitive_string(Primitive* primitive);
+
72
+
73} // namespace mlx::core
-
array sum(const array &a, bool keepdims, StreamOrDevice s={})
Sums the elements of an array.
+
Definition primitives.h:48
+
virtual void print(std::ostream &os)=0
Print the primitive.
+
Definition array.h:20
Definition allocator.h:7
-
constexpr Dtype bool_
Definition dtype.h:60
-
constexpr Dtype uint64
Definition dtype.h:65
-
constexpr Dtype uint16
Definition dtype.h:63
-
constexpr Dtype bfloat16
Definition dtype.h:74
-
constexpr Dtype int32
Definition dtype.h:69
-
constexpr Dtype float32
Definition dtype.h:73
-
constexpr Dtype int16
Definition dtype.h:68
-
constexpr Dtype int8
Definition dtype.h:67
-
constexpr Dtype int64
Definition dtype.h:70
-
constexpr Dtype uint8
Definition dtype.h:62
-
constexpr Dtype float16
Definition dtype.h:72
-
constexpr Dtype uint32
Definition dtype.h:64
-
constexpr Dtype complex64
Definition dtype.h:75
+
void debug_set_primitive_buffer_label(MTL::CommandBuffer *command_buffer, Primitive &primitive)
Definition utils.h:58
+
void set_vector_bytes(CommandEncoder &enc, const std::vector< T > &vec, size_t nelems, int idx)
Definition utils.h:14
+
void debug_set_stream_queue_label(MTL::CommandQueue *queue, int index)
Definition utils.h:50
+
MTL::Size get_block_dims(int dim0, int dim1, int dim2)
+
MTL::Size get_2d_grid_dims(const std::vector< int > &shape, const std::vector< size_t > &strides)
+
std::string get_primitive_string(Primitive *primitive)
+
NS::String * make_string(std::ostringstream &os)
Definition utils.h:45
+
std::string type_to_name(const array &a)
+
Definition device.h:20