mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-18 01:50:16 +08:00
docs update
This commit is contained in:

committed by
CircleCI Docs

parent
ba4eff9520
commit
0e688cbd0f
657
docs/build/html/namespacemlx_1_1core.html
vendored
657
docs/build/html/namespacemlx_1_1core.html
vendored
@@ -177,7 +177,11 @@ Classes</h2></td></tr>
|
||||
<tr class="separator:"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:"><td class="memItemLeft" align="right" valign="top">class  </td><td class="memItemRight" valign="bottom"><a class="el" href="classmlx_1_1core_1_1_cosh.html">Cosh</a></td></tr>
|
||||
<tr class="separator:"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:"><td class="memItemLeft" align="right" valign="top">class  </td><td class="memItemRight" valign="bottom"><a class="el" href="classmlx_1_1core_1_1_custom_v_j_p.html">CustomVJP</a></td></tr>
|
||||
<tr class="memitem:"><td class="memItemLeft" align="right" valign="top">class  </td><td class="memItemRight" valign="bottom"><a class="el" href="classmlx_1_1core_1_1_custom_transforms.html">CustomTransforms</a></td></tr>
|
||||
<tr class="separator:"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:"><td class="memItemLeft" align="right" valign="top">struct  </td><td class="memItemRight" valign="bottom"><a class="el" href="structmlx_1_1core_1_1_default_contiguous_reduce.html">DefaultContiguousReduce</a></td></tr>
|
||||
<tr class="separator:"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:"><td class="memItemLeft" align="right" valign="top">struct  </td><td class="memItemRight" valign="bottom"><a class="el" href="structmlx_1_1core_1_1_default_strided_reduce.html">DefaultStridedReduce</a></td></tr>
|
||||
<tr class="separator:"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:"><td class="memItemLeft" align="right" valign="top">class  </td><td class="memItemRight" valign="bottom"><a class="el" href="classmlx_1_1core_1_1_depends.html">Depends</a></td></tr>
|
||||
<tr class="separator:"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
@@ -217,6 +221,8 @@ Classes</h2></td></tr>
|
||||
<tr class="separator:"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:"><td class="memItemLeft" align="right" valign="top">class  </td><td class="memItemRight" valign="bottom"><a class="el" href="classmlx_1_1core_1_1_greater_equal.html">GreaterEqual</a></td></tr>
|
||||
<tr class="separator:"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:"><td class="memItemLeft" align="right" valign="top">class  </td><td class="memItemRight" valign="bottom"><a class="el" href="classmlx_1_1core_1_1_hadamard.html">Hadamard</a></td></tr>
|
||||
<tr class="separator:"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:"><td class="memItemLeft" align="right" valign="top">class  </td><td class="memItemRight" valign="bottom"><a class="el" href="classmlx_1_1core_1_1_inverse.html">Inverse</a></td></tr>
|
||||
<tr class="separator:"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:"><td class="memItemLeft" align="right" valign="top">class  </td><td class="memItemRight" valign="bottom"><a class="el" href="classmlx_1_1core_1_1_less.html">Less</a></td></tr>
|
||||
@@ -418,6 +424,22 @@ Functions</h2></td></tr>
|
||||
<tr class="memitem:aad636e2d0b2f882cadd1b438f4daa9ed" id="r_aad636e2d0b2f882cadd1b438f4daa9ed"><td class="memTemplParams" colspan="2">template<typename stride_t > </td></tr>
|
||||
<tr class="memitem:aad636e2d0b2f882cadd1b438f4daa9ed"><td class="memTemplItemLeft" align="right" valign="top">void </td><td class="memTemplItemRight" valign="bottom"><a class="el" href="#aad636e2d0b2f882cadd1b438f4daa9ed">copy_inplace</a> (const <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &src, <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &dst, const std::vector< int > &data_shape, const std::vector< stride_t > &i_strides, const std::vector< stride_t > &o_strides, int64_t i_offset, int64_t o_offset, <a class="el" href="#abd84ff6c5245e4e170b2ef5247594337">CopyType</a> ctype)</td></tr>
|
||||
<tr class="separator:aad636e2d0b2f882cadd1b438f4daa9ed"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:a50214cf406957fab27c8bef32046f030" id="r_a50214cf406957fab27c8bef32046f030"><td class="memItemLeft" align="right" valign="top">const std::map< int, std::string_view > </td><td class="memItemRight" valign="bottom"><a class="el" href="#a50214cf406957fab27c8bef32046f030">hadamard_matrices</a> ()</td></tr>
|
||||
<tr class="separator:a50214cf406957fab27c8bef32046f030"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:a3a8fe7ba84714dbb5fdc81e93a07abc8" id="r_a3a8fe7ba84714dbb5fdc81e93a07abc8"><td class="memItemLeft" align="right" valign="top">std::pair< int, int > </td><td class="memItemRight" valign="bottom"><a class="el" href="#a3a8fe7ba84714dbb5fdc81e93a07abc8">decompose_hadamard</a> (int n)</td></tr>
|
||||
<tr class="separator:a3a8fe7ba84714dbb5fdc81e93a07abc8"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:a1555dc378c5254e79199421761f26f2b" id="r_a1555dc378c5254e79199421761f26f2b"><td class="memItemLeft" align="right" valign="top"><a class="el" href="structmlx_1_1core_1_1_reduction_plan.html">ReductionPlan</a> </td><td class="memItemRight" valign="bottom"><a class="el" href="#a1555dc378c5254e79199421761f26f2b">get_reduction_plan</a> (const <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &x, const std::vector< int > axes)</td></tr>
|
||||
<tr class="separator:a1555dc378c5254e79199421761f26f2b"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:a9a9254ce9975ec247a2718bc02d6f201" id="r_a9a9254ce9975ec247a2718bc02d6f201"><td class="memItemLeft" align="right" valign="top">void </td><td class="memItemRight" valign="bottom"><a class="el" href="#a9a9254ce9975ec247a2718bc02d6f201">nd_loop</a> (std::function< void(int)> callback, const std::vector< int > &shape, const std::vector< size_t > &strides)</td></tr>
|
||||
<tr class="separator:a9a9254ce9975ec247a2718bc02d6f201"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:a44c3ea6db6553c3f6552b9ba64a69494" id="r_a44c3ea6db6553c3f6552b9ba64a69494"><td class="memItemLeft" align="right" valign="top">std::pair< std::vector< int >, std::vector< size_t > > </td><td class="memItemRight" valign="bottom"><a class="el" href="#a44c3ea6db6553c3f6552b9ba64a69494">shapes_without_reduction_axes</a> (const <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &x, const std::vector< int > &axes)</td></tr>
|
||||
<tr class="separator:a44c3ea6db6553c3f6552b9ba64a69494"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:aa08ffc1e8f2c58afb2d463496f827ef0" id="r_aa08ffc1e8f2c58afb2d463496f827ef0"><td class="memTemplParams" colspan="2">template<typename T , typename U , typename OpS , typename OpC , typename Op > </td></tr>
|
||||
<tr class="memitem:aa08ffc1e8f2c58afb2d463496f827ef0"><td class="memTemplItemLeft" align="right" valign="top">void </td><td class="memTemplItemRight" valign="bottom"><a class="el" href="#aa08ffc1e8f2c58afb2d463496f827ef0">reduction_op</a> (const <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &x, <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &out, const std::vector< int > &axes, U init, OpS ops, OpC opc, Op <a class="el" href="common_2binary_8h.html#a70228731d29946574b238d21fb4b360c">op</a>)</td></tr>
|
||||
<tr class="separator:aa08ffc1e8f2c58afb2d463496f827ef0"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:a064d61b6ddc9e5d1e261a7e33de71083" id="r_a064d61b6ddc9e5d1e261a7e33de71083"><td class="memTemplParams" colspan="2">template<typename T , typename U , typename Op > </td></tr>
|
||||
<tr class="memitem:a064d61b6ddc9e5d1e261a7e33de71083"><td class="memTemplItemLeft" align="right" valign="top">void </td><td class="memTemplItemRight" valign="bottom"><a class="el" href="#a064d61b6ddc9e5d1e261a7e33de71083">reduction_op</a> (const <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &x, <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &out, const std::vector< int > &axes, U init, Op <a class="el" href="common_2binary_8h.html#a70228731d29946574b238d21fb4b360c">op</a>)</td></tr>
|
||||
<tr class="separator:a064d61b6ddc9e5d1e261a7e33de71083"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:a07ae007c87cf2d723a0fa7b25a2b6a93" id="r_a07ae007c87cf2d723a0fa7b25a2b6a93"><td class="memItemLeft" align="right" valign="top">std::tuple< bool, int64_t, std::vector< int64_t > > </td><td class="memItemRight" valign="bottom"><a class="el" href="#a07ae007c87cf2d723a0fa7b25a2b6a93">prepare_slice</a> (const <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &in, std::vector< int > &start_indices, std::vector< int > &strides)</td></tr>
|
||||
<tr class="separator:a07ae007c87cf2d723a0fa7b25a2b6a93"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:a427f2c255dfc6e1f83f97587b08e71bc" id="r_a427f2c255dfc6e1f83f97587b08e71bc"><td class="memItemLeft" align="right" valign="top">void </td><td class="memItemRight" valign="bottom"><a class="el" href="#a427f2c255dfc6e1f83f97587b08e71bc">shared_buffer_slice</a> (const <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &in, const std::vector< size_t > &out_strides, size_t data_offset, <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &out)</td></tr>
|
||||
@@ -427,6 +449,9 @@ Functions</h2></td></tr>
|
||||
<tr class="separator:a4950c3248e70280b406a4f1430a85880"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:ad7e4f40eb351b554bbfabb6d7d600d06" id="r_ad7e4f40eb351b554bbfabb6d7d600d06"><td class="memItemLeft" align="right" valign="top">size_t </td><td class="memItemRight" valign="bottom"><a class="el" href="#ad7e4f40eb351b554bbfabb6d7d600d06">elem_to_loc</a> (int elem, const <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &a)</td></tr>
|
||||
<tr class="separator:ad7e4f40eb351b554bbfabb6d7d600d06"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:ac9fb1286a1a00395e901dbff80560895" id="r_ac9fb1286a1a00395e901dbff80560895"><td class="memTemplParams" colspan="2">template<typename stride_t > </td></tr>
|
||||
<tr class="memitem:ac9fb1286a1a00395e901dbff80560895"><td class="memTemplItemLeft" align="right" valign="top">std::vector< stride_t > </td><td class="memTemplItemRight" valign="bottom"><a class="el" href="#ac9fb1286a1a00395e901dbff80560895">make_contiguous_strides</a> (const std::vector< int > &shape)</td></tr>
|
||||
<tr class="separator:ac9fb1286a1a00395e901dbff80560895"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:a9d151ba3e138be1954d2f51f85806b0c" id="r_a9d151ba3e138be1954d2f51f85806b0c"><td class="memTemplParams" colspan="2">template<typename stride_t > </td></tr>
|
||||
<tr class="memitem:a9d151ba3e138be1954d2f51f85806b0c"><td class="memTemplItemLeft" align="right" valign="top">std::tuple< std::vector< int >, std::vector< std::vector< stride_t > > > </td><td class="memTemplItemRight" valign="bottom"><a class="el" href="#a9d151ba3e138be1954d2f51f85806b0c">collapse_contiguous_dims</a> (const std::vector< int > &shape, const std::vector< std::vector< stride_t > > strides)</td></tr>
|
||||
<tr class="separator:a9d151ba3e138be1954d2f51f85806b0c"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
@@ -459,14 +484,14 @@ Functions</h2></td></tr>
|
||||
<tr class="separator:ae55b801b09ccf55cba96278163a9b1ef"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:a76f614e9956a6ca05a9be4db5a483446" id="r_a76f614e9956a6ca05a9be4db5a483446"><td class="memItemLeft" align="right" valign="top">MTL::ComputePipelineState * </td><td class="memItemRight" valign="bottom"><a class="el" href="#a76f614e9956a6ca05a9be4db5a483446">get_arange_kernel</a> (<a class="el" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a> &d, const std::string &kernel_name, const <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &out)</td></tr>
|
||||
<tr class="separator:a76f614e9956a6ca05a9be4db5a483446"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:accf153854ef650d6a6633775d8a70612" id="r_accf153854ef650d6a6633775d8a70612"><td class="memItemLeft" align="right" valign="top">MTL::ComputePipelineState * </td><td class="memItemRight" valign="bottom"><a class="el" href="#accf153854ef650d6a6633775d8a70612">get_unary_kernel</a> (<a class="el" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a> &d, const std::string &kernel_name, const <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &out)</td></tr>
|
||||
<tr class="separator:accf153854ef650d6a6633775d8a70612"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:aec97852a7d8938407122d21e78d66f5f" id="r_aec97852a7d8938407122d21e78d66f5f"><td class="memItemLeft" align="right" valign="top">MTL::ComputePipelineState * </td><td class="memItemRight" valign="bottom"><a class="el" href="#aec97852a7d8938407122d21e78d66f5f">get_binary_kernel</a> (<a class="el" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a> &d, const std::string &kernel_name, const <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &in, const <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &out)</td></tr>
|
||||
<tr class="separator:aec97852a7d8938407122d21e78d66f5f"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:a0a5effc3e1cfd4123b9a63c08e947e45" id="r_a0a5effc3e1cfd4123b9a63c08e947e45"><td class="memItemLeft" align="right" valign="top">MTL::ComputePipelineState * </td><td class="memItemRight" valign="bottom"><a class="el" href="#a0a5effc3e1cfd4123b9a63c08e947e45">get_binary_two_kernel</a> (<a class="el" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a> &d, const std::string &kernel_name, const <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &in, const <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &out)</td></tr>
|
||||
<tr class="separator:a0a5effc3e1cfd4123b9a63c08e947e45"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:a266558f20a72c439396ecd492a08d65f" id="r_a266558f20a72c439396ecd492a08d65f"><td class="memItemLeft" align="right" valign="top">MTL::ComputePipelineState * </td><td class="memItemRight" valign="bottom"><a class="el" href="#a266558f20a72c439396ecd492a08d65f">get_ternary_kernel</a> (<a class="el" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a> &d, const std::string &kernel_name, const <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &out)</td></tr>
|
||||
<tr class="separator:a266558f20a72c439396ecd492a08d65f"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:a15175e8e2b1e26726c63393e4d68b628" id="r_a15175e8e2b1e26726c63393e4d68b628"><td class="memItemLeft" align="right" valign="top">MTL::ComputePipelineState * </td><td class="memItemRight" valign="bottom"><a class="el" href="#a15175e8e2b1e26726c63393e4d68b628">get_unary_kernel</a> (<a class="el" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a> &d, const std::string &kernel_name, <a class="el" href="structmlx_1_1core_1_1_dtype.html">Dtype</a> out_type, const std::string <a class="el" href="common_2binary_8h.html#a70228731d29946574b238d21fb4b360c">op</a>)</td></tr>
|
||||
<tr class="separator:a15175e8e2b1e26726c63393e4d68b628"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:a4decd4a07d91487e6903f6e3c8b7513a" id="r_a4decd4a07d91487e6903f6e3c8b7513a"><td class="memItemLeft" align="right" valign="top">MTL::ComputePipelineState * </td><td class="memItemRight" valign="bottom"><a class="el" href="#a4decd4a07d91487e6903f6e3c8b7513a">get_binary_kernel</a> (<a class="el" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a> &d, const std::string &kernel_name, <a class="el" href="structmlx_1_1core_1_1_dtype.html">Dtype</a> in_type, <a class="el" href="structmlx_1_1core_1_1_dtype.html">Dtype</a> out_type, const std::string <a class="el" href="common_2binary_8h.html#a70228731d29946574b238d21fb4b360c">op</a>)</td></tr>
|
||||
<tr class="separator:a4decd4a07d91487e6903f6e3c8b7513a"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:a4e809746f48e5dcf7fa63215d3f5e33e" id="r_a4e809746f48e5dcf7fa63215d3f5e33e"><td class="memItemLeft" align="right" valign="top">MTL::ComputePipelineState * </td><td class="memItemRight" valign="bottom"><a class="el" href="#a4e809746f48e5dcf7fa63215d3f5e33e">get_binary_two_kernel</a> (<a class="el" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a> &d, const std::string &kernel_name, <a class="el" href="structmlx_1_1core_1_1_dtype.html">Dtype</a> in_type, <a class="el" href="structmlx_1_1core_1_1_dtype.html">Dtype</a> out_type, const std::string <a class="el" href="common_2binary_8h.html#a70228731d29946574b238d21fb4b360c">op</a>)</td></tr>
|
||||
<tr class="separator:a4e809746f48e5dcf7fa63215d3f5e33e"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:a54eb3b65375022428aab5f810e40624b" id="r_a54eb3b65375022428aab5f810e40624b"><td class="memItemLeft" align="right" valign="top">MTL::ComputePipelineState * </td><td class="memItemRight" valign="bottom"><a class="el" href="#a54eb3b65375022428aab5f810e40624b">get_ternary_kernel</a> (<a class="el" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a> &d, const std::string &kernel_name, <a class="el" href="structmlx_1_1core_1_1_dtype.html">Dtype</a> type, const std::string <a class="el" href="common_2binary_8h.html#a70228731d29946574b238d21fb4b360c">op</a>)</td></tr>
|
||||
<tr class="separator:a54eb3b65375022428aab5f810e40624b"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:a05a220cff45f12439fde775983c6df78" id="r_a05a220cff45f12439fde775983c6df78"><td class="memItemLeft" align="right" valign="top">MTL::ComputePipelineState * </td><td class="memItemRight" valign="bottom"><a class="el" href="#a05a220cff45f12439fde775983c6df78">get_copy_kernel</a> (<a class="el" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a> &d, const std::string &kernel_name, const <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &in, const <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &out)</td></tr>
|
||||
<tr class="separator:a05a220cff45f12439fde775983c6df78"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:a35a412f688d79eb47e42d20a7c8650ee" id="r_a35a412f688d79eb47e42d20a7c8650ee"><td class="memItemLeft" align="right" valign="top">MTL::ComputePipelineState * </td><td class="memItemRight" valign="bottom"><a class="el" href="#a35a412f688d79eb47e42d20a7c8650ee">get_softmax_kernel</a> (<a class="el" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a> &d, const std::string &kernel_name, bool precise, const <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &out)</td></tr>
|
||||
@@ -493,8 +518,13 @@ Functions</h2></td></tr>
|
||||
<tr class="separator:adce79d220672f5f3c65cc31d145ca9c4"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:abce2b67044ee06a7bbe7a91ec7c8c48d" id="r_abce2b67044ee06a7bbe7a91ec7c8c48d"><td class="memItemLeft" align="right" valign="top">MTL::ComputePipelineState * </td><td class="memItemRight" valign="bottom"><a class="el" href="#abce2b67044ee06a7bbe7a91ec7c8c48d">get_steel_conv_general_kernel</a> (<a class="el" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a> &d, const std::string &kernel_name, const <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &out, int bm, int bn, int bk, int wm, int wn)</td></tr>
|
||||
<tr class="separator:abce2b67044ee06a7bbe7a91ec7c8c48d"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:a4d8800bb9892b04684c78e3e5c760983" id="r_a4d8800bb9892b04684c78e3e5c760983"><td class="memItemLeft" align="right" valign="top">MTL::ComputePipelineState * </td><td class="memItemRight" valign="bottom"><a class="el" href="#a4d8800bb9892b04684c78e3e5c760983">get_fft_kernel</a> (<a class="el" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a> &d, const std::string &kernel_name, const std::string &hash_name, const int tg_mem_size, const std::string &in_type, const std::string &out_type, int step, bool real, const <a class="el" href="namespacemlx_1_1core_1_1metal.html#a616e09a1ef321d527770721cef264c54">metal::MTLFCList</a> &func_consts)</td></tr>
|
||||
<tr class="separator:a4d8800bb9892b04684c78e3e5c760983"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:a1d4cffc3c78067b3d9a62d64f3fb686f" id="r_a1d4cffc3c78067b3d9a62d64f3fb686f"><td class="memItemLeft" align="right" valign="top">MTL::ComputePipelineState * </td><td class="memItemRight" valign="bottom"><a class="el" href="#a1d4cffc3c78067b3d9a62d64f3fb686f">get_fft_kernel</a> (<a class="el" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a> &d, const std::string &kernel_name, const std::string &hash_name, const <a class="el" href="namespacemlx_1_1core_1_1metal.html#a616e09a1ef321d527770721cef264c54">metal::MTLFCList</a> &func_consts, const std::string &template_def)</td></tr>
|
||||
<tr class="separator:a1d4cffc3c78067b3d9a62d64f3fb686f"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:aa3faeae5378bfaafe3ce3432a051e43e" id="r_aa3faeae5378bfaafe3ce3432a051e43e"><td class="memItemLeft" align="right" valign="top">MTL::ComputePipelineState * </td><td class="memItemRight" valign="bottom"><a class="el" href="#aa3faeae5378bfaafe3ce3432a051e43e">get_quantized_kernel</a> (<a class="el" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a> &d, const std::string &kernel_name, const std::string &template_def)</td></tr>
|
||||
<tr class="separator:aa3faeae5378bfaafe3ce3432a051e43e"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:aae0d19f0acdef2accd2428fb84c8a032" id="r_aae0d19f0acdef2accd2428fb84c8a032"><td class="memTemplParams" colspan="2">template<typename... Args> </td></tr>
|
||||
<tr class="memitem:aae0d19f0acdef2accd2428fb84c8a032"><td class="memTemplItemLeft" align="right" valign="top">std::string </td><td class="memTemplItemRight" valign="bottom"><a class="el" href="#aae0d19f0acdef2accd2428fb84c8a032">get_template_definition</a> (std::string name, std::string func, Args... args)</td></tr>
|
||||
<tr class="separator:aae0d19f0acdef2accd2428fb84c8a032"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:afe8386cea0c7b846dc78649927fd0c75" id="r_afe8386cea0c7b846dc78649927fd0c75"><td class="memItemLeft" align="right" valign="top">void </td><td class="memItemRight" valign="bottom"><a class="el" href="#afe8386cea0c7b846dc78649927fd0c75">steel_matmul_conv_groups</a> (const <a class="el" href="structmlx_1_1core_1_1_stream.html">Stream</a> &s, <a class="el" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a> &d, const <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &a, const <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &b, <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &out, int M, int N, int K, int lda, int ldb, int ldd, bool transpose_a, bool transpose_b, int groups, std::vector< <a class="el" href="classmlx_1_1core_1_1array.html">array</a> > &copies)</td></tr>
|
||||
<tr class="separator:afe8386cea0c7b846dc78649927fd0c75"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:ab43a7633794498e1c6775cca829eb886" id="r_ab43a7633794498e1c6775cca829eb886"><td class="memItemLeft" align="right" valign="top">void </td><td class="memItemRight" valign="bottom"><a class="el" href="#ab43a7633794498e1c6775cca829eb886">steel_matmul</a> (const <a class="el" href="structmlx_1_1core_1_1_stream.html">Stream</a> &s, <a class="el" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a> &d, const <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &a, const <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &b, <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &out, int M, int N, int K, int batch_size_out, int lda, int ldb, bool transpose_a, bool transpose_b, std::vector< <a class="el" href="classmlx_1_1core_1_1array.html">array</a> > &copies, std::vector< int > batch_shape={}, std::vector< size_t > A_batch_stride={}, std::vector< size_t > B_batch_stride={})</td></tr>
|
||||
@@ -683,6 +713,9 @@ Functions</h2></td></tr>
|
||||
<tr class="memitem:gaa6adbc9c86f0ab27d8810a02e9e719fd" id="r_gaa6adbc9c86f0ab27d8810a02e9e719fd"><td class="memItemLeft" align="right" valign="top"><a class="el" href="classmlx_1_1core_1_1array.html">array</a> </td><td class="memItemRight" valign="bottom"><a class="el" href="group__ops.html#gaa6adbc9c86f0ab27d8810a02e9e719fd">flatten</a> (const <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &a, <a class="el" href="#a95fc1013cc48fbfee0c54310711a5e58">StreamOrDevice</a> s={})</td></tr>
|
||||
<tr class="memdesc:gaa6adbc9c86f0ab27d8810a02e9e719fd"><td class="mdescLeft"> </td><td class="mdescRight">Flatten the array to 1D. <br /></td></tr>
|
||||
<tr class="separator:gaa6adbc9c86f0ab27d8810a02e9e719fd"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:ga001919c0ee4a9c3d7948ed32cb4c58d6" id="r_ga001919c0ee4a9c3d7948ed32cb4c58d6"><td class="memItemLeft" align="right" valign="top"><a class="el" href="classmlx_1_1core_1_1array.html">array</a> </td><td class="memItemRight" valign="bottom"><a class="el" href="group__ops.html#ga001919c0ee4a9c3d7948ed32cb4c58d6">hadamard_transform</a> (const <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &a, float scale=1.0f, <a class="el" href="#a95fc1013cc48fbfee0c54310711a5e58">StreamOrDevice</a> s={})</td></tr>
|
||||
<tr class="memdesc:ga001919c0ee4a9c3d7948ed32cb4c58d6"><td class="mdescLeft"> </td><td class="mdescRight"><a class="el" href="classmlx_1_1core_1_1_multiply.html">Multiply</a> the array by the <a class="el" href="classmlx_1_1core_1_1_hadamard.html">Hadamard</a> matrix of corresponding size. <br /></td></tr>
|
||||
<tr class="separator:ga001919c0ee4a9c3d7948ed32cb4c58d6"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:ga710daa7ec721bd4d3f326082cb195576" id="r_ga710daa7ec721bd4d3f326082cb195576"><td class="memItemLeft" align="right" valign="top"><a class="el" href="classmlx_1_1core_1_1array.html">array</a> </td><td class="memItemRight" valign="bottom"><a class="el" href="group__ops.html#ga710daa7ec721bd4d3f326082cb195576">squeeze</a> (const <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &a, const std::vector< int > &axes, <a class="el" href="#a95fc1013cc48fbfee0c54310711a5e58">StreamOrDevice</a> s={})</td></tr>
|
||||
<tr class="memdesc:ga710daa7ec721bd4d3f326082cb195576"><td class="mdescLeft"> </td><td class="mdescRight">Remove singleton dimensions at the given axes. <br /></td></tr>
|
||||
<tr class="separator:ga710daa7ec721bd4d3f326082cb195576"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
@@ -2196,6 +2229,10 @@ Functions</h2></td></tr>
|
||||
<tr class="separator:a57eb97a5eba99a846ac429795e407574"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:a7db909d54cf07375e89424c32c07a29c" id="r_a7db909d54cf07375e89424c32c07a29c"><td class="memItemLeft" align="right" valign="top">std::ostream & </td><td class="memItemRight" valign="bottom"><a class="el" href="#a7db909d54cf07375e89424c32c07a29c">operator<<</a> (std::ostream &os, const <a class="el" href="#acb5d16c9b83778c7621c38e522e0060b">bfloat16_t</a> &v)</td></tr>
|
||||
<tr class="separator:a7db909d54cf07375e89424c32c07a29c"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:adacbc4526e8964b267a8ec3eb1bc1a32" id="r_adacbc4526e8964b267a8ec3eb1bc1a32"><td class="memItemLeft" align="right" valign="top">bool </td><td class="memItemRight" valign="bottom"><a class="el" href="#adacbc4526e8964b267a8ec3eb1bc1a32">is_power_of_2</a> (int n)</td></tr>
|
||||
<tr class="separator:adacbc4526e8964b267a8ec3eb1bc1a32"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:a685c0530e338aabc622325685846ce93" id="r_a685c0530e338aabc622325685846ce93"><td class="memItemLeft" align="right" valign="top">int </td><td class="memItemRight" valign="bottom"><a class="el" href="#a685c0530e338aabc622325685846ce93">next_power_of_2</a> (int n)</td></tr>
|
||||
<tr class="separator:a685c0530e338aabc622325685846ce93"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
</table><table class="memberdecls">
|
||||
<tr class="heading"><td colspan="2"><h2 class="groupheader"><a id="var-members" name="var-members"></a>
|
||||
Variables</h2></td></tr>
|
||||
@@ -2205,6 +2242,12 @@ Variables</h2></td></tr>
|
||||
<tr class="memitem:a94c1057929b390e5613304afa16dfbda" id="r_a94c1057929b390e5613304afa16dfbda"><td class="memTemplParams" colspan="2">template<typename... T> </td></tr>
|
||||
<tr class="memitem:a94c1057929b390e5613304afa16dfbda"><td class="memTemplItemLeft" align="right" valign="top">constexpr bool </td><td class="memTemplItemRight" valign="bottom"><a class="el" href="#a94c1057929b390e5613304afa16dfbda">is_arrays_v</a> = (<a class="el" href="#a01b0d64a75dfa2e95d6c7b5c53d708af">is_array_v</a><T> && ...)</td></tr>
|
||||
<tr class="separator:a94c1057929b390e5613304afa16dfbda"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:a4beeeec4413be7adcfb14feaa9cf0e2e" id="r_a4beeeec4413be7adcfb14feaa9cf0e2e"><td class="memItemLeft" align="right" valign="top">constexpr std::string_view </td><td class="memItemRight" valign="bottom"><a class="el" href="#a4beeeec4413be7adcfb14feaa9cf0e2e">h12</a></td></tr>
|
||||
<tr class="separator:a4beeeec4413be7adcfb14feaa9cf0e2e"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:a862c6b94fec384c34a699ced64d01404" id="r_a862c6b94fec384c34a699ced64d01404"><td class="memItemLeft" align="right" valign="top">constexpr std::string_view </td><td class="memItemRight" valign="bottom"><a class="el" href="#a862c6b94fec384c34a699ced64d01404">h20</a></td></tr>
|
||||
<tr class="separator:a862c6b94fec384c34a699ced64d01404"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:ac447ad59592dd06435adca7df37e33ad" id="r_ac447ad59592dd06435adca7df37e33ad"><td class="memItemLeft" align="right" valign="top">constexpr std::string_view </td><td class="memItemRight" valign="bottom"><a class="el" href="#ac447ad59592dd06435adca7df37e33ad">h28</a></td></tr>
|
||||
<tr class="separator:ac447ad59592dd06435adca7df37e33ad"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:ab93149e46a6d8f3e1988123fab508dc2" id="r_ab93149e46a6d8f3e1988123fab508dc2"><td class="memItemLeft" align="right" valign="top">std::function< std::vector< <a class="el" href="classmlx_1_1core_1_1array.html">array</a> >(const std::vector< <a class="el" href="classmlx_1_1core_1_1array.html">array</a> > &) </td><td class="memItemRight" valign="bottom"><a class="el" href="#ab93149e46a6d8f3e1988123fab508dc2">compile</a> )(const std::function< std::vector< <a class="el" href="classmlx_1_1core_1_1array.html">array</a> >(const std::vector< <a class="el" href="classmlx_1_1core_1_1array.html">array</a> > &)> &fun, bool shapeless=false)</td></tr>
|
||||
<tr class="memdesc:ab93149e46a6d8f3e1988123fab508dc2"><td class="mdescLeft"> </td><td class="mdescRight">Compile takes a function and returns a compiled function. <br /></td></tr>
|
||||
<tr class="separator:ab93149e46a6d8f3e1988123fab508dc2"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
@@ -2259,8 +2302,11 @@ Variables</h2></td></tr>
|
||||
<tr class="memitem:a933289d4688479e1c4d8ba04332c406b" id="r_a933289d4688479e1c4d8ba04332c406b"><td class="memItemLeft" align="right" valign="top">std::function< <a class="el" href="classmlx_1_1core_1_1array.html">array</a>(const <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &) </td><td class="memItemRight" valign="bottom"><a class="el" href="#a933289d4688479e1c4d8ba04332c406b">vmap</a> )(const std::function< <a class="el" href="classmlx_1_1core_1_1array.html">array</a>(const <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &)> &fun, int in_axis=0, int out_axis=0)</td></tr>
|
||||
<tr class="memdesc:a933289d4688479e1c4d8ba04332c406b"><td class="mdescLeft"> </td><td class="mdescRight">Automatically vectorize a unary function over the requested axes. <br /></td></tr>
|
||||
<tr class="separator:a933289d4688479e1c4d8ba04332c406b"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:a58c5b448f357b48e86599eb8eeea141d" id="r_a58c5b448f357b48e86599eb8eeea141d"><td class="memItemLeft" align="right" valign="top">std::function< std::vector< <a class="el" href="classmlx_1_1core_1_1array.html">array</a> >(const std::vector< <a class="el" href="classmlx_1_1core_1_1array.html">array</a> > &) </td><td class="memItemRight" valign="bottom"><a class="el" href="#a58c5b448f357b48e86599eb8eeea141d">custom_function</a> )(std::function< std::vector< <a class="el" href="classmlx_1_1core_1_1array.html">array</a> >(const std::vector< <a class="el" href="classmlx_1_1core_1_1array.html">array</a> > &)> fun, std::optional< std::function< std::vector< <a class="el" href="classmlx_1_1core_1_1array.html">array</a> >(const std::vector< <a class="el" href="classmlx_1_1core_1_1array.html">array</a> > &, const std::vector< <a class="el" href="classmlx_1_1core_1_1array.html">array</a> > &, const std::vector< <a class="el" href="classmlx_1_1core_1_1array.html">array</a> > &)> > fun_vjp=std::nullopt, std::optional< std::function< std::vector< <a class="el" href="classmlx_1_1core_1_1array.html">array</a> >(const std::vector< <a class="el" href="classmlx_1_1core_1_1array.html">array</a> > &, const std::vector< <a class="el" href="classmlx_1_1core_1_1array.html">array</a> > &, const std::vector< int > &)> > fun_jvp=std::nullopt, std::optional< std::function< std::pair< std::vector< <a class="el" href="classmlx_1_1core_1_1array.html">array</a> >, std::vector< int > >(const std::vector< <a class="el" href="classmlx_1_1core_1_1array.html">array</a> > &, const std::vector< int > &)> > fun_vmap=std::nullopt)</td></tr>
|
||||
<tr class="memdesc:a58c5b448f357b48e86599eb8eeea141d"><td class="mdescLeft"> </td><td class="mdescRight">Redefine the transformations of <code>fun</code> according to the provided functions. <br /></td></tr>
|
||||
<tr class="separator:a58c5b448f357b48e86599eb8eeea141d"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:a3fa1f0ad360f3e16c146384276b1c467" id="r_a3fa1f0ad360f3e16c146384276b1c467"><td class="memItemLeft" align="right" valign="top">std::function< std::vector< <a class="el" href="classmlx_1_1core_1_1array.html">array</a> >(const std::vector< <a class="el" href="classmlx_1_1core_1_1array.html">array</a> > &) </td><td class="memItemRight" valign="bottom"><a class="el" href="#a3fa1f0ad360f3e16c146384276b1c467">custom_vjp</a> )(std::function< std::vector< <a class="el" href="classmlx_1_1core_1_1array.html">array</a> >(const std::vector< <a class="el" href="classmlx_1_1core_1_1array.html">array</a> > &)> fun, std::function< std::vector< <a class="el" href="classmlx_1_1core_1_1array.html">array</a> >(const std::vector< <a class="el" href="classmlx_1_1core_1_1array.html">array</a> > &, const std::vector< <a class="el" href="classmlx_1_1core_1_1array.html">array</a> > &, const std::vector< <a class="el" href="classmlx_1_1core_1_1array.html">array</a> > &)> fun_vjp)</td></tr>
|
||||
<tr class="memdesc:a3fa1f0ad360f3e16c146384276b1c467"><td class="mdescLeft"> </td><td class="mdescRight">Return the results of calling fun with args but if their vjp is computed it will be computed by fun_vjp. <br /></td></tr>
|
||||
<tr class="memdesc:a3fa1f0ad360f3e16c146384276b1c467"><td class="mdescLeft"> </td><td class="mdescRight">Return a function that behaves exactly like <code>fun</code> but if the vjp of the results is computed <code>fun_vjp</code> will be used instead of <code>vjp(fun, ...)</code> . <br /></td></tr>
|
||||
<tr class="separator:a3fa1f0ad360f3e16c146384276b1c467"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:a26127b71b2ec65c51d7627e71847083d" id="r_a26127b71b2ec65c51d7627e71847083d"><td class="memItemLeft" align="right" valign="top">std::function< std::vector< <a class="el" href="classmlx_1_1core_1_1array.html">array</a> >(const std::vector< <a class="el" href="classmlx_1_1core_1_1array.html">array</a> > &) </td><td class="memItemRight" valign="bottom"><a class="el" href="#a26127b71b2ec65c51d7627e71847083d">checkpoint</a> )(std::function< std::vector< <a class="el" href="classmlx_1_1core_1_1array.html">array</a> >(const std::vector< <a class="el" href="classmlx_1_1core_1_1array.html">array</a> > &)> fun)</td></tr>
|
||||
<tr class="memdesc:a26127b71b2ec65c51d7627e71847083d"><td class="mdescLeft"> </td><td class="mdescRight">Checkpoint the gradient of a function. <br /></td></tr>
|
||||
@@ -3293,6 +3339,31 @@ template<typename stride_t > </div>
|
||||
</table>
|
||||
</div><div class="memdoc">
|
||||
|
||||
</div>
|
||||
</div>
|
||||
<a id="a3a8fe7ba84714dbb5fdc81e93a07abc8" name="a3a8fe7ba84714dbb5fdc81e93a07abc8"></a>
|
||||
<h2 class="memtitle"><span class="permalink"><a href="#a3a8fe7ba84714dbb5fdc81e93a07abc8">◆ </a></span>decompose_hadamard()</h2>
|
||||
|
||||
<div class="memitem">
|
||||
<div class="memproto">
|
||||
<table class="mlabels">
|
||||
<tr>
|
||||
<td class="mlabels-left">
|
||||
<table class="memname">
|
||||
<tr>
|
||||
<td class="memname">std::pair< int, int > mlx::core::decompose_hadamard </td>
|
||||
<td>(</td>
|
||||
<td class="paramtype">int</td> <td class="paramname"><span class="paramname"><em>n</em></span></td><td>)</td>
|
||||
<td></td>
|
||||
</tr>
|
||||
</table>
|
||||
</td>
|
||||
<td class="mlabels-right">
|
||||
<span class="mlabels"><span class="mlabel">inline</span></span> </td>
|
||||
</tr>
|
||||
</table>
|
||||
</div><div class="memdoc">
|
||||
|
||||
</div>
|
||||
</div>
|
||||
<a id="a0196171cfe6ee2953113abce597dc815" name="a0196171cfe6ee2953113abce597dc815"></a>
|
||||
@@ -3576,8 +3647,8 @@ template<typename... Arrays, typename = enable_for_arrays_t<Arrays...>
|
||||
|
||||
</div>
|
||||
</div>
|
||||
<a id="aec97852a7d8938407122d21e78d66f5f" name="aec97852a7d8938407122d21e78d66f5f"></a>
|
||||
<h2 class="memtitle"><span class="permalink"><a href="#aec97852a7d8938407122d21e78d66f5f">◆ </a></span>get_binary_kernel()</h2>
|
||||
<a id="a4decd4a07d91487e6903f6e3c8b7513a" name="a4decd4a07d91487e6903f6e3c8b7513a"></a>
|
||||
<h2 class="memtitle"><span class="permalink"><a href="#a4decd4a07d91487e6903f6e3c8b7513a">◆ </a></span>get_binary_kernel()</h2>
|
||||
|
||||
<div class="memitem">
|
||||
<div class="memproto">
|
||||
@@ -3595,20 +3666,25 @@ template<typename... Arrays, typename = enable_for_arrays_t<Arrays...>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &</td> <td class="paramname"><span class="paramname"><em>in</em>, </span></td>
|
||||
<td class="paramtype"><a class="el" href="structmlx_1_1core_1_1_dtype.html">Dtype</a></td> <td class="paramname"><span class="paramname"><em>in_type</em>, </span></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &</td> <td class="paramname"><span class="paramname"><em>out</em></span> )</td>
|
||||
<td class="paramtype"><a class="el" href="structmlx_1_1core_1_1_dtype.html">Dtype</a></td> <td class="paramname"><span class="paramname"><em>out_type</em>, </span></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const std::string</td> <td class="paramname"><span class="paramname"><em>op</em></span> )</td>
|
||||
</tr>
|
||||
</table>
|
||||
</div><div class="memdoc">
|
||||
|
||||
</div>
|
||||
</div>
|
||||
<a id="a0a5effc3e1cfd4123b9a63c08e947e45" name="a0a5effc3e1cfd4123b9a63c08e947e45"></a>
|
||||
<h2 class="memtitle"><span class="permalink"><a href="#a0a5effc3e1cfd4123b9a63c08e947e45">◆ </a></span>get_binary_two_kernel()</h2>
|
||||
<a id="a4e809746f48e5dcf7fa63215d3f5e33e" name="a4e809746f48e5dcf7fa63215d3f5e33e"></a>
|
||||
<h2 class="memtitle"><span class="permalink"><a href="#a4e809746f48e5dcf7fa63215d3f5e33e">◆ </a></span>get_binary_two_kernel()</h2>
|
||||
|
||||
<div class="memitem">
|
||||
<div class="memproto">
|
||||
@@ -3626,12 +3702,17 @@ template<typename... Arrays, typename = enable_for_arrays_t<Arrays...>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &</td> <td class="paramname"><span class="paramname"><em>in</em>, </span></td>
|
||||
<td class="paramtype"><a class="el" href="structmlx_1_1core_1_1_dtype.html">Dtype</a></td> <td class="paramname"><span class="paramname"><em>in_type</em>, </span></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &</td> <td class="paramname"><span class="paramname"><em>out</em></span> )</td>
|
||||
<td class="paramtype"><a class="el" href="structmlx_1_1core_1_1_dtype.html">Dtype</a></td> <td class="paramname"><span class="paramname"><em>out_type</em>, </span></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const std::string</td> <td class="paramname"><span class="paramname"><em>op</em></span> )</td>
|
||||
</tr>
|
||||
</table>
|
||||
</div><div class="memdoc">
|
||||
@@ -3669,8 +3750,8 @@ template<typename... Arrays, typename = enable_for_arrays_t<Arrays...>
|
||||
|
||||
</div>
|
||||
</div>
|
||||
<a id="a4d8800bb9892b04684c78e3e5c760983" name="a4d8800bb9892b04684c78e3e5c760983"></a>
|
||||
<h2 class="memtitle"><span class="permalink"><a href="#a4d8800bb9892b04684c78e3e5c760983">◆ </a></span>get_fft_kernel()</h2>
|
||||
<a id="a1d4cffc3c78067b3d9a62d64f3fb686f" name="a1d4cffc3c78067b3d9a62d64f3fb686f"></a>
|
||||
<h2 class="memtitle"><span class="permalink"><a href="#a1d4cffc3c78067b3d9a62d64f3fb686f">◆ </a></span>get_fft_kernel()</h2>
|
||||
|
||||
<div class="memitem">
|
||||
<div class="memproto">
|
||||
@@ -3693,32 +3774,12 @@ template<typename... Arrays, typename = enable_for_arrays_t<Arrays...>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const int</td> <td class="paramname"><span class="paramname"><em>tg_mem_size</em>, </span></td>
|
||||
<td class="paramtype">const <a class="el" href="namespacemlx_1_1core_1_1metal.html#a616e09a1ef321d527770721cef264c54">metal::MTLFCList</a> &</td> <td class="paramname"><span class="paramname"><em>func_consts</em>, </span></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const std::string &</td> <td class="paramname"><span class="paramname"><em>in_type</em>, </span></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const std::string &</td> <td class="paramname"><span class="paramname"><em>out_type</em>, </span></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">int</td> <td class="paramname"><span class="paramname"><em>step</em>, </span></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">bool</td> <td class="paramname"><span class="paramname"><em>real</em>, </span></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const <a class="el" href="namespacemlx_1_1core_1_1metal.html#a616e09a1ef321d527770721cef264c54">metal::MTLFCList</a> &</td> <td class="paramname"><span class="paramname"><em>func_consts</em></span> )</td>
|
||||
<td class="paramtype">const std::string &</td> <td class="paramname"><span class="paramname"><em>template_def</em></span> )</td>
|
||||
</tr>
|
||||
</table>
|
||||
</div><div class="memdoc">
|
||||
@@ -3764,6 +3825,32 @@ template<typename... Arrays, typename = enable_for_arrays_t<Arrays...>
|
||||
</table>
|
||||
</div><div class="memdoc">
|
||||
|
||||
</div>
|
||||
</div>
|
||||
<a id="aa3faeae5378bfaafe3ce3432a051e43e" name="aa3faeae5378bfaafe3ce3432a051e43e"></a>
|
||||
<h2 class="memtitle"><span class="permalink"><a href="#aa3faeae5378bfaafe3ce3432a051e43e">◆ </a></span>get_quantized_kernel()</h2>
|
||||
|
||||
<div class="memitem">
|
||||
<div class="memproto">
|
||||
<table class="memname">
|
||||
<tr>
|
||||
<td class="memname">MTL::ComputePipelineState * mlx::core::get_quantized_kernel </td>
|
||||
<td>(</td>
|
||||
<td class="paramtype"><a class="el" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a> &</td> <td class="paramname"><span class="paramname"><em>d</em>, </span></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const std::string &</td> <td class="paramname"><span class="paramname"><em>kernel_name</em>, </span></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const std::string &</td> <td class="paramname"><span class="paramname"><em>template_def</em></span> )</td>
|
||||
</tr>
|
||||
</table>
|
||||
</div><div class="memdoc">
|
||||
|
||||
</div>
|
||||
</div>
|
||||
<a id="a51c4bb09230348bd0252e22bfdc9bc89" name="a51c4bb09230348bd0252e22bfdc9bc89"></a>
|
||||
@@ -3826,6 +3913,27 @@ template<typename... Arrays, typename = enable_for_arrays_t<Arrays...>
|
||||
</table>
|
||||
</div><div class="memdoc">
|
||||
|
||||
</div>
|
||||
</div>
|
||||
<a id="a1555dc378c5254e79199421761f26f2b" name="a1555dc378c5254e79199421761f26f2b"></a>
|
||||
<h2 class="memtitle"><span class="permalink"><a href="#a1555dc378c5254e79199421761f26f2b">◆ </a></span>get_reduction_plan()</h2>
|
||||
|
||||
<div class="memitem">
|
||||
<div class="memproto">
|
||||
<table class="memname">
|
||||
<tr>
|
||||
<td class="memname"><a class="el" href="structmlx_1_1core_1_1_reduction_plan.html">ReductionPlan</a> mlx::core::get_reduction_plan </td>
|
||||
<td>(</td>
|
||||
<td class="paramtype">const <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &</td> <td class="paramname"><span class="paramname"><em>x</em>, </span></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const std::vector< int ></td> <td class="paramname"><span class="paramname"><em>axes</em></span> )</td>
|
||||
</tr>
|
||||
</table>
|
||||
</div><div class="memdoc">
|
||||
|
||||
</div>
|
||||
</div>
|
||||
<a id="aeefaff208444d3fa61ecc0946fe1de5f" name="aeefaff208444d3fa61ecc0946fe1de5f"></a>
|
||||
@@ -4339,8 +4447,36 @@ template<typename... Arrays, typename = enable_for_arrays_t<Arrays...>
|
||||
|
||||
</div>
|
||||
</div>
|
||||
<a id="a266558f20a72c439396ecd492a08d65f" name="a266558f20a72c439396ecd492a08d65f"></a>
|
||||
<h2 class="memtitle"><span class="permalink"><a href="#a266558f20a72c439396ecd492a08d65f">◆ </a></span>get_ternary_kernel()</h2>
|
||||
<a id="aae0d19f0acdef2accd2428fb84c8a032" name="aae0d19f0acdef2accd2428fb84c8a032"></a>
|
||||
<h2 class="memtitle"><span class="permalink"><a href="#aae0d19f0acdef2accd2428fb84c8a032">◆ </a></span>get_template_definition()</h2>
|
||||
|
||||
<div class="memitem">
|
||||
<div class="memproto">
|
||||
<div class="memtemplate">
|
||||
template<typename... Args> </div>
|
||||
<table class="memname">
|
||||
<tr>
|
||||
<td class="memname">std::string mlx::core::get_template_definition </td>
|
||||
<td>(</td>
|
||||
<td class="paramtype">std::string</td> <td class="paramname"><span class="paramname"><em>name</em>, </span></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">std::string</td> <td class="paramname"><span class="paramname"><em>func</em>, </span></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">Args...</td> <td class="paramname"><span class="paramname"><em>args</em></span> )</td>
|
||||
</tr>
|
||||
</table>
|
||||
</div><div class="memdoc">
|
||||
|
||||
</div>
|
||||
</div>
|
||||
<a id="a54eb3b65375022428aab5f810e40624b" name="a54eb3b65375022428aab5f810e40624b"></a>
|
||||
<h2 class="memtitle"><span class="permalink"><a href="#a54eb3b65375022428aab5f810e40624b">◆ </a></span>get_ternary_kernel()</h2>
|
||||
|
||||
<div class="memitem">
|
||||
<div class="memproto">
|
||||
@@ -4358,7 +4494,12 @@ template<typename... Arrays, typename = enable_for_arrays_t<Arrays...>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &</td> <td class="paramname"><span class="paramname"><em>out</em></span> )</td>
|
||||
<td class="paramtype"><a class="el" href="structmlx_1_1core_1_1_dtype.html">Dtype</a></td> <td class="paramname"><span class="paramname"><em>type</em>, </span></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const std::string</td> <td class="paramname"><span class="paramname"><em>op</em></span> )</td>
|
||||
</tr>
|
||||
</table>
|
||||
</div><div class="memdoc">
|
||||
@@ -4382,8 +4523,8 @@ template<typename... Arrays, typename = enable_for_arrays_t<Arrays...>
|
||||
|
||||
</div>
|
||||
</div>
|
||||
<a id="accf153854ef650d6a6633775d8a70612" name="accf153854ef650d6a6633775d8a70612"></a>
|
||||
<h2 class="memtitle"><span class="permalink"><a href="#accf153854ef650d6a6633775d8a70612">◆ </a></span>get_unary_kernel()</h2>
|
||||
<a id="a15175e8e2b1e26726c63393e4d68b628" name="a15175e8e2b1e26726c63393e4d68b628"></a>
|
||||
<h2 class="memtitle"><span class="permalink"><a href="#a15175e8e2b1e26726c63393e4d68b628">◆ </a></span>get_unary_kernel()</h2>
|
||||
|
||||
<div class="memitem">
|
||||
<div class="memproto">
|
||||
@@ -4401,7 +4542,12 @@ template<typename... Arrays, typename = enable_for_arrays_t<Arrays...>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &</td> <td class="paramname"><span class="paramname"><em>out</em></span> )</td>
|
||||
<td class="paramtype"><a class="el" href="structmlx_1_1core_1_1_dtype.html">Dtype</a></td> <td class="paramname"><span class="paramname"><em>out_type</em>, </span></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const std::string</td> <td class="paramname"><span class="paramname"><em>op</em></span> )</td>
|
||||
</tr>
|
||||
</table>
|
||||
</div><div class="memdoc">
|
||||
@@ -4486,6 +4632,31 @@ template<typename... Arrays, typename = enable_for_arrays_t<Arrays...>
|
||||
<p>Returns a function which computes the gradient of the input function with respect to a single input array. </p>
|
||||
<p>The function being differentiated takes a vector of arrays and returns an array. The optional <code>argnum</code> index specifies which the argument to compute the gradient with respect to and defaults to 0. </p>
|
||||
|
||||
</div>
|
||||
</div>
|
||||
<a id="a50214cf406957fab27c8bef32046f030" name="a50214cf406957fab27c8bef32046f030"></a>
|
||||
<h2 class="memtitle"><span class="permalink"><a href="#a50214cf406957fab27c8bef32046f030">◆ </a></span>hadamard_matrices()</h2>
|
||||
|
||||
<div class="memitem">
|
||||
<div class="memproto">
|
||||
<table class="mlabels">
|
||||
<tr>
|
||||
<td class="mlabels-left">
|
||||
<table class="memname">
|
||||
<tr>
|
||||
<td class="memname">const std::map< int, std::string_view > mlx::core::hadamard_matrices </td>
|
||||
<td>(</td>
|
||||
<td class="paramname"><span class="paramname"></span></td><td>)</td>
|
||||
<td></td>
|
||||
</tr>
|
||||
</table>
|
||||
</td>
|
||||
<td class="mlabels-right">
|
||||
<span class="mlabels"><span class="mlabel">inline</span></span> </td>
|
||||
</tr>
|
||||
</table>
|
||||
</div><div class="memdoc">
|
||||
|
||||
</div>
|
||||
</div>
|
||||
<a id="a625ed440df3fa57318017c1f2c589efe" name="a625ed440df3fa57318017c1f2c589efe"></a>
|
||||
@@ -4511,6 +4682,31 @@ template<typename... Arrays, typename = enable_for_arrays_t<Arrays...>
|
||||
</table>
|
||||
</div><div class="memdoc">
|
||||
|
||||
</div>
|
||||
</div>
|
||||
<a id="adacbc4526e8964b267a8ec3eb1bc1a32" name="adacbc4526e8964b267a8ec3eb1bc1a32"></a>
|
||||
<h2 class="memtitle"><span class="permalink"><a href="#adacbc4526e8964b267a8ec3eb1bc1a32">◆ </a></span>is_power_of_2()</h2>
|
||||
|
||||
<div class="memitem">
|
||||
<div class="memproto">
|
||||
<table class="mlabels">
|
||||
<tr>
|
||||
<td class="mlabels-left">
|
||||
<table class="memname">
|
||||
<tr>
|
||||
<td class="memname">bool mlx::core::is_power_of_2 </td>
|
||||
<td>(</td>
|
||||
<td class="paramtype">int</td> <td class="paramname"><span class="paramname"><em>n</em></span></td><td>)</td>
|
||||
<td></td>
|
||||
</tr>
|
||||
</table>
|
||||
</td>
|
||||
<td class="mlabels-right">
|
||||
<span class="mlabels"><span class="mlabel">inline</span></span> </td>
|
||||
</tr>
|
||||
</table>
|
||||
</div><div class="memdoc">
|
||||
|
||||
</div>
|
||||
</div>
|
||||
<a id="ad4b664de4a4abd305827b30879b9da33" name="ad4b664de4a4abd305827b30879b9da33"></a>
|
||||
@@ -4849,6 +5045,51 @@ template<typename... Arrays, typename = enable_for_arrays_t<Arrays...>
|
||||
|
||||
<p><a class="el" href="classmlx_1_1core_1_1_load.html">Load</a> array map from .safetensors file format. </p>
|
||||
|
||||
</div>
|
||||
</div>
|
||||
<a id="ac9fb1286a1a00395e901dbff80560895" name="ac9fb1286a1a00395e901dbff80560895"></a>
|
||||
<h2 class="memtitle"><span class="permalink"><a href="#ac9fb1286a1a00395e901dbff80560895">◆ </a></span>make_contiguous_strides()</h2>
|
||||
|
||||
<div class="memitem">
|
||||
<div class="memproto">
|
||||
<div class="memtemplate">
|
||||
template<typename stride_t > </div>
|
||||
<table class="memname">
|
||||
<tr>
|
||||
<td class="memname">std::vector< stride_t > mlx::core::make_contiguous_strides </td>
|
||||
<td>(</td>
|
||||
<td class="paramtype">const std::vector< int > &</td> <td class="paramname"><span class="paramname"><em>shape</em></span></td><td>)</td>
|
||||
<td></td>
|
||||
</tr>
|
||||
</table>
|
||||
</div><div class="memdoc">
|
||||
|
||||
</div>
|
||||
</div>
|
||||
<a id="a9a9254ce9975ec247a2718bc02d6f201" name="a9a9254ce9975ec247a2718bc02d6f201"></a>
|
||||
<h2 class="memtitle"><span class="permalink"><a href="#a9a9254ce9975ec247a2718bc02d6f201">◆ </a></span>nd_loop()</h2>
|
||||
|
||||
<div class="memitem">
|
||||
<div class="memproto">
|
||||
<table class="memname">
|
||||
<tr>
|
||||
<td class="memname">void mlx::core::nd_loop </td>
|
||||
<td>(</td>
|
||||
<td class="paramtype">std::function< void(int)></td> <td class="paramname"><span class="paramname"><em>callback</em>, </span></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const std::vector< int > &</td> <td class="paramname"><span class="paramname"><em>shape</em>, </span></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const std::vector< size_t > &</td> <td class="paramname"><span class="paramname"><em>strides</em></span> )</td>
|
||||
</tr>
|
||||
</table>
|
||||
</div><div class="memdoc">
|
||||
|
||||
</div>
|
||||
</div>
|
||||
<a id="a6f7c63a9be10337b3b96d527e1db3c2f" name="a6f7c63a9be10337b3b96d527e1db3c2f"></a>
|
||||
@@ -4868,6 +5109,31 @@ template<typename... Arrays, typename = enable_for_arrays_t<Arrays...>
|
||||
|
||||
<p>Make a new stream on the given device. </p>
|
||||
|
||||
</div>
|
||||
</div>
|
||||
<a id="a685c0530e338aabc622325685846ce93" name="a685c0530e338aabc622325685846ce93"></a>
|
||||
<h2 class="memtitle"><span class="permalink"><a href="#a685c0530e338aabc622325685846ce93">◆ </a></span>next_power_of_2()</h2>
|
||||
|
||||
<div class="memitem">
|
||||
<div class="memproto">
|
||||
<table class="mlabels">
|
||||
<tr>
|
||||
<td class="mlabels-left">
|
||||
<table class="memname">
|
||||
<tr>
|
||||
<td class="memname">int mlx::core::next_power_of_2 </td>
|
||||
<td>(</td>
|
||||
<td class="paramtype">int</td> <td class="paramname"><span class="paramname"><em>n</em></span></td><td>)</td>
|
||||
<td></td>
|
||||
</tr>
|
||||
</table>
|
||||
</td>
|
||||
<td class="mlabels-right">
|
||||
<span class="mlabels"><span class="mlabel">inline</span></span> </td>
|
||||
</tr>
|
||||
</table>
|
||||
</div><div class="memdoc">
|
||||
|
||||
</div>
|
||||
</div>
|
||||
<a id="a0181b5d72bf3d34448dabc70f7ff858d" name="a0181b5d72bf3d34448dabc70f7ff858d"></a>
|
||||
@@ -15810,6 +16076,92 @@ template<typename T > </div>
|
||||
</table>
|
||||
</div><div class="memdoc">
|
||||
|
||||
</div>
|
||||
</div>
|
||||
<a id="a064d61b6ddc9e5d1e261a7e33de71083" name="a064d61b6ddc9e5d1e261a7e33de71083"></a>
|
||||
<h2 class="memtitle"><span class="permalink"><a href="#a064d61b6ddc9e5d1e261a7e33de71083">◆ </a></span>reduction_op() <span class="overload">[1/2]</span></h2>
|
||||
|
||||
<div class="memitem">
|
||||
<div class="memproto">
|
||||
<div class="memtemplate">
|
||||
template<typename T , typename U , typename Op > </div>
|
||||
<table class="memname">
|
||||
<tr>
|
||||
<td class="memname">void mlx::core::reduction_op </td>
|
||||
<td>(</td>
|
||||
<td class="paramtype">const <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &</td> <td class="paramname"><span class="paramname"><em>x</em>, </span></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype"><a class="el" href="classmlx_1_1core_1_1array.html">array</a> &</td> <td class="paramname"><span class="paramname"><em>out</em>, </span></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const std::vector< int > &</td> <td class="paramname"><span class="paramname"><em>axes</em>, </span></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">U</td> <td class="paramname"><span class="paramname"><em>init</em>, </span></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">Op</td> <td class="paramname"><span class="paramname"><em>op</em></span> )</td>
|
||||
</tr>
|
||||
</table>
|
||||
</div><div class="memdoc">
|
||||
|
||||
</div>
|
||||
</div>
|
||||
<a id="aa08ffc1e8f2c58afb2d463496f827ef0" name="aa08ffc1e8f2c58afb2d463496f827ef0"></a>
|
||||
<h2 class="memtitle"><span class="permalink"><a href="#aa08ffc1e8f2c58afb2d463496f827ef0">◆ </a></span>reduction_op() <span class="overload">[2/2]</span></h2>
|
||||
|
||||
<div class="memitem">
|
||||
<div class="memproto">
|
||||
<div class="memtemplate">
|
||||
template<typename T , typename U , typename OpS , typename OpC , typename Op > </div>
|
||||
<table class="memname">
|
||||
<tr>
|
||||
<td class="memname">void mlx::core::reduction_op </td>
|
||||
<td>(</td>
|
||||
<td class="paramtype">const <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &</td> <td class="paramname"><span class="paramname"><em>x</em>, </span></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype"><a class="el" href="classmlx_1_1core_1_1array.html">array</a> &</td> <td class="paramname"><span class="paramname"><em>out</em>, </span></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const std::vector< int > &</td> <td class="paramname"><span class="paramname"><em>axes</em>, </span></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">U</td> <td class="paramname"><span class="paramname"><em>init</em>, </span></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">OpS</td> <td class="paramname"><span class="paramname"><em>ops</em>, </span></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">OpC</td> <td class="paramname"><span class="paramname"><em>opc</em>, </span></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">Op</td> <td class="paramname"><span class="paramname"><em>op</em></span> )</td>
|
||||
</tr>
|
||||
</table>
|
||||
</div><div class="memdoc">
|
||||
|
||||
</div>
|
||||
</div>
|
||||
<a id="a8b984eef832f757e28cd262d64a49ae7" name="a8b984eef832f757e28cd262d64a49ae7"></a>
|
||||
@@ -16122,6 +16474,27 @@ template<typename T > </div>
|
||||
|
||||
<p>Make the stream the default for its device. </p>
|
||||
|
||||
</div>
|
||||
</div>
|
||||
<a id="a44c3ea6db6553c3f6552b9ba64a69494" name="a44c3ea6db6553c3f6552b9ba64a69494"></a>
|
||||
<h2 class="memtitle"><span class="permalink"><a href="#a44c3ea6db6553c3f6552b9ba64a69494">◆ </a></span>shapes_without_reduction_axes()</h2>
|
||||
|
||||
<div class="memitem">
|
||||
<div class="memproto">
|
||||
<table class="memname">
|
||||
<tr>
|
||||
<td class="memname">std::pair< std::vector< int >, std::vector< size_t > > mlx::core::shapes_without_reduction_axes </td>
|
||||
<td>(</td>
|
||||
<td class="paramtype">const <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &</td> <td class="paramname"><span class="paramname"><em>x</em>, </span></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const std::vector< int > &</td> <td class="paramname"><span class="paramname"><em>axes</em></span> )</td>
|
||||
</tr>
|
||||
</table>
|
||||
</div><div class="memdoc">
|
||||
|
||||
</div>
|
||||
</div>
|
||||
<a id="a427f2c255dfc6e1f83f97587b08e71bc" name="a427f2c255dfc6e1f83f97587b08e71bc"></a>
|
||||
@@ -17056,6 +17429,56 @@ template<typename T > </div>
|
||||
<b>Initial value:</b><div class="fragment"><div class="line">=</div>
|
||||
<div class="line"> Dtype::Category::complexfloating</div>
|
||||
</div><!-- fragment -->
|
||||
</div>
|
||||
</div>
|
||||
<a id="a58c5b448f357b48e86599eb8eeea141d" name="a58c5b448f357b48e86599eb8eeea141d"></a>
|
||||
<h2 class="memtitle"><span class="permalink"><a href="#a58c5b448f357b48e86599eb8eeea141d">◆ </a></span>custom_function</h2>
|
||||
|
||||
<div class="memitem">
|
||||
<div class="memproto">
|
||||
<table class="memname">
|
||||
<tr>
|
||||
<td class="memname">std::function< std::vector< <a class="el" href="classmlx_1_1core_1_1array.html">array</a> >(const std::vector< <a class="el" href="classmlx_1_1core_1_1array.html">array</a> > &) mlx::core::custom_function) (std::function< std::vector< <a class="el" href="classmlx_1_1core_1_1array.html">array</a> >(const std::vector< <a class="el" href="classmlx_1_1core_1_1array.html">array</a> > &)> fun, std::optional< std::function< std::vector< <a class="el" href="classmlx_1_1core_1_1array.html">array</a> >( const std::vector< <a class="el" href="classmlx_1_1core_1_1array.html">array</a> > &, const std::vector< <a class="el" href="classmlx_1_1core_1_1array.html">array</a> > &, const std::vector< <a class="el" href="classmlx_1_1core_1_1array.html">array</a> > &)> > fun_vjp=std::nullopt, std::optional< std::function< std::vector< <a class="el" href="classmlx_1_1core_1_1array.html">array</a> >( const std::vector< <a class="el" href="classmlx_1_1core_1_1array.html">array</a> > &, const std::vector< <a class="el" href="classmlx_1_1core_1_1array.html">array</a> > &, const std::vector< int > &)> > fun_jvp=std::nullopt, std::optional< std::function< std::pair< std::vector< <a class="el" href="classmlx_1_1core_1_1array.html">array</a> >, std::vector< int > >( const std::vector< <a class="el" href="classmlx_1_1core_1_1array.html">array</a> > &, const std::vector< int > &)> > fun_vmap=std::nullopt) </td>
|
||||
<td>(</td>
|
||||
<td class="paramtype">std::function< std::vector< <a class="el" href="classmlx_1_1core_1_1array.html">array</a> >(const std::vector< <a class="el" href="classmlx_1_1core_1_1array.html">array</a> > &)></td> <td class="paramname"><span class="paramname"><em>fun</em>, </span></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">std::optional< std::function< std::vector< <a class="el" href="classmlx_1_1core_1_1array.html">array</a> >(const std::vector< <a class="el" href="classmlx_1_1core_1_1array.html">array</a> > &, const std::vector< <a class="el" href="classmlx_1_1core_1_1array.html">array</a> > &, const std::vector< <a class="el" href="classmlx_1_1core_1_1array.html">array</a> > &)></td> <td class="paramname"><span class="paramname">, </span></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">fun_vjp</td> <td class="paramname"><span class="paramname"><span class="paramdefsep"> = </span><span class="paramdefval">std::nullopt</span>, </span></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">std::optional< std::function< std::vector< <a class="el" href="classmlx_1_1core_1_1array.html">array</a> >(const std::vector< <a class="el" href="classmlx_1_1core_1_1array.html">array</a> > &, const std::vector< <a class="el" href="classmlx_1_1core_1_1array.html">array</a> > &, const std::vector< int > &)></td> <td class="paramname"><span class="paramname">, </span></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">fun_jvp</td> <td class="paramname"><span class="paramname"><span class="paramdefsep"> = </span><span class="paramdefval">std::nullopt</span>, </span></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">std::optional< std::function< std::pair< std::vector< <a class="el" href="classmlx_1_1core_1_1array.html">array</a> >, std::vector< int > >(const std::vector< <a class="el" href="classmlx_1_1core_1_1array.html">array</a> > &, const std::vector< int > &)></td> <td class="paramname"><span class="paramname">, </span></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">fun_vmap</td> <td class="paramname"><span class="paramname"><span class="paramdefsep"> = </span><span class="paramdefval">std::nullopt</span></span> )</td>
|
||||
</tr>
|
||||
</table>
|
||||
</div><div class="memdoc">
|
||||
|
||||
<p>Redefine the transformations of <code>fun</code> according to the provided functions. </p>
|
||||
<p>Namely when calling the vjp of <code>fun</code> then <code>fun_vjp</code> will be called, <code>fun_jvp</code> for the jvp and <code>fun_vmap</code> for vmap.</p>
|
||||
<p>If any transformation is not provided, then a default one is created by calling <code>vjp</code>, <code>jvp</code> and <code>vmap</code> on the function directly. </p>
|
||||
|
||||
</div>
|
||||
</div>
|
||||
<a id="a3fa1f0ad360f3e16c146384276b1c467" name="a3fa1f0ad360f3e16c146384276b1c467"></a>
|
||||
@@ -17077,7 +17500,7 @@ template<typename T > </div>
|
||||
</table>
|
||||
</div><div class="memdoc">
|
||||
|
||||
<p>Return the results of calling fun with args but if their vjp is computed it will be computed by fun_vjp. </p>
|
||||
<p>Return a function that behaves exactly like <code>fun</code> but if the vjp of the results is computed <code>fun_vjp</code> will be used instead of <code>vjp(fun, ...)</code> . </p>
|
||||
|
||||
</div>
|
||||
</div>
|
||||
@@ -17224,6 +17647,138 @@ template<typename T > </div>
|
||||
<p>The function being differentiated takes a vector of arrays and returns an array. The vector of <code>argnums</code> specifies which the arguments to compute the gradient with respect to. At least one argument must be specified.</p>
|
||||
<p>The function being differentiated takes a vector of arrays and returns an array. The optional <code>argnum</code> index specifies which the argument to compute the gradient with respect to and defaults to 0. </p>
|
||||
|
||||
</div>
|
||||
</div>
|
||||
<a id="a4beeeec4413be7adcfb14feaa9cf0e2e" name="a4beeeec4413be7adcfb14feaa9cf0e2e"></a>
|
||||
<h2 class="memtitle"><span class="permalink"><a href="#a4beeeec4413be7adcfb14feaa9cf0e2e">◆ </a></span>h12</h2>
|
||||
|
||||
<div class="memitem">
|
||||
<div class="memproto">
|
||||
<table class="mlabels">
|
||||
<tr>
|
||||
<td class="mlabels-left">
|
||||
<table class="memname">
|
||||
<tr>
|
||||
<td class="memname">constexpr std::string_view mlx::core::h12</td>
|
||||
</tr>
|
||||
</table>
|
||||
</td>
|
||||
<td class="mlabels-right">
|
||||
<span class="mlabels"><span class="mlabel">constexpr</span></span> </td>
|
||||
</tr>
|
||||
</table>
|
||||
</div><div class="memdoc">
|
||||
<b>Initial value:</b><div class="fragment"><div class="line">= R<span class="stringliteral">"(</span></div>
|
||||
<div class="line"><span class="stringliteral">+-++++++++++</span></div>
|
||||
<div class="line"><span class="stringliteral">--+-+-+-+-+-</span></div>
|
||||
<div class="line"><span class="stringliteral">+++-++----++</span></div>
|
||||
<div class="line"><span class="stringliteral">+---+--+-++-</span></div>
|
||||
<div class="line"><span class="stringliteral">+++++-++----</span></div>
|
||||
<div class="line"><span class="stringliteral">+-+---+--+-+</span></div>
|
||||
<div class="line"><span class="stringliteral">++--+++-++--</span></div>
|
||||
<div class="line"><span class="stringliteral">+--++---+--+</span></div>
|
||||
<div class="line"><span class="stringliteral">++----+++-++</span></div>
|
||||
<div class="line"><span class="stringliteral">+--+-++---+-</span></div>
|
||||
<div class="line"><span class="stringliteral">++++----+++-</span></div>
|
||||
<div class="line"><span class="stringliteral">+-+--+-++---</span></div>
|
||||
<div class="line"><span class="stringliteral">)"</span></div>
|
||||
</div><!-- fragment -->
|
||||
</div>
|
||||
</div>
|
||||
<a id="a862c6b94fec384c34a699ced64d01404" name="a862c6b94fec384c34a699ced64d01404"></a>
|
||||
<h2 class="memtitle"><span class="permalink"><a href="#a862c6b94fec384c34a699ced64d01404">◆ </a></span>h20</h2>
|
||||
|
||||
<div class="memitem">
|
||||
<div class="memproto">
|
||||
<table class="mlabels">
|
||||
<tr>
|
||||
<td class="mlabels-left">
|
||||
<table class="memname">
|
||||
<tr>
|
||||
<td class="memname">constexpr std::string_view mlx::core::h20</td>
|
||||
</tr>
|
||||
</table>
|
||||
</td>
|
||||
<td class="mlabels-right">
|
||||
<span class="mlabels"><span class="mlabel">constexpr</span></span> </td>
|
||||
</tr>
|
||||
</table>
|
||||
</div><div class="memdoc">
|
||||
<b>Initial value:</b><div class="fragment"><div class="line">= R<span class="stringliteral">"(</span></div>
|
||||
<div class="line"><span class="stringliteral">+----+----++--++-++-</span></div>
|
||||
<div class="line"><span class="stringliteral">-+----+---+++---+-++</span></div>
|
||||
<div class="line"><span class="stringliteral">--+----+---+++-+-+-+</span></div>
|
||||
<div class="line"><span class="stringliteral">---+----+---+++++-+-</span></div>
|
||||
<div class="line"><span class="stringliteral">----+----++--++-++-+</span></div>
|
||||
<div class="line"><span class="stringliteral">-+++++-----+--+++--+</span></div>
|
||||
<div class="line"><span class="stringliteral">+-+++-+---+-+--+++--</span></div>
|
||||
<div class="line"><span class="stringliteral">++-++--+---+-+--+++-</span></div>
|
||||
<div class="line"><span class="stringliteral">+++-+---+---+-+--+++</span></div>
|
||||
<div class="line"><span class="stringliteral">++++-----++--+-+--++</span></div>
|
||||
<div class="line"><span class="stringliteral">--++-+-++-+-----++++</span></div>
|
||||
<div class="line"><span class="stringliteral">---++-+-++-+---+-+++</span></div>
|
||||
<div class="line"><span class="stringliteral">+---++-+-+--+--++-++</span></div>
|
||||
<div class="line"><span class="stringliteral">++---++-+----+-+++-+</span></div>
|
||||
<div class="line"><span class="stringliteral">-++---++-+----+++++-</span></div>
|
||||
<div class="line"><span class="stringliteral">-+--+--++-+----+----</span></div>
|
||||
<div class="line"><span class="stringliteral">+-+-----++-+----+---</span></div>
|
||||
<div class="line"><span class="stringliteral">-+-+-+---+--+----+--</span></div>
|
||||
<div class="line"><span class="stringliteral">--+-+++------+----+-</span></div>
|
||||
<div class="line"><span class="stringliteral">+--+--++------+----+</span></div>
|
||||
<div class="line"><span class="stringliteral">)"</span></div>
|
||||
</div><!-- fragment -->
|
||||
</div>
|
||||
</div>
|
||||
<a id="ac447ad59592dd06435adca7df37e33ad" name="ac447ad59592dd06435adca7df37e33ad"></a>
|
||||
<h2 class="memtitle"><span class="permalink"><a href="#ac447ad59592dd06435adca7df37e33ad">◆ </a></span>h28</h2>
|
||||
|
||||
<div class="memitem">
|
||||
<div class="memproto">
|
||||
<table class="mlabels">
|
||||
<tr>
|
||||
<td class="mlabels-left">
|
||||
<table class="memname">
|
||||
<tr>
|
||||
<td class="memname">constexpr std::string_view mlx::core::h28</td>
|
||||
</tr>
|
||||
</table>
|
||||
</td>
|
||||
<td class="mlabels-right">
|
||||
<span class="mlabels"><span class="mlabel">constexpr</span></span> </td>
|
||||
</tr>
|
||||
</table>
|
||||
</div><div class="memdoc">
|
||||
<b>Initial value:</b><div class="fragment"><div class="line">= R<span class="stringliteral">"(</span></div>
|
||||
<div class="line"><span class="stringliteral">+------++----++-+--+-+--++--</span></div>
|
||||
<div class="line"><span class="stringliteral">-+-----+++-----+-+--+-+--++-</span></div>
|
||||
<div class="line"><span class="stringliteral">--+-----+++---+-+-+----+--++</span></div>
|
||||
<div class="line"><span class="stringliteral">---+-----+++---+-+-+-+--+--+</span></div>
|
||||
<div class="line"><span class="stringliteral">----+-----+++---+-+-+++--+--</span></div>
|
||||
<div class="line"><span class="stringliteral">-----+-----++++--+-+--++--+-</span></div>
|
||||
<div class="line"><span class="stringliteral">------++----++-+--+-+--++--+</span></div>
|
||||
<div class="line"><span class="stringliteral">--++++-+-------++--+++-+--+-</span></div>
|
||||
<div class="line"><span class="stringliteral">---++++-+-----+-++--+-+-+--+</span></div>
|
||||
<div class="line"><span class="stringliteral">+---+++--+----++-++--+-+-+--</span></div>
|
||||
<div class="line"><span class="stringliteral">++---++---+----++-++--+-+-+-</span></div>
|
||||
<div class="line"><span class="stringliteral">+++---+----+----++-++--+-+-+</span></div>
|
||||
<div class="line"><span class="stringliteral">++++--------+-+--++-++--+-+-</span></div>
|
||||
<div class="line"><span class="stringliteral">-++++--------+++--++--+--+-+</span></div>
|
||||
<div class="line"><span class="stringliteral">-+-++-++--++--+--------++++-</span></div>
|
||||
<div class="line"><span class="stringliteral">+-+-++--+--++--+--------++++</span></div>
|
||||
<div class="line"><span class="stringliteral">-+-+-++--+--++--+----+---+++</span></div>
|
||||
<div class="line"><span class="stringliteral">+-+-+-++--+--+---+---++---++</span></div>
|
||||
<div class="line"><span class="stringliteral">++-+-+-++--+------+--+++---+</span></div>
|
||||
<div class="line"><span class="stringliteral">-++-+-+-++--+------+-++++---</span></div>
|
||||
<div class="line"><span class="stringliteral">+-++-+---++--+------+-++++--</span></div>
|
||||
<div class="line"><span class="stringliteral">-++--++-+-++-+++----++------</span></div>
|
||||
<div class="line"><span class="stringliteral">+-++--++-+-++-+++-----+-----</span></div>
|
||||
<div class="line"><span class="stringliteral">++-++---+-+-++-+++-----+----</span></div>
|
||||
<div class="line"><span class="stringliteral">-++-++-+-+-+-+--+++-----+---</span></div>
|
||||
<div class="line"><span class="stringliteral">--++-++++-+-+----+++-----+--</span></div>
|
||||
<div class="line"><span class="stringliteral">+--++-+-++-+-+----+++-----+-</span></div>
|
||||
<div class="line"><span class="stringliteral">++--++-+-++-+-+----++------+</span></div>
|
||||
<div class="line"><span class="stringliteral">)"</span></div>
|
||||
</div><!-- fragment -->
|
||||
</div>
|
||||
</div>
|
||||
<a id="a54c6fae21b7f2fea8e6f80011ef38534" name="a54c6fae21b7f2fea8e6f80011ef38534"></a>
|
||||
|
Reference in New Issue
Block a user