From 3e724a7c9876c0e9c6b6e56c9de453495caefa6e Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Fri, 10 May 2024 08:49:36 -0700 Subject: [PATCH] docs update --- docs/build/html/.buildinfo | 2 +- docs/build/html/_sources/cpp/ops.rst | 3 +- docs/build/html/_sources/install.rst | 5 +- .../python/_autosummary/mlx.core.arctan2.rst | 6 + .../python/_autosummary/mlx.core.array.rst | 1 + .../_autosummary/mlx.core.bitwise_and.rst | 6 + .../_autosummary/mlx.core.bitwise_or.rst | 6 + .../_autosummary/mlx.core.bitwise_xor.rst | 6 + .../_autosummary/mlx.core.block_sparse_mm.rst | 6 + .../python/_autosummary/mlx.core.conj.rst | 6 + .../_autosummary/mlx.core.conjugate.rst | 6 + .../_autosummary/mlx.core.left_shift.rst | 6 + .../mlx.core.metal.device_info.rst | 6 + .../mlx.core.metal.reset_peak_memory.rst | 6 + .../_autosummary/mlx.core.right_shift.rst | 6 + .../mlx.optimizers.clip_grad_norm.rst | 6 + .../_autosummary/mlx.utils.tree_reduce.rst | 6 + docs/build/html/_sources/python/metal.rst | 2 + docs/build/html/_sources/python/ops.rst | 11 +- .../build/html/_sources/python/optimizers.rst | 7 + .../build/html/_sources/python/tree_utils.rst | 1 + .../html/_static/documentation_options.js | 2 +- docs/build/html/allocator_8h.html | 124 + docs/build/html/allocator_8h_source.html | 191 + docs/build/html/annotated.html | 445 + docs/build/html/arange_8h.html | 107 + docs/build/html/arange_8h_source.html | 192 + docs/build/html/array_8h.html | 138 + docs/build/html/array_8h_source.html | 842 + docs/build/html/atomic_8h.html | 521 + docs/build/html/atomic_8h_source.html | 478 + .../html/backend_2accelerate_2utils_8h.html | 107 + .../backend_2accelerate_2utils_8h_source.html | 134 + docs/build/html/backend_2common_2ops_8h.html | 234 + .../html/backend_2common_2ops_8h_source.html | 1218 ++ .../build/html/backend_2common_2utils_8h.html | 121 + .../backend_2common_2utils_8h_source.html | 236 + .../html/backend_2metal_2allocator_8h.html | 161 + .../backend_2metal_2allocator_8h_source.html | 221 + .../build/html/backend_2metal_2device_8h.html | 135 + .../backend_2metal_2device_8h_source.html | 389 + .../backend_2metal_2kernels_2bf16_8h.html | 10952 +++++++++++ ...ckend_2metal_2kernels_2bf16_8h_source.html | 489 + .../backend_2metal_2kernels_2complex_8h.html | 504 + ...nd_2metal_2kernels_2complex_8h_source.html | 276 + ...nd_2metal_2kernels_2reduction_2ops_8h.html | 116 + ...al_2kernels_2reduction_2ops_8h_source.html | 385 + ..._2metal_2kernels_2reduction_2utils_8h.html | 126 + ..._2kernels_2reduction_2utils_8h_source.html | 111 + ..._2kernels_2steel_2gemm_2transforms_8h.html | 114 + ...ls_2steel_2gemm_2transforms_8h_source.html | 192 + ...kend_2metal_2kernels_2steel_2utils_8h.html | 215 + ...etal_2kernels_2steel_2utils_8h_source.html | 142 + .../backend_2metal_2kernels_2utils_8h.html | 862 + ...kend_2metal_2kernels_2utils_8h_source.html | 488 + docs/build/html/backend_2metal_2utils_8h.html | 102 + .../html/backend_2metal_2utils_8h_source.html | 249 + docs/build/html/bc_s.png | Bin 0 -> 676 bytes docs/build/html/bc_sd.png | Bin 0 -> 635 bytes docs/build/html/bf16__math_8h.html | 594 + docs/build/html/bf16__math_8h_source.html | 498 + docs/build/html/binary__two_8h.html | 101 + docs/build/html/binary__two_8h_source.html | 646 + .../html/class_m_p_s_1_1_kernel-members.html | 92 + docs/build/html/class_m_p_s_1_1_kernel.html | 144 + docs/build/html/class_m_p_s_1_1_kernel.png | Bin 0 -> 656 bytes .../html/class_m_p_s_1_1_matrix-members.html | 93 + docs/build/html/class_m_p_s_1_1_matrix.html | 183 + docs/build/html/class_m_p_s_1_1_matrix.png | Bin 0 -> 665 bytes ...s_m_p_s_1_1_matrix_descriptor-members.html | 93 + .../class_m_p_s_1_1_matrix_descriptor.html | 221 + .../class_m_p_s_1_1_matrix_descriptor.png | Bin 0 -> 813 bytes ...p_s_1_1_matrix_multiplication-members.html | 98 + ...class_m_p_s_1_1_matrix_multiplication.html | 318 + .../class_m_p_s_1_1_matrix_multiplication.png | Bin 0 -> 1041 bytes ..._matrix_vector_multiplication-members.html | 93 + ..._p_s_1_1_matrix_vector_multiplication.html | 213 + ...m_p_s_1_1_matrix_vector_multiplication.png | Bin 0 -> 1145 bytes .../html/class_m_p_s_1_1_vector-members.html | 93 + docs/build/html/class_m_p_s_1_1_vector.html | 183 + docs/build/html/class_m_p_s_1_1_vector.png | Bin 0 -> 658 bytes ...s_m_p_s_1_1_vector_descriptor-members.html | 92 + .../class_m_p_s_1_1_vector_descriptor.html | 178 + .../class_m_p_s_1_1_vector_descriptor.png | Bin 0 -> 794 bytes docs/build/html/classes.html | 152 + .../classmlx_1_1core_1_1_abs-members.html | 115 + docs/build/html/classmlx_1_1core_1_1_abs.html | 463 + docs/build/html/classmlx_1_1core_1_1_abs.png | Bin 0 -> 872 bytes .../classmlx_1_1core_1_1_add-members.html | 115 + docs/build/html/classmlx_1_1core_1_1_add.html | 463 + docs/build/html/classmlx_1_1core_1_1_add.png | Bin 0 -> 874 bytes .../classmlx_1_1core_1_1_add_m_m-members.html | 115 + .../html/classmlx_1_1core_1_1_add_m_m.html | 404 + .../html/classmlx_1_1core_1_1_add_m_m.png | Bin 0 -> 905 bytes .../classmlx_1_1core_1_1_arange-members.html | 115 + .../html/classmlx_1_1core_1_1_arange.html | 332 + .../html/classmlx_1_1core_1_1_arange.png | Bin 0 -> 907 bytes .../classmlx_1_1core_1_1_arc_cos-members.html | 115 + .../html/classmlx_1_1core_1_1_arc_cos.html | 463 + .../html/classmlx_1_1core_1_1_arc_cos.png | Bin 0 -> 897 bytes ...classmlx_1_1core_1_1_arc_cosh-members.html | 115 + .../html/classmlx_1_1core_1_1_arc_cosh.html | 463 + .../html/classmlx_1_1core_1_1_arc_cosh.png | Bin 0 -> 909 bytes .../classmlx_1_1core_1_1_arc_sin-members.html | 115 + .../html/classmlx_1_1core_1_1_arc_sin.html | 463 + .../html/classmlx_1_1core_1_1_arc_sin.png | Bin 0 -> 895 bytes ...classmlx_1_1core_1_1_arc_sinh-members.html | 115 + .../html/classmlx_1_1core_1_1_arc_sinh.html | 463 + .../html/classmlx_1_1core_1_1_arc_sinh.png | Bin 0 -> 901 bytes .../classmlx_1_1core_1_1_arc_tan-members.html | 115 + .../html/classmlx_1_1core_1_1_arc_tan.html | 463 + .../html/classmlx_1_1core_1_1_arc_tan.png | Bin 0 -> 895 bytes ...classmlx_1_1core_1_1_arc_tan2-members.html | 115 + .../html/classmlx_1_1core_1_1_arc_tan2.html | 463 + .../html/classmlx_1_1core_1_1_arc_tan2.png | Bin 0 -> 913 bytes ...classmlx_1_1core_1_1_arc_tanh-members.html | 115 + .../html/classmlx_1_1core_1_1_arc_tanh.html | 463 + .../html/classmlx_1_1core_1_1_arc_tanh.png | Bin 0 -> 901 bytes ...mlx_1_1core_1_1_arg_partition-members.html | 115 + .../classmlx_1_1core_1_1_arg_partition.html | 391 + .../classmlx_1_1core_1_1_arg_partition.png | Bin 0 -> 936 bytes ...assmlx_1_1core_1_1_arg_reduce-members.html | 118 + .../html/classmlx_1_1core_1_1_arg_reduce.html | 418 + .../html/classmlx_1_1core_1_1_arg_reduce.png | Bin 0 -> 932 bytes ...classmlx_1_1core_1_1_arg_sort-members.html | 115 + .../html/classmlx_1_1core_1_1_arg_sort.html | 386 + .../html/classmlx_1_1core_1_1_arg_sort.png | Bin 0 -> 919 bytes ...assmlx_1_1core_1_1_as_strided-members.html | 115 + .../html/classmlx_1_1core_1_1_as_strided.html | 413 + .../html/classmlx_1_1core_1_1_as_strided.png | Bin 0 -> 917 bytes .../classmlx_1_1core_1_1_as_type-members.html | 115 + .../html/classmlx_1_1core_1_1_as_type.html | 467 + .../html/classmlx_1_1core_1_1_as_type.png | Bin 0 -> 918 bytes ...lx_1_1core_1_1_bitwise_binary-members.html | 121 + .../classmlx_1_1core_1_1_bitwise_binary.html | 422 + .../classmlx_1_1core_1_1_bitwise_binary.png | Bin 0 -> 937 bytes ..._1_1core_1_1_block_masked_m_m-members.html | 115 + ...classmlx_1_1core_1_1_block_masked_m_m.html | 365 + .../classmlx_1_1core_1_1_block_masked_m_m.png | Bin 0 -> 966 bytes ..._1_1core_1_1_block_sparse_m_m-members.html | 115 + ...classmlx_1_1core_1_1_block_sparse_m_m.html | 361 + .../classmlx_1_1core_1_1_block_sparse_m_m.png | Bin 0 -> 952 bytes ...lassmlx_1_1core_1_1_broadcast-members.html | 115 + .../html/classmlx_1_1core_1_1_broadcast.html | 437 + .../html/classmlx_1_1core_1_1_broadcast.png | Bin 0 -> 905 bytes .../classmlx_1_1core_1_1_ceil-members.html | 115 + .../build/html/classmlx_1_1core_1_1_ceil.html | 463 + docs/build/html/classmlx_1_1core_1_1_ceil.png | Bin 0 -> 864 bytes ...classmlx_1_1core_1_1_compiled-members.html | 108 + .../html/classmlx_1_1core_1_1_compiled.html | 493 + .../html/classmlx_1_1core_1_1_compiled.png | Bin 0 -> 546 bytes ...ssmlx_1_1core_1_1_concatenate-members.html | 115 + .../classmlx_1_1core_1_1_concatenate.html | 437 + .../html/classmlx_1_1core_1_1_concatenate.png | Bin 0 -> 914 bytes ...lassmlx_1_1core_1_1_conjugate-members.html | 115 + .../html/classmlx_1_1core_1_1_conjugate.html | 382 + .../html/classmlx_1_1core_1_1_conjugate.png | Bin 0 -> 929 bytes ...ssmlx_1_1core_1_1_convolution-members.html | 115 + .../classmlx_1_1core_1_1_convolution.html | 390 + .../html/classmlx_1_1core_1_1_convolution.png | Bin 0 -> 907 bytes .../classmlx_1_1core_1_1_copy-members.html | 115 + .../build/html/classmlx_1_1core_1_1_copy.html | 463 + docs/build/html/classmlx_1_1core_1_1_copy.png | Bin 0 -> 892 bytes .../classmlx_1_1core_1_1_cos-members.html | 115 + docs/build/html/classmlx_1_1core_1_1_cos.html | 463 + docs/build/html/classmlx_1_1core_1_1_cos.png | Bin 0 -> 875 bytes .../classmlx_1_1core_1_1_cosh-members.html | 115 + .../build/html/classmlx_1_1core_1_1_cosh.html | 463 + docs/build/html/classmlx_1_1core_1_1_cosh.png | Bin 0 -> 888 bytes ...smlx_1_1core_1_1_custom_v_j_p-members.html | 107 + .../classmlx_1_1core_1_1_custom_v_j_p.html | 320 + .../classmlx_1_1core_1_1_custom_v_j_p.png | Bin 0 -> 575 bytes .../classmlx_1_1core_1_1_depends-members.html | 107 + .../html/classmlx_1_1core_1_1_depends.html | 316 + .../html/classmlx_1_1core_1_1_depends.png | Bin 0 -> 548 bytes .../classmlx_1_1core_1_1_div_mod-members.html | 107 + .../html/classmlx_1_1core_1_1_div_mod.html | 447 + .../html/classmlx_1_1core_1_1_div_mod.png | Bin 0 -> 536 bytes .../classmlx_1_1core_1_1_divide-members.html | 115 + .../html/classmlx_1_1core_1_1_divide.html | 463 + .../html/classmlx_1_1core_1_1_divide.png | Bin 0 -> 897 bytes .../classmlx_1_1core_1_1_equal-members.html | 115 + .../html/classmlx_1_1core_1_1_equal.html | 467 + .../build/html/classmlx_1_1core_1_1_equal.png | Bin 0 -> 893 bytes .../classmlx_1_1core_1_1_erf-members.html | 115 + docs/build/html/classmlx_1_1core_1_1_erf.html | 463 + docs/build/html/classmlx_1_1core_1_1_erf.png | Bin 0 -> 861 bytes .../classmlx_1_1core_1_1_erf_inv-members.html | 115 + .../html/classmlx_1_1core_1_1_erf_inv.html | 463 + .../html/classmlx_1_1core_1_1_erf_inv.png | Bin 0 -> 880 bytes .../classmlx_1_1core_1_1_event-members.html | 99 + .../html/classmlx_1_1core_1_1_event.html | 320 + .../classmlx_1_1core_1_1_exp-members.html | 115 + docs/build/html/classmlx_1_1core_1_1_exp.html | 463 + docs/build/html/classmlx_1_1core_1_1_exp.png | Bin 0 -> 875 bytes .../classmlx_1_1core_1_1_expm1-members.html | 115 + .../html/classmlx_1_1core_1_1_expm1.html | 434 + .../build/html/classmlx_1_1core_1_1_expm1.png | Bin 0 -> 883 bytes .../classmlx_1_1core_1_1_f_f_t-members.html | 115 + .../html/classmlx_1_1core_1_1_f_f_t.html | 447 + .../build/html/classmlx_1_1core_1_1_f_f_t.png | Bin 0 -> 847 bytes .../classmlx_1_1core_1_1_floor-members.html | 115 + .../html/classmlx_1_1core_1_1_floor.html | 463 + .../build/html/classmlx_1_1core_1_1_floor.png | Bin 0 -> 866 bytes .../classmlx_1_1core_1_1_full-members.html | 115 + .../build/html/classmlx_1_1core_1_1_full.html | 433 + docs/build/html/classmlx_1_1core_1_1_full.png | Bin 0 -> 852 bytes .../classmlx_1_1core_1_1_gather-members.html | 115 + .../html/classmlx_1_1core_1_1_gather.html | 442 + .../html/classmlx_1_1core_1_1_gather.png | Bin 0 -> 893 bytes .../classmlx_1_1core_1_1_greater-members.html | 115 + .../html/classmlx_1_1core_1_1_greater.html | 463 + .../html/classmlx_1_1core_1_1_greater.png | Bin 0 -> 910 bytes ...mlx_1_1core_1_1_greater_equal-members.html | 115 + .../classmlx_1_1core_1_1_greater_equal.html | 463 + .../classmlx_1_1core_1_1_greater_equal.png | Bin 0 -> 945 bytes .../classmlx_1_1core_1_1_inverse-members.html | 115 + .../html/classmlx_1_1core_1_1_inverse.html | 323 + .../html/classmlx_1_1core_1_1_inverse.png | Bin 0 -> 884 bytes .../classmlx_1_1core_1_1_less-members.html | 115 + .../build/html/classmlx_1_1core_1_1_less.html | 463 + docs/build/html/classmlx_1_1core_1_1_less.png | Bin 0 -> 867 bytes ...assmlx_1_1core_1_1_less_equal-members.html | 115 + .../html/classmlx_1_1core_1_1_less_equal.html | 463 + .../html/classmlx_1_1core_1_1_less_equal.png | Bin 0 -> 926 bytes .../classmlx_1_1core_1_1_load-members.html | 115 + .../build/html/classmlx_1_1core_1_1_load.html | 303 + docs/build/html/classmlx_1_1core_1_1_load.png | Bin 0 -> 872 bytes .../classmlx_1_1core_1_1_log-members.html | 119 + docs/build/html/classmlx_1_1core_1_1_log.html | 496 + docs/build/html/classmlx_1_1core_1_1_log.png | Bin 0 -> 866 bytes .../classmlx_1_1core_1_1_log1p-members.html | 115 + .../html/classmlx_1_1core_1_1_log1p.html | 434 + .../build/html/classmlx_1_1core_1_1_log1p.png | Bin 0 -> 884 bytes ...ssmlx_1_1core_1_1_log_add_exp-members.html | 115 + .../classmlx_1_1core_1_1_log_add_exp.html | 463 + .../html/classmlx_1_1core_1_1_log_add_exp.png | Bin 0 -> 943 bytes ...ssmlx_1_1core_1_1_logical_and-members.html | 115 + .../classmlx_1_1core_1_1_logical_and.html | 463 + .../html/classmlx_1_1core_1_1_logical_and.png | Bin 0 -> 930 bytes ...ssmlx_1_1core_1_1_logical_not-members.html | 115 + .../classmlx_1_1core_1_1_logical_not.html | 463 + .../html/classmlx_1_1core_1_1_logical_not.png | Bin 0 -> 918 bytes ...assmlx_1_1core_1_1_logical_or-members.html | 115 + .../html/classmlx_1_1core_1_1_logical_or.html | 463 + .../html/classmlx_1_1core_1_1_logical_or.png | Bin 0 -> 920 bytes .../classmlx_1_1core_1_1_matmul-members.html | 115 + .../html/classmlx_1_1core_1_1_matmul.html | 395 + .../html/classmlx_1_1core_1_1_matmul.png | Bin 0 -> 885 bytes .../classmlx_1_1core_1_1_maximum-members.html | 115 + .../html/classmlx_1_1core_1_1_maximum.html | 463 + .../html/classmlx_1_1core_1_1_maximum.png | Bin 0 -> 901 bytes .../classmlx_1_1core_1_1_minimum-members.html | 115 + .../html/classmlx_1_1core_1_1_minimum.html | 463 + .../html/classmlx_1_1core_1_1_minimum.png | Bin 0 -> 892 bytes ...classmlx_1_1core_1_1_multiply-members.html | 115 + .../html/classmlx_1_1core_1_1_multiply.html | 463 + .../html/classmlx_1_1core_1_1_multiply.png | Bin 0 -> 909 bytes ...classmlx_1_1core_1_1_negative-members.html | 115 + .../html/classmlx_1_1core_1_1_negative.html | 463 + .../html/classmlx_1_1core_1_1_negative.png | Bin 0 -> 929 bytes ...lassmlx_1_1core_1_1_not_equal-members.html | 115 + .../html/classmlx_1_1core_1_1_not_equal.html | 463 + .../html/classmlx_1_1core_1_1_not_equal.png | Bin 0 -> 916 bytes ..._1core_1_1_number_of_elements-members.html | 115 + ...assmlx_1_1core_1_1_number_of_elements.html | 396 + ...lassmlx_1_1core_1_1_number_of_elements.png | Bin 0 -> 991 bytes .../classmlx_1_1core_1_1_pad-members.html | 115 + docs/build/html/classmlx_1_1core_1_1_pad.html | 447 + docs/build/html/classmlx_1_1core_1_1_pad.png | Bin 0 -> 874 bytes ...lassmlx_1_1core_1_1_partition-members.html | 115 + .../html/classmlx_1_1core_1_1_partition.html | 472 + .../html/classmlx_1_1core_1_1_partition.png | Bin 0 -> 888 bytes .../classmlx_1_1core_1_1_power-members.html | 115 + .../html/classmlx_1_1core_1_1_power.html | 463 + .../build/html/classmlx_1_1core_1_1_power.png | Bin 0 -> 900 bytes ...lassmlx_1_1core_1_1_primitive-members.html | 106 + .../html/classmlx_1_1core_1_1_primitive.html | 631 + .../html/classmlx_1_1core_1_1_primitive.png | Bin 0 -> 3642 bytes .../classmlx_1_1core_1_1_q_r_f-members.html | 107 + .../html/classmlx_1_1core_1_1_q_r_f.html | 273 + .../build/html/classmlx_1_1core_1_1_q_r_f.png | Bin 0 -> 520 bytes ..._1_1core_1_1_quantized_matmul-members.html | 115 + ...classmlx_1_1core_1_1_quantized_matmul.html | 447 + .../classmlx_1_1core_1_1_quantized_matmul.png | Bin 0 -> 975 bytes ...ssmlx_1_1core_1_1_random_bits-members.html | 115 + .../classmlx_1_1core_1_1_random_bits.html | 361 + .../html/classmlx_1_1core_1_1_random_bits.png | Bin 0 -> 920 bytes .../classmlx_1_1core_1_1_reduce-members.html | 122 + .../html/classmlx_1_1core_1_1_reduce.html | 472 + .../html/classmlx_1_1core_1_1_reduce.png | Bin 0 -> 895 bytes ...lassmlx_1_1core_1_1_remainder-members.html | 115 + .../html/classmlx_1_1core_1_1_remainder.html | 463 + .../html/classmlx_1_1core_1_1_remainder.png | Bin 0 -> 917 bytes .../classmlx_1_1core_1_1_reshape-members.html | 115 + .../html/classmlx_1_1core_1_1_reshape.html | 437 + .../html/classmlx_1_1core_1_1_reshape.png | Bin 0 -> 910 bytes .../classmlx_1_1core_1_1_round-members.html | 115 + .../html/classmlx_1_1core_1_1_round.html | 463 + .../build/html/classmlx_1_1core_1_1_round.png | Bin 0 -> 881 bytes .../classmlx_1_1core_1_1_s_v_d-members.html | 107 + .../html/classmlx_1_1core_1_1_s_v_d.html | 307 + .../build/html/classmlx_1_1core_1_1_s_v_d.png | Bin 0 -> 520 bytes .../classmlx_1_1core_1_1_scan-members.html | 120 + .../build/html/classmlx_1_1core_1_1_scan.html | 483 + docs/build/html/classmlx_1_1core_1_1_scan.png | Bin 0 -> 884 bytes .../classmlx_1_1core_1_1_scatter-members.html | 121 + .../html/classmlx_1_1core_1_1_scatter.html | 444 + .../html/classmlx_1_1core_1_1_scatter.png | Bin 0 -> 901 bytes .../classmlx_1_1core_1_1_select-members.html | 115 + .../html/classmlx_1_1core_1_1_select.html | 463 + .../html/classmlx_1_1core_1_1_select.png | Bin 0 -> 884 bytes .../classmlx_1_1core_1_1_sigmoid-members.html | 115 + .../html/classmlx_1_1core_1_1_sigmoid.html | 463 + .../html/classmlx_1_1core_1_1_sigmoid.png | Bin 0 -> 906 bytes .../classmlx_1_1core_1_1_sign-members.html | 115 + .../build/html/classmlx_1_1core_1_1_sign.html | 463 + docs/build/html/classmlx_1_1core_1_1_sign.png | Bin 0 -> 890 bytes .../classmlx_1_1core_1_1_sin-members.html | 115 + docs/build/html/classmlx_1_1core_1_1_sin.html | 463 + docs/build/html/classmlx_1_1core_1_1_sin.png | Bin 0 -> 864 bytes .../classmlx_1_1core_1_1_sinh-members.html | 115 + .../build/html/classmlx_1_1core_1_1_sinh.html | 463 + docs/build/html/classmlx_1_1core_1_1_sinh.png | Bin 0 -> 870 bytes .../classmlx_1_1core_1_1_slice-members.html | 115 + .../html/classmlx_1_1core_1_1_slice.html | 447 + .../build/html/classmlx_1_1core_1_1_slice.png | Bin 0 -> 884 bytes ...smlx_1_1core_1_1_slice_update-members.html | 115 + .../classmlx_1_1core_1_1_slice_update.html | 447 + .../classmlx_1_1core_1_1_slice_update.png | Bin 0 -> 918 bytes .../classmlx_1_1core_1_1_softmax-members.html | 115 + .../html/classmlx_1_1core_1_1_softmax.html | 467 + .../html/classmlx_1_1core_1_1_softmax.png | Bin 0 -> 894 bytes .../classmlx_1_1core_1_1_sort-members.html | 115 + .../build/html/classmlx_1_1core_1_1_sort.html | 467 + docs/build/html/classmlx_1_1core_1_1_sort.png | Bin 0 -> 870 bytes .../classmlx_1_1core_1_1_split-members.html | 107 + .../html/classmlx_1_1core_1_1_split.html | 426 + .../build/html/classmlx_1_1core_1_1_split.png | Bin 0 -> 527 bytes .../classmlx_1_1core_1_1_sqrt-members.html | 115 + .../build/html/classmlx_1_1core_1_1_sqrt.html | 467 + docs/build/html/classmlx_1_1core_1_1_sqrt.png | Bin 0 -> 887 bytes .../classmlx_1_1core_1_1_square-members.html | 115 + .../html/classmlx_1_1core_1_1_square.html | 463 + .../html/classmlx_1_1core_1_1_square.png | Bin 0 -> 906 bytes ...mlx_1_1core_1_1_stop_gradient-members.html | 115 + .../classmlx_1_1core_1_1_stop_gradient.html | 382 + .../classmlx_1_1core_1_1_stop_gradient.png | Bin 0 -> 934 bytes ...classmlx_1_1core_1_1_subtract-members.html | 115 + .../html/classmlx_1_1core_1_1_subtract.html | 463 + .../html/classmlx_1_1core_1_1_subtract.png | Bin 0 -> 903 bytes .../classmlx_1_1core_1_1_tan-members.html | 115 + docs/build/html/classmlx_1_1core_1_1_tan.html | 463 + docs/build/html/classmlx_1_1core_1_1_tan.png | Bin 0 -> 875 bytes .../classmlx_1_1core_1_1_tanh-members.html | 115 + .../build/html/classmlx_1_1core_1_1_tanh.html | 463 + docs/build/html/classmlx_1_1core_1_1_tanh.png | Bin 0 -> 879 bytes ...lassmlx_1_1core_1_1_transpose-members.html | 115 + .../html/classmlx_1_1core_1_1_transpose.html | 437 + .../html/classmlx_1_1core_1_1_transpose.png | Bin 0 -> 914 bytes ...x_1_1core_1_1_unary_primitive-members.html | 114 + .../classmlx_1_1core_1_1_unary_primitive.html | 532 + .../classmlx_1_1core_1_1_unary_primitive.png | Bin 0 -> 31591 bytes .../classmlx_1_1core_1_1_uniform-members.html | 115 + .../html/classmlx_1_1core_1_1_uniform.html | 352 + .../html/classmlx_1_1core_1_1_uniform.png | Bin 0 -> 876 bytes ...re_1_1allocator_1_1_allocator-members.html | 98 + ...lx_1_1core_1_1allocator_1_1_allocator.html | 338 + ...mlx_1_1core_1_1allocator_1_1_allocator.png | Bin 0 -> 1087 bytes ...1core_1_1allocator_1_1_buffer-members.html | 94 + ...ssmlx_1_1core_1_1allocator_1_1_buffer.html | 201 + ...llocator_1_1_common_allocator-members.html | 99 + ...ore_1_1allocator_1_1_common_allocator.html | 219 + ...core_1_1allocator_1_1_common_allocator.png | Bin 0 -> 724 bytes .../classmlx_1_1core_1_1array-members.html | 159 + .../build/html/classmlx_1_1core_1_1array.html | 2000 ++ ...lx_1_1core_1_1fast_1_1_custom-members.html | 107 + .../classmlx_1_1core_1_1fast_1_1_custom.html | 306 + .../classmlx_1_1core_1_1fast_1_1_custom.png | Bin 0 -> 2664 bytes ..._1core_1_1fast_1_1_layer_norm-members.html | 109 + ...assmlx_1_1core_1_1fast_1_1_layer_norm.html | 327 + ...lassmlx_1_1core_1_1fast_1_1_layer_norm.png | Bin 0 -> 951 bytes ..._1_1fast_1_1_layer_norm_v_j_p-members.html | 109 + ..._1_1core_1_1fast_1_1_layer_norm_v_j_p.html | 284 + ...x_1_1core_1_1fast_1_1_layer_norm_v_j_p.png | Bin 0 -> 994 bytes ..._1core_1_1fast_1_1_r_m_s_norm-members.html | 109 + ...assmlx_1_1core_1_1fast_1_1_r_m_s_norm.html | 327 + ...lassmlx_1_1core_1_1fast_1_1_r_m_s_norm.png | Bin 0 -> 927 bytes ..._1_1fast_1_1_r_m_s_norm_v_j_p-members.html | 109 + ..._1_1core_1_1fast_1_1_r_m_s_norm_v_j_p.html | 284 + ...x_1_1core_1_1fast_1_1_r_m_s_norm_v_j_p.png | Bin 0 -> 981 bytes ...lx_1_1core_1_1fast_1_1_ro_p_e-members.html | 109 + .../classmlx_1_1core_1_1fast_1_1_ro_p_e.html | 352 + .../classmlx_1_1core_1_1fast_1_1_ro_p_e.png | Bin 0 -> 863 bytes ..._scaled_dot_product_attention-members.html | 110 + ...fast_1_1_scaled_dot_product_attention.html | 333 + ...1fast_1_1_scaled_dot_product_attention.png | Bin 0 -> 1075 bytes ...1_1core_1_1io_1_1_file_reader-members.html | 98 + ...lassmlx_1_1core_1_1io_1_1_file_reader.html | 346 + ...classmlx_1_1core_1_1io_1_1_file_reader.png | Bin 0 -> 642 bytes ...1_1core_1_1io_1_1_file_writer-members.html | 98 + ...lassmlx_1_1core_1_1io_1_1_file_writer.html | 346 + ...classmlx_1_1core_1_1io_1_1_file_writer.png | Bin 0 -> 612 bytes ...smlx_1_1core_1_1io_1_1_reader-members.html | 96 + .../classmlx_1_1core_1_1io_1_1_reader.html | 291 + .../classmlx_1_1core_1_1io_1_1_reader.png | Bin 0 -> 647 bytes ...smlx_1_1core_1_1io_1_1_writer-members.html | 96 + .../classmlx_1_1core_1_1io_1_1_writer.html | 291 + .../classmlx_1_1core_1_1io_1_1_writer.png | Bin 0 -> 619 bytes ...x_1_1core_1_1metal_1_1_device-members.html | 112 + .../classmlx_1_1core_1_1metal_1_1_device.html | 635 + ..._1_1metal_1_1_metal_allocator-members.html | 106 + ..._1_1core_1_1metal_1_1_metal_allocator.html | 388 + ...x_1_1core_1_1metal_1_1_metal_allocator.png | Bin 0 -> 680 bytes ...re_1_1random_1_1_key_sequence-members.html | 94 + ...lx_1_1core_1_1random_1_1_key_sequence.html | 197 + ...re_1_1scheduler_1_1_scheduler-members.html | 104 + ...lx_1_1core_1_1scheduler_1_1_scheduler.html | 478 + ...etfft_1_1detail_1_1_t__dcst23-members.html | 93 + ...lasspocketfft_1_1detail_1_1_t__dcst23.html | 210 + ...ketfft_1_1detail_1_1_t__dcst4-members.html | 93 + ...classpocketfft_1_1detail_1_1_t__dcst4.html | 210 + ...cketfft_1_1detail_1_1_t__dct1-members.html | 93 + .../classpocketfft_1_1detail_1_1_t__dct1.html | 210 + ...cketfft_1_1detail_1_1_t__dst1-members.html | 93 + .../classpocketfft_1_1detail_1_1_t__dst1.html | 210 + ...asspocketfft_1_1detail_1_1arr-members.html | 100 + .../html/classpocketfft_1_1detail_1_1arr.html | 391 + ...ketfft_1_1detail_1_1arr__info-members.html | 99 + ...classpocketfft_1_1detail_1_1arr__info.html | 357 + .../classpocketfft_1_1detail_1_1arr__info.png | Bin 0 -> 1040 bytes ...spocketfft_1_1detail_1_1cfftp-members.html | 92 + .../classpocketfft_1_1detail_1_1cfftp.html | 171 + ...pocketfft_1_1detail_1_1cndarr-members.html | 102 + .../classpocketfft_1_1detail_1_1cndarr.html | 229 + .../classpocketfft_1_1detail_1_1cndarr.png | Bin 0 -> 1037 bytes ...ocketfft_1_1detail_1_1fftblue-members.html | 93 + .../classpocketfft_1_1detail_1_1fftblue.html | 212 + ...tfft_1_1detail_1_1multi__iter-members.html | 101 + ...asspocketfft_1_1detail_1_1multi__iter.html | 437 + ...spocketfft_1_1detail_1_1ndarr-members.html | 104 + .../classpocketfft_1_1detail_1_1ndarr.html | 209 + .../classpocketfft_1_1detail_1_1ndarr.png | Bin 0 -> 1032 bytes ...fft_1_1detail_1_1pocketfft__c-members.html | 93 + ...sspocketfft_1_1detail_1_1pocketfft__c.html | 200 + ...fft_1_1detail_1_1pocketfft__r-members.html | 93 + ...sspocketfft_1_1detail_1_1pocketfft__r.html | 200 + ...ketfft_1_1detail_1_1rev__iter-members.html | 95 + ...classpocketfft_1_1detail_1_1rev__iter.html | 240 + ...spocketfft_1_1detail_1_1rfftp-members.html | 92 + .../classpocketfft_1_1detail_1_1rfftp.html | 171 + ...fft_1_1detail_1_1simple__iter-members.html | 94 + ...sspocketfft_1_1detail_1_1simple__iter.html | 209 + ...t_1_1detail_1_1sincos__2pibyn-members.html | 92 + ...pocketfft_1_1detail_1_1sincos__2pibyn.html | 159 + ...hreading_1_1concurrent__queue-members.html | 93 + ...ail_1_1threading_1_1concurrent__queue.html | 187 + ...1detail_1_1threading_1_1latch-members.html | 94 + ...etfft_1_1detail_1_1threading_1_1latch.html | 209 + ..._1_1threading_1_1thread__pool-members.html | 96 + ..._1detail_1_1threading_1_1thread__pool.html | 263 + docs/build/html/clipboard.js | 61 + docs/build/html/closed.png | Bin 0 -> 132 bytes docs/build/html/common_2binary_8h.html | 117 + docs/build/html/common_2binary_8h_source.html | 764 + .../html/common_2compiled__preamble_8h.html | 118 + .../common_2compiled__preamble_8h_source.html | 107 + docs/build/html/common_2copy_8h.html | 122 + docs/build/html/common_2copy_8h_source.html | 145 + docs/build/html/common_2reduce_8h.html | 136 + docs/build/html/common_2reduce_8h_source.html | 483 + docs/build/html/common_2ternary_8h.html | 103 + .../build/html/common_2ternary_8h_source.html | 327 + docs/build/html/common_2unary_8h.html | 103 + docs/build/html/common_2unary_8h_source.html | 229 + docs/build/html/compile_8h.html | 130 + docs/build/html/compile_8h_source.html | 123 + docs/build/html/compile__impl_8h.html | 108 + docs/build/html/compile__impl_8h_source.html | 107 + docs/build/html/compiled_8h.html | 131 + docs/build/html/compiled_8h_source.html | 195 + docs/build/html/conv_2loader_8h.html | 91 + docs/build/html/conv_2loader_8h_source.html | 100 + docs/build/html/conv_2params_8h.html | 111 + docs/build/html/conv_2params_8h_source.html | 202 + docs/build/html/conv_8h.html | 92 + docs/build/html/conv_8h_source.html | 108 + docs/build/html/cookie.js | 58 + docs/build/html/cpp/ops.html | 2234 ++- docs/build/html/defines_8h.html | 309 + docs/build/html/defines_8h_source.html | 119 + docs/build/html/dev/extensions.html | 23 +- docs/build/html/dev/metal_debugger.html | 23 +- docs/build/html/device_8h.html | 117 + docs/build/html/device_8h_source.html | 139 + .../dir_1683daa6c50d5a1449f58a10604f9f12.html | 93 + .../dir_1d446c9bd3c99228254c9484e0bc5c06.html | 97 + .../dir_2193406f5b2eae6fc53753d8a9a80df3.html | 95 + .../dir_47795aa8999234f6f402f7e89d34d08e.html | 99 + .../dir_6768c99e6145fb9510ccdb40db8ede25.html | 101 + .../dir_70a37effa88bcbd6b791977fa1e64356.html | 126 + .../dir_76215a6c54e2b67053e723fc2395583c.html | 100 + .../dir_86b95e7b1d0d6e25466bb9213752d32f.html | 93 + .../dir_938ab0ecf10b8b860ff766c820f665fd.html | 146 + .../dir_ad00dcd1517bfdbe01f68ec9b4eff877.html | 93 + .../dir_ba4426224ef60f409462a2a12fa18f06.html | 97 + .../dir_d0c977ea65824390717cdb7efc36c157.html | 116 + .../dir_df9494e83ef22ae6150a0e080d9709ed.html | 102 + .../dir_f149b24a1b5be11cd70151abe517e3f8.html | 117 + .../dir_f60cd69d27fd3faa641c79056fff0e2d.html | 97 + docs/build/html/doc.svg | 12 + docs/build/html/docd.svg | 12 + docs/build/html/doxygen.css | 2225 +++ docs/build/html/doxygen.svg | 28 + docs/build/html/doxygen_crawl.html | 1133 ++ docs/build/html/dtype_8h.html | 181 + docs/build/html/dtype_8h_source.html | 282 + docs/build/html/dynsections.js | 194 + docs/build/html/erf_8h.html | 135 + docs/build/html/erf_8h_source.html | 173 + docs/build/html/event_8h.html | 108 + docs/build/html/event_8h_source.html | 174 + .../html/examples/linear_regression.html | 23 +- docs/build/html/examples/llama-inference.html | 23 +- docs/build/html/examples/mlp.html | 23 +- docs/build/html/expm1f_8h.html | 139 + docs/build/html/expm1f_8h_source.html | 189 + docs/build/html/fast_8h.html | 116 + docs/build/html/fast_8h_source.html | 143 + docs/build/html/fast__primitives_8h.html | 120 + .../html/fast__primitives_8h_source.html | 396 + docs/build/html/fft_8h.html | 177 + docs/build/html/fft_8h_source.html | 281 + docs/build/html/files.html | 184 + docs/build/html/folderclosed.svg | 11 + docs/build/html/folderclosedd.svg | 11 + docs/build/html/folderopen.svg | 17 + docs/build/html/folderopend.svg | 12 + docs/build/html/fp16_8h.html | 741 + docs/build/html/fp16_8h_source.html | 369 + docs/build/html/functions.html | 86 + docs/build/html/functions_a.html | 126 + docs/build/html/functions_b.html | 112 + docs/build/html/functions_c.html | 121 + docs/build/html/functions_d.html | 111 + docs/build/html/functions_e.html | 104 + docs/build/html/functions_enum.html | 90 + docs/build/html/functions_eval.html | 100 + docs/build/html/functions_f.html | 102 + docs/build/html/functions_func.html | 86 + docs/build/html/functions_func_a.html | 113 + docs/build/html/functions_func_b.html | 94 + docs/build/html/functions_func_c.html | 113 + docs/build/html/functions_func_d.html | 103 + docs/build/html/functions_func_e.html | 102 + docs/build/html/functions_func_f.html | 92 + docs/build/html/functions_func_g.html | 102 + docs/build/html/functions_func_h.html | 85 + docs/build/html/functions_func_i.html | 101 + docs/build/html/functions_func_j.html | 85 + docs/build/html/functions_func_k.html | 85 + docs/build/html/functions_func_l.html | 105 + docs/build/html/functions_func_m.html | 97 + docs/build/html/functions_func_n.html | 96 + docs/build/html/functions_func_o.html | 104 + docs/build/html/functions_func_p.html | 97 + docs/build/html/functions_func_q.html | 87 + docs/build/html/functions_func_r.html | 107 + docs/build/html/functions_func_s.html | 146 + docs/build/html/functions_func_t.html | 98 + docs/build/html/functions_func_u.html | 87 + docs/build/html/functions_func_v.html | 89 + docs/build/html/functions_func_w.html | 87 + docs/build/html/functions_func_~.html | 98 + docs/build/html/functions_g.html | 107 + docs/build/html/functions_h.html | 85 + docs/build/html/functions_i.html | 112 + docs/build/html/functions_j.html | 88 + docs/build/html/functions_k.html | 89 + docs/build/html/functions_l.html | 112 + docs/build/html/functions_m.html | 110 + docs/build/html/functions_n.html | 103 + docs/build/html/functions_o.html | 112 + docs/build/html/functions_p.html | 100 + docs/build/html/functions_q.html | 89 + docs/build/html/functions_r.html | 119 + docs/build/html/functions_rela.html | 85 + docs/build/html/functions_s.html | 164 + docs/build/html/functions_t.html | 122 + docs/build/html/functions_type.html | 92 + docs/build/html/functions_u.html | 88 + docs/build/html/functions_v.html | 95 + docs/build/html/functions_vars.html | 91 + docs/build/html/functions_vars_b.html | 102 + docs/build/html/functions_vars_c.html | 91 + docs/build/html/functions_vars_d.html | 91 + docs/build/html/functions_vars_e.html | 85 + docs/build/html/functions_vars_f.html | 94 + docs/build/html/functions_vars_g.html | 89 + docs/build/html/functions_vars_i.html | 95 + docs/build/html/functions_vars_j.html | 87 + docs/build/html/functions_vars_k.html | 87 + docs/build/html/functions_vars_l.html | 88 + docs/build/html/functions_vars_m.html | 95 + docs/build/html/functions_vars_n.html | 91 + docs/build/html/functions_vars_o.html | 88 + docs/build/html/functions_vars_p.html | 86 + docs/build/html/functions_vars_q.html | 86 + docs/build/html/functions_vars_r.html | 93 + docs/build/html/functions_vars_s.html | 103 + docs/build/html/functions_vars_t.html | 106 + docs/build/html/functions_vars_v.html | 87 + docs/build/html/functions_vars_w.html | 91 + docs/build/html/functions_w.html | 94 + docs/build/html/functions_x.html | 85 + docs/build/html/functions_~.html | 98 + docs/build/html/gemm_2loader_8h.html | 108 + docs/build/html/gemm_2loader_8h_source.html | 247 + docs/build/html/gemm_2params_8h.html | 109 + docs/build/html/gemm_2params_8h_source.html | 195 + docs/build/html/genindex.html | 449 +- docs/build/html/gguf_8h.html | 112 + docs/build/html/gguf_8h_source.html | 119 + docs/build/html/globals.html | 88 + docs/build/html/globals_a.html | 85 + docs/build/html/globals_b.html | 97 + docs/build/html/globals_c.html | 91 + docs/build/html/globals_d.html | 89 + docs/build/html/globals_defs.html | 220 + docs/build/html/globals_e.html | 96 + docs/build/html/globals_f.html | 88 + docs/build/html/globals_func.html | 85 + docs/build/html/globals_func_c.html | 85 + docs/build/html/globals_func_e.html | 96 + docs/build/html/globals_func_f.html | 85 + docs/build/html/globals_func_g.html | 85 + docs/build/html/globals_func_l.html | 85 + docs/build/html/globals_func_m.html | 95 + docs/build/html/globals_func_o.html | 100 + docs/build/html/globals_func_s.html | 85 + docs/build/html/globals_g.html | 85 + docs/build/html/globals_h.html | 91 + docs/build/html/globals_i.html | 122 + docs/build/html/globals_l.html | 85 + docs/build/html/globals_m.html | 101 + docs/build/html/globals_n.html | 85 + docs/build/html/globals_o.html | 101 + docs/build/html/globals_p.html | 104 + docs/build/html/globals_r.html | 87 + docs/build/html/globals_s.html | 90 + docs/build/html/globals_type.html | 84 + docs/build/html/globals_u.html | 86 + docs/build/html/globals_vars.html | 102 + docs/build/html/graph__utils_8h.html | 120 + docs/build/html/graph__utils_8h_source.html | 136 + docs/build/html/group__ops.html | 9191 +++++++++ docs/build/html/half__types_8h.html | 197 + docs/build/html/half__types_8h_source.html | 158 + docs/build/html/hierarchy.html | 441 + docs/build/html/index.html | 23 +- docs/build/html/indexing_8h.html | 594 + docs/build/html/indexing_8h_source.html | 155 + docs/build/html/install.html | 29 +- docs/build/html/io_8h.html | 144 + docs/build/html/io_8h_source.html | 161 + docs/build/html/jquery.js | 34 + .../html/kernels_2steel_2gemm_2gemm_8h.html | 112 + .../kernels_2steel_2gemm_2gemm_8h_source.html | 413 + docs/build/html/lapack__helper_8h.html | 116 + docs/build/html/lapack__helper_8h_source.html | 115 + docs/build/html/linalg_8h.html | 129 + docs/build/html/linalg_8h_source.html | 169 + docs/build/html/load_8h.html | 116 + docs/build/html/load_8h_source.html | 282 + docs/build/html/loader__channel__l_8h.html | 111 + .../html/loader__channel__l_8h_source.html | 653 + docs/build/html/loader__channel__n_8h.html | 119 + .../html/loader__channel__n_8h_source.html | 499 + docs/build/html/loader__general_8h.html | 109 + .../build/html/loader__general_8h_source.html | 465 + docs/build/html/matmul_8h.html | 115 + docs/build/html/matmul_8h_source.html | 154 + docs/build/html/menu.js | 134 + docs/build/html/menudata.js | 225 + .../html/metal_2compiled__preamble_8h.html | 107 + .../metal_2compiled__preamble_8h_source.html | 103 + docs/build/html/metal_2copy_8h.html | 116 + docs/build/html/metal_2copy_8h_source.html | 140 + .../build/html/metal_2kernels_2binary_8h.html | 147 + .../metal_2kernels_2binary_8h_source.html | 559 + ...metal_2kernels_2compiled__preamble_8h.html | 115 + ...kernels_2compiled__preamble_8h_source.html | 103 + .../html/metal_2kernels_2ternary_8h.html | 97 + .../metal_2kernels_2ternary_8h_source.html | 108 + docs/build/html/metal_2kernels_2unary_8h.html | 165 + .../html/metal_2kernels_2unary_8h_source.html | 811 + docs/build/html/metal_2reduce_8h.html | 112 + docs/build/html/metal_2reduce_8h_source.html | 145 + docs/build/html/metal_8h.html | 131 + docs/build/html/metal_8h_source.html | 177 + docs/build/html/metal__impl_8h.html | 121 + docs/build/html/metal__impl_8h_source.html | 124 + docs/build/html/minus.svg | 8 + docs/build/html/minusd.svg | 8 + docs/build/html/mlx_8h.html | 102 + docs/build/html/mlx_8h_source.html | 122 + docs/build/html/mma_8h.html | 110 + docs/build/html/mma_8h_source.html | 418 + docs/build/html/mps_2gemm_8h.html | 235 + docs/build/html/mps_2gemm_8h_source.html | 572 + docs/build/html/namespace_m_p_s.html | 160 + docs/build/html/namespace_m_t_l.html | 91 + .../html/namespace_m_t_l_1_1_private.html | 97 + ...namespace_m_t_l_1_1_private_1_1_class.html | 227 + ...espace_m_t_l_1_1_private_1_1_selector.html | 439 + docs/build/html/namespacemembers.html | 86 + docs/build/html/namespacemembers_a.html | 122 + docs/build/html/namespacemembers_b.html | 99 + docs/build/html/namespacemembers_c.html | 129 + docs/build/html/namespacemembers_d.html | 107 + docs/build/html/namespacemembers_e.html | 99 + docs/build/html/namespacemembers_enum.html | 86 + docs/build/html/namespacemembers_eval.html | 93 + docs/build/html/namespacemembers_f.html | 111 + docs/build/html/namespacemembers_func.html | 86 + docs/build/html/namespacemembers_func_a.html | 121 + docs/build/html/namespacemembers_func_b.html | 95 + docs/build/html/namespacemembers_func_c.html | 117 + docs/build/html/namespacemembers_func_d.html | 100 + docs/build/html/namespacemembers_func_e.html | 98 + docs/build/html/namespacemembers_func_f.html | 106 + docs/build/html/namespacemembers_func_g.html | 102 + docs/build/html/namespacemembers_func_i.html | 104 + docs/build/html/namespacemembers_func_j.html | 85 + docs/build/html/namespacemembers_func_k.html | 86 + docs/build/html/namespacemembers_func_l.html | 102 + docs/build/html/namespacemembers_func_m.html | 102 + docs/build/html/namespacemembers_func_n.html | 96 + docs/build/html/namespacemembers_func_o.html | 112 + docs/build/html/namespacemembers_func_p.html | 98 + docs/build/html/namespacemembers_func_q.html | 87 + docs/build/html/namespacemembers_func_r.html | 108 + docs/build/html/namespacemembers_func_s.html | 146 + docs/build/html/namespacemembers_func_t.html | 103 + docs/build/html/namespacemembers_func_u.html | 85 + docs/build/html/namespacemembers_func_v.html | 90 + docs/build/html/namespacemembers_func_w.html | 86 + docs/build/html/namespacemembers_func_z.html | 86 + docs/build/html/namespacemembers_g.html | 109 + docs/build/html/namespacemembers_i.html | 112 + docs/build/html/namespacemembers_j.html | 85 + docs/build/html/namespacemembers_k.html | 86 + docs/build/html/namespacemembers_l.html | 102 + docs/build/html/namespacemembers_m.html | 104 + docs/build/html/namespacemembers_n.html | 98 + docs/build/html/namespacemembers_o.html | 112 + docs/build/html/namespacemembers_p.html | 98 + docs/build/html/namespacemembers_q.html | 87 + docs/build/html/namespacemembers_r.html | 109 + docs/build/html/namespacemembers_s.html | 152 + docs/build/html/namespacemembers_t.html | 103 + docs/build/html/namespacemembers_type.html | 98 + docs/build/html/namespacemembers_u.html | 90 + docs/build/html/namespacemembers_v.html | 92 + docs/build/html/namespacemembers_vars.html | 156 + docs/build/html/namespacemembers_w.html | 86 + docs/build/html/namespacemembers_z.html | 86 + docs/build/html/namespacemetal.html | 1661 ++ docs/build/html/namespacemetal_1_1fast.html | 1178 ++ .../build/html/namespacemetal_1_1precise.html | 1178 ++ docs/build/html/namespacemlx.html | 93 + docs/build/html/namespacemlx_1_1core.html | 16307 ++++++++++++++++ .../namespacemlx_1_1core_1_1allocator.html | 180 + .../html/namespacemlx_1_1core_1_1detail.html | 439 + .../html/namespacemlx_1_1core_1_1fast.html | 277 + .../html/namespacemlx_1_1core_1_1fft.html | 1082 + .../html/namespacemlx_1_1core_1_1io.html | 101 + .../html/namespacemlx_1_1core_1_1linalg.html | 414 + .../html/namespacemlx_1_1core_1_1metal.html | 519 + .../html/namespacemlx_1_1core_1_1random.html | 1149 ++ .../namespacemlx_1_1core_1_1scheduler.html | 255 + docs/build/html/namespacemlx_1_1steel.html | 147 + docs/build/html/namespacepocketfft.html | 91 + .../html/namespacepocketfft_1_1detail.html | 1725 ++ ...spacepocketfft_1_1detail_1_1threading.html | 245 + docs/build/html/namespaces.html | 367 + docs/build/html/nav_f.png | Bin 0 -> 153 bytes docs/build/html/nav_fd.png | Bin 0 -> 169 bytes docs/build/html/nav_g.png | Bin 0 -> 95 bytes docs/build/html/nav_h.png | Bin 0 -> 98 bytes docs/build/html/nav_hd.png | Bin 0 -> 114 bytes docs/build/html/objects.inv | Bin 9826 -> 23657 bytes docs/build/html/open.png | Bin 0 -> 123 bytes docs/build/html/ops_8h.html | 879 + docs/build/html/ops_8h_source.html | 1432 ++ docs/build/html/plus.svg | 9 + docs/build/html/plusd.svg | 9 + docs/build/html/pocketfft_8h.html | 1134 ++ docs/build/html/pocketfft_8h_source.html | 4170 ++++ docs/build/html/primitives_8h.html | 417 + docs/build/html/primitives_8h_source.html | 2872 +++ .../python/_autosummary/mlx.core.Device.html | 23 +- .../python/_autosummary/mlx.core.Dtype.html | 23 +- .../_autosummary/mlx.core.DtypeCategory.html | 23 +- .../python/_autosummary/mlx.core.Stream.html | 23 +- .../python/_autosummary/mlx.core.abs.html | 23 +- .../python/_autosummary/mlx.core.add.html | 23 +- .../python/_autosummary/mlx.core.all.html | 23 +- .../_autosummary/mlx.core.allclose.html | 23 +- .../python/_autosummary/mlx.core.any.html | 23 +- .../python/_autosummary/mlx.core.arange.html | 23 +- .../python/_autosummary/mlx.core.arccos.html | 23 +- .../python/_autosummary/mlx.core.arccosh.html | 23 +- .../python/_autosummary/mlx.core.arcsin.html | 23 +- .../python/_autosummary/mlx.core.arcsinh.html | 23 +- .../python/_autosummary/mlx.core.arctan.html | 29 +- .../python/_autosummary/mlx.core.arctan2.html | 925 + .../python/_autosummary/mlx.core.arctanh.html | 29 +- .../python/_autosummary/mlx.core.argmax.html | 23 +- .../python/_autosummary/mlx.core.argmin.html | 23 +- .../_autosummary/mlx.core.argpartition.html | 23 +- .../python/_autosummary/mlx.core.argsort.html | 23 +- .../python/_autosummary/mlx.core.array.T.html | 23 +- .../_autosummary/mlx.core.array.abs.html | 23 +- .../_autosummary/mlx.core.array.all.html | 23 +- .../_autosummary/mlx.core.array.any.html | 23 +- .../_autosummary/mlx.core.array.argmax.html | 23 +- .../_autosummary/mlx.core.array.argmin.html | 23 +- .../_autosummary/mlx.core.array.astype.html | 23 +- .../_autosummary/mlx.core.array.at.html | 23 +- .../_autosummary/mlx.core.array.cos.html | 23 +- .../_autosummary/mlx.core.array.cummax.html | 23 +- .../_autosummary/mlx.core.array.cummin.html | 23 +- .../_autosummary/mlx.core.array.cumprod.html | 23 +- .../_autosummary/mlx.core.array.cumsum.html | 23 +- .../_autosummary/mlx.core.array.diag.html | 23 +- .../_autosummary/mlx.core.array.diagonal.html | 23 +- .../_autosummary/mlx.core.array.dtype.html | 23 +- .../_autosummary/mlx.core.array.exp.html | 23 +- .../_autosummary/mlx.core.array.flatten.html | 23 +- .../python/_autosummary/mlx.core.array.html | 94 +- .../_autosummary/mlx.core.array.item.html | 23 +- .../_autosummary/mlx.core.array.itemsize.html | 23 +- .../_autosummary/mlx.core.array.log.html | 23 +- .../_autosummary/mlx.core.array.log10.html | 23 +- .../_autosummary/mlx.core.array.log1p.html | 23 +- .../_autosummary/mlx.core.array.log2.html | 23 +- .../mlx.core.array.logsumexp.html | 23 +- .../_autosummary/mlx.core.array.max.html | 23 +- .../_autosummary/mlx.core.array.mean.html | 23 +- .../_autosummary/mlx.core.array.min.html | 23 +- .../_autosummary/mlx.core.array.moveaxis.html | 23 +- .../_autosummary/mlx.core.array.nbytes.html | 23 +- .../_autosummary/mlx.core.array.ndim.html | 23 +- .../_autosummary/mlx.core.array.prod.html | 23 +- .../mlx.core.array.reciprocal.html | 23 +- .../_autosummary/mlx.core.array.reshape.html | 23 +- .../_autosummary/mlx.core.array.round.html | 23 +- .../_autosummary/mlx.core.array.rsqrt.html | 23 +- .../_autosummary/mlx.core.array.shape.html | 23 +- .../_autosummary/mlx.core.array.sin.html | 23 +- .../_autosummary/mlx.core.array.size.html | 23 +- .../_autosummary/mlx.core.array.split.html | 23 +- .../_autosummary/mlx.core.array.sqrt.html | 23 +- .../_autosummary/mlx.core.array.square.html | 23 +- .../_autosummary/mlx.core.array.squeeze.html | 23 +- .../_autosummary/mlx.core.array.sum.html | 23 +- .../_autosummary/mlx.core.array.swapaxes.html | 23 +- .../_autosummary/mlx.core.array.tolist.html | 23 +- .../mlx.core.array.transpose.html | 23 +- .../_autosummary/mlx.core.array.var.html | 23 +- .../_autosummary/mlx.core.array_equal.html | 23 +- .../_autosummary/mlx.core.atleast_1d.html | 23 +- .../_autosummary/mlx.core.atleast_2d.html | 23 +- .../_autosummary/mlx.core.atleast_3d.html | 29 +- .../_autosummary/mlx.core.bitwise_and.html | 927 + .../_autosummary/mlx.core.bitwise_or.html | 927 + .../_autosummary/mlx.core.bitwise_xor.html | 928 + .../mlx.core.block_masked_mm.html | 35 +- .../mlx.core.block_sparse_mm.html | 928 + .../_autosummary/mlx.core.broadcast_to.html | 35 +- .../python/_autosummary/mlx.core.ceil.html | 29 +- .../python/_autosummary/mlx.core.clip.html | 23 +- .../python/_autosummary/mlx.core.compile.html | 23 +- .../_autosummary/mlx.core.concatenate.html | 29 +- .../python/_autosummary/mlx.core.conj.html | 917 + .../_autosummary/mlx.core.conjugate.html | 917 + .../python/_autosummary/mlx.core.conv1d.html | 23 +- .../python/_autosummary/mlx.core.conv2d.html | 23 +- .../_autosummary/mlx.core.conv_general.html | 23 +- .../_autosummary/mlx.core.convolve.html | 29 +- .../python/_autosummary/mlx.core.cos.html | 23 +- .../python/_autosummary/mlx.core.cosh.html | 23 +- .../python/_autosummary/mlx.core.cummax.html | 23 +- .../python/_autosummary/mlx.core.cummin.html | 23 +- .../python/_autosummary/mlx.core.cumprod.html | 23 +- .../python/_autosummary/mlx.core.cumsum.html | 23 +- .../_autosummary/mlx.core.default_device.html | 23 +- .../_autosummary/mlx.core.default_stream.html | 23 +- .../python/_autosummary/mlx.core.degrees.html | 23 +- .../_autosummary/mlx.core.dequantize.html | 23 +- .../python/_autosummary/mlx.core.diag.html | 23 +- .../_autosummary/mlx.core.diagonal.html | 23 +- .../mlx.core.disable_compile.html | 23 +- .../python/_autosummary/mlx.core.divide.html | 23 +- .../python/_autosummary/mlx.core.divmod.html | 23 +- .../_autosummary/mlx.core.enable_compile.html | 23 +- .../python/_autosummary/mlx.core.equal.html | 23 +- .../python/_autosummary/mlx.core.erf.html | 23 +- .../python/_autosummary/mlx.core.erfinv.html | 23 +- .../python/_autosummary/mlx.core.eval.html | 23 +- .../python/_autosummary/mlx.core.exp.html | 23 +- .../_autosummary/mlx.core.expand_dims.html | 23 +- .../python/_autosummary/mlx.core.expm1.html | 23 +- .../python/_autosummary/mlx.core.eye.html | 23 +- .../mlx.core.fast.layer_norm.html | 23 +- .../_autosummary/mlx.core.fast.rms_norm.html | 23 +- .../_autosummary/mlx.core.fast.rope.html | 23 +- ...ore.fast.scaled_dot_product_attention.html | 23 +- .../python/_autosummary/mlx.core.fft.fft.html | 23 +- .../_autosummary/mlx.core.fft.fft2.html | 23 +- .../_autosummary/mlx.core.fft.fftn.html | 23 +- .../_autosummary/mlx.core.fft.ifft.html | 23 +- .../_autosummary/mlx.core.fft.ifft2.html | 23 +- .../_autosummary/mlx.core.fft.ifftn.html | 23 +- .../_autosummary/mlx.core.fft.irfft.html | 23 +- .../_autosummary/mlx.core.fft.irfft2.html | 23 +- .../_autosummary/mlx.core.fft.irfftn.html | 23 +- .../_autosummary/mlx.core.fft.rfft.html | 23 +- .../_autosummary/mlx.core.fft.rfft2.html | 23 +- .../_autosummary/mlx.core.fft.rfftn.html | 23 +- .../python/_autosummary/mlx.core.flatten.html | 23 +- .../python/_autosummary/mlx.core.floor.html | 23 +- .../_autosummary/mlx.core.floor_divide.html | 23 +- .../python/_autosummary/mlx.core.full.html | 23 +- .../python/_autosummary/mlx.core.grad.html | 23 +- .../python/_autosummary/mlx.core.greater.html | 23 +- .../_autosummary/mlx.core.greater_equal.html | 23 +- .../_autosummary/mlx.core.identity.html | 23 +- .../python/_autosummary/mlx.core.inner.html | 23 +- .../python/_autosummary/mlx.core.isclose.html | 23 +- .../python/_autosummary/mlx.core.isinf.html | 23 +- .../python/_autosummary/mlx.core.isnan.html | 23 +- .../_autosummary/mlx.core.isneginf.html | 23 +- .../_autosummary/mlx.core.isposinf.html | 29 +- .../_autosummary/mlx.core.issubdtype.html | 23 +- .../python/_autosummary/mlx.core.jvp.html | 23 +- .../_autosummary/mlx.core.left_shift.html | 928 + .../python/_autosummary/mlx.core.less.html | 29 +- .../_autosummary/mlx.core.less_equal.html | 23 +- .../_autosummary/mlx.core.linalg.norm.html | 23 +- .../_autosummary/mlx.core.linalg.qr.html | 23 +- .../_autosummary/mlx.core.linspace.html | 23 +- .../python/_autosummary/mlx.core.load.html | 23 +- .../python/_autosummary/mlx.core.log.html | 23 +- .../python/_autosummary/mlx.core.log10.html | 23 +- .../python/_autosummary/mlx.core.log1p.html | 23 +- .../python/_autosummary/mlx.core.log2.html | 23 +- .../_autosummary/mlx.core.logaddexp.html | 23 +- .../_autosummary/mlx.core.logical_and.html | 23 +- .../_autosummary/mlx.core.logical_not.html | 23 +- .../_autosummary/mlx.core.logical_or.html | 23 +- .../_autosummary/mlx.core.logsumexp.html | 23 +- .../python/_autosummary/mlx.core.matmul.html | 23 +- .../python/_autosummary/mlx.core.max.html | 23 +- .../python/_autosummary/mlx.core.maximum.html | 23 +- .../python/_autosummary/mlx.core.mean.html | 23 +- .../_autosummary/mlx.core.meshgrid.html | 23 +- .../mlx.core.metal.clear_cache.html | 23 +- .../mlx.core.metal.device_info.html | 926 + .../mlx.core.metal.get_active_memory.html | 29 +- .../mlx.core.metal.get_cache_memory.html | 29 +- .../mlx.core.metal.get_peak_memory.html | 33 +- .../mlx.core.metal.is_available.html | 29 +- .../mlx.core.metal.reset_peak_memory.html | 911 + .../mlx.core.metal.set_cache_limit.html | 23 +- .../mlx.core.metal.set_memory_limit.html | 23 +- .../mlx.core.metal.start_capture.html | 23 +- .../mlx.core.metal.stop_capture.html | 23 +- .../python/_autosummary/mlx.core.min.html | 23 +- .../python/_autosummary/mlx.core.minimum.html | 23 +- .../_autosummary/mlx.core.moveaxis.html | 23 +- .../_autosummary/mlx.core.multiply.html | 23 +- .../_autosummary/mlx.core.negative.html | 23 +- .../_autosummary/mlx.core.new_stream.html | 23 +- .../_autosummary/mlx.core.not_equal.html | 23 +- .../python/_autosummary/mlx.core.ones.html | 23 +- .../_autosummary/mlx.core.ones_like.html | 23 +- .../python/_autosummary/mlx.core.outer.html | 23 +- .../python/_autosummary/mlx.core.pad.html | 23 +- .../_autosummary/mlx.core.partition.html | 23 +- .../python/_autosummary/mlx.core.prod.html | 23 +- .../_autosummary/mlx.core.quantize.html | 23 +- .../mlx.core.quantized_matmul.html | 23 +- .../python/_autosummary/mlx.core.radians.html | 23 +- .../mlx.core.random.bernoulli.html | 23 +- .../mlx.core.random.categorical.html | 23 +- .../_autosummary/mlx.core.random.gumbel.html | 23 +- .../_autosummary/mlx.core.random.key.html | 23 +- .../mlx.core.random.multivariate_normal.html | 23 +- .../_autosummary/mlx.core.random.normal.html | 23 +- .../_autosummary/mlx.core.random.randint.html | 23 +- .../_autosummary/mlx.core.random.seed.html | 23 +- .../_autosummary/mlx.core.random.split.html | 23 +- .../mlx.core.random.truncated_normal.html | 23 +- .../_autosummary/mlx.core.random.uniform.html | 23 +- .../_autosummary/mlx.core.reciprocal.html | 23 +- .../python/_autosummary/mlx.core.repeat.html | 23 +- .../python/_autosummary/mlx.core.reshape.html | 29 +- .../_autosummary/mlx.core.right_shift.html | 928 + .../python/_autosummary/mlx.core.round.html | 29 +- .../python/_autosummary/mlx.core.rsqrt.html | 23 +- .../python/_autosummary/mlx.core.save.html | 23 +- .../_autosummary/mlx.core.save_gguf.html | 23 +- .../mlx.core.save_safetensors.html | 23 +- .../python/_autosummary/mlx.core.savez.html | 23 +- .../mlx.core.savez_compressed.html | 23 +- .../mlx.core.set_default_device.html | 23 +- .../mlx.core.set_default_stream.html | 23 +- .../python/_autosummary/mlx.core.sigmoid.html | 23 +- .../python/_autosummary/mlx.core.sign.html | 23 +- .../python/_autosummary/mlx.core.sin.html | 23 +- .../python/_autosummary/mlx.core.sinh.html | 23 +- .../python/_autosummary/mlx.core.softmax.html | 23 +- .../python/_autosummary/mlx.core.sort.html | 23 +- .../python/_autosummary/mlx.core.split.html | 23 +- .../python/_autosummary/mlx.core.sqrt.html | 23 +- .../python/_autosummary/mlx.core.square.html | 23 +- .../python/_autosummary/mlx.core.squeeze.html | 23 +- .../python/_autosummary/mlx.core.stack.html | 23 +- .../python/_autosummary/mlx.core.std.html | 23 +- .../_autosummary/mlx.core.stop_gradient.html | 23 +- .../_autosummary/mlx.core.subtract.html | 23 +- .../python/_autosummary/mlx.core.sum.html | 23 +- .../_autosummary/mlx.core.swapaxes.html | 23 +- .../_autosummary/mlx.core.synchronize.html | 23 +- .../python/_autosummary/mlx.core.take.html | 23 +- .../mlx.core.take_along_axis.html | 23 +- .../python/_autosummary/mlx.core.tan.html | 23 +- .../python/_autosummary/mlx.core.tanh.html | 23 +- .../_autosummary/mlx.core.tensordot.html | 23 +- .../python/_autosummary/mlx.core.tile.html | 23 +- .../python/_autosummary/mlx.core.topk.html | 23 +- .../_autosummary/mlx.core.transpose.html | 23 +- .../python/_autosummary/mlx.core.tri.html | 23 +- .../python/_autosummary/mlx.core.tril.html | 23 +- .../python/_autosummary/mlx.core.triu.html | 23 +- .../_autosummary/mlx.core.value_and_grad.html | 23 +- .../python/_autosummary/mlx.core.var.html | 23 +- .../python/_autosummary/mlx.core.vjp.html | 23 +- .../python/_autosummary/mlx.core.vmap.html | 23 +- .../python/_autosummary/mlx.core.where.html | 23 +- .../python/_autosummary/mlx.core.zeros.html | 23 +- .../_autosummary/mlx.core.zeros_like.html | 23 +- .../python/_autosummary/mlx.nn.quantize.html | 23 +- .../_autosummary/mlx.nn.value_and_grad.html | 23 +- .../mlx.optimizers.clip_grad_norm.html | 936 + .../_autosummary/mlx.utils.tree_flatten.html | 23 +- .../_autosummary/mlx.utils.tree_map.html | 23 +- .../mlx.utils.tree_map_with_path.html | 29 +- .../_autosummary/mlx.utils.tree_reduce.html | 940 + .../mlx.utils.tree_unflatten.html | 23 +- .../python/_autosummary/stream_class.html | 23 +- docs/build/html/python/array.html | 23 +- docs/build/html/python/data_types.html | 23 +- .../html/python/devices_and_streams.html | 23 +- docs/build/html/python/fast.html | 23 +- docs/build/html/python/fft.html | 23 +- docs/build/html/python/linalg.html | 23 +- docs/build/html/python/metal.html | 33 +- docs/build/html/python/nn.html | 23 +- .../python/nn/_autosummary/mlx.nn.ALiBi.html | 23 +- .../nn/_autosummary/mlx.nn.AvgPool1d.html | 23 +- .../nn/_autosummary/mlx.nn.AvgPool2d.html | 23 +- .../nn/_autosummary/mlx.nn.BatchNorm.html | 23 +- .../python/nn/_autosummary/mlx.nn.Conv1d.html | 23 +- .../python/nn/_autosummary/mlx.nn.Conv2d.html | 23 +- .../nn/_autosummary/mlx.nn.Dropout.html | 23 +- .../nn/_autosummary/mlx.nn.Dropout2d.html | 23 +- .../nn/_autosummary/mlx.nn.Dropout3d.html | 23 +- .../nn/_autosummary/mlx.nn.Embedding.html | 23 +- .../python/nn/_autosummary/mlx.nn.GELU.html | 23 +- .../python/nn/_autosummary/mlx.nn.GRU.html | 23 +- .../nn/_autosummary/mlx.nn.GroupNorm.html | 23 +- .../nn/_autosummary/mlx.nn.InstanceNorm.html | 23 +- .../python/nn/_autosummary/mlx.nn.LSTM.html | 23 +- .../nn/_autosummary/mlx.nn.LayerNorm.html | 23 +- .../python/nn/_autosummary/mlx.nn.Linear.html | 23 +- .../nn/_autosummary/mlx.nn.MaxPool1d.html | 23 +- .../nn/_autosummary/mlx.nn.MaxPool2d.html | 23 +- .../python/nn/_autosummary/mlx.nn.Mish.html | 23 +- .../nn/_autosummary/mlx.nn.Module.apply.html | 23 +- .../mlx.nn.Module.apply_to_modules.html | 23 +- .../_autosummary/mlx.nn.Module.children.html | 23 +- .../nn/_autosummary/mlx.nn.Module.eval.html | 23 +- .../mlx.nn.Module.filter_and_map.html | 23 +- .../nn/_autosummary/mlx.nn.Module.freeze.html | 23 +- .../mlx.nn.Module.leaf_modules.html | 23 +- .../mlx.nn.Module.load_weights.html | 23 +- .../_autosummary/mlx.nn.Module.modules.html | 23 +- .../mlx.nn.Module.named_modules.html | 23 +- .../mlx.nn.Module.parameters.html | 23 +- .../mlx.nn.Module.save_weights.html | 23 +- .../_autosummary/mlx.nn.Module.set_dtype.html | 23 +- .../nn/_autosummary/mlx.nn.Module.state.html | 23 +- .../nn/_autosummary/mlx.nn.Module.train.html | 23 +- .../mlx.nn.Module.trainable_parameters.html | 23 +- .../_autosummary/mlx.nn.Module.training.html | 23 +- .../_autosummary/mlx.nn.Module.unfreeze.html | 23 +- .../nn/_autosummary/mlx.nn.Module.update.html | 23 +- .../mlx.nn.Module.update_modules.html | 23 +- .../mlx.nn.MultiHeadAttention.html | 23 +- .../python/nn/_autosummary/mlx.nn.PReLU.html | 23 +- .../mlx.nn.QuantizedEmbedding.html | 23 +- .../_autosummary/mlx.nn.QuantizedLinear.html | 23 +- .../nn/_autosummary/mlx.nn.RMSNorm.html | 23 +- .../python/nn/_autosummary/mlx.nn.RNN.html | 23 +- .../python/nn/_autosummary/mlx.nn.ReLU.html | 23 +- .../python/nn/_autosummary/mlx.nn.RoPE.html | 23 +- .../python/nn/_autosummary/mlx.nn.SELU.html | 23 +- .../nn/_autosummary/mlx.nn.Sequential.html | 23 +- .../python/nn/_autosummary/mlx.nn.SiLU.html | 23 +- .../mlx.nn.SinusoidalPositionalEncoding.html | 23 +- .../nn/_autosummary/mlx.nn.Softshrink.html | 23 +- .../python/nn/_autosummary/mlx.nn.Step.html | 23 +- .../nn/_autosummary/mlx.nn.Transformer.html | 23 +- .../nn/_autosummary/mlx.nn.Upsample.html | 23 +- .../nn/_autosummary/mlx.nn.init.constant.html | 23 +- .../mlx.nn.init.glorot_normal.html | 23 +- .../mlx.nn.init.glorot_uniform.html | 23 +- .../_autosummary/mlx.nn.init.he_normal.html | 23 +- .../_autosummary/mlx.nn.init.he_uniform.html | 23 +- .../nn/_autosummary/mlx.nn.init.identity.html | 23 +- .../nn/_autosummary/mlx.nn.init.normal.html | 23 +- .../nn/_autosummary/mlx.nn.init.uniform.html | 23 +- .../nn/_autosummary_functions/mlx.nn.elu.html | 23 +- .../_autosummary_functions/mlx.nn.gelu.html | 23 +- .../mlx.nn.gelu_approx.html | 23 +- .../mlx.nn.gelu_fast_approx.html | 23 +- .../nn/_autosummary_functions/mlx.nn.glu.html | 23 +- .../mlx.nn.hardswish.html | 23 +- .../mlx.nn.leaky_relu.html | 23 +- .../mlx.nn.log_sigmoid.html | 23 +- .../mlx.nn.log_softmax.html | 23 +- .../mlx.nn.losses.binary_cross_entropy.html | 23 +- .../mlx.nn.losses.cosine_similarity_loss.html | 23 +- .../mlx.nn.losses.cross_entropy.html | 23 +- .../mlx.nn.losses.gaussian_nll_loss.html | 23 +- .../mlx.nn.losses.hinge_loss.html | 23 +- .../mlx.nn.losses.huber_loss.html | 23 +- .../mlx.nn.losses.kl_div_loss.html | 23 +- .../mlx.nn.losses.l1_loss.html | 23 +- .../mlx.nn.losses.log_cosh_loss.html | 23 +- .../mlx.nn.losses.margin_ranking_loss.html | 23 +- .../mlx.nn.losses.mse_loss.html | 23 +- .../mlx.nn.losses.nll_loss.html | 23 +- .../mlx.nn.losses.smooth_l1_loss.html | 23 +- .../mlx.nn.losses.triplet_loss.html | 23 +- .../_autosummary_functions/mlx.nn.mish.html | 23 +- .../_autosummary_functions/mlx.nn.prelu.html | 23 +- .../_autosummary_functions/mlx.nn.relu.html | 23 +- .../_autosummary_functions/mlx.nn.relu6.html | 23 +- .../_autosummary_functions/mlx.nn.selu.html | 23 +- .../mlx.nn.sigmoid.html | 23 +- .../_autosummary_functions/mlx.nn.silu.html | 23 +- .../mlx.nn.softmax.html | 23 +- .../mlx.nn.softplus.html | 23 +- .../mlx.nn.softshrink.html | 23 +- .../_autosummary_functions/mlx.nn.step.html | 23 +- .../_autosummary_functions/mlx.nn.tanh.html | 23 +- docs/build/html/python/nn/functions.html | 23 +- docs/build/html/python/nn/init.html | 23 +- docs/build/html/python/nn/layers.html | 23 +- docs/build/html/python/nn/losses.html | 23 +- docs/build/html/python/nn/module.html | 23 +- docs/build/html/python/ops.html | 228 +- docs/build/html/python/optimizers.html | 32 +- .../_autosummary/mlx.optimizers.AdaDelta.html | 23 +- .../mlx.optimizers.Adafactor.html | 23 +- .../_autosummary/mlx.optimizers.Adagrad.html | 23 +- .../_autosummary/mlx.optimizers.Adam.html | 23 +- .../_autosummary/mlx.optimizers.AdamW.html | 23 +- .../_autosummary/mlx.optimizers.Adamax.html | 23 +- .../_autosummary/mlx.optimizers.Lion.html | 23 +- ....optimizers.Optimizer.apply_gradients.html | 23 +- .../mlx.optimizers.Optimizer.init.html | 23 +- .../mlx.optimizers.Optimizer.state.html | 23 +- .../mlx.optimizers.Optimizer.update.html | 23 +- .../_autosummary/mlx.optimizers.RMSprop.html | 23 +- .../_autosummary/mlx.optimizers.SGD.html | 23 +- .../mlx.optimizers.cosine_decay.html | 23 +- .../mlx.optimizers.exponential_decay.html | 23 +- .../mlx.optimizers.join_schedules.html | 23 +- .../mlx.optimizers.linear_schedule.html | 23 +- .../mlx.optimizers.step_decay.html | 29 +- .../python/optimizers/common_optimizers.html | 23 +- .../html/python/optimizers/optimizer.html | 23 +- .../html/python/optimizers/schedulers.html | 23 +- docs/build/html/python/random.html | 23 +- docs/build/html/python/transforms.html | 23 +- docs/build/html/python/tree_utils.html | 32 +- docs/build/html/random_8h.html | 186 + docs/build/html/random_8h_source.html | 356 + docs/build/html/reduce__inst_8h.html | 357 + docs/build/html/reduce__inst_8h_source.html | 179 + ...d__dot__product__attention__params_8h.html | 97 + ..._product__attention__params_8h_source.html | 114 + docs/build/html/scheduler_8h.html | 135 + docs/build/html/scheduler_8h_source.html | 360 + docs/build/html/search.html | 23 +- docs/build/html/search/all_0.js | 12 + docs/build/html/search/all_1.js | 96 + docs/build/html/search/all_10.js | 60 + docs/build/html/search/all_11.js | 11 + docs/build/html/search/all_12.js | 71 + docs/build/html/search/all_13.js | 156 + docs/build/html/search/all_14.js | 66 + docs/build/html/search/all_15.js | 18 + docs/build/html/search/all_16.js | 24 + docs/build/html/search/all_17.js | 15 + docs/build/html/search/all_18.js | 4 + docs/build/html/search/all_19.js | 5 + docs/build/html/search/all_1a.js | 17 + docs/build/html/search/all_2.js | 68 + docs/build/html/search/all_3.js | 111 + docs/build/html/search/all_4.js | 61 + docs/build/html/search/all_5.js | 53 + docs/build/html/search/all_6.js | 54 + docs/build/html/search/all_7.js | 57 + docs/build/html/search/all_8.js | 12 + docs/build/html/search/all_9.js | 101 + docs/build/html/search/all_a.js | 7 + docs/build/html/search/all_b.js | 11 + docs/build/html/search/all_c.js | 72 + docs/build/html/search/all_d.js | 102 + docs/build/html/search/all_e.js | 36 + docs/build/html/search/all_f.js | 55 + docs/build/html/search/classes_0.js | 6 + docs/build/html/search/classes_1.js | 32 + docs/build/html/search/classes_10.js | 18 + docs/build/html/search/classes_11.js | 28 + docs/build/html/search/classes_12.js | 15 + docs/build/html/search/classes_13.js | 6 + docs/build/html/search/classes_14.js | 7 + docs/build/html/search/classes_15.js | 4 + docs/build/html/search/classes_2.js | 16 + docs/build/html/search/classes_3.js | 39 + docs/build/html/search/classes_4.js | 9 + docs/build/html/search/classes_5.js | 13 + docs/build/html/search/classes_6.js | 10 + docs/build/html/search/classes_7.js | 10 + docs/build/html/search/classes_8.js | 8 + docs/build/html/search/classes_9.js | 5 + docs/build/html/search/classes_a.js | 32 + docs/build/html/search/classes_b.js | 20 + docs/build/html/search/classes_c.js | 10 + docs/build/html/search/classes_d.js | 4 + docs/build/html/search/classes_e.js | 11 + docs/build/html/search/classes_f.js | 5 + docs/build/html/search/close.svg | 18 + docs/build/html/search/defines_0.js | 7 + docs/build/html/search/defines_1.js | 4 + docs/build/html/search/defines_2.js | 13 + docs/build/html/search/defines_3.js | 5 + docs/build/html/search/defines_4.js | 8 + docs/build/html/search/defines_5.js | 4 + docs/build/html/search/defines_6.js | 10 + docs/build/html/search/defines_7.js | 40 + docs/build/html/search/defines_8.js | 6 + docs/build/html/search/defines_9.js | 22 + docs/build/html/search/defines_a.js | 5 + docs/build/html/search/defines_b.js | 4 + docs/build/html/search/enums_0.js | 4 + docs/build/html/search/enums_1.js | 6 + docs/build/html/search/enums_2.js | 5 + docs/build/html/search/enums_3.js | 4 + docs/build/html/search/enums_4.js | 4 + docs/build/html/search/enums_5.js | 5 + docs/build/html/search/enums_6.js | 4 + docs/build/html/search/enums_7.js | 4 + docs/build/html/search/enumvalues_0.js | 7 + docs/build/html/search/enumvalues_1.js | 6 + docs/build/html/search/enumvalues_10.js | 10 + docs/build/html/search/enumvalues_11.js | 5 + docs/build/html/search/enumvalues_12.js | 4 + docs/build/html/search/enumvalues_2.js | 10 + docs/build/html/search/enumvalues_3.js | 9 + docs/build/html/search/enumvalues_4.js | 5 + docs/build/html/search/enumvalues_5.js | 7 + docs/build/html/search/enumvalues_6.js | 10 + docs/build/html/search/enumvalues_7.js | 10 + docs/build/html/search/enumvalues_8.js | 4 + docs/build/html/search/enumvalues_9.js | 5 + docs/build/html/search/enumvalues_a.js | 7 + docs/build/html/search/enumvalues_b.js | 4 + docs/build/html/search/enumvalues_c.js | 4 + docs/build/html/search/enumvalues_d.js | 4 + docs/build/html/search/enumvalues_e.js | 7 + docs/build/html/search/enumvalues_f.js | 5 + docs/build/html/search/files_0.js | 7 + docs/build/html/search/files_1.js | 7 + docs/build/html/search/files_10.js | 5 + docs/build/html/search/files_2.js | 10 + docs/build/html/search/files_3.js | 6 + docs/build/html/search/files_4.js | 6 + docs/build/html/search/files_5.js | 7 + docs/build/html/search/files_6.js | 6 + docs/build/html/search/files_7.js | 4 + docs/build/html/search/files_8.js | 5 + docs/build/html/search/files_9.js | 10 + docs/build/html/search/files_a.js | 8 + docs/build/html/search/files_b.js | 4 + docs/build/html/search/files_c.js | 6 + docs/build/html/search/files_d.js | 6 + docs/build/html/search/files_e.js | 6 + docs/build/html/search/files_f.js | 7 + docs/build/html/search/functions_0.js | 7 + docs/build/html/search/functions_1.js | 69 + docs/build/html/search/functions_10.js | 29 + docs/build/html/search/functions_11.js | 9 + docs/build/html/search/functions_12.js | 49 + docs/build/html/search/functions_13.js | 122 + docs/build/html/search/functions_14.js | 36 + docs/build/html/search/functions_15.js | 7 + docs/build/html/search/functions_16.js | 12 + docs/build/html/search/functions_17.js | 7 + docs/build/html/search/functions_18.js | 5 + docs/build/html/search/functions_19.js | 17 + docs/build/html/search/functions_2.js | 25 + docs/build/html/search/functions_3.js | 65 + docs/build/html/search/functions_4.js | 37 + docs/build/html/search/functions_5.js | 42 + docs/build/html/search/functions_6.js | 33 + docs/build/html/search/functions_7.js | 36 + docs/build/html/search/functions_8.js | 4 + docs/build/html/search/functions_9.js | 39 + docs/build/html/search/functions_a.js | 4 + docs/build/html/search/functions_b.js | 6 + docs/build/html/search/functions_c.js | 42 + docs/build/html/search/functions_d.js | 42 + docs/build/html/search/functions_e.js | 23 + docs/build/html/search/functions_f.js | 46 + docs/build/html/search/groups_0.js | 4 + docs/build/html/search/groups_1.js | 4 + docs/build/html/search/groups_2.js | 4 + docs/build/html/search/mag.svg | 24 + docs/build/html/search/mag_d.svg | 24 + docs/build/html/search/mag_sel.svg | 31 + docs/build/html/search/mag_seld.svg | 31 + docs/build/html/search/namespaces_0.js | 23 + docs/build/html/search/namespaces_1.js | 6 + docs/build/html/search/related_0.js | 4 + docs/build/html/search/related_1.js | 5 + docs/build/html/search/search.css | 291 + docs/build/html/search/search.js | 694 + docs/build/html/search/searchdata.js | 48 + docs/build/html/search/typedefs_0.js | 5 + docs/build/html/search/typedefs_1.js | 4 + docs/build/html/search/typedefs_2.js | 6 + docs/build/html/search/typedefs_3.js | 4 + docs/build/html/search/typedefs_4.js | 4 + docs/build/html/search/typedefs_5.js | 5 + docs/build/html/search/typedefs_6.js | 4 + docs/build/html/search/typedefs_7.js | 5 + docs/build/html/search/typedefs_8.js | 6 + docs/build/html/search/typedefs_9.js | 4 + docs/build/html/search/typedefs_a.js | 8 + docs/build/html/search/typedefs_b.js | 4 + docs/build/html/search/typedefs_c.js | 6 + docs/build/html/search/variables_0.js | 10 + docs/build/html/search/variables_1.js | 25 + docs/build/html/search/variables_10.js | 15 + docs/build/html/search/variables_11.js | 26 + docs/build/html/search/variables_12.js | 25 + docs/build/html/search/variables_13.js | 9 + docs/build/html/search/variables_14.js | 8 + docs/build/html/search/variables_15.js | 10 + docs/build/html/search/variables_2.js | 20 + docs/build/html/search/variables_3.js | 10 + docs/build/html/search/variables_4.js | 4 + docs/build/html/search/variables_5.js | 17 + docs/build/html/search/variables_6.js | 11 + docs/build/html/search/variables_7.js | 23 + docs/build/html/search/variables_8.js | 6 + docs/build/html/search/variables_9.js | 6 + docs/build/html/search/variables_a.js | 7 + docs/build/html/search/variables_b.js | 18 + docs/build/html/search/variables_c.js | 13 + docs/build/html/search/variables_d.js | 8 + docs/build/html/search/variables_e.js | 6 + docs/build/html/search/variables_f.js | 5 + docs/build/html/searchindex.js | 2 +- docs/build/html/splitbar.png | Bin 0 -> 314 bytes docs/build/html/splitbard.png | Bin 0 -> 282 bytes docs/build/html/stream_8h.html | 127 + docs/build/html/stream_8h_source.html | 146 + .../struct___m_l_x___b_float16-members.html | 101 + .../html/struct___m_l_x___b_float16.html | 534 + ...b_float16_1_1bits__to__bfloat__struct.html | 92 + docs/build/html/struct_abs-members.html | 93 + docs/build/html/struct_abs.html | 306 + docs/build/html/struct_add-members.html | 87 + docs/build/html/struct_add.html | 130 + docs/build/html/struct_and-members.html | 92 + docs/build/html/struct_and.html | 291 + docs/build/html/struct_arc_cos-members.html | 87 + docs/build/html/struct_arc_cos.html | 126 + docs/build/html/struct_arc_cosh-members.html | 87 + docs/build/html/struct_arc_cosh.html | 126 + docs/build/html/struct_arc_sin-members.html | 87 + docs/build/html/struct_arc_sin.html | 126 + docs/build/html/struct_arc_sinh-members.html | 87 + docs/build/html/struct_arc_sinh.html | 126 + docs/build/html/struct_arc_tan-members.html | 87 + docs/build/html/struct_arc_tan.html | 126 + docs/build/html/struct_arc_tan2-members.html | 87 + docs/build/html/struct_arc_tan2.html | 130 + docs/build/html/struct_arc_tanh-members.html | 87 + docs/build/html/struct_arc_tanh.html | 126 + .../html/struct_bitwise_and-members.html | 87 + docs/build/html/struct_bitwise_and.html | 130 + .../build/html/struct_bitwise_or-members.html | 87 + docs/build/html/struct_bitwise_or.html | 130 + .../html/struct_bitwise_xor-members.html | 87 + docs/build/html/struct_bitwise_xor.html | 130 + docs/build/html/struct_ceil-members.html | 96 + docs/build/html/struct_ceil.html | 396 + docs/build/html/struct_conjugate-members.html | 87 + docs/build/html/struct_conjugate.html | 123 + docs/build/html/struct_cos-members.html | 88 + docs/build/html/struct_cos.html | 156 + docs/build/html/struct_cosh-members.html | 88 + docs/build/html/struct_cosh.html | 156 + docs/build/html/struct_divide-members.html | 87 + docs/build/html/struct_divide.html | 130 + docs/build/html/struct_equal-members.html | 87 + docs/build/html/struct_equal.html | 130 + docs/build/html/struct_erf-members.html | 87 + docs/build/html/struct_erf.html | 126 + docs/build/html/struct_erf_inv-members.html | 87 + docs/build/html/struct_erf_inv.html | 126 + docs/build/html/struct_exp-members.html | 88 + docs/build/html/struct_exp.html | 156 + docs/build/html/struct_expm1-members.html | 87 + docs/build/html/struct_expm1.html | 126 + docs/build/html/struct_floor-members.html | 96 + docs/build/html/struct_floor.html | 396 + docs/build/html/struct_greater-members.html | 87 + docs/build/html/struct_greater.html | 130 + .../html/struct_greater_equal-members.html | 87 + docs/build/html/struct_greater_equal.html | 130 + docs/build/html/struct_indices-members.html | 90 + docs/build/html/struct_indices.html | 168 + .../build/html/struct_left_shift-members.html | 87 + docs/build/html/struct_left_shift.html | 130 + docs/build/html/struct_less-members.html | 87 + docs/build/html/struct_less.html | 130 + .../build/html/struct_less_equal-members.html | 87 + docs/build/html/struct_less_equal.html | 130 + docs/build/html/struct_limits-members.html | 90 + docs/build/html/struct_limits.html | 200 + ...t_limits_3_01bfloat16__t_01_4-members.html | 90 + .../struct_limits_3_01bfloat16__t_01_4.html | 192 + .../struct_limits_3_01bool_01_4-members.html | 88 + .../html/struct_limits_3_01bool_01_4.html | 144 + .../struct_limits_3_01float_01_4-members.html | 90 + .../html/struct_limits_3_01float_01_4.html | 192 + .../struct_limits_3_01half_01_4-members.html | 90 + .../html/struct_limits_3_01half_01_4.html | 192 + ...ruct_limits_3_01int16__t_01_4-members.html | 90 + .../html/struct_limits_3_01int16__t_01_4.html | 192 + ...ruct_limits_3_01int32__t_01_4-members.html | 90 + .../html/struct_limits_3_01int32__t_01_4.html | 192 + ...ruct_limits_3_01int64__t_01_4-members.html | 90 + .../html/struct_limits_3_01int64__t_01_4.html | 192 + ...truct_limits_3_01int8__t_01_4-members.html | 90 + .../html/struct_limits_3_01int8__t_01_4.html | 192 + ...uct_limits_3_01uint16__t_01_4-members.html | 90 + .../struct_limits_3_01uint16__t_01_4.html | 192 + ...uct_limits_3_01uint32__t_01_4-members.html | 90 + .../struct_limits_3_01uint32__t_01_4.html | 192 + ...uct_limits_3_01uint64__t_01_4-members.html | 90 + .../struct_limits_3_01uint64__t_01_4.html | 192 + ...ruct_limits_3_01uint8__t_01_4-members.html | 90 + .../html/struct_limits_3_01uint8__t_01_4.html | 192 + docs/build/html/struct_log-members.html | 87 + docs/build/html/struct_log.html | 126 + docs/build/html/struct_log10-members.html | 87 + docs/build/html/struct_log10.html | 126 + docs/build/html/struct_log1p-members.html | 87 + docs/build/html/struct_log1p.html | 126 + docs/build/html/struct_log2-members.html | 87 + docs/build/html/struct_log2.html | 126 + .../html/struct_log_add_exp-members.html | 87 + docs/build/html/struct_log_add_exp.html | 130 + .../html/struct_logical_and-members.html | 87 + docs/build/html/struct_logical_and.html | 130 + .../html/struct_logical_not-members.html | 87 + docs/build/html/struct_logical_not.html | 126 + .../build/html/struct_logical_or-members.html | 87 + docs/build/html/struct_logical_or.html | 130 + .../struct_m_l_x_conv_params-members.html | 101 + docs/build/html/struct_m_l_x_conv_params.html | 366 + ..._dot_product_attention_params-members.html | 91 + ...x_scaled_dot_product_attention_params.html | 176 + docs/build/html/struct_max-members.html | 90 + docs/build/html/struct_max.html | 233 + docs/build/html/struct_maximum-members.html | 89 + docs/build/html/struct_maximum.html | 198 + docs/build/html/struct_min-members.html | 90 + docs/build/html/struct_min.html | 233 + docs/build/html/struct_minimum-members.html | 89 + docs/build/html/struct_minimum.html | 198 + docs/build/html/struct_multiply-members.html | 87 + docs/build/html/struct_multiply.html | 130 + .../build/html/struct_na_n_equal-members.html | 88 + docs/build/html/struct_na_n_equal.html | 164 + docs/build/html/struct_negative-members.html | 87 + docs/build/html/struct_negative.html | 126 + docs/build/html/struct_none-members.html | 87 + docs/build/html/struct_none.html | 135 + docs/build/html/struct_not_equal-members.html | 88 + docs/build/html/struct_not_equal.html | 164 + docs/build/html/struct_or-members.html | 92 + docs/build/html/struct_or.html | 291 + docs/build/html/struct_power-members.html | 89 + docs/build/html/struct_power.html | 198 + docs/build/html/struct_prod-members.html | 90 + docs/build/html/struct_prod.html | 233 + docs/build/html/struct_remainder-members.html | 90 + docs/build/html/struct_remainder.html | 232 + .../html/struct_right_shift-members.html | 87 + docs/build/html/struct_right_shift.html | 130 + docs/build/html/struct_round-members.html | 88 + docs/build/html/struct_round.html | 156 + docs/build/html/struct_rsqrt-members.html | 87 + docs/build/html/struct_rsqrt.html | 126 + docs/build/html/struct_select-members.html | 87 + docs/build/html/struct_select.html | 135 + docs/build/html/struct_sigmoid-members.html | 87 + docs/build/html/struct_sigmoid.html | 126 + docs/build/html/struct_sign-members.html | 88 + docs/build/html/struct_sign.html | 156 + docs/build/html/struct_sin-members.html | 88 + docs/build/html/struct_sin.html | 156 + docs/build/html/struct_sinh-members.html | 88 + docs/build/html/struct_sinh.html | 156 + docs/build/html/struct_sqrt-members.html | 87 + docs/build/html/struct_sqrt.html | 126 + docs/build/html/struct_square-members.html | 87 + docs/build/html/struct_square.html | 126 + docs/build/html/struct_subtract-members.html | 87 + docs/build/html/struct_subtract.html | 130 + docs/build/html/struct_sum-members.html | 90 + docs/build/html/struct_sum.html | 233 + docs/build/html/struct_tan-members.html | 88 + docs/build/html/struct_tan.html | 156 + docs/build/html/struct_tanh-members.html | 88 + docs/build/html/struct_tanh.html | 156 + .../html/structcomplex64__t-members.html | 97 + docs/build/html/structcomplex64__t.html | 405 + ...ts__impl_3_01bfloat16__t_01_4-members.html | 107 + ...ic__limits__impl_3_01bfloat16__t_01_4.html | 546 + ...ric__limits__impl_3_01bfloat16__t_01_4.png | Bin 0 -> 833 bytes ...1core_1_1___m_l_x___b_float16-members.html | 97 + ...ctmlx_1_1core_1_1___m_l_x___b_float16.html | 284 + ...1_1core_1_1___m_l_x___float16-members.html | 97 + ...ructmlx_1_1core_1_1___m_l_x___float16.html | 284 + .../structmlx_1_1core_1_1_device-members.html | 96 + .../html/structmlx_1_1core_1_1_device.html | 255 + .../structmlx_1_1core_1_1_dtype-members.html | 97 + .../html/structmlx_1_1core_1_1_dtype.html | 344 + ...uctmlx_1_1core_1_1_node_namer-members.html | 92 + .../structmlx_1_1core_1_1_node_namer.html | 140 + ...x_1_1core_1_1_print_formatter-members.html | 102 + ...structmlx_1_1core_1_1_print_formatter.html | 462 + ...lx_1_1core_1_1_reduction_plan-members.html | 95 + .../structmlx_1_1core_1_1_reduction_plan.html | 216 + .../structmlx_1_1core_1_1_stream-members.html | 93 + .../html/structmlx_1_1core_1_1_stream.html | 168 + ...lx_1_1core_1_1_stream_context-members.html | 92 + .../structmlx_1_1core_1_1_stream_context.html | 154 + ...mlx_1_1core_1_1_type_to_dtype-members.html | 91 + .../structmlx_1_1core_1_1_type_to_dtype.html | 121 + ...e_1_1array_1_1_array_iterator-members.html | 100 + ...x_1_1core_1_1array_1_1_array_iterator.html | 341 + ...mlx_1_1core_1_1array_1_1_data-members.html | 96 + .../structmlx_1_1core_1_1array_1_1_data.html | 250 + ...lx_1_1core_1_1array_1_1_flags-members.html | 93 + .../structmlx_1_1core_1_1array_1_1_flags.html | 148 + ...tmlx_1_1core_1_1complex128__t-members.html | 94 + .../structmlx_1_1core_1_1complex128__t.html | 222 + .../structmlx_1_1core_1_1complex128__t.png | Bin 0 -> 614 bytes ...ctmlx_1_1core_1_1complex64__t-members.html | 94 + .../structmlx_1_1core_1_1complex64__t.html | 222 + .../structmlx_1_1core_1_1complex64__t.png | Bin 0 -> 618 bytes ...mlx_1_1core_1_1detail_1_1_abs-members.html | 96 + .../structmlx_1_1core_1_1detail_1_1_abs.html | 265 + ...mlx_1_1core_1_1detail_1_1_add-members.html | 91 + .../structmlx_1_1core_1_1detail_1_1_add.html | 134 + ...1_1core_1_1detail_1_1_arc_cos-members.html | 91 + ...ructmlx_1_1core_1_1detail_1_1_arc_cos.html | 130 + ..._1core_1_1detail_1_1_arc_cosh-members.html | 91 + ...uctmlx_1_1core_1_1detail_1_1_arc_cosh.html | 130 + ...1_1core_1_1detail_1_1_arc_sin-members.html | 91 + ...ructmlx_1_1core_1_1detail_1_1_arc_sin.html | 130 + ..._1core_1_1detail_1_1_arc_sinh-members.html | 91 + ...uctmlx_1_1core_1_1detail_1_1_arc_sinh.html | 130 + ...1_1core_1_1detail_1_1_arc_tan-members.html | 91 + ...ructmlx_1_1core_1_1detail_1_1_arc_tan.html | 130 + ..._1core_1_1detail_1_1_arc_tan2-members.html | 91 + ...uctmlx_1_1core_1_1detail_1_1_arc_tan2.html | 134 + ..._1core_1_1detail_1_1_arc_tanh-members.html | 91 + ...uctmlx_1_1core_1_1detail_1_1_arc_tanh.html | 130 + ...ore_1_1detail_1_1_bitwise_and-members.html | 91 + ...mlx_1_1core_1_1detail_1_1_bitwise_and.html | 134 + ...core_1_1detail_1_1_bitwise_or-members.html | 91 + ...tmlx_1_1core_1_1detail_1_1_bitwise_or.html | 134 + ...ore_1_1detail_1_1_bitwise_xor-members.html | 91 + ...mlx_1_1core_1_1detail_1_1_bitwise_xor.html | 134 + ...lx_1_1core_1_1detail_1_1_ceil-members.html | 100 + .../structmlx_1_1core_1_1detail_1_1_ceil.html | 373 + ...1core_1_1detail_1_1_conjugate-members.html | 91 + ...ctmlx_1_1core_1_1detail_1_1_conjugate.html | 127 + ...mlx_1_1core_1_1detail_1_1_cos-members.html | 91 + .../structmlx_1_1core_1_1detail_1_1_cos.html | 130 + ...lx_1_1core_1_1detail_1_1_cosh-members.html | 91 + .../structmlx_1_1core_1_1detail_1_1_cosh.html | 130 + ..._1_1core_1_1detail_1_1_divide-members.html | 91 + ...tructmlx_1_1core_1_1detail_1_1_divide.html | 134 + ...x_1_1core_1_1detail_1_1_equal-members.html | 91 + ...structmlx_1_1core_1_1detail_1_1_equal.html | 134 + ...mlx_1_1core_1_1detail_1_1_erf-members.html | 91 + .../structmlx_1_1core_1_1detail_1_1_erf.html | 130 + ...1_1core_1_1detail_1_1_erf_inv-members.html | 91 + ...ructmlx_1_1core_1_1detail_1_1_erf_inv.html | 130 + ...mlx_1_1core_1_1detail_1_1_exp-members.html | 92 + .../structmlx_1_1core_1_1detail_1_1_exp.html | 157 + ...x_1_1core_1_1detail_1_1_expm1-members.html | 91 + ...structmlx_1_1core_1_1detail_1_1_expm1.html | 130 + ...x_1_1core_1_1detail_1_1_floor-members.html | 100 + ...structmlx_1_1core_1_1detail_1_1_floor.html | 373 + ...1_1core_1_1detail_1_1_greater-members.html | 91 + ...ructmlx_1_1core_1_1detail_1_1_greater.html | 134 + ...e_1_1detail_1_1_greater_equal-members.html | 91 + ...x_1_1core_1_1detail_1_1_greater_equal.html | 134 + ...core_1_1detail_1_1_in_tracing-members.html | 93 + ...tmlx_1_1core_1_1detail_1_1_in_tracing.html | 186 + ...core_1_1detail_1_1_left_shift-members.html | 91 + ...tmlx_1_1core_1_1detail_1_1_left_shift.html | 134 + ...lx_1_1core_1_1detail_1_1_less-members.html | 91 + .../structmlx_1_1core_1_1detail_1_1_less.html | 134 + ...core_1_1detail_1_1_less_equal-members.html | 91 + ...tmlx_1_1core_1_1detail_1_1_less_equal.html | 134 + ...mlx_1_1core_1_1detail_1_1_log-members.html | 91 + .../structmlx_1_1core_1_1detail_1_1_log.html | 130 + ...x_1_1core_1_1detail_1_1_log10-members.html | 91 + ...structmlx_1_1core_1_1detail_1_1_log10.html | 130 + ...x_1_1core_1_1detail_1_1_log1p-members.html | 91 + ...structmlx_1_1core_1_1detail_1_1_log1p.html | 130 + ...lx_1_1core_1_1detail_1_1_log2-members.html | 91 + .../structmlx_1_1core_1_1detail_1_1_log2.html | 130 + ...ore_1_1detail_1_1_log_add_exp-members.html | 91 + ...mlx_1_1core_1_1detail_1_1_log_add_exp.html | 134 + ...ore_1_1detail_1_1_logical_and-members.html | 91 + ...mlx_1_1core_1_1detail_1_1_logical_and.html | 134 + ...ore_1_1detail_1_1_logical_not-members.html | 91 + ...mlx_1_1core_1_1detail_1_1_logical_not.html | 130 + ...core_1_1detail_1_1_logical_or-members.html | 91 + ...tmlx_1_1core_1_1detail_1_1_logical_or.html | 134 + ...1_1core_1_1detail_1_1_maximum-members.html | 92 + ...ructmlx_1_1core_1_1detail_1_1_maximum.html | 168 + ...1_1core_1_1detail_1_1_minimum-members.html | 92 + ...ructmlx_1_1core_1_1detail_1_1_minimum.html | 168 + ..._1core_1_1detail_1_1_multiply-members.html | 91 + ...uctmlx_1_1core_1_1detail_1_1_multiply.html | 134 + ...core_1_1detail_1_1_na_n_equal-members.html | 91 + ...tmlx_1_1core_1_1detail_1_1_na_n_equal.html | 134 + ..._1core_1_1detail_1_1_negative-members.html | 91 + ...uctmlx_1_1core_1_1detail_1_1_negative.html | 130 + ...1core_1_1detail_1_1_not_equal-members.html | 91 + ...ctmlx_1_1core_1_1detail_1_1_not_equal.html | 134 + ...x_1_1core_1_1detail_1_1_power-members.html | 92 + ...structmlx_1_1core_1_1detail_1_1_power.html | 168 + ...1core_1_1detail_1_1_remainder-members.html | 94 + ...ctmlx_1_1core_1_1detail_1_1_remainder.html | 233 + ...ore_1_1detail_1_1_right_shift-members.html | 91 + ...mlx_1_1core_1_1detail_1_1_right_shift.html | 134 + ...x_1_1core_1_1detail_1_1_round-members.html | 92 + ...structmlx_1_1core_1_1detail_1_1_round.html | 157 + ...x_1_1core_1_1detail_1_1_rsqrt-members.html | 91 + ...structmlx_1_1core_1_1detail_1_1_rsqrt.html | 130 + ..._1_1core_1_1detail_1_1_select-members.html | 91 + ...tructmlx_1_1core_1_1detail_1_1_select.html | 139 + ...1_1core_1_1detail_1_1_sigmoid-members.html | 91 + ...ructmlx_1_1core_1_1detail_1_1_sigmoid.html | 130 + ...lx_1_1core_1_1detail_1_1_sign-members.html | 95 + .../structmlx_1_1core_1_1detail_1_1_sign.html | 238 + ...mlx_1_1core_1_1detail_1_1_sin-members.html | 91 + .../structmlx_1_1core_1_1detail_1_1_sin.html | 130 + ...lx_1_1core_1_1detail_1_1_sinh-members.html | 91 + .../structmlx_1_1core_1_1detail_1_1_sinh.html | 130 + ...lx_1_1core_1_1detail_1_1_sqrt-members.html | 91 + .../structmlx_1_1core_1_1detail_1_1_sqrt.html | 130 + ..._1_1core_1_1detail_1_1_square-members.html | 91 + ...tructmlx_1_1core_1_1detail_1_1_square.html | 130 + ..._1core_1_1detail_1_1_subtract-members.html | 91 + ...uctmlx_1_1core_1_1detail_1_1_subtract.html | 134 + ...mlx_1_1core_1_1detail_1_1_tan-members.html | 91 + .../structmlx_1_1core_1_1detail_1_1_tan.html | 130 + ...lx_1_1core_1_1detail_1_1_tanh-members.html | 91 + .../structmlx_1_1core_1_1detail_1_1_tanh.html | 130 + ..._1_1metal_1_1_command_encoder-members.html | 100 + ..._1_1core_1_1metal_1_1_command_encoder.html | 387 + ...ncoder_1_1_concurrent_context-members.html | 92 + ...ommand_encoder_1_1_concurrent_context.html | 154 + ..._1scheduler_1_1_stream_thread-members.html | 100 + ..._1core_1_1scheduler_1_1_stream_thread.html | 313 + ...mlx_1_1steel_1_1_accum_helper-members.html | 91 + .../structmlx_1_1steel_1_1_accum_helper.html | 118 + ...mlx_1_1steel_1_1_block_loader-members.html | 103 + .../structmlx_1_1steel_1_1_block_loader.html | 409 + ..._block_loader_1_1_read_vector-members.html | 91 + ...teel_1_1_block_loader_1_1_read_vector.html | 118 + ...tmlx_1_1steel_1_1_block_m_m_a-members.html | 115 + .../structmlx_1_1steel_1_1_block_m_m_a.html | 704 + ...lx_1_1steel_1_1_block_swizzle-members.html | 91 + .../structmlx_1_1steel_1_1_block_swizzle.html | 131 + ...x_1_1steel_1_1_channel_helper-members.html | 93 + ...structmlx_1_1steel_1_1_channel_helper.html | 154 + ...1_1_channel_helper_3_011_01_4-members.html | 93 + ..._1steel_1_1_channel_helper_3_011_01_4.html | 148 + ...1_1_channel_helper_3_012_01_4-members.html | 93 + ..._1steel_1_1_channel_helper_3_012_01_4.html | 148 + ...1_1_channel_helper_3_013_01_4-members.html | 93 + ..._1steel_1_1_channel_helper_3_013_01_4.html | 148 + ...1_1_channel_helper_3_014_01_4-members.html | 93 + ..._1steel_1_1_channel_helper_3_014_01_4.html | 148 + ...1_1_conv2_d_general_base_info-members.html | 92 + ..._1steel_1_1_conv2_d_general_base_info.html | 132 + ...1_conv2_d_general_jump_params-members.html | 98 + ...steel_1_1_conv2_d_general_jump_params.html | 228 + ..._d_input_block_loader_general-members.html | 114 + ..._1_conv2_d_input_block_loader_general.html | 610 + ...put_block_loader_large_filter-members.html | 112 + ...nv2_d_input_block_loader_large_filter.html | 564 + ...t_block_loader_small_channels-members.html | 111 + ...2_d_input_block_loader_small_channels.html | 546 + ...put_block_loader_small_filter-members.html | 112 + ...nv2_d_input_block_loader_small_filter.html | 569 + ...1_conv2_d_weight_block_loader-members.html | 110 + ...steel_1_1_conv2_d_weight_block_loader.html | 532 + ...d_weight_block_loader_general-members.html | 113 + ...1_conv2_d_weight_block_loader_general.html | 596 + ...t_block_loader_small_channels-members.html | 110 + ..._d_weight_block_loader_small_channels.html | 528 + ...el_1_1_g_e_m_m_add_m_m_params-members.html | 95 + ...x_1_1steel_1_1_g_e_m_m_add_m_m_params.html | 180 + ...x_1_1steel_1_1_g_e_m_m_kernel-members.html | 101 + ...structmlx_1_1steel_1_1_g_e_m_m_kernel.html | 456 + ...x_1_1steel_1_1_g_e_m_m_params-members.html | 104 + ...structmlx_1_1steel_1_1_g_e_m_m_params.html | 324 + ...el_1_1_g_e_m_m_spilt_k_params-members.html | 102 + ...x_1_1steel_1_1_g_e_m_m_spilt_k_params.html | 292 + ..._implicit_gemm_conv2_d_params-members.html | 100 + ...teel_1_1_implicit_gemm_conv2_d_params.html | 260 + ...structmlx_1_1steel_1_1_loop_alignment.html | 92 + ...lx_1_1steel_1_1_transform_add-members.html | 92 + .../structmlx_1_1steel_1_1_transform_add.html | 171 + ..._1_1steel_1_1_transform_axpby-members.html | 94 + ...tructmlx_1_1steel_1_1_transform_axpby.html | 208 + ...x_1_1steel_1_1_transform_none-members.html | 92 + ...structmlx_1_1steel_1_1_transform_none.html | 162 + .../build/html/structmlx__atomic-members.html | 87 + docs/build/html/structmlx__atomic.html | 114 + ..._atomic_3_01_t_01_4_01_4_01_4-members.html | 87 + ...__metal__atomic_3_01_t_01_4_01_4_01_4.html | 114 + ...etfft_1_1detail_1_1_exec_c2_c-members.html | 92 + ...ructpocketfft_1_1detail_1_1_exec_c2_c.html | 175 + ...etfft_1_1detail_1_1_exec_dcst-members.html | 94 + ...ructpocketfft_1_1detail_1_1_exec_dcst.html | 207 + ...ft_1_1detail_1_1_exec_hartley-members.html | 91 + ...tpocketfft_1_1detail_1_1_exec_hartley.html | 154 + ...etfft_1_1detail_1_1_exec_r2_r-members.html | 93 + ...ructpocketfft_1_1detail_1_1_exec_r2_r.html | 191 + ...cketfft_1_1detail_1_1_v_l_e_n-members.html | 91 + ...structpocketfft_1_1detail_1_1_v_l_e_n.html | 126 + ...ructpocketfft_1_1detail_1_1_v_t_y_p_e.html | 92 + ...cketfft_1_1detail_1_1add__vec-members.html | 91 + ...structpocketfft_1_1detail_1_1add__vec.html | 118 + ...ec_3_01cmplx_3_01_t_01_4_01_4-members.html | 91 + ..._1add__vec_3_01cmplx_3_01_t_01_4_01_4.html | 118 + ...tpocketfft_1_1detail_1_1cmplx-members.html | 106 + .../structpocketfft_1_1detail_1_1cmplx.html | 588 + ...reading_1_1aligned__allocator-members.html | 95 + ...il_1_1threading_1_1aligned__allocator.html | 247 + ...ctpocketfft_1_1detail_1_1util-members.html | 99 + .../structpocketfft_1_1detail_1_1util.html | 409 + docs/build/html/sync_off.png | Bin 0 -> 853 bytes docs/build/html/sync_on.png | Bin 0 -> 845 bytes docs/build/html/tab_a.png | Bin 0 -> 142 bytes docs/build/html/tab_ad.png | Bin 0 -> 135 bytes docs/build/html/tab_b.png | Bin 0 -> 169 bytes docs/build/html/tab_bd.png | Bin 0 -> 173 bytes docs/build/html/tab_h.png | Bin 0 -> 177 bytes docs/build/html/tab_hd.png | Bin 0 -> 180 bytes docs/build/html/tab_s.png | Bin 0 -> 184 bytes docs/build/html/tab_sd.png | Bin 0 -> 188 bytes docs/build/html/tabs.css | 1 + docs/build/html/threefry_8h.html | 110 + docs/build/html/threefry_8h_source.html | 111 + docs/build/html/topics.html | 88 + docs/build/html/transforms_8h.html | 172 + docs/build/html/transforms_8h_source.html | 248 + docs/build/html/transforms__impl_8h.html | 123 + .../html/transforms__impl_8h_source.html | 158 + docs/build/html/types_2bf16_8h.html | 743 + docs/build/html/types_2bf16_8h_source.html | 322 + docs/build/html/types_2complex_8h.html | 256 + docs/build/html/types_2complex_8h_source.html | 249 + .../html/unionbool4__or__uint-members.html | 88 + docs/build/html/unionbool4__or__uint.html | 128 + ...re_1_1detail_1_1_int_or_float-members.html | 92 + ...lx_1_1core_1_1detail_1_1_int_or_float.html | 132 + docs/build/html/usage/compile.html | 23 +- .../build/html/usage/function_transforms.html | 23 +- docs/build/html/usage/indexing.html | 23 +- docs/build/html/usage/lazy_evaluation.html | 23 +- docs/build/html/usage/numpy.html | 23 +- docs/build/html/usage/quick_start.html | 23 +- docs/build/html/usage/saving_and_loading.html | 23 +- docs/build/html/usage/unified_memory.html | 23 +- docs/build/html/usage/using_streams.html | 23 +- docs/build/html/utils_8h.html | 172 + docs/build/html/utils_8h_source.html | 276 + 1835 files changed, 274325 insertions(+), 2256 deletions(-) create mode 100644 docs/build/html/_sources/python/_autosummary/mlx.core.arctan2.rst create mode 100644 docs/build/html/_sources/python/_autosummary/mlx.core.bitwise_and.rst create mode 100644 docs/build/html/_sources/python/_autosummary/mlx.core.bitwise_or.rst create mode 100644 docs/build/html/_sources/python/_autosummary/mlx.core.bitwise_xor.rst create mode 100644 docs/build/html/_sources/python/_autosummary/mlx.core.block_sparse_mm.rst create mode 100644 docs/build/html/_sources/python/_autosummary/mlx.core.conj.rst create mode 100644 docs/build/html/_sources/python/_autosummary/mlx.core.conjugate.rst create mode 100644 docs/build/html/_sources/python/_autosummary/mlx.core.left_shift.rst create mode 100644 docs/build/html/_sources/python/_autosummary/mlx.core.metal.device_info.rst create mode 100644 docs/build/html/_sources/python/_autosummary/mlx.core.metal.reset_peak_memory.rst create mode 100644 docs/build/html/_sources/python/_autosummary/mlx.core.right_shift.rst create mode 100644 docs/build/html/_sources/python/_autosummary/mlx.optimizers.clip_grad_norm.rst create mode 100644 docs/build/html/_sources/python/_autosummary/mlx.utils.tree_reduce.rst create mode 100644 docs/build/html/allocator_8h.html create mode 100644 docs/build/html/allocator_8h_source.html create mode 100644 docs/build/html/annotated.html create mode 100644 docs/build/html/arange_8h.html create mode 100644 docs/build/html/arange_8h_source.html create mode 100644 docs/build/html/array_8h.html create mode 100644 docs/build/html/array_8h_source.html create mode 100644 docs/build/html/atomic_8h.html create mode 100644 docs/build/html/atomic_8h_source.html create mode 100644 docs/build/html/backend_2accelerate_2utils_8h.html create mode 100644 docs/build/html/backend_2accelerate_2utils_8h_source.html create mode 100644 docs/build/html/backend_2common_2ops_8h.html create mode 100644 docs/build/html/backend_2common_2ops_8h_source.html create mode 100644 docs/build/html/backend_2common_2utils_8h.html create mode 100644 docs/build/html/backend_2common_2utils_8h_source.html create mode 100644 docs/build/html/backend_2metal_2allocator_8h.html create mode 100644 docs/build/html/backend_2metal_2allocator_8h_source.html create mode 100644 docs/build/html/backend_2metal_2device_8h.html create mode 100644 docs/build/html/backend_2metal_2device_8h_source.html create mode 100644 docs/build/html/backend_2metal_2kernels_2bf16_8h.html create mode 100644 docs/build/html/backend_2metal_2kernels_2bf16_8h_source.html create mode 100644 docs/build/html/backend_2metal_2kernels_2complex_8h.html create mode 100644 docs/build/html/backend_2metal_2kernels_2complex_8h_source.html create mode 100644 docs/build/html/backend_2metal_2kernels_2reduction_2ops_8h.html create mode 100644 docs/build/html/backend_2metal_2kernels_2reduction_2ops_8h_source.html create mode 100644 docs/build/html/backend_2metal_2kernels_2reduction_2utils_8h.html create mode 100644 docs/build/html/backend_2metal_2kernels_2reduction_2utils_8h_source.html create mode 100644 docs/build/html/backend_2metal_2kernels_2steel_2gemm_2transforms_8h.html create mode 100644 docs/build/html/backend_2metal_2kernels_2steel_2gemm_2transforms_8h_source.html create mode 100644 docs/build/html/backend_2metal_2kernels_2steel_2utils_8h.html create mode 100644 docs/build/html/backend_2metal_2kernels_2steel_2utils_8h_source.html create mode 100644 docs/build/html/backend_2metal_2kernels_2utils_8h.html create mode 100644 docs/build/html/backend_2metal_2kernels_2utils_8h_source.html create mode 100644 docs/build/html/backend_2metal_2utils_8h.html create mode 100644 docs/build/html/backend_2metal_2utils_8h_source.html create mode 100644 docs/build/html/bc_s.png create mode 100644 docs/build/html/bc_sd.png create mode 100644 docs/build/html/bf16__math_8h.html create mode 100644 docs/build/html/bf16__math_8h_source.html create mode 100644 docs/build/html/binary__two_8h.html create mode 100644 docs/build/html/binary__two_8h_source.html create mode 100644 docs/build/html/class_m_p_s_1_1_kernel-members.html create mode 100644 docs/build/html/class_m_p_s_1_1_kernel.html create mode 100644 docs/build/html/class_m_p_s_1_1_kernel.png create mode 100644 docs/build/html/class_m_p_s_1_1_matrix-members.html create mode 100644 docs/build/html/class_m_p_s_1_1_matrix.html create mode 100644 docs/build/html/class_m_p_s_1_1_matrix.png create mode 100644 docs/build/html/class_m_p_s_1_1_matrix_descriptor-members.html create mode 100644 docs/build/html/class_m_p_s_1_1_matrix_descriptor.html create mode 100644 docs/build/html/class_m_p_s_1_1_matrix_descriptor.png create mode 100644 docs/build/html/class_m_p_s_1_1_matrix_multiplication-members.html create mode 100644 docs/build/html/class_m_p_s_1_1_matrix_multiplication.html create mode 100644 docs/build/html/class_m_p_s_1_1_matrix_multiplication.png create mode 100644 docs/build/html/class_m_p_s_1_1_matrix_vector_multiplication-members.html create mode 100644 docs/build/html/class_m_p_s_1_1_matrix_vector_multiplication.html create mode 100644 docs/build/html/class_m_p_s_1_1_matrix_vector_multiplication.png create mode 100644 docs/build/html/class_m_p_s_1_1_vector-members.html create mode 100644 docs/build/html/class_m_p_s_1_1_vector.html create mode 100644 docs/build/html/class_m_p_s_1_1_vector.png create mode 100644 docs/build/html/class_m_p_s_1_1_vector_descriptor-members.html create mode 100644 docs/build/html/class_m_p_s_1_1_vector_descriptor.html create mode 100644 docs/build/html/class_m_p_s_1_1_vector_descriptor.png create mode 100644 docs/build/html/classes.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_abs-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_abs.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_abs.png create mode 100644 docs/build/html/classmlx_1_1core_1_1_add-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_add.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_add.png create mode 100644 docs/build/html/classmlx_1_1core_1_1_add_m_m-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_add_m_m.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_add_m_m.png create mode 100644 docs/build/html/classmlx_1_1core_1_1_arange-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_arange.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_arange.png create mode 100644 docs/build/html/classmlx_1_1core_1_1_arc_cos-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_arc_cos.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_arc_cos.png create mode 100644 docs/build/html/classmlx_1_1core_1_1_arc_cosh-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_arc_cosh.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_arc_cosh.png create mode 100644 docs/build/html/classmlx_1_1core_1_1_arc_sin-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_arc_sin.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_arc_sin.png create mode 100644 docs/build/html/classmlx_1_1core_1_1_arc_sinh-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_arc_sinh.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_arc_sinh.png create mode 100644 docs/build/html/classmlx_1_1core_1_1_arc_tan-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_arc_tan.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_arc_tan.png create mode 100644 docs/build/html/classmlx_1_1core_1_1_arc_tan2-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_arc_tan2.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_arc_tan2.png create mode 100644 docs/build/html/classmlx_1_1core_1_1_arc_tanh-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_arc_tanh.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_arc_tanh.png create mode 100644 docs/build/html/classmlx_1_1core_1_1_arg_partition-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_arg_partition.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_arg_partition.png create mode 100644 docs/build/html/classmlx_1_1core_1_1_arg_reduce-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_arg_reduce.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_arg_reduce.png create mode 100644 docs/build/html/classmlx_1_1core_1_1_arg_sort-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_arg_sort.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_arg_sort.png create mode 100644 docs/build/html/classmlx_1_1core_1_1_as_strided-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_as_strided.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_as_strided.png create mode 100644 docs/build/html/classmlx_1_1core_1_1_as_type-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_as_type.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_as_type.png create mode 100644 docs/build/html/classmlx_1_1core_1_1_bitwise_binary-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_bitwise_binary.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_bitwise_binary.png create mode 100644 docs/build/html/classmlx_1_1core_1_1_block_masked_m_m-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_block_masked_m_m.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_block_masked_m_m.png create mode 100644 docs/build/html/classmlx_1_1core_1_1_block_sparse_m_m-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_block_sparse_m_m.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_block_sparse_m_m.png create mode 100644 docs/build/html/classmlx_1_1core_1_1_broadcast-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_broadcast.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_broadcast.png create mode 100644 docs/build/html/classmlx_1_1core_1_1_ceil-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_ceil.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_ceil.png create mode 100644 docs/build/html/classmlx_1_1core_1_1_compiled-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_compiled.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_compiled.png create mode 100644 docs/build/html/classmlx_1_1core_1_1_concatenate-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_concatenate.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_concatenate.png create mode 100644 docs/build/html/classmlx_1_1core_1_1_conjugate-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_conjugate.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_conjugate.png create mode 100644 docs/build/html/classmlx_1_1core_1_1_convolution-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_convolution.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_convolution.png create mode 100644 docs/build/html/classmlx_1_1core_1_1_copy-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_copy.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_copy.png create mode 100644 docs/build/html/classmlx_1_1core_1_1_cos-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_cos.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_cos.png create mode 100644 docs/build/html/classmlx_1_1core_1_1_cosh-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_cosh.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_cosh.png create mode 100644 docs/build/html/classmlx_1_1core_1_1_custom_v_j_p-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_custom_v_j_p.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_custom_v_j_p.png create mode 100644 docs/build/html/classmlx_1_1core_1_1_depends-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_depends.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_depends.png create mode 100644 docs/build/html/classmlx_1_1core_1_1_div_mod-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_div_mod.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_div_mod.png create mode 100644 docs/build/html/classmlx_1_1core_1_1_divide-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_divide.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_divide.png create mode 100644 docs/build/html/classmlx_1_1core_1_1_equal-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_equal.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_equal.png create mode 100644 docs/build/html/classmlx_1_1core_1_1_erf-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_erf.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_erf.png create mode 100644 docs/build/html/classmlx_1_1core_1_1_erf_inv-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_erf_inv.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_erf_inv.png create mode 100644 docs/build/html/classmlx_1_1core_1_1_event-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_event.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_exp-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_exp.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_exp.png create mode 100644 docs/build/html/classmlx_1_1core_1_1_expm1-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_expm1.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_expm1.png create mode 100644 docs/build/html/classmlx_1_1core_1_1_f_f_t-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_f_f_t.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_f_f_t.png create mode 100644 docs/build/html/classmlx_1_1core_1_1_floor-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_floor.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_floor.png create mode 100644 docs/build/html/classmlx_1_1core_1_1_full-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_full.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_full.png create mode 100644 docs/build/html/classmlx_1_1core_1_1_gather-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_gather.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_gather.png create mode 100644 docs/build/html/classmlx_1_1core_1_1_greater-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_greater.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_greater.png create mode 100644 docs/build/html/classmlx_1_1core_1_1_greater_equal-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_greater_equal.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_greater_equal.png create mode 100644 docs/build/html/classmlx_1_1core_1_1_inverse-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_inverse.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_inverse.png create mode 100644 docs/build/html/classmlx_1_1core_1_1_less-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_less.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_less.png create mode 100644 docs/build/html/classmlx_1_1core_1_1_less_equal-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_less_equal.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_less_equal.png create mode 100644 docs/build/html/classmlx_1_1core_1_1_load-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_load.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_load.png create mode 100644 docs/build/html/classmlx_1_1core_1_1_log-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_log.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_log.png create mode 100644 docs/build/html/classmlx_1_1core_1_1_log1p-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_log1p.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_log1p.png create mode 100644 docs/build/html/classmlx_1_1core_1_1_log_add_exp-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_log_add_exp.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_log_add_exp.png create mode 100644 docs/build/html/classmlx_1_1core_1_1_logical_and-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_logical_and.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_logical_and.png create mode 100644 docs/build/html/classmlx_1_1core_1_1_logical_not-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_logical_not.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_logical_not.png create mode 100644 docs/build/html/classmlx_1_1core_1_1_logical_or-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_logical_or.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_logical_or.png create mode 100644 docs/build/html/classmlx_1_1core_1_1_matmul-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_matmul.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_matmul.png create mode 100644 docs/build/html/classmlx_1_1core_1_1_maximum-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_maximum.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_maximum.png create mode 100644 docs/build/html/classmlx_1_1core_1_1_minimum-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_minimum.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_minimum.png create mode 100644 docs/build/html/classmlx_1_1core_1_1_multiply-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_multiply.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_multiply.png create mode 100644 docs/build/html/classmlx_1_1core_1_1_negative-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_negative.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_negative.png create mode 100644 docs/build/html/classmlx_1_1core_1_1_not_equal-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_not_equal.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_not_equal.png create mode 100644 docs/build/html/classmlx_1_1core_1_1_number_of_elements-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_number_of_elements.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_number_of_elements.png create mode 100644 docs/build/html/classmlx_1_1core_1_1_pad-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_pad.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_pad.png create mode 100644 docs/build/html/classmlx_1_1core_1_1_partition-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_partition.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_partition.png create mode 100644 docs/build/html/classmlx_1_1core_1_1_power-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_power.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_power.png create mode 100644 docs/build/html/classmlx_1_1core_1_1_primitive-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_primitive.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_primitive.png create mode 100644 docs/build/html/classmlx_1_1core_1_1_q_r_f-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_q_r_f.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_q_r_f.png create mode 100644 docs/build/html/classmlx_1_1core_1_1_quantized_matmul-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_quantized_matmul.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_quantized_matmul.png create mode 100644 docs/build/html/classmlx_1_1core_1_1_random_bits-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_random_bits.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_random_bits.png create mode 100644 docs/build/html/classmlx_1_1core_1_1_reduce-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_reduce.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_reduce.png create mode 100644 docs/build/html/classmlx_1_1core_1_1_remainder-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_remainder.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_remainder.png create mode 100644 docs/build/html/classmlx_1_1core_1_1_reshape-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_reshape.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_reshape.png create mode 100644 docs/build/html/classmlx_1_1core_1_1_round-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_round.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_round.png create mode 100644 docs/build/html/classmlx_1_1core_1_1_s_v_d-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_s_v_d.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_s_v_d.png create mode 100644 docs/build/html/classmlx_1_1core_1_1_scan-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_scan.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_scan.png create mode 100644 docs/build/html/classmlx_1_1core_1_1_scatter-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_scatter.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_scatter.png create mode 100644 docs/build/html/classmlx_1_1core_1_1_select-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_select.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_select.png create mode 100644 docs/build/html/classmlx_1_1core_1_1_sigmoid-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_sigmoid.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_sigmoid.png create mode 100644 docs/build/html/classmlx_1_1core_1_1_sign-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_sign.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_sign.png create mode 100644 docs/build/html/classmlx_1_1core_1_1_sin-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_sin.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_sin.png create mode 100644 docs/build/html/classmlx_1_1core_1_1_sinh-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_sinh.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_sinh.png create mode 100644 docs/build/html/classmlx_1_1core_1_1_slice-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_slice.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_slice.png create mode 100644 docs/build/html/classmlx_1_1core_1_1_slice_update-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_slice_update.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_slice_update.png create mode 100644 docs/build/html/classmlx_1_1core_1_1_softmax-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_softmax.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_softmax.png create mode 100644 docs/build/html/classmlx_1_1core_1_1_sort-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_sort.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_sort.png create mode 100644 docs/build/html/classmlx_1_1core_1_1_split-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_split.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_split.png create mode 100644 docs/build/html/classmlx_1_1core_1_1_sqrt-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_sqrt.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_sqrt.png create mode 100644 docs/build/html/classmlx_1_1core_1_1_square-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_square.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_square.png create mode 100644 docs/build/html/classmlx_1_1core_1_1_stop_gradient-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_stop_gradient.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_stop_gradient.png create mode 100644 docs/build/html/classmlx_1_1core_1_1_subtract-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_subtract.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_subtract.png create mode 100644 docs/build/html/classmlx_1_1core_1_1_tan-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_tan.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_tan.png create mode 100644 docs/build/html/classmlx_1_1core_1_1_tanh-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_tanh.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_tanh.png create mode 100644 docs/build/html/classmlx_1_1core_1_1_transpose-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_transpose.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_transpose.png create mode 100644 docs/build/html/classmlx_1_1core_1_1_unary_primitive-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_unary_primitive.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_unary_primitive.png create mode 100644 docs/build/html/classmlx_1_1core_1_1_uniform-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_uniform.html create mode 100644 docs/build/html/classmlx_1_1core_1_1_uniform.png create mode 100644 docs/build/html/classmlx_1_1core_1_1allocator_1_1_allocator-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1allocator_1_1_allocator.html create mode 100644 docs/build/html/classmlx_1_1core_1_1allocator_1_1_allocator.png create mode 100644 docs/build/html/classmlx_1_1core_1_1allocator_1_1_buffer-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1allocator_1_1_buffer.html create mode 100644 docs/build/html/classmlx_1_1core_1_1allocator_1_1_common_allocator-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1allocator_1_1_common_allocator.html create mode 100644 docs/build/html/classmlx_1_1core_1_1allocator_1_1_common_allocator.png create mode 100644 docs/build/html/classmlx_1_1core_1_1array-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1array.html create mode 100644 docs/build/html/classmlx_1_1core_1_1fast_1_1_custom-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1fast_1_1_custom.html create mode 100644 docs/build/html/classmlx_1_1core_1_1fast_1_1_custom.png create mode 100644 docs/build/html/classmlx_1_1core_1_1fast_1_1_layer_norm-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1fast_1_1_layer_norm.html create mode 100644 docs/build/html/classmlx_1_1core_1_1fast_1_1_layer_norm.png create mode 100644 docs/build/html/classmlx_1_1core_1_1fast_1_1_layer_norm_v_j_p-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1fast_1_1_layer_norm_v_j_p.html create mode 100644 docs/build/html/classmlx_1_1core_1_1fast_1_1_layer_norm_v_j_p.png create mode 100644 docs/build/html/classmlx_1_1core_1_1fast_1_1_r_m_s_norm-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1fast_1_1_r_m_s_norm.html create mode 100644 docs/build/html/classmlx_1_1core_1_1fast_1_1_r_m_s_norm.png create mode 100644 docs/build/html/classmlx_1_1core_1_1fast_1_1_r_m_s_norm_v_j_p-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1fast_1_1_r_m_s_norm_v_j_p.html create mode 100644 docs/build/html/classmlx_1_1core_1_1fast_1_1_r_m_s_norm_v_j_p.png create mode 100644 docs/build/html/classmlx_1_1core_1_1fast_1_1_ro_p_e-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1fast_1_1_ro_p_e.html create mode 100644 docs/build/html/classmlx_1_1core_1_1fast_1_1_ro_p_e.png create mode 100644 docs/build/html/classmlx_1_1core_1_1fast_1_1_scaled_dot_product_attention-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1fast_1_1_scaled_dot_product_attention.html create mode 100644 docs/build/html/classmlx_1_1core_1_1fast_1_1_scaled_dot_product_attention.png create mode 100644 docs/build/html/classmlx_1_1core_1_1io_1_1_file_reader-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1io_1_1_file_reader.html create mode 100644 docs/build/html/classmlx_1_1core_1_1io_1_1_file_reader.png create mode 100644 docs/build/html/classmlx_1_1core_1_1io_1_1_file_writer-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1io_1_1_file_writer.html create mode 100644 docs/build/html/classmlx_1_1core_1_1io_1_1_file_writer.png create mode 100644 docs/build/html/classmlx_1_1core_1_1io_1_1_reader-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1io_1_1_reader.html create mode 100644 docs/build/html/classmlx_1_1core_1_1io_1_1_reader.png create mode 100644 docs/build/html/classmlx_1_1core_1_1io_1_1_writer-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1io_1_1_writer.html create mode 100644 docs/build/html/classmlx_1_1core_1_1io_1_1_writer.png create mode 100644 docs/build/html/classmlx_1_1core_1_1metal_1_1_device-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1metal_1_1_device.html create mode 100644 docs/build/html/classmlx_1_1core_1_1metal_1_1_metal_allocator-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1metal_1_1_metal_allocator.html create mode 100644 docs/build/html/classmlx_1_1core_1_1metal_1_1_metal_allocator.png create mode 100644 docs/build/html/classmlx_1_1core_1_1random_1_1_key_sequence-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1random_1_1_key_sequence.html create mode 100644 docs/build/html/classmlx_1_1core_1_1scheduler_1_1_scheduler-members.html create mode 100644 docs/build/html/classmlx_1_1core_1_1scheduler_1_1_scheduler.html create mode 100644 docs/build/html/classpocketfft_1_1detail_1_1_t__dcst23-members.html create mode 100644 docs/build/html/classpocketfft_1_1detail_1_1_t__dcst23.html create mode 100644 docs/build/html/classpocketfft_1_1detail_1_1_t__dcst4-members.html create mode 100644 docs/build/html/classpocketfft_1_1detail_1_1_t__dcst4.html create mode 100644 docs/build/html/classpocketfft_1_1detail_1_1_t__dct1-members.html create mode 100644 docs/build/html/classpocketfft_1_1detail_1_1_t__dct1.html create mode 100644 docs/build/html/classpocketfft_1_1detail_1_1_t__dst1-members.html create mode 100644 docs/build/html/classpocketfft_1_1detail_1_1_t__dst1.html create mode 100644 docs/build/html/classpocketfft_1_1detail_1_1arr-members.html create mode 100644 docs/build/html/classpocketfft_1_1detail_1_1arr.html create mode 100644 docs/build/html/classpocketfft_1_1detail_1_1arr__info-members.html create mode 100644 docs/build/html/classpocketfft_1_1detail_1_1arr__info.html create mode 100644 docs/build/html/classpocketfft_1_1detail_1_1arr__info.png create mode 100644 docs/build/html/classpocketfft_1_1detail_1_1cfftp-members.html create mode 100644 docs/build/html/classpocketfft_1_1detail_1_1cfftp.html create mode 100644 docs/build/html/classpocketfft_1_1detail_1_1cndarr-members.html create mode 100644 docs/build/html/classpocketfft_1_1detail_1_1cndarr.html create mode 100644 docs/build/html/classpocketfft_1_1detail_1_1cndarr.png create mode 100644 docs/build/html/classpocketfft_1_1detail_1_1fftblue-members.html create mode 100644 docs/build/html/classpocketfft_1_1detail_1_1fftblue.html create mode 100644 docs/build/html/classpocketfft_1_1detail_1_1multi__iter-members.html create mode 100644 docs/build/html/classpocketfft_1_1detail_1_1multi__iter.html create mode 100644 docs/build/html/classpocketfft_1_1detail_1_1ndarr-members.html create mode 100644 docs/build/html/classpocketfft_1_1detail_1_1ndarr.html create mode 100644 docs/build/html/classpocketfft_1_1detail_1_1ndarr.png create mode 100644 docs/build/html/classpocketfft_1_1detail_1_1pocketfft__c-members.html create mode 100644 docs/build/html/classpocketfft_1_1detail_1_1pocketfft__c.html create mode 100644 docs/build/html/classpocketfft_1_1detail_1_1pocketfft__r-members.html create mode 100644 docs/build/html/classpocketfft_1_1detail_1_1pocketfft__r.html create mode 100644 docs/build/html/classpocketfft_1_1detail_1_1rev__iter-members.html create mode 100644 docs/build/html/classpocketfft_1_1detail_1_1rev__iter.html create mode 100644 docs/build/html/classpocketfft_1_1detail_1_1rfftp-members.html create mode 100644 docs/build/html/classpocketfft_1_1detail_1_1rfftp.html create mode 100644 docs/build/html/classpocketfft_1_1detail_1_1simple__iter-members.html create mode 100644 docs/build/html/classpocketfft_1_1detail_1_1simple__iter.html create mode 100644 docs/build/html/classpocketfft_1_1detail_1_1sincos__2pibyn-members.html create mode 100644 docs/build/html/classpocketfft_1_1detail_1_1sincos__2pibyn.html create mode 100644 docs/build/html/classpocketfft_1_1detail_1_1threading_1_1concurrent__queue-members.html create mode 100644 docs/build/html/classpocketfft_1_1detail_1_1threading_1_1concurrent__queue.html create mode 100644 docs/build/html/classpocketfft_1_1detail_1_1threading_1_1latch-members.html create mode 100644 docs/build/html/classpocketfft_1_1detail_1_1threading_1_1latch.html create mode 100644 docs/build/html/classpocketfft_1_1detail_1_1threading_1_1thread__pool-members.html create mode 100644 docs/build/html/classpocketfft_1_1detail_1_1threading_1_1thread__pool.html create mode 100644 docs/build/html/clipboard.js create mode 100644 docs/build/html/closed.png create mode 100644 docs/build/html/common_2binary_8h.html create mode 100644 docs/build/html/common_2binary_8h_source.html create mode 100644 docs/build/html/common_2compiled__preamble_8h.html create mode 100644 docs/build/html/common_2compiled__preamble_8h_source.html create mode 100644 docs/build/html/common_2copy_8h.html create mode 100644 docs/build/html/common_2copy_8h_source.html create mode 100644 docs/build/html/common_2reduce_8h.html create mode 100644 docs/build/html/common_2reduce_8h_source.html create mode 100644 docs/build/html/common_2ternary_8h.html create mode 100644 docs/build/html/common_2ternary_8h_source.html create mode 100644 docs/build/html/common_2unary_8h.html create mode 100644 docs/build/html/common_2unary_8h_source.html create mode 100644 docs/build/html/compile_8h.html create mode 100644 docs/build/html/compile_8h_source.html create mode 100644 docs/build/html/compile__impl_8h.html create mode 100644 docs/build/html/compile__impl_8h_source.html create mode 100644 docs/build/html/compiled_8h.html create mode 100644 docs/build/html/compiled_8h_source.html create mode 100644 docs/build/html/conv_2loader_8h.html create mode 100644 docs/build/html/conv_2loader_8h_source.html create mode 100644 docs/build/html/conv_2params_8h.html create mode 100644 docs/build/html/conv_2params_8h_source.html create mode 100644 docs/build/html/conv_8h.html create mode 100644 docs/build/html/conv_8h_source.html create mode 100644 docs/build/html/cookie.js create mode 100644 docs/build/html/defines_8h.html create mode 100644 docs/build/html/defines_8h_source.html create mode 100644 docs/build/html/device_8h.html create mode 100644 docs/build/html/device_8h_source.html create mode 100644 docs/build/html/dir_1683daa6c50d5a1449f58a10604f9f12.html create mode 100644 docs/build/html/dir_1d446c9bd3c99228254c9484e0bc5c06.html create mode 100644 docs/build/html/dir_2193406f5b2eae6fc53753d8a9a80df3.html create mode 100644 docs/build/html/dir_47795aa8999234f6f402f7e89d34d08e.html create mode 100644 docs/build/html/dir_6768c99e6145fb9510ccdb40db8ede25.html create mode 100644 docs/build/html/dir_70a37effa88bcbd6b791977fa1e64356.html create mode 100644 docs/build/html/dir_76215a6c54e2b67053e723fc2395583c.html create mode 100644 docs/build/html/dir_86b95e7b1d0d6e25466bb9213752d32f.html create mode 100644 docs/build/html/dir_938ab0ecf10b8b860ff766c820f665fd.html create mode 100644 docs/build/html/dir_ad00dcd1517bfdbe01f68ec9b4eff877.html create mode 100644 docs/build/html/dir_ba4426224ef60f409462a2a12fa18f06.html create mode 100644 docs/build/html/dir_d0c977ea65824390717cdb7efc36c157.html create mode 100644 docs/build/html/dir_df9494e83ef22ae6150a0e080d9709ed.html create mode 100644 docs/build/html/dir_f149b24a1b5be11cd70151abe517e3f8.html create mode 100644 docs/build/html/dir_f60cd69d27fd3faa641c79056fff0e2d.html create mode 100644 docs/build/html/doc.svg create mode 100644 docs/build/html/docd.svg create mode 100644 docs/build/html/doxygen.css create mode 100644 docs/build/html/doxygen.svg create mode 100644 docs/build/html/doxygen_crawl.html create mode 100644 docs/build/html/dtype_8h.html create mode 100644 docs/build/html/dtype_8h_source.html create mode 100644 docs/build/html/dynsections.js create mode 100644 docs/build/html/erf_8h.html create mode 100644 docs/build/html/erf_8h_source.html create mode 100644 docs/build/html/event_8h.html create mode 100644 docs/build/html/event_8h_source.html create mode 100644 docs/build/html/expm1f_8h.html create mode 100644 docs/build/html/expm1f_8h_source.html create mode 100644 docs/build/html/fast_8h.html create mode 100644 docs/build/html/fast_8h_source.html create mode 100644 docs/build/html/fast__primitives_8h.html create mode 100644 docs/build/html/fast__primitives_8h_source.html create mode 100644 docs/build/html/fft_8h.html create mode 100644 docs/build/html/fft_8h_source.html create mode 100644 docs/build/html/files.html create mode 100644 docs/build/html/folderclosed.svg create mode 100644 docs/build/html/folderclosedd.svg create mode 100644 docs/build/html/folderopen.svg create mode 100644 docs/build/html/folderopend.svg create mode 100644 docs/build/html/fp16_8h.html create mode 100644 docs/build/html/fp16_8h_source.html create mode 100644 docs/build/html/functions.html create mode 100644 docs/build/html/functions_a.html create mode 100644 docs/build/html/functions_b.html create mode 100644 docs/build/html/functions_c.html create mode 100644 docs/build/html/functions_d.html create mode 100644 docs/build/html/functions_e.html create mode 100644 docs/build/html/functions_enum.html create mode 100644 docs/build/html/functions_eval.html create mode 100644 docs/build/html/functions_f.html create mode 100644 docs/build/html/functions_func.html create mode 100644 docs/build/html/functions_func_a.html create mode 100644 docs/build/html/functions_func_b.html create mode 100644 docs/build/html/functions_func_c.html create mode 100644 docs/build/html/functions_func_d.html create mode 100644 docs/build/html/functions_func_e.html create mode 100644 docs/build/html/functions_func_f.html create mode 100644 docs/build/html/functions_func_g.html create mode 100644 docs/build/html/functions_func_h.html create mode 100644 docs/build/html/functions_func_i.html create mode 100644 docs/build/html/functions_func_j.html create mode 100644 docs/build/html/functions_func_k.html create mode 100644 docs/build/html/functions_func_l.html create mode 100644 docs/build/html/functions_func_m.html create mode 100644 docs/build/html/functions_func_n.html create mode 100644 docs/build/html/functions_func_o.html create mode 100644 docs/build/html/functions_func_p.html create mode 100644 docs/build/html/functions_func_q.html create mode 100644 docs/build/html/functions_func_r.html create mode 100644 docs/build/html/functions_func_s.html create mode 100644 docs/build/html/functions_func_t.html create mode 100644 docs/build/html/functions_func_u.html create mode 100644 docs/build/html/functions_func_v.html create mode 100644 docs/build/html/functions_func_w.html create mode 100644 docs/build/html/functions_func_~.html create mode 100644 docs/build/html/functions_g.html create mode 100644 docs/build/html/functions_h.html create mode 100644 docs/build/html/functions_i.html create mode 100644 docs/build/html/functions_j.html create mode 100644 docs/build/html/functions_k.html create mode 100644 docs/build/html/functions_l.html create mode 100644 docs/build/html/functions_m.html create mode 100644 docs/build/html/functions_n.html create mode 100644 docs/build/html/functions_o.html create mode 100644 docs/build/html/functions_p.html create mode 100644 docs/build/html/functions_q.html create mode 100644 docs/build/html/functions_r.html create mode 100644 docs/build/html/functions_rela.html create mode 100644 docs/build/html/functions_s.html create mode 100644 docs/build/html/functions_t.html create mode 100644 docs/build/html/functions_type.html create mode 100644 docs/build/html/functions_u.html create mode 100644 docs/build/html/functions_v.html create mode 100644 docs/build/html/functions_vars.html create mode 100644 docs/build/html/functions_vars_b.html create mode 100644 docs/build/html/functions_vars_c.html create mode 100644 docs/build/html/functions_vars_d.html create mode 100644 docs/build/html/functions_vars_e.html create mode 100644 docs/build/html/functions_vars_f.html create mode 100644 docs/build/html/functions_vars_g.html create mode 100644 docs/build/html/functions_vars_i.html create mode 100644 docs/build/html/functions_vars_j.html create mode 100644 docs/build/html/functions_vars_k.html create mode 100644 docs/build/html/functions_vars_l.html create mode 100644 docs/build/html/functions_vars_m.html create mode 100644 docs/build/html/functions_vars_n.html create mode 100644 docs/build/html/functions_vars_o.html create mode 100644 docs/build/html/functions_vars_p.html create mode 100644 docs/build/html/functions_vars_q.html create mode 100644 docs/build/html/functions_vars_r.html create mode 100644 docs/build/html/functions_vars_s.html create mode 100644 docs/build/html/functions_vars_t.html create mode 100644 docs/build/html/functions_vars_v.html create mode 100644 docs/build/html/functions_vars_w.html create mode 100644 docs/build/html/functions_w.html create mode 100644 docs/build/html/functions_x.html create mode 100644 docs/build/html/functions_~.html create mode 100644 docs/build/html/gemm_2loader_8h.html create mode 100644 docs/build/html/gemm_2loader_8h_source.html create mode 100644 docs/build/html/gemm_2params_8h.html create mode 100644 docs/build/html/gemm_2params_8h_source.html create mode 100644 docs/build/html/gguf_8h.html create mode 100644 docs/build/html/gguf_8h_source.html create mode 100644 docs/build/html/globals.html create mode 100644 docs/build/html/globals_a.html create mode 100644 docs/build/html/globals_b.html create mode 100644 docs/build/html/globals_c.html create mode 100644 docs/build/html/globals_d.html create mode 100644 docs/build/html/globals_defs.html create mode 100644 docs/build/html/globals_e.html create mode 100644 docs/build/html/globals_f.html create mode 100644 docs/build/html/globals_func.html create mode 100644 docs/build/html/globals_func_c.html create mode 100644 docs/build/html/globals_func_e.html create mode 100644 docs/build/html/globals_func_f.html create mode 100644 docs/build/html/globals_func_g.html create mode 100644 docs/build/html/globals_func_l.html create mode 100644 docs/build/html/globals_func_m.html create mode 100644 docs/build/html/globals_func_o.html create mode 100644 docs/build/html/globals_func_s.html create mode 100644 docs/build/html/globals_g.html create mode 100644 docs/build/html/globals_h.html create mode 100644 docs/build/html/globals_i.html create mode 100644 docs/build/html/globals_l.html create mode 100644 docs/build/html/globals_m.html create mode 100644 docs/build/html/globals_n.html create mode 100644 docs/build/html/globals_o.html create mode 100644 docs/build/html/globals_p.html create mode 100644 docs/build/html/globals_r.html create mode 100644 docs/build/html/globals_s.html create mode 100644 docs/build/html/globals_type.html create mode 100644 docs/build/html/globals_u.html create mode 100644 docs/build/html/globals_vars.html create mode 100644 docs/build/html/graph__utils_8h.html create mode 100644 docs/build/html/graph__utils_8h_source.html create mode 100644 docs/build/html/group__ops.html create mode 100644 docs/build/html/half__types_8h.html create mode 100644 docs/build/html/half__types_8h_source.html create mode 100644 docs/build/html/hierarchy.html create mode 100644 docs/build/html/indexing_8h.html create mode 100644 docs/build/html/indexing_8h_source.html create mode 100644 docs/build/html/io_8h.html create mode 100644 docs/build/html/io_8h_source.html create mode 100644 docs/build/html/jquery.js create mode 100644 docs/build/html/kernels_2steel_2gemm_2gemm_8h.html create mode 100644 docs/build/html/kernels_2steel_2gemm_2gemm_8h_source.html create mode 100644 docs/build/html/lapack__helper_8h.html create mode 100644 docs/build/html/lapack__helper_8h_source.html create mode 100644 docs/build/html/linalg_8h.html create mode 100644 docs/build/html/linalg_8h_source.html create mode 100644 docs/build/html/load_8h.html create mode 100644 docs/build/html/load_8h_source.html create mode 100644 docs/build/html/loader__channel__l_8h.html create mode 100644 docs/build/html/loader__channel__l_8h_source.html create mode 100644 docs/build/html/loader__channel__n_8h.html create mode 100644 docs/build/html/loader__channel__n_8h_source.html create mode 100644 docs/build/html/loader__general_8h.html create mode 100644 docs/build/html/loader__general_8h_source.html create mode 100644 docs/build/html/matmul_8h.html create mode 100644 docs/build/html/matmul_8h_source.html create mode 100644 docs/build/html/menu.js create mode 100644 docs/build/html/menudata.js create mode 100644 docs/build/html/metal_2compiled__preamble_8h.html create mode 100644 docs/build/html/metal_2compiled__preamble_8h_source.html create mode 100644 docs/build/html/metal_2copy_8h.html create mode 100644 docs/build/html/metal_2copy_8h_source.html create mode 100644 docs/build/html/metal_2kernels_2binary_8h.html create mode 100644 docs/build/html/metal_2kernels_2binary_8h_source.html create mode 100644 docs/build/html/metal_2kernels_2compiled__preamble_8h.html create mode 100644 docs/build/html/metal_2kernels_2compiled__preamble_8h_source.html create mode 100644 docs/build/html/metal_2kernels_2ternary_8h.html create mode 100644 docs/build/html/metal_2kernels_2ternary_8h_source.html create mode 100644 docs/build/html/metal_2kernels_2unary_8h.html create mode 100644 docs/build/html/metal_2kernels_2unary_8h_source.html create mode 100644 docs/build/html/metal_2reduce_8h.html create mode 100644 docs/build/html/metal_2reduce_8h_source.html create mode 100644 docs/build/html/metal_8h.html create mode 100644 docs/build/html/metal_8h_source.html create mode 100644 docs/build/html/metal__impl_8h.html create mode 100644 docs/build/html/metal__impl_8h_source.html create mode 100644 docs/build/html/minus.svg create mode 100644 docs/build/html/minusd.svg create mode 100644 docs/build/html/mlx_8h.html create mode 100644 docs/build/html/mlx_8h_source.html create mode 100644 docs/build/html/mma_8h.html create mode 100644 docs/build/html/mma_8h_source.html create mode 100644 docs/build/html/mps_2gemm_8h.html create mode 100644 docs/build/html/mps_2gemm_8h_source.html create mode 100644 docs/build/html/namespace_m_p_s.html create mode 100644 docs/build/html/namespace_m_t_l.html create mode 100644 docs/build/html/namespace_m_t_l_1_1_private.html create mode 100644 docs/build/html/namespace_m_t_l_1_1_private_1_1_class.html create mode 100644 docs/build/html/namespace_m_t_l_1_1_private_1_1_selector.html create mode 100644 docs/build/html/namespacemembers.html create mode 100644 docs/build/html/namespacemembers_a.html create mode 100644 docs/build/html/namespacemembers_b.html create mode 100644 docs/build/html/namespacemembers_c.html create mode 100644 docs/build/html/namespacemembers_d.html create mode 100644 docs/build/html/namespacemembers_e.html create mode 100644 docs/build/html/namespacemembers_enum.html create mode 100644 docs/build/html/namespacemembers_eval.html create mode 100644 docs/build/html/namespacemembers_f.html create mode 100644 docs/build/html/namespacemembers_func.html create mode 100644 docs/build/html/namespacemembers_func_a.html create mode 100644 docs/build/html/namespacemembers_func_b.html create mode 100644 docs/build/html/namespacemembers_func_c.html create mode 100644 docs/build/html/namespacemembers_func_d.html create mode 100644 docs/build/html/namespacemembers_func_e.html create mode 100644 docs/build/html/namespacemembers_func_f.html create mode 100644 docs/build/html/namespacemembers_func_g.html create mode 100644 docs/build/html/namespacemembers_func_i.html create mode 100644 docs/build/html/namespacemembers_func_j.html create mode 100644 docs/build/html/namespacemembers_func_k.html create mode 100644 docs/build/html/namespacemembers_func_l.html create mode 100644 docs/build/html/namespacemembers_func_m.html create mode 100644 docs/build/html/namespacemembers_func_n.html create mode 100644 docs/build/html/namespacemembers_func_o.html create mode 100644 docs/build/html/namespacemembers_func_p.html create mode 100644 docs/build/html/namespacemembers_func_q.html create mode 100644 docs/build/html/namespacemembers_func_r.html create mode 100644 docs/build/html/namespacemembers_func_s.html create mode 100644 docs/build/html/namespacemembers_func_t.html create mode 100644 docs/build/html/namespacemembers_func_u.html create mode 100644 docs/build/html/namespacemembers_func_v.html create mode 100644 docs/build/html/namespacemembers_func_w.html create mode 100644 docs/build/html/namespacemembers_func_z.html create mode 100644 docs/build/html/namespacemembers_g.html create mode 100644 docs/build/html/namespacemembers_i.html create mode 100644 docs/build/html/namespacemembers_j.html create mode 100644 docs/build/html/namespacemembers_k.html create mode 100644 docs/build/html/namespacemembers_l.html create mode 100644 docs/build/html/namespacemembers_m.html create mode 100644 docs/build/html/namespacemembers_n.html create mode 100644 docs/build/html/namespacemembers_o.html create mode 100644 docs/build/html/namespacemembers_p.html create mode 100644 docs/build/html/namespacemembers_q.html create mode 100644 docs/build/html/namespacemembers_r.html create mode 100644 docs/build/html/namespacemembers_s.html create mode 100644 docs/build/html/namespacemembers_t.html create mode 100644 docs/build/html/namespacemembers_type.html create mode 100644 docs/build/html/namespacemembers_u.html create mode 100644 docs/build/html/namespacemembers_v.html create mode 100644 docs/build/html/namespacemembers_vars.html create mode 100644 docs/build/html/namespacemembers_w.html create mode 100644 docs/build/html/namespacemembers_z.html create mode 100644 docs/build/html/namespacemetal.html create mode 100644 docs/build/html/namespacemetal_1_1fast.html create mode 100644 docs/build/html/namespacemetal_1_1precise.html create mode 100644 docs/build/html/namespacemlx.html create mode 100644 docs/build/html/namespacemlx_1_1core.html create mode 100644 docs/build/html/namespacemlx_1_1core_1_1allocator.html create mode 100644 docs/build/html/namespacemlx_1_1core_1_1detail.html create mode 100644 docs/build/html/namespacemlx_1_1core_1_1fast.html create mode 100644 docs/build/html/namespacemlx_1_1core_1_1fft.html create mode 100644 docs/build/html/namespacemlx_1_1core_1_1io.html create mode 100644 docs/build/html/namespacemlx_1_1core_1_1linalg.html create mode 100644 docs/build/html/namespacemlx_1_1core_1_1metal.html create mode 100644 docs/build/html/namespacemlx_1_1core_1_1random.html create mode 100644 docs/build/html/namespacemlx_1_1core_1_1scheduler.html create mode 100644 docs/build/html/namespacemlx_1_1steel.html create mode 100644 docs/build/html/namespacepocketfft.html create mode 100644 docs/build/html/namespacepocketfft_1_1detail.html create mode 100644 docs/build/html/namespacepocketfft_1_1detail_1_1threading.html create mode 100644 docs/build/html/namespaces.html create mode 100644 docs/build/html/nav_f.png create mode 100644 docs/build/html/nav_fd.png create mode 100644 docs/build/html/nav_g.png create mode 100644 docs/build/html/nav_h.png create mode 100644 docs/build/html/nav_hd.png create mode 100644 docs/build/html/open.png create mode 100644 docs/build/html/ops_8h.html create mode 100644 docs/build/html/ops_8h_source.html create mode 100644 docs/build/html/plus.svg create mode 100644 docs/build/html/plusd.svg create mode 100644 docs/build/html/pocketfft_8h.html create mode 100644 docs/build/html/pocketfft_8h_source.html create mode 100644 docs/build/html/primitives_8h.html create mode 100644 docs/build/html/primitives_8h_source.html create mode 100644 docs/build/html/python/_autosummary/mlx.core.arctan2.html create mode 100644 docs/build/html/python/_autosummary/mlx.core.bitwise_and.html create mode 100644 docs/build/html/python/_autosummary/mlx.core.bitwise_or.html create mode 100644 docs/build/html/python/_autosummary/mlx.core.bitwise_xor.html create mode 100644 docs/build/html/python/_autosummary/mlx.core.block_sparse_mm.html create mode 100644 docs/build/html/python/_autosummary/mlx.core.conj.html create mode 100644 docs/build/html/python/_autosummary/mlx.core.conjugate.html create mode 100644 docs/build/html/python/_autosummary/mlx.core.left_shift.html create mode 100644 docs/build/html/python/_autosummary/mlx.core.metal.device_info.html create mode 100644 docs/build/html/python/_autosummary/mlx.core.metal.reset_peak_memory.html create mode 100644 docs/build/html/python/_autosummary/mlx.core.right_shift.html create mode 100644 docs/build/html/python/_autosummary/mlx.optimizers.clip_grad_norm.html create mode 100644 docs/build/html/python/_autosummary/mlx.utils.tree_reduce.html create mode 100644 docs/build/html/random_8h.html create mode 100644 docs/build/html/random_8h_source.html create mode 100644 docs/build/html/reduce__inst_8h.html create mode 100644 docs/build/html/reduce__inst_8h_source.html create mode 100644 docs/build/html/scaled__dot__product__attention__params_8h.html create mode 100644 docs/build/html/scaled__dot__product__attention__params_8h_source.html create mode 100644 docs/build/html/scheduler_8h.html create mode 100644 docs/build/html/scheduler_8h_source.html create mode 100644 docs/build/html/search/all_0.js create mode 100644 docs/build/html/search/all_1.js create mode 100644 docs/build/html/search/all_10.js create mode 100644 docs/build/html/search/all_11.js create mode 100644 docs/build/html/search/all_12.js create mode 100644 docs/build/html/search/all_13.js create mode 100644 docs/build/html/search/all_14.js create mode 100644 docs/build/html/search/all_15.js create mode 100644 docs/build/html/search/all_16.js create mode 100644 docs/build/html/search/all_17.js create mode 100644 docs/build/html/search/all_18.js create mode 100644 docs/build/html/search/all_19.js create mode 100644 docs/build/html/search/all_1a.js create mode 100644 docs/build/html/search/all_2.js create mode 100644 docs/build/html/search/all_3.js create mode 100644 docs/build/html/search/all_4.js create mode 100644 docs/build/html/search/all_5.js create mode 100644 docs/build/html/search/all_6.js create mode 100644 docs/build/html/search/all_7.js create mode 100644 docs/build/html/search/all_8.js create mode 100644 docs/build/html/search/all_9.js create mode 100644 docs/build/html/search/all_a.js create mode 100644 docs/build/html/search/all_b.js create mode 100644 docs/build/html/search/all_c.js create mode 100644 docs/build/html/search/all_d.js create mode 100644 docs/build/html/search/all_e.js create mode 100644 docs/build/html/search/all_f.js create mode 100644 docs/build/html/search/classes_0.js create mode 100644 docs/build/html/search/classes_1.js create mode 100644 docs/build/html/search/classes_10.js create mode 100644 docs/build/html/search/classes_11.js create mode 100644 docs/build/html/search/classes_12.js create mode 100644 docs/build/html/search/classes_13.js create mode 100644 docs/build/html/search/classes_14.js create mode 100644 docs/build/html/search/classes_15.js create mode 100644 docs/build/html/search/classes_2.js create mode 100644 docs/build/html/search/classes_3.js create mode 100644 docs/build/html/search/classes_4.js create mode 100644 docs/build/html/search/classes_5.js create mode 100644 docs/build/html/search/classes_6.js create mode 100644 docs/build/html/search/classes_7.js create mode 100644 docs/build/html/search/classes_8.js create mode 100644 docs/build/html/search/classes_9.js create mode 100644 docs/build/html/search/classes_a.js create mode 100644 docs/build/html/search/classes_b.js create mode 100644 docs/build/html/search/classes_c.js create mode 100644 docs/build/html/search/classes_d.js create mode 100644 docs/build/html/search/classes_e.js create mode 100644 docs/build/html/search/classes_f.js create mode 100644 docs/build/html/search/close.svg create mode 100644 docs/build/html/search/defines_0.js create mode 100644 docs/build/html/search/defines_1.js create mode 100644 docs/build/html/search/defines_2.js create mode 100644 docs/build/html/search/defines_3.js create mode 100644 docs/build/html/search/defines_4.js create mode 100644 docs/build/html/search/defines_5.js create mode 100644 docs/build/html/search/defines_6.js create mode 100644 docs/build/html/search/defines_7.js create mode 100644 docs/build/html/search/defines_8.js create mode 100644 docs/build/html/search/defines_9.js create mode 100644 docs/build/html/search/defines_a.js create mode 100644 docs/build/html/search/defines_b.js create mode 100644 docs/build/html/search/enums_0.js create mode 100644 docs/build/html/search/enums_1.js create mode 100644 docs/build/html/search/enums_2.js create mode 100644 docs/build/html/search/enums_3.js create mode 100644 docs/build/html/search/enums_4.js create mode 100644 docs/build/html/search/enums_5.js create mode 100644 docs/build/html/search/enums_6.js create mode 100644 docs/build/html/search/enums_7.js create mode 100644 docs/build/html/search/enumvalues_0.js create mode 100644 docs/build/html/search/enumvalues_1.js create mode 100644 docs/build/html/search/enumvalues_10.js create mode 100644 docs/build/html/search/enumvalues_11.js create mode 100644 docs/build/html/search/enumvalues_12.js create mode 100644 docs/build/html/search/enumvalues_2.js create mode 100644 docs/build/html/search/enumvalues_3.js create mode 100644 docs/build/html/search/enumvalues_4.js create mode 100644 docs/build/html/search/enumvalues_5.js create mode 100644 docs/build/html/search/enumvalues_6.js create mode 100644 docs/build/html/search/enumvalues_7.js create mode 100644 docs/build/html/search/enumvalues_8.js create mode 100644 docs/build/html/search/enumvalues_9.js create mode 100644 docs/build/html/search/enumvalues_a.js create mode 100644 docs/build/html/search/enumvalues_b.js create mode 100644 docs/build/html/search/enumvalues_c.js create mode 100644 docs/build/html/search/enumvalues_d.js create mode 100644 docs/build/html/search/enumvalues_e.js create mode 100644 docs/build/html/search/enumvalues_f.js create mode 100644 docs/build/html/search/files_0.js create mode 100644 docs/build/html/search/files_1.js create mode 100644 docs/build/html/search/files_10.js create mode 100644 docs/build/html/search/files_2.js create mode 100644 docs/build/html/search/files_3.js create mode 100644 docs/build/html/search/files_4.js create mode 100644 docs/build/html/search/files_5.js create mode 100644 docs/build/html/search/files_6.js create mode 100644 docs/build/html/search/files_7.js create mode 100644 docs/build/html/search/files_8.js create mode 100644 docs/build/html/search/files_9.js create mode 100644 docs/build/html/search/files_a.js create mode 100644 docs/build/html/search/files_b.js create mode 100644 docs/build/html/search/files_c.js create mode 100644 docs/build/html/search/files_d.js create mode 100644 docs/build/html/search/files_e.js create mode 100644 docs/build/html/search/files_f.js create mode 100644 docs/build/html/search/functions_0.js create mode 100644 docs/build/html/search/functions_1.js create mode 100644 docs/build/html/search/functions_10.js create mode 100644 docs/build/html/search/functions_11.js create mode 100644 docs/build/html/search/functions_12.js create mode 100644 docs/build/html/search/functions_13.js create mode 100644 docs/build/html/search/functions_14.js create mode 100644 docs/build/html/search/functions_15.js create mode 100644 docs/build/html/search/functions_16.js create mode 100644 docs/build/html/search/functions_17.js create mode 100644 docs/build/html/search/functions_18.js create mode 100644 docs/build/html/search/functions_19.js create mode 100644 docs/build/html/search/functions_2.js create mode 100644 docs/build/html/search/functions_3.js create mode 100644 docs/build/html/search/functions_4.js create mode 100644 docs/build/html/search/functions_5.js create mode 100644 docs/build/html/search/functions_6.js create mode 100644 docs/build/html/search/functions_7.js create mode 100644 docs/build/html/search/functions_8.js create mode 100644 docs/build/html/search/functions_9.js create mode 100644 docs/build/html/search/functions_a.js create mode 100644 docs/build/html/search/functions_b.js create mode 100644 docs/build/html/search/functions_c.js create mode 100644 docs/build/html/search/functions_d.js create mode 100644 docs/build/html/search/functions_e.js create mode 100644 docs/build/html/search/functions_f.js create mode 100644 docs/build/html/search/groups_0.js create mode 100644 docs/build/html/search/groups_1.js create mode 100644 docs/build/html/search/groups_2.js create mode 100644 docs/build/html/search/mag.svg create mode 100644 docs/build/html/search/mag_d.svg create mode 100644 docs/build/html/search/mag_sel.svg create mode 100644 docs/build/html/search/mag_seld.svg create mode 100644 docs/build/html/search/namespaces_0.js create mode 100644 docs/build/html/search/namespaces_1.js create mode 100644 docs/build/html/search/related_0.js create mode 100644 docs/build/html/search/related_1.js create mode 100644 docs/build/html/search/search.css create mode 100644 docs/build/html/search/search.js create mode 100644 docs/build/html/search/searchdata.js create mode 100644 docs/build/html/search/typedefs_0.js create mode 100644 docs/build/html/search/typedefs_1.js create mode 100644 docs/build/html/search/typedefs_2.js create mode 100644 docs/build/html/search/typedefs_3.js create mode 100644 docs/build/html/search/typedefs_4.js create mode 100644 docs/build/html/search/typedefs_5.js create mode 100644 docs/build/html/search/typedefs_6.js create mode 100644 docs/build/html/search/typedefs_7.js create mode 100644 docs/build/html/search/typedefs_8.js create mode 100644 docs/build/html/search/typedefs_9.js create mode 100644 docs/build/html/search/typedefs_a.js create mode 100644 docs/build/html/search/typedefs_b.js create mode 100644 docs/build/html/search/typedefs_c.js create mode 100644 docs/build/html/search/variables_0.js create mode 100644 docs/build/html/search/variables_1.js create mode 100644 docs/build/html/search/variables_10.js create mode 100644 docs/build/html/search/variables_11.js create mode 100644 docs/build/html/search/variables_12.js create mode 100644 docs/build/html/search/variables_13.js create mode 100644 docs/build/html/search/variables_14.js create mode 100644 docs/build/html/search/variables_15.js create mode 100644 docs/build/html/search/variables_2.js create mode 100644 docs/build/html/search/variables_3.js create mode 100644 docs/build/html/search/variables_4.js create mode 100644 docs/build/html/search/variables_5.js create mode 100644 docs/build/html/search/variables_6.js create mode 100644 docs/build/html/search/variables_7.js create mode 100644 docs/build/html/search/variables_8.js create mode 100644 docs/build/html/search/variables_9.js create mode 100644 docs/build/html/search/variables_a.js create mode 100644 docs/build/html/search/variables_b.js create mode 100644 docs/build/html/search/variables_c.js create mode 100644 docs/build/html/search/variables_d.js create mode 100644 docs/build/html/search/variables_e.js create mode 100644 docs/build/html/search/variables_f.js create mode 100644 docs/build/html/splitbar.png create mode 100644 docs/build/html/splitbard.png create mode 100644 docs/build/html/stream_8h.html create mode 100644 docs/build/html/stream_8h_source.html create mode 100644 docs/build/html/struct___m_l_x___b_float16-members.html create mode 100644 docs/build/html/struct___m_l_x___b_float16.html create mode 100644 docs/build/html/struct___m_l_x___b_float16_1_1bits__to__bfloat__struct.html create mode 100644 docs/build/html/struct_abs-members.html create mode 100644 docs/build/html/struct_abs.html create mode 100644 docs/build/html/struct_add-members.html create mode 100644 docs/build/html/struct_add.html create mode 100644 docs/build/html/struct_and-members.html create mode 100644 docs/build/html/struct_and.html create mode 100644 docs/build/html/struct_arc_cos-members.html create mode 100644 docs/build/html/struct_arc_cos.html create mode 100644 docs/build/html/struct_arc_cosh-members.html create mode 100644 docs/build/html/struct_arc_cosh.html create mode 100644 docs/build/html/struct_arc_sin-members.html create mode 100644 docs/build/html/struct_arc_sin.html create mode 100644 docs/build/html/struct_arc_sinh-members.html create mode 100644 docs/build/html/struct_arc_sinh.html create mode 100644 docs/build/html/struct_arc_tan-members.html create mode 100644 docs/build/html/struct_arc_tan.html create mode 100644 docs/build/html/struct_arc_tan2-members.html create mode 100644 docs/build/html/struct_arc_tan2.html create mode 100644 docs/build/html/struct_arc_tanh-members.html create mode 100644 docs/build/html/struct_arc_tanh.html create mode 100644 docs/build/html/struct_bitwise_and-members.html create mode 100644 docs/build/html/struct_bitwise_and.html create mode 100644 docs/build/html/struct_bitwise_or-members.html create mode 100644 docs/build/html/struct_bitwise_or.html create mode 100644 docs/build/html/struct_bitwise_xor-members.html create mode 100644 docs/build/html/struct_bitwise_xor.html create mode 100644 docs/build/html/struct_ceil-members.html create mode 100644 docs/build/html/struct_ceil.html create mode 100644 docs/build/html/struct_conjugate-members.html create mode 100644 docs/build/html/struct_conjugate.html create mode 100644 docs/build/html/struct_cos-members.html create mode 100644 docs/build/html/struct_cos.html create mode 100644 docs/build/html/struct_cosh-members.html create mode 100644 docs/build/html/struct_cosh.html create mode 100644 docs/build/html/struct_divide-members.html create mode 100644 docs/build/html/struct_divide.html create mode 100644 docs/build/html/struct_equal-members.html create mode 100644 docs/build/html/struct_equal.html create mode 100644 docs/build/html/struct_erf-members.html create mode 100644 docs/build/html/struct_erf.html create mode 100644 docs/build/html/struct_erf_inv-members.html create mode 100644 docs/build/html/struct_erf_inv.html create mode 100644 docs/build/html/struct_exp-members.html create mode 100644 docs/build/html/struct_exp.html create mode 100644 docs/build/html/struct_expm1-members.html create mode 100644 docs/build/html/struct_expm1.html create mode 100644 docs/build/html/struct_floor-members.html create mode 100644 docs/build/html/struct_floor.html create mode 100644 docs/build/html/struct_greater-members.html create mode 100644 docs/build/html/struct_greater.html create mode 100644 docs/build/html/struct_greater_equal-members.html create mode 100644 docs/build/html/struct_greater_equal.html create mode 100644 docs/build/html/struct_indices-members.html create mode 100644 docs/build/html/struct_indices.html create mode 100644 docs/build/html/struct_left_shift-members.html create mode 100644 docs/build/html/struct_left_shift.html create mode 100644 docs/build/html/struct_less-members.html create mode 100644 docs/build/html/struct_less.html create mode 100644 docs/build/html/struct_less_equal-members.html create mode 100644 docs/build/html/struct_less_equal.html create mode 100644 docs/build/html/struct_limits-members.html create mode 100644 docs/build/html/struct_limits.html create mode 100644 docs/build/html/struct_limits_3_01bfloat16__t_01_4-members.html create mode 100644 docs/build/html/struct_limits_3_01bfloat16__t_01_4.html create mode 100644 docs/build/html/struct_limits_3_01bool_01_4-members.html create mode 100644 docs/build/html/struct_limits_3_01bool_01_4.html create mode 100644 docs/build/html/struct_limits_3_01float_01_4-members.html create mode 100644 docs/build/html/struct_limits_3_01float_01_4.html create mode 100644 docs/build/html/struct_limits_3_01half_01_4-members.html create mode 100644 docs/build/html/struct_limits_3_01half_01_4.html create mode 100644 docs/build/html/struct_limits_3_01int16__t_01_4-members.html create mode 100644 docs/build/html/struct_limits_3_01int16__t_01_4.html create mode 100644 docs/build/html/struct_limits_3_01int32__t_01_4-members.html create mode 100644 docs/build/html/struct_limits_3_01int32__t_01_4.html create mode 100644 docs/build/html/struct_limits_3_01int64__t_01_4-members.html create mode 100644 docs/build/html/struct_limits_3_01int64__t_01_4.html create mode 100644 docs/build/html/struct_limits_3_01int8__t_01_4-members.html create mode 100644 docs/build/html/struct_limits_3_01int8__t_01_4.html create mode 100644 docs/build/html/struct_limits_3_01uint16__t_01_4-members.html create mode 100644 docs/build/html/struct_limits_3_01uint16__t_01_4.html create mode 100644 docs/build/html/struct_limits_3_01uint32__t_01_4-members.html create mode 100644 docs/build/html/struct_limits_3_01uint32__t_01_4.html create mode 100644 docs/build/html/struct_limits_3_01uint64__t_01_4-members.html create mode 100644 docs/build/html/struct_limits_3_01uint64__t_01_4.html create mode 100644 docs/build/html/struct_limits_3_01uint8__t_01_4-members.html create mode 100644 docs/build/html/struct_limits_3_01uint8__t_01_4.html create mode 100644 docs/build/html/struct_log-members.html create mode 100644 docs/build/html/struct_log.html create mode 100644 docs/build/html/struct_log10-members.html create mode 100644 docs/build/html/struct_log10.html create mode 100644 docs/build/html/struct_log1p-members.html create mode 100644 docs/build/html/struct_log1p.html create mode 100644 docs/build/html/struct_log2-members.html create mode 100644 docs/build/html/struct_log2.html create mode 100644 docs/build/html/struct_log_add_exp-members.html create mode 100644 docs/build/html/struct_log_add_exp.html create mode 100644 docs/build/html/struct_logical_and-members.html create mode 100644 docs/build/html/struct_logical_and.html create mode 100644 docs/build/html/struct_logical_not-members.html create mode 100644 docs/build/html/struct_logical_not.html create mode 100644 docs/build/html/struct_logical_or-members.html create mode 100644 docs/build/html/struct_logical_or.html create mode 100644 docs/build/html/struct_m_l_x_conv_params-members.html create mode 100644 docs/build/html/struct_m_l_x_conv_params.html create mode 100644 docs/build/html/struct_m_l_x_scaled_dot_product_attention_params-members.html create mode 100644 docs/build/html/struct_m_l_x_scaled_dot_product_attention_params.html create mode 100644 docs/build/html/struct_max-members.html create mode 100644 docs/build/html/struct_max.html create mode 100644 docs/build/html/struct_maximum-members.html create mode 100644 docs/build/html/struct_maximum.html create mode 100644 docs/build/html/struct_min-members.html create mode 100644 docs/build/html/struct_min.html create mode 100644 docs/build/html/struct_minimum-members.html create mode 100644 docs/build/html/struct_minimum.html create mode 100644 docs/build/html/struct_multiply-members.html create mode 100644 docs/build/html/struct_multiply.html create mode 100644 docs/build/html/struct_na_n_equal-members.html create mode 100644 docs/build/html/struct_na_n_equal.html create mode 100644 docs/build/html/struct_negative-members.html create mode 100644 docs/build/html/struct_negative.html create mode 100644 docs/build/html/struct_none-members.html create mode 100644 docs/build/html/struct_none.html create mode 100644 docs/build/html/struct_not_equal-members.html create mode 100644 docs/build/html/struct_not_equal.html create mode 100644 docs/build/html/struct_or-members.html create mode 100644 docs/build/html/struct_or.html create mode 100644 docs/build/html/struct_power-members.html create mode 100644 docs/build/html/struct_power.html create mode 100644 docs/build/html/struct_prod-members.html create mode 100644 docs/build/html/struct_prod.html create mode 100644 docs/build/html/struct_remainder-members.html create mode 100644 docs/build/html/struct_remainder.html create mode 100644 docs/build/html/struct_right_shift-members.html create mode 100644 docs/build/html/struct_right_shift.html create mode 100644 docs/build/html/struct_round-members.html create mode 100644 docs/build/html/struct_round.html create mode 100644 docs/build/html/struct_rsqrt-members.html create mode 100644 docs/build/html/struct_rsqrt.html create mode 100644 docs/build/html/struct_select-members.html create mode 100644 docs/build/html/struct_select.html create mode 100644 docs/build/html/struct_sigmoid-members.html create mode 100644 docs/build/html/struct_sigmoid.html create mode 100644 docs/build/html/struct_sign-members.html create mode 100644 docs/build/html/struct_sign.html create mode 100644 docs/build/html/struct_sin-members.html create mode 100644 docs/build/html/struct_sin.html create mode 100644 docs/build/html/struct_sinh-members.html create mode 100644 docs/build/html/struct_sinh.html create mode 100644 docs/build/html/struct_sqrt-members.html create mode 100644 docs/build/html/struct_sqrt.html create mode 100644 docs/build/html/struct_square-members.html create mode 100644 docs/build/html/struct_square.html create mode 100644 docs/build/html/struct_subtract-members.html create mode 100644 docs/build/html/struct_subtract.html create mode 100644 docs/build/html/struct_sum-members.html create mode 100644 docs/build/html/struct_sum.html create mode 100644 docs/build/html/struct_tan-members.html create mode 100644 docs/build/html/struct_tan.html create mode 100644 docs/build/html/struct_tanh-members.html create mode 100644 docs/build/html/struct_tanh.html create mode 100644 docs/build/html/structcomplex64__t-members.html create mode 100644 docs/build/html/structcomplex64__t.html create mode 100644 docs/build/html/structmetal_1_1__numeric__limits__impl_3_01bfloat16__t_01_4-members.html create mode 100644 docs/build/html/structmetal_1_1__numeric__limits__impl_3_01bfloat16__t_01_4.html create mode 100644 docs/build/html/structmetal_1_1__numeric__limits__impl_3_01bfloat16__t_01_4.png create mode 100644 docs/build/html/structmlx_1_1core_1_1___m_l_x___b_float16-members.html create mode 100644 docs/build/html/structmlx_1_1core_1_1___m_l_x___b_float16.html create mode 100644 docs/build/html/structmlx_1_1core_1_1___m_l_x___float16-members.html create mode 100644 docs/build/html/structmlx_1_1core_1_1___m_l_x___float16.html create mode 100644 docs/build/html/structmlx_1_1core_1_1_device-members.html create mode 100644 docs/build/html/structmlx_1_1core_1_1_device.html create mode 100644 docs/build/html/structmlx_1_1core_1_1_dtype-members.html create mode 100644 docs/build/html/structmlx_1_1core_1_1_dtype.html create mode 100644 docs/build/html/structmlx_1_1core_1_1_node_namer-members.html create mode 100644 docs/build/html/structmlx_1_1core_1_1_node_namer.html create mode 100644 docs/build/html/structmlx_1_1core_1_1_print_formatter-members.html create mode 100644 docs/build/html/structmlx_1_1core_1_1_print_formatter.html create mode 100644 docs/build/html/structmlx_1_1core_1_1_reduction_plan-members.html create mode 100644 docs/build/html/structmlx_1_1core_1_1_reduction_plan.html create mode 100644 docs/build/html/structmlx_1_1core_1_1_stream-members.html create mode 100644 docs/build/html/structmlx_1_1core_1_1_stream.html create mode 100644 docs/build/html/structmlx_1_1core_1_1_stream_context-members.html create mode 100644 docs/build/html/structmlx_1_1core_1_1_stream_context.html create mode 100644 docs/build/html/structmlx_1_1core_1_1_type_to_dtype-members.html create mode 100644 docs/build/html/structmlx_1_1core_1_1_type_to_dtype.html create mode 100644 docs/build/html/structmlx_1_1core_1_1array_1_1_array_iterator-members.html create mode 100644 docs/build/html/structmlx_1_1core_1_1array_1_1_array_iterator.html create mode 100644 docs/build/html/structmlx_1_1core_1_1array_1_1_data-members.html create mode 100644 docs/build/html/structmlx_1_1core_1_1array_1_1_data.html create mode 100644 docs/build/html/structmlx_1_1core_1_1array_1_1_flags-members.html create mode 100644 docs/build/html/structmlx_1_1core_1_1array_1_1_flags.html create mode 100644 docs/build/html/structmlx_1_1core_1_1complex128__t-members.html create mode 100644 docs/build/html/structmlx_1_1core_1_1complex128__t.html create mode 100644 docs/build/html/structmlx_1_1core_1_1complex128__t.png create mode 100644 docs/build/html/structmlx_1_1core_1_1complex64__t-members.html create mode 100644 docs/build/html/structmlx_1_1core_1_1complex64__t.html create mode 100644 docs/build/html/structmlx_1_1core_1_1complex64__t.png create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_abs-members.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_abs.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_add-members.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_add.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_arc_cos-members.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_arc_cos.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_arc_cosh-members.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_arc_cosh.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_arc_sin-members.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_arc_sin.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_arc_sinh-members.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_arc_sinh.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_arc_tan-members.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_arc_tan.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_arc_tan2-members.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_arc_tan2.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_arc_tanh-members.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_arc_tanh.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_bitwise_and-members.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_bitwise_and.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_bitwise_or-members.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_bitwise_or.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_bitwise_xor-members.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_bitwise_xor.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_ceil-members.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_ceil.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_conjugate-members.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_conjugate.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_cos-members.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_cos.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_cosh-members.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_cosh.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_divide-members.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_divide.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_equal-members.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_equal.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_erf-members.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_erf.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_erf_inv-members.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_erf_inv.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_exp-members.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_exp.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_expm1-members.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_expm1.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_floor-members.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_floor.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_greater-members.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_greater.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_greater_equal-members.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_greater_equal.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_in_tracing-members.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_in_tracing.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_left_shift-members.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_left_shift.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_less-members.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_less.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_less_equal-members.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_less_equal.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_log-members.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_log.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_log10-members.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_log10.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_log1p-members.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_log1p.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_log2-members.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_log2.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_log_add_exp-members.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_log_add_exp.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_logical_and-members.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_logical_and.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_logical_not-members.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_logical_not.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_logical_or-members.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_logical_or.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_maximum-members.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_maximum.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_minimum-members.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_minimum.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_multiply-members.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_multiply.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_na_n_equal-members.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_na_n_equal.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_negative-members.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_negative.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_not_equal-members.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_not_equal.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_power-members.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_power.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_remainder-members.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_remainder.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_right_shift-members.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_right_shift.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_round-members.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_round.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_rsqrt-members.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_rsqrt.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_select-members.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_select.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_sigmoid-members.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_sigmoid.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_sign-members.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_sign.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_sin-members.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_sin.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_sinh-members.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_sinh.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_sqrt-members.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_sqrt.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_square-members.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_square.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_subtract-members.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_subtract.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_tan-members.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_tan.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_tanh-members.html create mode 100644 docs/build/html/structmlx_1_1core_1_1detail_1_1_tanh.html create mode 100644 docs/build/html/structmlx_1_1core_1_1metal_1_1_command_encoder-members.html create mode 100644 docs/build/html/structmlx_1_1core_1_1metal_1_1_command_encoder.html create mode 100644 docs/build/html/structmlx_1_1core_1_1metal_1_1_command_encoder_1_1_concurrent_context-members.html create mode 100644 docs/build/html/structmlx_1_1core_1_1metal_1_1_command_encoder_1_1_concurrent_context.html create mode 100644 docs/build/html/structmlx_1_1core_1_1scheduler_1_1_stream_thread-members.html create mode 100644 docs/build/html/structmlx_1_1core_1_1scheduler_1_1_stream_thread.html create mode 100644 docs/build/html/structmlx_1_1steel_1_1_accum_helper-members.html create mode 100644 docs/build/html/structmlx_1_1steel_1_1_accum_helper.html create mode 100644 docs/build/html/structmlx_1_1steel_1_1_block_loader-members.html create mode 100644 docs/build/html/structmlx_1_1steel_1_1_block_loader.html create mode 100644 docs/build/html/structmlx_1_1steel_1_1_block_loader_1_1_read_vector-members.html create mode 100644 docs/build/html/structmlx_1_1steel_1_1_block_loader_1_1_read_vector.html create mode 100644 docs/build/html/structmlx_1_1steel_1_1_block_m_m_a-members.html create mode 100644 docs/build/html/structmlx_1_1steel_1_1_block_m_m_a.html create mode 100644 docs/build/html/structmlx_1_1steel_1_1_block_swizzle-members.html create mode 100644 docs/build/html/structmlx_1_1steel_1_1_block_swizzle.html create mode 100644 docs/build/html/structmlx_1_1steel_1_1_channel_helper-members.html create mode 100644 docs/build/html/structmlx_1_1steel_1_1_channel_helper.html create mode 100644 docs/build/html/structmlx_1_1steel_1_1_channel_helper_3_011_01_4-members.html create mode 100644 docs/build/html/structmlx_1_1steel_1_1_channel_helper_3_011_01_4.html create mode 100644 docs/build/html/structmlx_1_1steel_1_1_channel_helper_3_012_01_4-members.html create mode 100644 docs/build/html/structmlx_1_1steel_1_1_channel_helper_3_012_01_4.html create mode 100644 docs/build/html/structmlx_1_1steel_1_1_channel_helper_3_013_01_4-members.html create mode 100644 docs/build/html/structmlx_1_1steel_1_1_channel_helper_3_013_01_4.html create mode 100644 docs/build/html/structmlx_1_1steel_1_1_channel_helper_3_014_01_4-members.html create mode 100644 docs/build/html/structmlx_1_1steel_1_1_channel_helper_3_014_01_4.html create mode 100644 docs/build/html/structmlx_1_1steel_1_1_conv2_d_general_base_info-members.html create mode 100644 docs/build/html/structmlx_1_1steel_1_1_conv2_d_general_base_info.html create mode 100644 docs/build/html/structmlx_1_1steel_1_1_conv2_d_general_jump_params-members.html create mode 100644 docs/build/html/structmlx_1_1steel_1_1_conv2_d_general_jump_params.html create mode 100644 docs/build/html/structmlx_1_1steel_1_1_conv2_d_input_block_loader_general-members.html create mode 100644 docs/build/html/structmlx_1_1steel_1_1_conv2_d_input_block_loader_general.html create mode 100644 docs/build/html/structmlx_1_1steel_1_1_conv2_d_input_block_loader_large_filter-members.html create mode 100644 docs/build/html/structmlx_1_1steel_1_1_conv2_d_input_block_loader_large_filter.html create mode 100644 docs/build/html/structmlx_1_1steel_1_1_conv2_d_input_block_loader_small_channels-members.html create mode 100644 docs/build/html/structmlx_1_1steel_1_1_conv2_d_input_block_loader_small_channels.html create mode 100644 docs/build/html/structmlx_1_1steel_1_1_conv2_d_input_block_loader_small_filter-members.html create mode 100644 docs/build/html/structmlx_1_1steel_1_1_conv2_d_input_block_loader_small_filter.html create mode 100644 docs/build/html/structmlx_1_1steel_1_1_conv2_d_weight_block_loader-members.html create mode 100644 docs/build/html/structmlx_1_1steel_1_1_conv2_d_weight_block_loader.html create mode 100644 docs/build/html/structmlx_1_1steel_1_1_conv2_d_weight_block_loader_general-members.html create mode 100644 docs/build/html/structmlx_1_1steel_1_1_conv2_d_weight_block_loader_general.html create mode 100644 docs/build/html/structmlx_1_1steel_1_1_conv2_d_weight_block_loader_small_channels-members.html create mode 100644 docs/build/html/structmlx_1_1steel_1_1_conv2_d_weight_block_loader_small_channels.html create mode 100644 docs/build/html/structmlx_1_1steel_1_1_g_e_m_m_add_m_m_params-members.html create mode 100644 docs/build/html/structmlx_1_1steel_1_1_g_e_m_m_add_m_m_params.html create mode 100644 docs/build/html/structmlx_1_1steel_1_1_g_e_m_m_kernel-members.html create mode 100644 docs/build/html/structmlx_1_1steel_1_1_g_e_m_m_kernel.html create mode 100644 docs/build/html/structmlx_1_1steel_1_1_g_e_m_m_params-members.html create mode 100644 docs/build/html/structmlx_1_1steel_1_1_g_e_m_m_params.html create mode 100644 docs/build/html/structmlx_1_1steel_1_1_g_e_m_m_spilt_k_params-members.html create mode 100644 docs/build/html/structmlx_1_1steel_1_1_g_e_m_m_spilt_k_params.html create mode 100644 docs/build/html/structmlx_1_1steel_1_1_implicit_gemm_conv2_d_params-members.html create mode 100644 docs/build/html/structmlx_1_1steel_1_1_implicit_gemm_conv2_d_params.html create mode 100644 docs/build/html/structmlx_1_1steel_1_1_loop_alignment.html create mode 100644 docs/build/html/structmlx_1_1steel_1_1_transform_add-members.html create mode 100644 docs/build/html/structmlx_1_1steel_1_1_transform_add.html create mode 100644 docs/build/html/structmlx_1_1steel_1_1_transform_axpby-members.html create mode 100644 docs/build/html/structmlx_1_1steel_1_1_transform_axpby.html create mode 100644 docs/build/html/structmlx_1_1steel_1_1_transform_none-members.html create mode 100644 docs/build/html/structmlx_1_1steel_1_1_transform_none.html create mode 100644 docs/build/html/structmlx__atomic-members.html create mode 100644 docs/build/html/structmlx__atomic.html create mode 100644 docs/build/html/structmlx__atomic_3_01_t_00_01enable__if__t_3_01is__metal__atomic_3_01_t_01_4_01_4_01_4-members.html create mode 100644 docs/build/html/structmlx__atomic_3_01_t_00_01enable__if__t_3_01is__metal__atomic_3_01_t_01_4_01_4_01_4.html create mode 100644 docs/build/html/structpocketfft_1_1detail_1_1_exec_c2_c-members.html create mode 100644 docs/build/html/structpocketfft_1_1detail_1_1_exec_c2_c.html create mode 100644 docs/build/html/structpocketfft_1_1detail_1_1_exec_dcst-members.html create mode 100644 docs/build/html/structpocketfft_1_1detail_1_1_exec_dcst.html create mode 100644 docs/build/html/structpocketfft_1_1detail_1_1_exec_hartley-members.html create mode 100644 docs/build/html/structpocketfft_1_1detail_1_1_exec_hartley.html create mode 100644 docs/build/html/structpocketfft_1_1detail_1_1_exec_r2_r-members.html create mode 100644 docs/build/html/structpocketfft_1_1detail_1_1_exec_r2_r.html create mode 100644 docs/build/html/structpocketfft_1_1detail_1_1_v_l_e_n-members.html create mode 100644 docs/build/html/structpocketfft_1_1detail_1_1_v_l_e_n.html create mode 100644 docs/build/html/structpocketfft_1_1detail_1_1_v_t_y_p_e.html create mode 100644 docs/build/html/structpocketfft_1_1detail_1_1add__vec-members.html create mode 100644 docs/build/html/structpocketfft_1_1detail_1_1add__vec.html create mode 100644 docs/build/html/structpocketfft_1_1detail_1_1add__vec_3_01cmplx_3_01_t_01_4_01_4-members.html create mode 100644 docs/build/html/structpocketfft_1_1detail_1_1add__vec_3_01cmplx_3_01_t_01_4_01_4.html create mode 100644 docs/build/html/structpocketfft_1_1detail_1_1cmplx-members.html create mode 100644 docs/build/html/structpocketfft_1_1detail_1_1cmplx.html create mode 100644 docs/build/html/structpocketfft_1_1detail_1_1threading_1_1aligned__allocator-members.html create mode 100644 docs/build/html/structpocketfft_1_1detail_1_1threading_1_1aligned__allocator.html create mode 100644 docs/build/html/structpocketfft_1_1detail_1_1util-members.html create mode 100644 docs/build/html/structpocketfft_1_1detail_1_1util.html create mode 100644 docs/build/html/sync_off.png create mode 100644 docs/build/html/sync_on.png create mode 100644 docs/build/html/tab_a.png create mode 100644 docs/build/html/tab_ad.png create mode 100644 docs/build/html/tab_b.png create mode 100644 docs/build/html/tab_bd.png create mode 100644 docs/build/html/tab_h.png create mode 100644 docs/build/html/tab_hd.png create mode 100644 docs/build/html/tab_s.png create mode 100644 docs/build/html/tab_sd.png create mode 100644 docs/build/html/tabs.css create mode 100644 docs/build/html/threefry_8h.html create mode 100644 docs/build/html/threefry_8h_source.html create mode 100644 docs/build/html/topics.html create mode 100644 docs/build/html/transforms_8h.html create mode 100644 docs/build/html/transforms_8h_source.html create mode 100644 docs/build/html/transforms__impl_8h.html create mode 100644 docs/build/html/transforms__impl_8h_source.html create mode 100644 docs/build/html/types_2bf16_8h.html create mode 100644 docs/build/html/types_2bf16_8h_source.html create mode 100644 docs/build/html/types_2complex_8h.html create mode 100644 docs/build/html/types_2complex_8h_source.html create mode 100644 docs/build/html/unionbool4__or__uint-members.html create mode 100644 docs/build/html/unionbool4__or__uint.html create mode 100644 docs/build/html/unionmlx_1_1core_1_1detail_1_1_int_or_float-members.html create mode 100644 docs/build/html/unionmlx_1_1core_1_1detail_1_1_int_or_float.html create mode 100644 docs/build/html/utils_8h.html create mode 100644 docs/build/html/utils_8h_source.html diff --git a/docs/build/html/.buildinfo b/docs/build/html/.buildinfo index 5a4af442a..4cd5f8315 100644 --- a/docs/build/html/.buildinfo +++ b/docs/build/html/.buildinfo @@ -1,4 +1,4 @@ # Sphinx build info version 1 # This file hashes the configuration used when building these files. When it is not found, a full rebuild will be done. -config: 0c08faf7a4a5981ee1e4c3cab57ef3b9 +config: 6d31d3d7850f7f8959377483b35af018 tags: 645f666f9bcd5a90fca523b33c5a78b7 diff --git a/docs/build/html/_sources/cpp/ops.rst b/docs/build/html/_sources/cpp/ops.rst index 4d2d1404e..009a10b1e 100644 --- a/docs/build/html/_sources/cpp/ops.rst +++ b/docs/build/html/_sources/cpp/ops.rst @@ -3,4 +3,5 @@ Operations ========== - +.. doxygengroup:: ops + :content-only: diff --git a/docs/build/html/_sources/install.rst b/docs/build/html/_sources/install.rst index f34db7270..252b234e6 100644 --- a/docs/build/html/_sources/install.rst +++ b/docs/build/html/_sources/install.rst @@ -157,7 +157,10 @@ should point to the path to the built metal library. - OFF * - MLX_METAL_DEBUG - OFF - + * - MLX_BUILD_SAFETENSORS + - ON + * - MLX_BUILD_GGUF + - ON .. note:: diff --git a/docs/build/html/_sources/python/_autosummary/mlx.core.arctan2.rst b/docs/build/html/_sources/python/_autosummary/mlx.core.arctan2.rst new file mode 100644 index 000000000..2039f68f0 --- /dev/null +++ b/docs/build/html/_sources/python/_autosummary/mlx.core.arctan2.rst @@ -0,0 +1,6 @@ +mlx.core.arctan2 +================ + +.. currentmodule:: mlx.core + +.. autofunction:: arctan2 \ No newline at end of file diff --git a/docs/build/html/_sources/python/_autosummary/mlx.core.array.rst b/docs/build/html/_sources/python/_autosummary/mlx.core.array.rst index 9ea269f2a..e845b3cf8 100644 --- a/docs/build/html/_sources/python/_autosummary/mlx.core.array.rst +++ b/docs/build/html/_sources/python/_autosummary/mlx.core.array.rst @@ -20,6 +20,7 @@ ~array.argmax ~array.argmin ~array.astype + ~array.conj ~array.cos ~array.cummax ~array.cummin diff --git a/docs/build/html/_sources/python/_autosummary/mlx.core.bitwise_and.rst b/docs/build/html/_sources/python/_autosummary/mlx.core.bitwise_and.rst new file mode 100644 index 000000000..6b8497e5c --- /dev/null +++ b/docs/build/html/_sources/python/_autosummary/mlx.core.bitwise_and.rst @@ -0,0 +1,6 @@ +mlx.core.bitwise\_and +===================== + +.. currentmodule:: mlx.core + +.. autofunction:: bitwise_and \ No newline at end of file diff --git a/docs/build/html/_sources/python/_autosummary/mlx.core.bitwise_or.rst b/docs/build/html/_sources/python/_autosummary/mlx.core.bitwise_or.rst new file mode 100644 index 000000000..15eb14604 --- /dev/null +++ b/docs/build/html/_sources/python/_autosummary/mlx.core.bitwise_or.rst @@ -0,0 +1,6 @@ +mlx.core.bitwise\_or +==================== + +.. currentmodule:: mlx.core + +.. autofunction:: bitwise_or \ No newline at end of file diff --git a/docs/build/html/_sources/python/_autosummary/mlx.core.bitwise_xor.rst b/docs/build/html/_sources/python/_autosummary/mlx.core.bitwise_xor.rst new file mode 100644 index 000000000..ae41e5f49 --- /dev/null +++ b/docs/build/html/_sources/python/_autosummary/mlx.core.bitwise_xor.rst @@ -0,0 +1,6 @@ +mlx.core.bitwise\_xor +===================== + +.. currentmodule:: mlx.core + +.. autofunction:: bitwise_xor \ No newline at end of file diff --git a/docs/build/html/_sources/python/_autosummary/mlx.core.block_sparse_mm.rst b/docs/build/html/_sources/python/_autosummary/mlx.core.block_sparse_mm.rst new file mode 100644 index 000000000..72a9dd120 --- /dev/null +++ b/docs/build/html/_sources/python/_autosummary/mlx.core.block_sparse_mm.rst @@ -0,0 +1,6 @@ +mlx.core.block\_sparse\_mm +========================== + +.. currentmodule:: mlx.core + +.. autofunction:: block_sparse_mm \ No newline at end of file diff --git a/docs/build/html/_sources/python/_autosummary/mlx.core.conj.rst b/docs/build/html/_sources/python/_autosummary/mlx.core.conj.rst new file mode 100644 index 000000000..f1dd8954d --- /dev/null +++ b/docs/build/html/_sources/python/_autosummary/mlx.core.conj.rst @@ -0,0 +1,6 @@ +mlx.core.conj +============= + +.. currentmodule:: mlx.core + +.. autofunction:: conj \ No newline at end of file diff --git a/docs/build/html/_sources/python/_autosummary/mlx.core.conjugate.rst b/docs/build/html/_sources/python/_autosummary/mlx.core.conjugate.rst new file mode 100644 index 000000000..3d3e20560 --- /dev/null +++ b/docs/build/html/_sources/python/_autosummary/mlx.core.conjugate.rst @@ -0,0 +1,6 @@ +mlx.core.conjugate +================== + +.. currentmodule:: mlx.core + +.. autofunction:: conjugate \ No newline at end of file diff --git a/docs/build/html/_sources/python/_autosummary/mlx.core.left_shift.rst b/docs/build/html/_sources/python/_autosummary/mlx.core.left_shift.rst new file mode 100644 index 000000000..a99502501 --- /dev/null +++ b/docs/build/html/_sources/python/_autosummary/mlx.core.left_shift.rst @@ -0,0 +1,6 @@ +mlx.core.left\_shift +==================== + +.. currentmodule:: mlx.core + +.. autofunction:: left_shift \ No newline at end of file diff --git a/docs/build/html/_sources/python/_autosummary/mlx.core.metal.device_info.rst b/docs/build/html/_sources/python/_autosummary/mlx.core.metal.device_info.rst new file mode 100644 index 000000000..1c914a29a --- /dev/null +++ b/docs/build/html/_sources/python/_autosummary/mlx.core.metal.device_info.rst @@ -0,0 +1,6 @@ +mlx.core.metal.device\_info +=========================== + +.. currentmodule:: mlx.core.metal + +.. autofunction:: device_info \ No newline at end of file diff --git a/docs/build/html/_sources/python/_autosummary/mlx.core.metal.reset_peak_memory.rst b/docs/build/html/_sources/python/_autosummary/mlx.core.metal.reset_peak_memory.rst new file mode 100644 index 000000000..4bdbd144a --- /dev/null +++ b/docs/build/html/_sources/python/_autosummary/mlx.core.metal.reset_peak_memory.rst @@ -0,0 +1,6 @@ +mlx.core.metal.reset\_peak\_memory +================================== + +.. currentmodule:: mlx.core.metal + +.. autofunction:: reset_peak_memory \ No newline at end of file diff --git a/docs/build/html/_sources/python/_autosummary/mlx.core.right_shift.rst b/docs/build/html/_sources/python/_autosummary/mlx.core.right_shift.rst new file mode 100644 index 000000000..471b61b95 --- /dev/null +++ b/docs/build/html/_sources/python/_autosummary/mlx.core.right_shift.rst @@ -0,0 +1,6 @@ +mlx.core.right\_shift +===================== + +.. currentmodule:: mlx.core + +.. autofunction:: right_shift \ No newline at end of file diff --git a/docs/build/html/_sources/python/_autosummary/mlx.optimizers.clip_grad_norm.rst b/docs/build/html/_sources/python/_autosummary/mlx.optimizers.clip_grad_norm.rst new file mode 100644 index 000000000..ccd4924c5 --- /dev/null +++ b/docs/build/html/_sources/python/_autosummary/mlx.optimizers.clip_grad_norm.rst @@ -0,0 +1,6 @@ +mlx.optimizers.clip\_grad\_norm +=============================== + +.. currentmodule:: mlx.optimizers + +.. autofunction:: clip_grad_norm \ No newline at end of file diff --git a/docs/build/html/_sources/python/_autosummary/mlx.utils.tree_reduce.rst b/docs/build/html/_sources/python/_autosummary/mlx.utils.tree_reduce.rst new file mode 100644 index 000000000..0bba35704 --- /dev/null +++ b/docs/build/html/_sources/python/_autosummary/mlx.utils.tree_reduce.rst @@ -0,0 +1,6 @@ +mlx.utils.tree\_reduce +====================== + +.. currentmodule:: mlx.utils + +.. autofunction:: tree_reduce \ No newline at end of file diff --git a/docs/build/html/_sources/python/metal.rst b/docs/build/html/_sources/python/metal.rst index 589ec0a82..cb2cdb38e 100644 --- a/docs/build/html/_sources/python/metal.rst +++ b/docs/build/html/_sources/python/metal.rst @@ -7,8 +7,10 @@ Metal :toctree: _autosummary is_available + device_info get_active_memory get_peak_memory + reset_peak_memory get_cache_memory set_memory_limit set_cache_limit diff --git a/docs/build/html/_sources/python/ops.rst b/docs/build/html/_sources/python/ops.rst index 7795512a0..177332c49 100644 --- a/docs/build/html/_sources/python/ops.rst +++ b/docs/build/html/_sources/python/ops.rst @@ -19,6 +19,7 @@ Operations arcsin arcsinh arctan + arctan2 arctanh argmax argmin @@ -28,11 +29,17 @@ Operations atleast_1d atleast_2d atleast_3d - broadcast_to + bitwise_and + bitwise_or + bitwise_xor block_masked_mm + block_sparse_mm + broadcast_to ceil clip concatenate + conj + conjugate convolve conv1d conv2d @@ -69,6 +76,7 @@ Operations isnan isneginf isposinf + left_shift less less_equal linspace @@ -105,6 +113,7 @@ Operations reciprocal repeat reshape + right_shift round rsqrt save diff --git a/docs/build/html/_sources/python/optimizers.rst b/docs/build/html/_sources/python/optimizers.rst index f437ddc15..84ab933ac 100644 --- a/docs/build/html/_sources/python/optimizers.rst +++ b/docs/build/html/_sources/python/optimizers.rst @@ -1,5 +1,7 @@ .. _optimizers: +.. currentmodule:: mlx.optimizers + Optimizers ========== @@ -34,3 +36,8 @@ model's parameters and the **optimizer state**. optimizers/optimizer optimizers/common_optimizers optimizers/schedulers + +.. autosummary:: + :toctree: _autosummary + + clip_grad_norm diff --git a/docs/build/html/_sources/python/tree_utils.rst b/docs/build/html/_sources/python/tree_utils.rst index dbd0ebce9..6dc60b47d 100644 --- a/docs/build/html/_sources/python/tree_utils.rst +++ b/docs/build/html/_sources/python/tree_utils.rst @@ -20,3 +20,4 @@ return python trees will be using the default python ``dict``, ``list`` and tree_unflatten tree_map tree_map_with_path + tree_reduce diff --git a/docs/build/html/_static/documentation_options.js b/docs/build/html/_static/documentation_options.js index b217302fd..607aaea4c 100644 --- a/docs/build/html/_static/documentation_options.js +++ b/docs/build/html/_static/documentation_options.js @@ -1,5 +1,5 @@ const DOCUMENTATION_OPTIONS = { - VERSION: '0.12.0', + VERSION: '0.13.0', LANGUAGE: 'en', COLLAPSE_INDEX: false, BUILDER: 'html', diff --git a/docs/build/html/allocator_8h.html b/docs/build/html/allocator_8h.html new file mode 100644 index 000000000..8e3b1e763 --- /dev/null +++ b/docs/build/html/allocator_8h.html @@ -0,0 +1,124 @@ + + + + + + + +MLX: mlx/allocator.h File Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
allocator.h File Reference
+
+
+
#include <cstdlib>
+
+

Go to the source code of this file.

+ + + + + + + + +

+Classes

class  mlx::core::allocator::Buffer
 
class  mlx::core::allocator::Allocator
 
class  mlx::core::allocator::CommonAllocator
 
+ + + + + + + +

+Namespaces

namespace  mlx
 
namespace  mlx::core
 
namespace  mlx::core::allocator
 
+ + + + + + + + + +

+Functions

Buffer mlx::core::allocator::malloc (size_t size)
 
void mlx::core::allocator::free (Buffer buffer)
 
Buffer mlx::core::allocator::malloc_or_wait (size_t size)
 
Allocatormlx::core::allocator::allocator ()
 
+
+ + + + diff --git a/docs/build/html/allocator_8h_source.html b/docs/build/html/allocator_8h_source.html new file mode 100644 index 000000000..917df829b --- /dev/null +++ b/docs/build/html/allocator_8h_source.html @@ -0,0 +1,191 @@ + + + + + + + +MLX: mlx/allocator.h Source File + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
allocator.h
+
+
+Go to the documentation of this file.
1// Copyright © 2023 Apple Inc.
+
2
+
3#pragma once
+
4
+
5#include <cstdlib>
+
6
+
+ +
8
+
9// Simple wrapper around buffer pointers
+
10// WARNING: Only Buffer objects constructed from and those that wrap
+
11// raw pointers from mlx::allocator are supported.
+
+
12class Buffer {
+
13 private:
+
14 void* ptr_;
+
15
+
16 public:
+
17 Buffer(void* ptr) : ptr_(ptr) {};
+
18
+
19 // Get the raw data pointer from the buffer
+
20 void* raw_ptr();
+
21
+
22 // Get the buffer pointer from the buffer
+
+
23 const void* ptr() const {
+
24 return ptr_;
+
25 };
+
+
+
26 void* ptr() {
+
27 return ptr_;
+
28 };
+
+
29};
+
+
30
+
31Buffer malloc(size_t size);
+
32
+
33void free(Buffer buffer);
+
34
+
35// Wait for running tasks to finish and free up memory
+
36// if allocation fails
+ +
38
+
+
39class Allocator {
+
41 public:
+
42 virtual Buffer malloc(size_t size, bool allow_swap = false) = 0;
+
43 virtual void free(Buffer buffer) = 0;
+
44
+
45 Allocator() = default;
+
46 Allocator(const Allocator& other) = delete;
+
47 Allocator(Allocator&& other) = delete;
+
48 Allocator& operator=(const Allocator& other) = delete;
+
49 Allocator& operator=(Allocator&& other) = delete;
+
50 virtual ~Allocator() = default;
+
51};
+
+
52
+ +
54
+
+
55class CommonAllocator : public Allocator {
+
57 public:
+
58 virtual Buffer malloc(size_t size, bool allow_swap = false) override;
+
59 virtual void free(Buffer buffer) override;
+
60
+
61 private:
+
62 CommonAllocator() = default;
+ +
64};
+
+
65
+
66} // namespace mlx::core::allocator
+
+
Definition allocator.h:39
+
Allocator & operator=(const Allocator &other)=delete
+
Allocator & operator=(Allocator &&other)=delete
+ +
Allocator(Allocator &&other)=delete
+ +
virtual Buffer malloc(size_t size, bool allow_swap=false)=0
Abstract base class for a memory allocator.
+
Allocator(const Allocator &other)=delete
+
virtual void free(Buffer buffer)=0
+
Definition allocator.h:12
+ +
const void * ptr() const
Definition allocator.h:23
+
Buffer(void *ptr)
Definition allocator.h:17
+
void * ptr()
Definition allocator.h:26
+
Definition allocator.h:55
+
virtual Buffer malloc(size_t size, bool allow_swap=false) override
A general CPU allocator.
+
virtual void free(Buffer buffer) override
+ +
Definition allocator.h:7
+
Buffer malloc(size_t size)
+
void free(Buffer buffer)
+
Buffer malloc_or_wait(size_t size)
+
Allocator & allocator()
+
+ + + + diff --git a/docs/build/html/annotated.html b/docs/build/html/annotated.html new file mode 100644 index 000000000..220efdb5b --- /dev/null +++ b/docs/build/html/annotated.html @@ -0,0 +1,445 @@ + + + + + + + +MLX: Class List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + +
+ +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ +
+
Class List
+
+
+
Here are the classes, structs, unions and interfaces with brief descriptions:
+
[detail level 12345]
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
 Nmetal
 Nmlx
 NMPS
 Npocketfft
 C_MLX_BFloat16
 CAbs
 CAdd
 CAnd
 CArcCos
 CArcCosh
 CArcSin
 CArcSinh
 CArcTan
 CArcTan2
 CArcTanh
 CBitwiseAnd
 CBitwiseOr
 CBitwiseXor
 Cbool4_or_uint
 CCeil
 Ccomplex64_t
 CConjugate
 CCos
 CCosh
 CDivide
 CEqual
 CErf
 CErfInv
 CExp
 CExpm1
 CFloor
 CGreater
 CGreaterEqual
 CIndices
 CLeftShift
 CLess
 CLessEqual
 CLimits
 CLimits< bfloat16_t >
 CLimits< bool >
 CLimits< float >
 CLimits< half >
 CLimits< int16_t >
 CLimits< int32_t >
 CLimits< int64_t >
 CLimits< int8_t >
 CLimits< uint16_t >
 CLimits< uint32_t >
 CLimits< uint64_t >
 CLimits< uint8_t >
 CLog
 CLog10
 CLog1p
 CLog2
 CLogAddExp
 CLogicalAnd
 CLogicalNot
 CLogicalOr
 CMax
 CMaximum
 CMin
 CMinimum
 Cmlx_atomic
 Cmlx_atomic< T, enable_if_t< is_metal_atomic< T > > >
 CMLXConvParams
 CMLXScaledDotProductAttentionParams
 CMultiply
 CNaNEqual
 CNegative
 CNone
 CNotEqual
 COr
 CPower
 CProd
 CRemainder
 CRightShift
 CRound
 CRsqrt
 CSelect
 CSigmoid
 CSign
 CSin
 CSinh
 CSqrt
 CSquare
 CSubtract
 CSum
 CTan
 CTanh
+
+
+ + + + diff --git a/docs/build/html/arange_8h.html b/docs/build/html/arange_8h.html new file mode 100644 index 000000000..804fdf88e --- /dev/null +++ b/docs/build/html/arange_8h.html @@ -0,0 +1,107 @@ + + + + + + + +MLX: mlx/backend/common/arange.h File Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
arange.h File Reference
+
+
+
#include "mlx/allocator.h"
+#include "mlx/array.h"
+
+

Go to the source code of this file.

+ + + + + + +

+Namespaces

namespace  mlx
 
namespace  mlx::core
 
+ + + +

+Functions

void mlx::core::arange (const std::vector< array > &inputs, array &out, double start, double step)
 
+
+ + + + diff --git a/docs/build/html/arange_8h_source.html b/docs/build/html/arange_8h_source.html new file mode 100644 index 000000000..2f042106a --- /dev/null +++ b/docs/build/html/arange_8h_source.html @@ -0,0 +1,192 @@ + + + + + + + +MLX: mlx/backend/common/arange.h Source File + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
arange.h
+
+
+Go to the documentation of this file.
1// Copyright © 2023 Apple Inc.
+
2
+
3#pragma once
+
4
+
5#include "mlx/allocator.h"
+
6#include "mlx/array.h"
+
7
+
8namespace mlx::core {
+
9
+
10namespace {
+
11
+
12template <typename T>
+
13void arange(T start, T next, array& out, size_t size) {
+
14 auto ptr = out.data<T>();
+
15 auto step_size = next - start;
+
16 for (int i = 0; i < size; ++i) {
+
17 ptr[i] = start;
+
18 start += step_size;
+
19 }
+
20}
+
21
+
22} // namespace
+
23
+
+
24void arange(
+
25 const std::vector<array>& inputs,
+
26 array& out,
+
27 double start,
+
28 double step) {
+
29 assert(inputs.size() == 0);
+ +
31 switch (out.dtype()) {
+
32 case bool_:
+
33 throw std::runtime_error("Bool type unsupported for arange.");
+
34 break;
+
35 case uint8:
+
36 arange<uint8_t>(start, start + step, out, out.size());
+
37 break;
+
38 case uint16:
+
39 arange<uint16_t>(start, start + step, out, out.size());
+
40 break;
+
41 case uint32:
+
42 arange<uint32_t>(start, start + step, out, out.size());
+
43 break;
+
44 case uint64:
+
45 arange<uint64_t>(start, start + step, out, out.size());
+
46 break;
+
47 case int8:
+
48 arange<int8_t>(start, start + step, out, out.size());
+
49 break;
+
50 case int16:
+
51 arange<int16_t>(start, start + step, out, out.size());
+
52 break;
+
53 case int32:
+
54 arange<int32_t>(start, start + step, out, out.size());
+
55 break;
+
56 case int64:
+
57 arange<int64_t>(start, start + step, out, out.size());
+
58 break;
+
59 case float16:
+
60 arange<float16_t>(start, start + step, out, out.size());
+
61 break;
+
62 case float32:
+
63 arange<float>(start, start + step, out, out.size());
+
64 break;
+
65 case bfloat16:
+
66 arange<bfloat16_t>(start, start + step, out, out.size());
+
67 break;
+
68 case complex64:
+
69 arange<complex64_t>(start, start + step, out, out.size());
+
70 break;
+
71 }
+
72}
+
+
73
+
74} // namespace mlx::core
+ + +
BufferHolder * next
Definition allocator.h:37
+
Definition array.h:20
+
size_t nbytes() const
The number of bytes in the array.
Definition array.h:89
+
size_t size() const
The number of elements in the array.
Definition array.h:84
+
void set_data(allocator::Buffer buffer, deleter_t d=allocator::free)
+
Dtype dtype() const
Get the arrays data type.
Definition array.h:127
+
Buffer malloc_or_wait(size_t size)
+
Definition allocator.h:7
+
constexpr Dtype bool_
Definition dtype.h:60
+
constexpr Dtype uint64
Definition dtype.h:65
+
constexpr Dtype uint16
Definition dtype.h:63
+
void arange(const std::vector< array > &inputs, array &out, double start, double step)
Definition arange.h:24
+
constexpr Dtype bfloat16
Definition dtype.h:74
+
constexpr Dtype int32
Definition dtype.h:69
+
constexpr Dtype float32
Definition dtype.h:73
+
constexpr Dtype int16
Definition dtype.h:68
+
constexpr Dtype int8
Definition dtype.h:67
+
constexpr Dtype int64
Definition dtype.h:70
+
constexpr Dtype uint8
Definition dtype.h:62
+
constexpr Dtype float16
Definition dtype.h:72
+
constexpr Dtype uint32
Definition dtype.h:64
+
constexpr Dtype complex64
Definition dtype.h:75
+
+ + + + diff --git a/docs/build/html/array_8h.html b/docs/build/html/array_8h.html new file mode 100644 index 000000000..bdfcf22d8 --- /dev/null +++ b/docs/build/html/array_8h.html @@ -0,0 +1,138 @@ + + + + + + + +MLX: mlx/array.h File Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
array.h File Reference
+
+
+
#include <algorithm>
+#include <cstdint>
+#include <functional>
+#include <memory>
+#include <vector>
+#include "mlx/allocator.h"
+#include "mlx/dtype.h"
+#include "mlx/event.h"
+
+

Go to the source code of this file.

+ + + + + + + + + + +

+Classes

class  mlx::core::array
 
struct  mlx::core::array::ArrayIterator
 
struct  mlx::core::array::Data
 
struct  mlx::core::array::Flags
 
+ + + + + +

+Namespaces

namespace  mlx
 
namespace  mlx::core
 
+ + + + + + +

+Typedefs

using mlx::core::deleter_t = std::function<void(allocator::Buffer)>
 
template<typename... T>
using mlx::core::enable_for_arrays_t = typename std::enable_if_t<is_arrays_v<T...>>
 
+ + + + + + + +

+Variables

template<typename T >
constexpr bool mlx::core::is_array_v
 
template<typename... T>
constexpr bool mlx::core::is_arrays_v = (is_array_v<T> && ...)
 
+
+ + + + diff --git a/docs/build/html/array_8h_source.html b/docs/build/html/array_8h_source.html new file mode 100644 index 000000000..4b15bc560 --- /dev/null +++ b/docs/build/html/array_8h_source.html @@ -0,0 +1,842 @@ + + + + + + + +MLX: mlx/array.h Source File + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
array.h
+
+
+Go to the documentation of this file.
1// Copyright © 2023 Apple Inc.
+
2#pragma once
+
3
+
4#include <algorithm>
+
5#include <cstdint>
+
6#include <functional>
+
7#include <memory>
+
8#include <vector>
+
9
+
10#include "mlx/allocator.h"
+
11#include "mlx/dtype.h"
+
12#include "mlx/event.h"
+
13
+
14namespace mlx::core {
+
15
+
16// Forward declaration
+
17class Primitive;
+
18using deleter_t = std::function<void(allocator::Buffer)>;
+
19
+
+
20class array {
+
21 /* An array is really a node in a graph. It contains a shared ArrayDesc
+
22 * object */
+
23
+
24 public:
+
26 template <typename T>
+
27 explicit array(T val, Dtype dtype = TypeToDtype<T>());
+
28
+
29 /* Special case since std::complex can't be implicitly converted to other
+
30 * types. */
+
31 explicit array(const std::complex<float>& val, Dtype dtype = complex64);
+
32
+
33 template <typename It>
+
34 array(
+
35 It data,
+
36 std::vector<int> shape,
+
37 Dtype dtype =
+
38 TypeToDtype<typename std::iterator_traits<It>::value_type>());
+
39
+
40 template <typename T>
+
41 array(std::initializer_list<T> data, Dtype dtype = TypeToDtype<T>());
+
42
+
43 /* Special case so empty lists default to float32. */
+
44 array(std::initializer_list<float> data);
+
45
+
46 /* Special case so array({}, type) is an empty array. */
+
47 array(std::initializer_list<int> data, Dtype dtype);
+
48
+
49 template <typename T>
+
50 array(
+
51 std::initializer_list<T> data,
+
52 std::vector<int> shape,
+ +
54
+
55 /* Build an array from a buffer */
+ + +
58 std::vector<int> shape,
+ +
60 deleter_t deleter = allocator::free);
+
61
+
63 array& operator=(const array& other) && = delete;
+
64 array& operator=(array&& other) && = delete;
+
65
+
67 array& operator=(array&& other) & = default;
+
68 array(const array& other) = default;
+
69 array(array&& other) = default;
+
70
+
+
71 array& operator=(const array& other) & {
+
72 if (this->id() != other.id()) {
+
73 this->array_desc_ = other.array_desc_;
+
74 }
+
75 return *this;
+
76 };
+
+
77
+
+
79 size_t itemsize() const {
+
80 return size_of(dtype());
+
81 };
+
+
82
+
+
84 size_t size() const {
+
85 return array_desc_->size;
+
86 };
+
+
87
+
+
89 size_t nbytes() const {
+
90 return size() * itemsize();
+
91 };
+
+
92
+
+
94 size_t ndim() const {
+
95 return array_desc_->shape.size();
+
96 };
+
+
97
+
+
99 const std::vector<int>& shape() const {
+
100 return array_desc_->shape;
+
101 };
+
+
102
+
+
108 int shape(int dim) const {
+
109 return shape().at(dim < 0 ? dim + ndim() : dim);
+
110 };
+
+
111
+
+
113 const std::vector<size_t>& strides() const {
+
114 return array_desc_->strides;
+
115 };
+
+
116
+
+
122 size_t strides(int dim) const {
+
123 return strides().at(dim < 0 ? dim + ndim() : dim);
+
124 };
+
+
125
+
+
127 Dtype dtype() const {
+
128 return array_desc_->dtype;
+
129 };
+
+
130
+
132 void eval();
+
133
+
135 template <typename T>
+
136 T item();
+
137
+
138 template <typename T>
+
139 T item() const;
+
140
+
+ +
142 using iterator_category = std::random_access_iterator_tag;
+
143 using difference_type = size_t;
+
144 using value_type = const array;
+ +
146
+
147 explicit ArrayIterator(const array& arr, int idx = 0);
+
148
+ +
150
+
+ +
152 idx += diff;
+
153 return *this;
+
154 }
+
+
155
+
+ +
157 idx++;
+
158 return *this;
+
159 }
+
+
160
+
+
161 friend bool operator==(const ArrayIterator& a, const ArrayIterator& b) {
+
162 return a.arr.id() == b.arr.id() && a.idx == b.idx;
+
163 };
+
+
+
164 friend bool operator!=(const ArrayIterator& a, const ArrayIterator& b) {
+
165 return !(a == b);
+
166 };
+
+
167
+
168 private:
+
169 const array& arr;
+
170 int idx;
+
171 };
+
+
172
+
+ +
174 return ArrayIterator(*this);
+
175 }
+
+
+ +
177 return ArrayIterator(*this, shape(0));
+
178 }
+
+
179
+ +
187 std::vector<int> shape,
+
188 Dtype dtype,
+
189 std::shared_ptr<Primitive> primitive,
+
190 std::vector<array> inputs);
+
191
+
192 static std::vector<array> make_arrays(
+
193 std::vector<std::vector<int>> shapes,
+
194 const std::vector<Dtype>& dtypes,
+
195 const std::shared_ptr<Primitive>& primitive,
+
196 const std::vector<array>& inputs);
+
197
+
+
199 std::uintptr_t id() const {
+
200 return reinterpret_cast<std::uintptr_t>(array_desc_.get());
+
201 }
+
+
202
+
+
204 std::uintptr_t primitive_id() const {
+
205 return reinterpret_cast<std::uintptr_t>(array_desc_->primitive.get());
+
206 }
+
+
207
+
+
208 struct Data {
+ + + +
213 // Not copyable
+
214 Data(const Data& d) = delete;
+
215 Data& operator=(const Data& d) = delete;
+
+ +
217 d(buffer);
+
218 }
+
+
219 };
+
+
220
+
+
221 struct Flags {
+
222 // True if there are no gaps in the underlying data. Each item
+
223 // in the underlying data buffer belongs to at least one index.
+
224 bool contiguous : 1;
+
225
+ + +
228 };
+
+
229
+
+ +
232 return *(array_desc_->primitive);
+
233 };
+
+
234
+
+
236 std::shared_ptr<Primitive>& primitive_ptr() const {
+
237 return array_desc_->primitive;
+
238 };
+
+
239
+
+
241 bool has_primitive() const {
+
242 return array_desc_->primitive != nullptr;
+
243 };
+
+
244
+
+
246 const std::vector<array>& inputs() const {
+
247 return array_desc_->inputs;
+
248 };
+
+
249
+
+
250 std::vector<array>& inputs() {
+
251 return array_desc_->inputs;
+
252 }
+
+
253
+
+
255 bool is_donatable() const {
+
256 return array_desc_.use_count() == 1 && (array_desc_->data.use_count() == 1);
+
257 }
+
+
258
+
+
260 const std::vector<array>& siblings() const {
+
261 return array_desc_->siblings;
+
262 };
+
+
263
+
+
265 std::vector<array>& siblings() {
+
266 return array_desc_->siblings;
+
267 };
+
+
268
+
+
269 void set_siblings(std::vector<array> siblings, uint16_t position) {
+
270 array_desc_->siblings = std::move(siblings);
+
271 array_desc_->position = position;
+
272 }
+
+
273
+
+
276 std::vector<array> outputs() const {
+
277 auto idx = array_desc_->position;
+
278 std::vector<array> outputs;
+
279 outputs.reserve(siblings().size() + 1);
+
280 outputs.insert(outputs.end(), siblings().begin(), siblings().begin() + idx);
+
281 outputs.push_back(*this);
+
282 outputs.insert(outputs.end(), siblings().begin() + idx, siblings().end());
+
283 return outputs;
+
284 };
+
+
285
+
287 void detach();
+
288
+
+
290 const Flags& flags() const {
+
291 return array_desc_->flags;
+
292 };
+
+
293
+
+
295 size_t data_size() const {
+
296 return array_desc_->data_size;
+
297 };
+
+
298
+
+ +
300 return array_desc_->data->buffer;
+
301 };
+
+
+
302 const allocator::Buffer& buffer() const {
+
303 return array_desc_->data->buffer;
+
304 };
+
+
305
+
306 // Return a copy of the shared pointer
+
307 // to the array::Data struct
+
+
308 std::shared_ptr<Data> data_shared_ptr() const {
+
309 return array_desc_->data;
+
310 }
+
+
311 // Return a raw pointer to the arrays data
+
312 template <typename T>
+
+
313 T* data() {
+
314 return static_cast<T*>(array_desc_->data_ptr);
+
315 };
+
+
316
+
317 template <typename T>
+
+
318 const T* data() const {
+
319 return static_cast<T*>(array_desc_->data_ptr);
+
320 };
+
+
321
+ +
323
+
+
324 bool is_available() const {
+
325 return status() == Status::available;
+
326 }
+
+
+
327 const Status status() const {
+
328 return array_desc_->status;
+
329 }
+
+
330
+
+
331 void set_status(Status s) const {
+
332 array_desc_->status = s;
+
333 }
+
+
334
+
335 // Get the array's shared event
+
+
336 Event& event() const {
+
337 return array_desc_->event;
+
338 }
+
+
339
+
340 // Attach an event to a not yet evaluated array
+
+
341 void attach_event(Event e) const {
+
342 array_desc_->event = std::move(e);
+
343 }
+
+
344
+
345 // Mark the array as a tracer array (true) or not.
+
+ +
347 array_desc_->is_tracer = is_tracer;
+
348 }
+
+
349 // Check if the array is a tracer array
+
350 bool is_tracer() const;
+
351
+ +
353
+ + +
356 size_t data_size,
+
357 std::vector<size_t> strides,
+
358 Flags flags,
+ +
360
+ +
362 const array& other,
+
363 const std::vector<size_t>& strides,
+
364 Flags flags,
+
365 size_t data_size,
+
366 size_t offset = 0);
+
367
+
368 void copy_shared_buffer(const array& other);
+
369
+ +
371 array other,
+
372 const std::vector<size_t>& strides,
+
373 Flags flags,
+
374 size_t data_size,
+
375 size_t offset = 0);
+
376
+ +
378
+
+
379 void overwrite_descriptor(const array& other) {
+
380 array_desc_ = other.array_desc_;
+
381 }
+
+
382
+ +
384
+
385 private:
+
386 // Initialize the arrays data
+
387 template <typename It>
+
388 void init(const It src);
+
389
+
390 struct ArrayDesc {
+
391 std::vector<int> shape;
+
392 std::vector<size_t> strides;
+
393 size_t size;
+
394 Dtype dtype;
+
395 std::shared_ptr<Primitive> primitive;
+
396
+
397 Status status;
+
398
+
399 // An event on the array used for synchronization
+
400 Event event;
+
401
+
402 // Indicates an array is being used in a graph transform
+
403 // and should not be detached from the graph
+
404 bool is_tracer{false};
+
405
+
406 // This is a shared pointer so that *different* arrays
+
407 // can share the underlying data buffer.
+
408 std::shared_ptr<Data> data;
+
409
+
410 // Properly offset data pointer
+
411 void* data_ptr{nullptr};
+
412
+
413 // The size in elements of the data buffer the array accesses
+
414 // This can be different than the actual size of the array if it
+
415 // has been broadcast or irregularly strided.
+
416 size_t data_size;
+
417
+
418 // Contains useful meta data about the array
+
419 Flags flags;
+
420
+
421 std::vector<array> inputs;
+
422 // An array to keep track of the siblings from a multi-output
+
423 // primitive.
+
424 std::vector<array> siblings;
+
425 // The arrays position in the output list
+
426 uint32_t position{0};
+
427
+
428 explicit ArrayDesc(std::vector<int> shape, Dtype dtype);
+
429
+
430 explicit ArrayDesc(
+
431 std::vector<int> shape,
+
432 Dtype dtype,
+
433 std::shared_ptr<Primitive> primitive,
+
434 std::vector<array> inputs);
+
435
+
436 ~ArrayDesc();
+
437
+
438 private:
+
439 // Initialize size, strides, and other metadata
+
440 void init();
+
441 };
+
442
+
443 // The ArrayDesc contains the details of the materialized array including the
+
444 // shape, strides, the data type. It also includes
+
445 // the primitive which knows how to compute the array's data from its inputs
+
446 // and the list of array's inputs for the primitive.
+
447 std::shared_ptr<ArrayDesc> array_desc_;
+
448};
+
+
449
+
450template <typename T>
+
+
451array::array(T val, Dtype dtype /* = TypeToDtype<T>() */)
+
452 : array_desc_(std::make_shared<ArrayDesc>(std::vector<int>{}, dtype)) {
+
453 init(&val);
+
454}
+
+
455
+
456template <typename It>
+
+ +
458 It data,
+
459 std::vector<int> shape,
+
460 Dtype dtype /* = TypeToDtype<typename std::iterator_traits<It>::value_type>() */) :
+
461 array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
+
462 init(data);
+
463}
+
+
464
+
465template <typename T>
+
+ +
467 std::initializer_list<T> data,
+
468 Dtype dtype /* = TypeToDtype<T>() */)
+
469 : array_desc_(std::make_shared<ArrayDesc>(
+
470 std::vector<int>{static_cast<int>(data.size())},
+
471 dtype)) {
+
472 init(data.begin());
+
473}
+
+
474
+
475template <typename T>
+
+ +
477 std::initializer_list<T> data,
+
478 std::vector<int> shape,
+
479 Dtype dtype /* = TypeToDtype<T>() */)
+
480 : array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
+
481 if (data.size() != size()) {
+
482 throw std::invalid_argument(
+
483 "Data size and provided shape mismatch in array construction.");
+
484 }
+
485 init(data.begin());
+
486}
+
+
487
+
488template <typename T>
+
+ +
490 if (size() != 1) {
+
491 throw std::invalid_argument("item can only be called on arrays of size 1.");
+
492 }
+
493 eval();
+
494 return *data<T>();
+
495}
+
+
496
+
497template <typename T>
+
+
498T array::item() const {
+
499 if (size() != 1) {
+
500 throw std::invalid_argument("item can only be called on arrays of size 1.");
+
501 }
+
502 if (status() == Status::unscheduled) {
+
503 throw std::invalid_argument(
+
504 "item() const can only be called on evaled arrays");
+
505 }
+
506 const_cast<array*>(this)->eval();
+
507 return *data<T>();
+
508}
+
+
509
+
510template <typename It>
+
511void array::init(It src) {
+ +
513 switch (dtype()) {
+
514 case bool_:
+
515 std::copy(src, src + size(), data<bool>());
+
516 break;
+
517 case uint8:
+
518 std::copy(src, src + size(), data<uint8_t>());
+
519 break;
+
520 case uint16:
+
521 std::copy(src, src + size(), data<uint16_t>());
+
522 break;
+
523 case uint32:
+
524 std::copy(src, src + size(), data<uint32_t>());
+
525 break;
+
526 case uint64:
+
527 std::copy(src, src + size(), data<uint64_t>());
+
528 break;
+
529 case int8:
+
530 std::copy(src, src + size(), data<int8_t>());
+
531 break;
+
532 case int16:
+
533 std::copy(src, src + size(), data<int16_t>());
+
534 break;
+
535 case int32:
+
536 std::copy(src, src + size(), data<int32_t>());
+
537 break;
+
538 case int64:
+
539 std::copy(src, src + size(), data<int64_t>());
+
540 break;
+
541 case float16:
+
542 std::copy(src, src + size(), data<float16_t>());
+
543 break;
+
544 case float32:
+
545 std::copy(src, src + size(), data<float>());
+
546 break;
+
547 case bfloat16:
+
548 std::copy(src, src + size(), data<bfloat16_t>());
+
549 break;
+
550 case complex64:
+
551 std::copy(src, src + size(), data<complex64_t>());
+
552 break;
+
553 }
+
554}
+
555
+
556/* Utilities for determining whether a template parameter is array. */
+
557template <typename T>
+
558inline constexpr bool is_array_v =
+
559 std::is_same_v<std::remove_cv_t<std::remove_reference_t<T>>, array>;
+
560
+
561template <typename... T>
+
562inline constexpr bool is_arrays_v = (is_array_v<T> && ...);
+
563
+
564template <typename... T>
+
565using enable_for_arrays_t = typename std::enable_if_t<is_arrays_v<T...>>;
+
566
+
567} // namespace mlx::core
+ +
Definition event.h:11
+
Definition primitives.h:48
+
Definition allocator.h:12
+
Definition array.h:20
+
void attach_event(Event e) const
Definition array.h:341
+
const Flags & flags() const
Get the Flags bit-field.
Definition array.h:290
+
Event & event() const
Definition array.h:336
+
static std::vector< array > make_arrays(std::vector< std::vector< int > > shapes, const std::vector< Dtype > &dtypes, const std::shared_ptr< Primitive > &primitive, const std::vector< array > &inputs)
+
const std::vector< size_t > & strides() const
The strides of the array.
Definition array.h:113
+
Status
Definition array.h:322
+
@ available
Definition array.h:322
+
@ unscheduled
Definition array.h:322
+
@ scheduled
Definition array.h:322
+
void set_data(allocator::Buffer buffer, size_t data_size, std::vector< size_t > strides, Flags flags, deleter_t d=allocator::free)
+
void eval()
Evaluate the array.
+
void copy_shared_buffer(const array &other, const std::vector< size_t > &strides, Flags flags, size_t data_size, size_t offset=0)
+
const std::vector< array > & inputs() const
The array's inputs.
Definition array.h:246
+
array(const array &other)=default
+
std::vector< array > outputs() const
The outputs of the array's primitive (i.e.
Definition array.h:276
+ +
size_t nbytes() const
The number of bytes in the array.
Definition array.h:89
+
void move_shared_buffer(array other)
+
array(std::initializer_list< float > data)
+
bool is_donatable() const
True indicates the arrays buffer is safe to reuse.
Definition array.h:255
+
const std::vector< int > & shape() const
The shape of the array as a vector of integers.
Definition array.h:99
+
std::shared_ptr< Primitive > & primitive_ptr() const
A shared pointer to the array's primitive.
Definition array.h:236
+
int shape(int dim) const
Get the size of the corresponding dimension.
Definition array.h:108
+
size_t ndim() const
The number of dimensions of the array.
Definition array.h:94
+
size_t size() const
The number of elements in the array.
Definition array.h:84
+
array(allocator::Buffer data, std::vector< int > shape, Dtype dtype, deleter_t deleter=allocator::free)
+
array & operator=(array &&other) &&=delete
+
array & operator=(const array &other) &
Definition array.h:71
+
ArrayIterator end() const
Definition array.h:176
+
array(std::initializer_list< int > data, Dtype dtype)
+
void set_data(allocator::Buffer buffer, deleter_t d=allocator::free)
+
const allocator::Buffer & buffer() const
Definition array.h:302
+
void set_status(Status s) const
Definition array.h:331
+
array(const std::complex< float > &val, Dtype dtype=complex64)
+
std::vector< array > & siblings()
The array's siblings.
Definition array.h:265
+
T * data()
Definition array.h:313
+
array(T val, Dtype dtype=TypeToDtype< T >())
Construct a scalar array with zero dimensions.
Definition array.h:451
+
ArrayIterator begin() const
Definition array.h:173
+
Primitive & primitive() const
The array's primitive.
Definition array.h:231
+
void detach()
Detach the array from the graph.
+
array & operator=(const array &other) &&=delete
Assignment to rvalue does not compile.
+
void set_siblings(std::vector< array > siblings, uint16_t position)
Definition array.h:269
+
T item()
Get the value from a scalar array.
Definition array.h:489
+
size_t strides(int dim) const
Get the stride of the corresponding dimension.
Definition array.h:122
+
void copy_shared_buffer(const array &other)
+
void overwrite_descriptor(const array &other)
Definition array.h:379
+
const T * data() const
Definition array.h:318
+
bool has_primitive() const
Check if the array has an attached primitive or is a leaf node.
Definition array.h:241
+
allocator::Buffer & buffer()
Definition array.h:299
+
array(array &&other)=default
+
std::shared_ptr< Data > data_shared_ptr() const
Definition array.h:308
+
void move_shared_buffer(array other, const std::vector< size_t > &strides, Flags flags, size_t data_size, size_t offset=0)
+
const std::vector< array > & siblings() const
The array's siblings.
Definition array.h:260
+
std::vector< array > & inputs()
Definition array.h:250
+
array & operator=(array &&other) &=default
Default copy and move constructors otherwise.
+
array(std::vector< int > shape, Dtype dtype, std::shared_ptr< Primitive > primitive, std::vector< array > inputs)
The following methods should be used with caution.
+
const Status status() const
Definition array.h:327
+
std::uintptr_t id() const
A unique identifier for an array.
Definition array.h:199
+
Dtype dtype() const
Get the arrays data type.
Definition array.h:127
+
bool is_available() const
Definition array.h:324
+
void set_tracer(bool is_tracer)
Definition array.h:346
+
size_t itemsize() const
The size of the array's datatype in bytes.
Definition array.h:79
+
std::uintptr_t primitive_id() const
A unique identifier for an arrays primitive.
Definition array.h:204
+
bool is_tracer() const
+
size_t data_size() const
The size (in elements) of the underlying buffer the array points to.
Definition array.h:295
+ + +
array std(const array &a, bool keepdims, int ddof=0, StreamOrDevice s={})
Computes the standard deviation of the elements of an array.
+
Buffer malloc(size_t size)
+
void free(Buffer buffer)
+
Definition allocator.h:7
+
constexpr bool is_array_v
Definition array.h:558
+
constexpr Dtype bool_
Definition dtype.h:60
+
std::function< void(allocator::Buffer)> deleter_t
Definition array.h:18
+
constexpr Dtype uint64
Definition dtype.h:65
+
constexpr Dtype uint16
Definition dtype.h:63
+
constexpr Dtype bfloat16
Definition dtype.h:74
+
constexpr Dtype int32
Definition dtype.h:69
+
constexpr Dtype float32
Definition dtype.h:73
+
constexpr Dtype int16
Definition dtype.h:68
+
constexpr Dtype int8
Definition dtype.h:67
+
constexpr Dtype int64
Definition dtype.h:70
+
constexpr bool is_arrays_v
Definition array.h:562
+
constexpr Dtype uint8
Definition dtype.h:62
+
constexpr Dtype float16
Definition dtype.h:72
+
constexpr Dtype uint32
Definition dtype.h:64
+
uint8_t size_of(const Dtype &t)
Definition dtype.h:95
+
typename std::enable_if_t< is_arrays_v< T... > > enable_for_arrays_t
Definition array.h:565
+
constexpr Dtype complex64
Definition dtype.h:75
+
Definition dtype.h:15
+
Definition dtype.h:102
+
Definition array.h:141
+ +
friend bool operator==(const ArrayIterator &a, const ArrayIterator &b)
Definition array.h:161
+
std::random_access_iterator_tag iterator_category
Definition array.h:142
+
ArrayIterator & operator++()
Definition array.h:156
+
friend bool operator!=(const ArrayIterator &a, const ArrayIterator &b)
Definition array.h:164
+
ArrayIterator(const array &arr, int idx=0)
+
size_t difference_type
Definition array.h:143
+
const array value_type
Definition array.h:144
+
ArrayIterator & operator+(difference_type diff)
Definition array.h:151
+
Definition array.h:208
+
~Data()
Definition array.h:216
+
deleter_t d
Definition array.h:210
+
Data(const Data &d)=delete
+
Data & operator=(const Data &d)=delete
+
Data(allocator::Buffer buffer, deleter_t d=allocator::free)
Definition array.h:211
+
allocator::Buffer buffer
Definition array.h:209
+
Definition array.h:221
+
bool row_contiguous
Definition array.h:226
+
bool col_contiguous
Definition array.h:227
+
bool contiguous
Definition array.h:224
+
+ + + + diff --git a/docs/build/html/atomic_8h.html b/docs/build/html/atomic_8h.html new file mode 100644 index 000000000..f9df80082 --- /dev/null +++ b/docs/build/html/atomic_8h.html @@ -0,0 +1,521 @@ + + + + + + + +MLX: mlx/backend/metal/kernels/atomic.h File Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
atomic.h File Reference
+
+
+
#include <metal_atomic>
+#include <metal_stdlib>
+#include "mlx/backend/metal/kernels/bf16.h"
+
+

Go to the source code of this file.

+ + + + + + +

+Classes

struct  mlx_atomic< T, typename >
 
struct  mlx_atomic< T, enable_if_t< is_metal_atomic< T > > >
 
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Functions

template<typename T , enable_if_t< is_metal_atomic< T >, bool > = true>
METAL_FUNC T mlx_atomic_load_explicit (device mlx_atomic< T > *object, uint offset)
 
template<typename T , enable_if_t< is_metal_atomic< T >, bool > = true>
METAL_FUNC void mlx_atomic_store_explicit (device mlx_atomic< T > *object, T val, uint offset)
 
template<typename T , enable_if_t< is_metal_atomic< T >, bool > = true>
METAL_FUNC void mlx_atomic_fetch_and_explicit (device mlx_atomic< T > *object, T val, uint offset)
 
template<typename T , enable_if_t< is_metal_atomic< T >, bool > = true>
METAL_FUNC void mlx_atomic_fetch_or_explicit (device mlx_atomic< T > *object, T val, uint offset)
 
template<typename T , enable_if_t< is_metal_atomic< T >, bool > = true>
METAL_FUNC void mlx_atomic_fetch_min_explicit (device mlx_atomic< T > *object, T val, uint offset)
 
template<typename T , enable_if_t< is_metal_atomic< T >, bool > = true>
METAL_FUNC void mlx_atomic_fetch_max_explicit (device mlx_atomic< T > *object, T val, uint offset)
 
template<typename T , enable_if_t< is_metal_atomic< T >, bool > = true>
METAL_FUNC void mlx_atomic_fetch_add_explicit (device mlx_atomic< T > *object, T val, uint offset)
 
template<typename T , enable_if_t< is_metal_atomic< T >, bool > = true>
METAL_FUNC void mlx_atomic_fetch_mul_explicit (device mlx_atomic< T > *object, T val, uint offset)
 
template<typename T , enable_if_t< is_metal_atomic< T >, bool > = true>
METAL_FUNC bool mlx_atomic_compare_exchange_weak_explicit (device mlx_atomic< T > *object, thread T *expected, T val, uint offset)
 
template<>
METAL_FUNC void mlx_atomic_fetch_min_explicit< float > (device mlx_atomic< float > *object, float val, uint offset)
 
template<>
METAL_FUNC void mlx_atomic_fetch_max_explicit< float > (device mlx_atomic< float > *object, float val, uint offset)
 
template<typename T , enable_if_t<!is_metal_atomic< T >, bool > = true>
METAL_FUNC bool mlx_atomic_compare_exchange_weak_explicit (device mlx_atomic< T > *object, thread uint *expected, uint val, uint offset)
 
+ + + + +

+Variables

template<typename T >
constexpr constant bool is_metal_atomic
 
+

Function Documentation

+ +

◆ mlx_atomic_compare_exchange_weak_explicit() [1/2]

+ +
+
+
+template<typename T , enable_if_t< is_metal_atomic< T >, bool > = true>
+ + + + + + + + + + + + + + + + + + + + + +
METAL_FUNC bool mlx_atomic_compare_exchange_weak_explicit (device mlx_atomic< T > * object,
thread T * expected,
T val,
uint offset )
+
+ +
+
+ +

◆ mlx_atomic_compare_exchange_weak_explicit() [2/2]

+ +
+
+
+template<typename T , enable_if_t<!is_metal_atomic< T >, bool > = true>
+ + + + + + + + + + + + + + + + + + + + + +
METAL_FUNC bool mlx_atomic_compare_exchange_weak_explicit (device mlx_atomic< T > * object,
thread uint * expected,
uint val,
uint offset )
+
+ +
+
+ +

◆ mlx_atomic_fetch_add_explicit()

+ +
+
+
+template<typename T , enable_if_t< is_metal_atomic< T >, bool > = true>
+ + + + + + + + + + + + + + + + +
METAL_FUNC void mlx_atomic_fetch_add_explicit (device mlx_atomic< T > * object,
T val,
uint offset )
+
+ +
+
+ +

◆ mlx_atomic_fetch_and_explicit()

+ +
+
+
+template<typename T , enable_if_t< is_metal_atomic< T >, bool > = true>
+ + + + + + + + + + + + + + + + +
METAL_FUNC void mlx_atomic_fetch_and_explicit (device mlx_atomic< T > * object,
T val,
uint offset )
+
+ +
+
+ +

◆ mlx_atomic_fetch_max_explicit()

+ +
+
+
+template<typename T , enable_if_t< is_metal_atomic< T >, bool > = true>
+ + + + + + + + + + + + + + + + +
METAL_FUNC void mlx_atomic_fetch_max_explicit (device mlx_atomic< T > * object,
T val,
uint offset )
+
+ +
+
+ +

◆ mlx_atomic_fetch_max_explicit< float >()

+ +
+
+
+template<>
+ + + + + + + + + + + + + + + + +
METAL_FUNC void mlx_atomic_fetch_max_explicit< float > (device mlx_atomic< float > * object,
float val,
uint offset )
+
+ +
+
+ +

◆ mlx_atomic_fetch_min_explicit()

+ +
+
+
+template<typename T , enable_if_t< is_metal_atomic< T >, bool > = true>
+ + + + + + + + + + + + + + + + +
METAL_FUNC void mlx_atomic_fetch_min_explicit (device mlx_atomic< T > * object,
T val,
uint offset )
+
+ +
+
+ +

◆ mlx_atomic_fetch_min_explicit< float >()

+ +
+
+
+template<>
+ + + + + + + + + + + + + + + + +
METAL_FUNC void mlx_atomic_fetch_min_explicit< float > (device mlx_atomic< float > * object,
float val,
uint offset )
+
+ +
+
+ +

◆ mlx_atomic_fetch_mul_explicit()

+ +
+
+
+template<typename T , enable_if_t< is_metal_atomic< T >, bool > = true>
+ + + + + + + + + + + + + + + + +
METAL_FUNC void mlx_atomic_fetch_mul_explicit (device mlx_atomic< T > * object,
T val,
uint offset )
+
+ +
+
+ +

◆ mlx_atomic_fetch_or_explicit()

+ +
+
+
+template<typename T , enable_if_t< is_metal_atomic< T >, bool > = true>
+ + + + + + + + + + + + + + + + +
METAL_FUNC void mlx_atomic_fetch_or_explicit (device mlx_atomic< T > * object,
T val,
uint offset )
+
+ +
+
+ +

◆ mlx_atomic_load_explicit()

+ +
+
+
+template<typename T , enable_if_t< is_metal_atomic< T >, bool > = true>
+ + + + + + + + + + + +
METAL_FUNC T mlx_atomic_load_explicit (device mlx_atomic< T > * object,
uint offset )
+
+ +
+
+ +

◆ mlx_atomic_store_explicit()

+ +
+
+
+template<typename T , enable_if_t< is_metal_atomic< T >, bool > = true>
+ + + + + + + + + + + + + + + + +
METAL_FUNC void mlx_atomic_store_explicit (device mlx_atomic< T > * object,
T val,
uint offset )
+
+ +
+
+

Variable Documentation

+ +

◆ is_metal_atomic

+ +
+
+
+template<typename T >
+ + + + + +
+ + + + +
constexpr constant bool is_metal_atomic
+
+constexpr
+
+Initial value:
= _disjunction<
+
is_same<T, int>,
+
is_same<T, uint>,
+
is_same<T, ulong>,
+
is_same<T, float>>::value
+
+
+
+
+ + + + diff --git a/docs/build/html/atomic_8h_source.html b/docs/build/html/atomic_8h_source.html new file mode 100644 index 000000000..b1e920a46 --- /dev/null +++ b/docs/build/html/atomic_8h_source.html @@ -0,0 +1,478 @@ + + + + + + + +MLX: mlx/backend/metal/kernels/atomic.h Source File + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
atomic.h
+
+
+Go to the documentation of this file.
1// Copyright © 2023 Apple Inc.
+
2
+
3#pragma once
+
4
+
5#include <metal_atomic>
+
6#include <metal_stdlib>
+ +
8
+
9using namespace metal;
+
10
+
12// Atomic utils
+
14
+
15#pragma METAL internals : enable
+
16template <typename T>
+
17constexpr constant bool is_metal_atomic = _disjunction<
+
18 is_same<T, int>,
+
19 is_same<T, uint>,
+
20 is_same<T, ulong>,
+
21 is_same<T, float>>::value;
+
22
+
23#pragma METAL internals : disable
+
24
+
25template <typename T, typename = void>
+
+
26struct mlx_atomic {
+
27 atomic<uint> val;
+
28};
+
+
29
+
30template <typename T>
+
+
31struct mlx_atomic<T, enable_if_t<is_metal_atomic<T>>> {
+
32 atomic<T> val;
+
33};
+
+
34
+
36// Native metal atomics
+
38
+
39template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
+
40METAL_FUNC T
+
+
41mlx_atomic_load_explicit(device mlx_atomic<T>* object, uint offset) {
+
42 return atomic_load_explicit(&(object[offset].val), memory_order_relaxed);
+
43}
+
+
44
+
45template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
+
46METAL_FUNC void
+
+
47mlx_atomic_store_explicit(device mlx_atomic<T>* object, T val, uint offset) {
+
48 atomic_store_explicit(&(object[offset].val), val, memory_order_relaxed);
+
49}
+
+
50
+
51template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
+
+ +
53 device mlx_atomic<T>* object,
+
54 T val,
+
55 uint offset) {
+
56 atomic_fetch_and_explicit(&(object[offset].val), val, memory_order_relaxed);
+
57}
+
+
58
+
59template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
+
60METAL_FUNC void
+
+
61mlx_atomic_fetch_or_explicit(device mlx_atomic<T>* object, T val, uint offset) {
+
62 atomic_fetch_or_explicit(&(object[offset].val), val, memory_order_relaxed);
+
63}
+
+
64
+
65template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
+
+ +
67 device mlx_atomic<T>* object,
+
68 T val,
+
69 uint offset) {
+
70 atomic_fetch_min_explicit(&(object[offset].val), val, memory_order_relaxed);
+
71}
+
+
72
+
73template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
+
+ +
75 device mlx_atomic<T>* object,
+
76 T val,
+
77 uint offset) {
+
78 atomic_fetch_max_explicit(&(object[offset].val), val, memory_order_relaxed);
+
79}
+
+
80
+
81template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
+
+ +
83 device mlx_atomic<T>* object,
+
84 T val,
+
85 uint offset) {
+
86 atomic_fetch_add_explicit(&(object[offset].val), val, memory_order_relaxed);
+
87}
+
+
88
+
89template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
+
+ +
91 device mlx_atomic<T>* object,
+
92 T val,
+
93 uint offset) {
+
94 T expected = mlx_atomic_load_explicit(object, offset);
+ +
96 object, &expected, val * expected, offset)) {
+
97 }
+
98}
+
+
99
+
100template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
+
+ +
102 device mlx_atomic<T>* object,
+
103 thread T* expected,
+
104 T val,
+
105 uint offset) {
+
106 return atomic_compare_exchange_weak_explicit(
+
107 &(object[offset].val),
+
108 expected,
+
109 val,
+
110 memory_order_relaxed,
+
111 memory_order_relaxed);
+
112}
+
+
113
+
114// Specialization for float since it does not atomic_fetch_min_explicit
+
115template <>
+
+ +
117 device mlx_atomic<float>* object,
+
118 float val,
+
119 uint offset) {
+
120 float expected = mlx_atomic_load_explicit(object, offset);
+
121 while (val < expected) {
+ +
123 object, &expected, val, offset)) {
+
124 return;
+
125 }
+
126 }
+
127}
+
+
128
+
129// Specialization for float since it does not atomic_fetch_max_explicit
+
130template <>
+
+ +
132 device mlx_atomic<float>* object,
+
133 float val,
+
134 uint offset) {
+
135 float expected = mlx_atomic_load_explicit(object, offset);
+
136 while (val > expected) {
+ +
138 object, &expected, val, offset)) {
+
139 return;
+
140 }
+
141 }
+
142}
+
+
143
+
145// Custom atomics
+
147
+
148namespace {
+
149
+
150template <typename T>
+
151constexpr constant uint packing_size = sizeof(uint) / sizeof(T);
+
152
+
153template <typename T>
+
154union uint_or_packed {
+
155 T val[packing_size<T>];
+
156 uint bits;
+
157};
+
158
+
159template <typename T, typename Op>
+
160struct mlx_atomic_update_helper {
+
161 uint operator()(uint_or_packed<T> init, T update, uint elem_offset) {
+
162 Op op;
+
163 init.val[elem_offset] = op(update, init.val[elem_offset]);
+
164 return init.bits;
+
165 }
+
166};
+
167
+
168template <typename T, typename Op>
+
169METAL_FUNC void mlx_atomic_update_and_store(
+
170 device mlx_atomic<T>* object,
+
171 T update,
+
172 uint offset) {
+
173 uint pack_offset = offset / packing_size<T>;
+
174 uint elem_offset = offset % packing_size<T>;
+
175
+
176 mlx_atomic_update_helper<T, Op> helper;
+
177 uint_or_packed<T> expected;
+
178 expected.bits =
+
179 atomic_load_explicit(&(object[pack_offset].val), memory_order_relaxed);
+
180
+
181 while (Op::condition(update, expected.val[elem_offset]) &&
+ +
183 object,
+
184 &(expected.bits),
+
185 helper(expected, update, elem_offset),
+
186 pack_offset)) {
+
187 }
+
188}
+
189
+
190template <typename T>
+
191struct __None {
+
192 static bool condition(T a, T b) {
+
193#pragma unused(a)
+
194#pragma unused(b)
+
195 return true;
+
196 }
+
197
+
198 T operator()(T a, T b) {
+
199#pragma unused(b)
+
200 return a;
+
201 }
+
202};
+
203
+
204template <typename T>
+
205struct __Add {
+
206 static bool condition(T a, T b) {
+
207#pragma unused(a)
+
208#pragma unused(b)
+
209 return true;
+
210 }
+
211
+
212 T operator()(T a, T b) {
+
213 return a + b;
+
214 }
+
215};
+
216
+
217template <typename T>
+
218struct __Mul {
+
219 static bool condition(T a, T b) {
+
220#pragma unused(a)
+
221 return b != 0;
+
222 }
+
223
+
224 T operator()(T a, T b) {
+
225 return a * b;
+
226 }
+
227};
+
228
+
229template <typename T>
+
230struct __Max {
+
231 static bool condition(T a, T b) {
+
232 return a > b;
+
233 }
+
234
+
235 T operator()(T a, T b) {
+
236 return max(a, b);
+
237 }
+
238};
+
239
+
240template <typename T>
+
241struct __Min {
+
242 static bool condition(T a, T b) {
+
243 return a < b;
+
244 }
+
245
+
246 T operator()(T a, T b) {
+
247 return min(a, b);
+
248 }
+
249};
+
250
+
251} // namespace
+
252
+
253template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
+
254METAL_FUNC T
+
255mlx_atomic_load_explicit(device mlx_atomic<T>* object, uint offset) {
+
256 uint pack_offset = offset / sizeof(T);
+
257 uint elem_offset = offset % sizeof(T);
+
258 uint_or_packed<T> packed_val;
+
259 packed_val.bits =
+
260 atomic_load_explicit(&(object[pack_offset].val), memory_order_relaxed);
+
261 return packed_val.val[elem_offset];
+
262}
+
263
+
264template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
+
265METAL_FUNC void
+
266mlx_atomic_store_explicit(device mlx_atomic<T>* object, T val, uint offset) {
+
267 mlx_atomic_update_and_store<T, __None<T>>(object, val, offset);
+
268}
+
269
+
270template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
+
271METAL_FUNC void mlx_atomic_fetch_and_explicit(
+
272 device mlx_atomic<T>* object,
+
273 T val,
+
274 uint offset) {
+
275 uint pack_offset = offset / packing_size<T>;
+
276 uint elem_offset = offset % packing_size<T>;
+
277 uint_or_packed<T> identity;
+
278 identity.bits = __UINT32_MAX__;
+
279 identity.val[elem_offset] = val;
+
280
+
281 atomic_fetch_and_explicit(
+
282 &(object[pack_offset].val), identity.bits, memory_order_relaxed);
+
283}
+
284
+
285template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
+
286METAL_FUNC void
+
287mlx_atomic_fetch_or_explicit(device mlx_atomic<T>* object, T val, uint offset) {
+
288 uint pack_offset = offset / packing_size<T>;
+
289 uint elem_offset = offset % packing_size<T>;
+
290 uint_or_packed<T> identity;
+
291 identity.bits = 0;
+
292 identity.val[elem_offset] = val;
+
293
+
294 atomic_fetch_or_explicit(
+
295 &(object[pack_offset].val), identity.bits, memory_order_relaxed);
+
296}
+
297
+
298template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
+
299METAL_FUNC void mlx_atomic_fetch_min_explicit(
+
300 device mlx_atomic<T>* object,
+
301 T val,
+
302 uint offset) {
+
303 mlx_atomic_update_and_store<T, __Min<T>>(object, val, offset);
+
304}
+
305
+
306template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
+
307METAL_FUNC void mlx_atomic_fetch_max_explicit(
+
308 device mlx_atomic<T>* object,
+
309 T val,
+
310 uint offset) {
+
311 mlx_atomic_update_and_store<T, __Max<T>>(object, val, offset);
+
312}
+
313
+
314template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
+
315METAL_FUNC void mlx_atomic_fetch_add_explicit(
+
316 device mlx_atomic<T>* object,
+
317 T val,
+
318 uint offset) {
+
319 mlx_atomic_update_and_store<T, __Add<T>>(object, val, offset);
+
320}
+
321
+
322template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
+
323METAL_FUNC void mlx_atomic_fetch_mul_explicit(
+
324 device mlx_atomic<T>* object,
+
325 T val,
+
326 uint offset) {
+
327 mlx_atomic_update_and_store<T, __Mul<T>>(object, val, offset);
+
328}
+
329
+
330template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
+
+ +
332 device mlx_atomic<T>* object,
+
333 thread uint* expected,
+
334 uint val,
+
335 uint offset) {
+
336 return atomic_compare_exchange_weak_explicit(
+
337 &(object[offset].val),
+
338 expected,
+
339 val,
+
340 memory_order_relaxed,
+
341 memory_order_relaxed);
+
342}
+
+
METAL_FUNC void mlx_atomic_fetch_add_explicit(device mlx_atomic< T > *object, T val, uint offset)
Definition atomic.h:82
+
METAL_FUNC void mlx_atomic_fetch_max_explicit< float >(device mlx_atomic< float > *object, float val, uint offset)
Definition atomic.h:131
+
METAL_FUNC void mlx_atomic_fetch_and_explicit(device mlx_atomic< T > *object, T val, uint offset)
Definition atomic.h:52
+
METAL_FUNC T mlx_atomic_load_explicit(device mlx_atomic< T > *object, uint offset)
Definition atomic.h:41
+
METAL_FUNC void mlx_atomic_store_explicit(device mlx_atomic< T > *object, T val, uint offset)
Definition atomic.h:47
+
constexpr constant bool is_metal_atomic
Definition atomic.h:17
+
METAL_FUNC void mlx_atomic_fetch_or_explicit(device mlx_atomic< T > *object, T val, uint offset)
Definition atomic.h:61
+
METAL_FUNC void mlx_atomic_fetch_min_explicit< float >(device mlx_atomic< float > *object, float val, uint offset)
Definition atomic.h:116
+
METAL_FUNC void mlx_atomic_fetch_max_explicit(device mlx_atomic< T > *object, T val, uint offset)
Definition atomic.h:74
+
METAL_FUNC void mlx_atomic_fetch_min_explicit(device mlx_atomic< T > *object, T val, uint offset)
Definition atomic.h:66
+
METAL_FUNC void mlx_atomic_fetch_mul_explicit(device mlx_atomic< T > *object, T val, uint offset)
Definition atomic.h:90
+
METAL_FUNC bool mlx_atomic_compare_exchange_weak_explicit(device mlx_atomic< T > *object, thread T *expected, T val, uint offset)
Definition atomic.h:101
+ +
Op op
Definition binary.h:139
+
array identity(int n, Dtype dtype, StreamOrDevice s={})
Create a square matrix of shape (n,n) of zeros, and ones in the major diagonal.
+
Definition bf16.h:265
+
METAL_FUNC bfloat16_t min(bfloat16_t x, bfloat16_t y)
Definition bf16_math.h:234
+
METAL_FUNC bfloat16_t max(bfloat16_t x, bfloat16_t y)
Definition bf16_math.h:234
+
array bits(const std::vector< int > &shape, int width, const std::optional< array > &key=std::nullopt, StreamOrDevice s={})
Generate an array with type uint32 filled with random bits.
+ +
Definition atomic.h:26
+
atomic< uint > val
Definition atomic.h:27
+
+ + + + diff --git a/docs/build/html/backend_2accelerate_2utils_8h.html b/docs/build/html/backend_2accelerate_2utils_8h.html new file mode 100644 index 000000000..60d5dacd1 --- /dev/null +++ b/docs/build/html/backend_2accelerate_2utils_8h.html @@ -0,0 +1,107 @@ + + + + + + + +MLX: mlx/backend/accelerate/utils.h File Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
utils.h File Reference
+
+
+
#include <vecLib/BNNS/bnns.h>
+#include "mlx/dtype.h"
+
+

Go to the source code of this file.

+ + + + + + +

+Namespaces

namespace  mlx
 
namespace  mlx::core
 
+ + + +

+Functions

BNNSDataType mlx::core::to_bnns_dtype (Dtype mlx_dtype)
 
+
+ + + + diff --git a/docs/build/html/backend_2accelerate_2utils_8h_source.html b/docs/build/html/backend_2accelerate_2utils_8h_source.html new file mode 100644 index 000000000..02e8f28ac --- /dev/null +++ b/docs/build/html/backend_2accelerate_2utils_8h_source.html @@ -0,0 +1,134 @@ + + + + + + + +MLX: mlx/backend/accelerate/utils.h Source File + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
utils.h
+
+
+Go to the documentation of this file.
1// Copyright © 2023 Apple Inc.
+
2
+
3#pragma once
+
4
+
5#include <vecLib/BNNS/bnns.h>
+
6#include "mlx/dtype.h"
+
7
+
8namespace mlx::core {
+
9
+
+
10BNNSDataType to_bnns_dtype(Dtype mlx_dtype) {
+
11 uint32_t size_bits = size_of(mlx_dtype) * 8;
+
12 switch (kindof(mlx_dtype)) {
+
13 case Dtype::Kind::b:
+
14 return BNNSDataTypeBoolean;
+
15 case Dtype::Kind::u:
+
16 return BNNSDataType(BNNSDataTypeUIntBit | size_bits);
+
17 case Dtype::Kind::i:
+
18 return BNNSDataType(BNNSDataTypeIntBit | size_bits);
+
19 case Dtype::Kind::f:
+
20 return BNNSDataType(BNNSDataTypeFloatBit | size_bits);
+
21 case Dtype::Kind::V:
+
22 return BNNSDataTypeBFloat16;
+
23 case Dtype::Kind::c:
+
24 throw std::invalid_argument("BNNS does not support complex types");
+
25 }
+
26}
+
+
27
+
28} // namespace mlx::core
+ +
Definition allocator.h:7
+
BNNSDataType to_bnns_dtype(Dtype mlx_dtype)
Definition utils.h:10
+
Dtype::Kind kindof(const Dtype &t)
+
uint8_t size_of(const Dtype &t)
Definition dtype.h:95
+
Definition dtype.h:15
+ + + + + + +
+ + + + diff --git a/docs/build/html/backend_2common_2ops_8h.html b/docs/build/html/backend_2common_2ops_8h.html new file mode 100644 index 000000000..c976bb466 --- /dev/null +++ b/docs/build/html/backend_2common_2ops_8h.html @@ -0,0 +1,234 @@ + + + + + + + +MLX: mlx/backend/common/ops.h File Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
ops.h File Reference
+
+
+
#include <stdint.h>
+#include <cmath>
+#include <complex>
+
+

Go to the source code of this file.

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Classes

union  mlx::core::detail::IntOrFloat
 
struct  mlx::core::detail::Abs
 
struct  mlx::core::detail::ArcCos
 
struct  mlx::core::detail::ArcCosh
 
struct  mlx::core::detail::ArcSin
 
struct  mlx::core::detail::ArcSinh
 
struct  mlx::core::detail::ArcTan
 
struct  mlx::core::detail::ArcTan2
 
struct  mlx::core::detail::ArcTanh
 
struct  mlx::core::detail::Ceil
 
struct  mlx::core::detail::Conjugate
 
struct  mlx::core::detail::Cos
 
struct  mlx::core::detail::Cosh
 
struct  mlx::core::detail::Erf
 
struct  mlx::core::detail::ErfInv
 
struct  mlx::core::detail::Exp
 
struct  mlx::core::detail::Expm1
 
struct  mlx::core::detail::Floor
 
struct  mlx::core::detail::Log
 
struct  mlx::core::detail::Log2
 
struct  mlx::core::detail::Log10
 
struct  mlx::core::detail::Log1p
 
struct  mlx::core::detail::LogicalNot
 
struct  mlx::core::detail::Negative
 
struct  mlx::core::detail::Round
 
struct  mlx::core::detail::Sigmoid
 
struct  mlx::core::detail::Sign
 
struct  mlx::core::detail::Sin
 
struct  mlx::core::detail::Sinh
 
struct  mlx::core::detail::Square
 
struct  mlx::core::detail::Sqrt
 
struct  mlx::core::detail::Rsqrt
 
struct  mlx::core::detail::Tan
 
struct  mlx::core::detail::Tanh
 
struct  mlx::core::detail::Add
 
struct  mlx::core::detail::Divide
 
struct  mlx::core::detail::Remainder
 
struct  mlx::core::detail::Equal
 
struct  mlx::core::detail::NaNEqual
 
struct  mlx::core::detail::Greater
 
struct  mlx::core::detail::GreaterEqual
 
struct  mlx::core::detail::Less
 
struct  mlx::core::detail::LessEqual
 
struct  mlx::core::detail::Maximum
 
struct  mlx::core::detail::Minimum
 
struct  mlx::core::detail::LogAddExp
 
struct  mlx::core::detail::Multiply
 
struct  mlx::core::detail::NotEqual
 
struct  mlx::core::detail::Power
 
struct  mlx::core::detail::Subtract
 
struct  mlx::core::detail::LogicalAnd
 
struct  mlx::core::detail::LogicalOr
 
struct  mlx::core::detail::Select
 
struct  mlx::core::detail::BitwiseAnd
 
struct  mlx::core::detail::BitwiseOr
 
struct  mlx::core::detail::BitwiseXor
 
struct  mlx::core::detail::LeftShift
 
struct  mlx::core::detail::RightShift
 
+ + + + + + + +

+Namespaces

namespace  mlx
 
namespace  mlx::core
 
namespace  mlx::core::detail
 
+ + + + + + + +

+Functions

float mlx::core::detail::fast_exp (float x)
 
float mlx::core::detail::fast_erf (float a)
 
float mlx::core::detail::fast_erfinv (float a)
 
+
+ + + + diff --git a/docs/build/html/backend_2common_2ops_8h_source.html b/docs/build/html/backend_2common_2ops_8h_source.html new file mode 100644 index 000000000..1b6cebe96 --- /dev/null +++ b/docs/build/html/backend_2common_2ops_8h_source.html @@ -0,0 +1,1218 @@ + + + + + + + +MLX: mlx/backend/common/ops.h Source File + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
ops.h
+
+
+Go to the documentation of this file.
1// Copyright © 2023-2024 Apple Inc.
+
2
+
3#pragma once
+
4#include <stdint.h>
+
5#include <cmath>
+
6#include <complex>
+
7
+
+ +
9
+
10namespace {
+
11constexpr float inf = std::numeric_limits<float>::infinity();
+
12} // namespace
+
13
+
+
14typedef union {
+
15 int i;
+
16 float f;
+ +
+
18
+
+
19inline float fast_exp(float x) {
+
20 if (x == -std::numeric_limits<float>::infinity()) {
+
21 return 0.0f;
+
22 } else if (x == std::numeric_limits<float>::infinity() || std::isnan(x)) {
+
23 return x;
+
24 }
+
25 x *= 1.442695; // multiply with log_2(e)
+
26 float ipart, fpart;
+
27 IntOrFloat epart;
+
28 x = std::max(-80.f, std::min(x, 80.f));
+
29 ipart = std::floor(x + 0.5);
+
30 fpart = x - ipart;
+
31
+
32 x = 1.535336188319500e-4f;
+
33 x = x * fpart + 1.339887440266574e-3f;
+
34 x = x * fpart + 9.618437357674640e-3f;
+
35 x = x * fpart + 5.550332471162809e-2f;
+
36 x = x * fpart + 2.402264791363012e-1f;
+
37 x = x * fpart + 6.931472028550421e-1f;
+
38 x = x * fpart + 1.000000000000000f;
+
39
+
40 // generate 2**ipart in the floating point representation using integer
+
41 // bitshifting
+
42 epart.i = (int(ipart) + 127) << 23;
+
43
+
44 return epart.f * x;
+
45}
+
+
46
+
+
47inline float fast_erf(float a) {
+
48 float r, s, t, u;
+
49 t = std::abs(a);
+
50 s = a * a;
+
51 if (t > 0.927734375f) {
+
52 // maximum error 0.99527 ulp
+
53 r = std::fma(
+
54 -1.72853470e-5f, t, 3.83197126e-4f); // -0x1.220000p-16,0x1.91cfb2p-12
+
55 u = std::fma(
+
56 -3.88396438e-3f, t, 2.42546219e-2f); // -0x1.fd1438p-9, 0x1.8d6342p-6
+
57 r = std::fma(r, s, u);
+
58 r = std::fma(r, t, -1.06777877e-1f); // -0x1.b55cb8p-4
+
59 r = std::fma(r, t, -6.34846687e-1f); // -0x1.450aa0p-1
+
60 r = std::fma(r, t, -1.28717512e-1f); // -0x1.079d0cp-3
+
61 r = std::fma(r, t, -t);
+
62 // TODO, replace with expm1 when implemented
+
63 r = 1.0f - std::exp(r);
+
64 r = std::copysign(r, a);
+
65 } else {
+
66 // maximum error 0.98929 ulp
+
67 r = -5.96761703e-4f; // -0x1.38e000p-11
+
68 r = std::fma(r, s, 4.99119423e-3f); // 0x1.471a58p-8
+
69 r = std::fma(r, s, -2.67681349e-2f); // -0x1.b691b2p-6
+
70 r = std::fma(r, s, 1.12819925e-1f); // 0x1.ce1c44p-4
+
71 r = std::fma(r, s, -3.76125336e-1f); // -0x1.812700p-2
+
72 r = std::fma(r, s, 1.28379166e-1f); // 0x1.06eba8p-3
+
73 r = std::fma(r, a, a);
+
74 }
+
75 return r;
+
76}
+
+
77
+
+
78inline float fast_erfinv(float a) {
+
79 auto t = std::fma(a, 0.0f - a, 1.0f);
+
80 t = std::log(t);
+
81 float p;
+
82 if (std::abs(t) > 6.125f) { // maximum ulp error = 2.35793
+
83 p = 3.03697567e-10f; // 0x1.4deb44p-32
+
84 p = std::fma(p, t, 2.93243101e-8f); // 0x1.f7c9aep-26
+
85 p = std::fma(p, t, 1.22150334e-6f); // 0x1.47e512p-20
+
86 p = std::fma(p, t, 2.84108955e-5f); // 0x1.dca7dep-16
+
87 p = std::fma(p, t, 3.93552968e-4f); // 0x1.9cab92p-12
+
88 p = std::fma(p, t, 3.02698812e-3f); // 0x1.8cc0dep-9
+
89 p = std::fma(p, t, 4.83185798e-3f); // 0x1.3ca920p-8
+
90 p = std::fma(p, t, -2.64646143e-1f); // -0x1.0eff66p-2
+
91 p = std::fma(p, t, 8.40016484e-1f); // 0x1.ae16a4p-1
+
92 } else { // maximum ulp error = 2.35002
+
93 p = 5.43877832e-9f; // 0x1.75c000p-28
+
94 p = std::fma(p, t, 1.43285448e-7f); // 0x1.33b402p-23
+
95 p = std::fma(p, t, 1.22774793e-6f); // 0x1.499232p-20
+
96 p = std::fma(p, t, 1.12963626e-7f); // 0x1.e52cd2p-24
+
97 p = std::fma(p, t, -5.61530760e-5f); // -0x1.d70bd0p-15
+
98 p = std::fma(p, t, -1.47697632e-4f); // -0x1.35be90p-13
+
99 p = std::fma(p, t, 2.31468678e-3f); // 0x1.2f6400p-9
+
100 p = std::fma(p, t, 1.15392581e-2f); // 0x1.7a1e50p-7
+
101 p = std::fma(p, t, -2.32015476e-1f); // -0x1.db2aeep-3
+
102 p = std::fma(p, t, 8.86226892e-1f); // 0x1.c5bf88p-1
+
103 }
+
104 return a * p;
+
105}
+
+
106
+
+
107struct Abs {
+
108 template <typename T>
+
+
109 T operator()(T x) {
+
110 return std::abs(x);
+
111 };
+
+
+
112 uint8_t operator()(uint8_t x) {
+
113 return x;
+
114 };
+
+
+
115 uint16_t operator()(uint16_t x) {
+
116 return x;
+
117 };
+
+
+
118 uint32_t operator()(uint32_t x) {
+
119 return x;
+
120 };
+
+
+
121 uint64_t operator()(uint64_t x) {
+
122 return x;
+
123 };
+
+
+
124 bool operator()(bool x) {
+
125 return x;
+
126 };
+
+
127};
+
+
128
+
+
129struct ArcCos {
+
130 template <typename T>
+
+
131 T operator()(T x) {
+
132 return std::acos(x);
+
133 };
+
+
134};
+
+
135
+
+
136struct ArcCosh {
+
137 template <typename T>
+
+
138 T operator()(T x) {
+
139 return std::acosh(x);
+
140 };
+
+
141};
+
+
142
+
+
143struct ArcSin {
+
144 template <typename T>
+
+
145 T operator()(T x) {
+
146 return std::asin(x);
+
147 };
+
+
148};
+
+
149
+
+
150struct ArcSinh {
+
151 template <typename T>
+
+
152 T operator()(T x) {
+
153 return std::asinh(x);
+
154 };
+
+
155};
+
+
156
+
+
157struct ArcTan {
+
158 template <typename T>
+
+
159 T operator()(T x) {
+
160 return std::atan(x);
+
161 };
+
+
162};
+
+
163
+
+
164struct ArcTan2 {
+
165 template <typename T>
+
+
166 T operator()(T y, T x) {
+
167 return std::atan2(y, x);
+
168 };
+
+
169};
+
+
170
+
+
171struct ArcTanh {
+
172 template <typename T>
+
+
173 T operator()(T x) {
+
174 return std::atanh(x);
+
175 };
+
+
176};
+
+
177
+
+
178struct Ceil {
+
179 template <typename T>
+
+
180 T operator()(T x) {
+
181 return std::ceil(x);
+
182 };
+
+
+
183 int8_t operator()(int8_t x) {
+
184 return x;
+
185 };
+
+
+
186 int16_t operator()(int16_t x) {
+
187 return x;
+
188 };
+
+
+
189 int32_t operator()(int32_t x) {
+
190 return x;
+
191 };
+
+
+
192 int64_t operator()(int64_t x) {
+
193 return x;
+
194 };
+
+
+
195 uint8_t operator()(uint8_t x) {
+
196 return x;
+
197 };
+
+
+
198 uint16_t operator()(uint16_t x) {
+
199 return x;
+
200 };
+
+
+
201 uint32_t operator()(uint32_t x) {
+
202 return x;
+
203 };
+
+
+
204 uint64_t operator()(uint64_t x) {
+
205 return x;
+
206 };
+
+
+
207 bool operator()(bool x) {
+
208 return x;
+
209 };
+
+
210};
+
+
211
+
+
212struct Conjugate {
+
+ +
214 return std::conj(x);
+
215 }
+
+
216};
+
+
217
+
+
218struct Cos {
+
219 template <typename T>
+
+
220 T operator()(T x) {
+
221 return std::cos(x);
+
222 };
+
+
223};
+
+
224
+
+
225struct Cosh {
+
226 template <typename T>
+
+
227 T operator()(T x) {
+
228 return std::cosh(x);
+
229 };
+
+
230};
+
+
231
+
+
232struct Erf {
+
233 template <typename T>
+
+
234 T operator()(T x) {
+
235 return static_cast<T>(fast_erf(static_cast<float>(x)));
+
236 };
+
+
237};
+
+
238
+
+
239struct ErfInv {
+
240 template <typename T>
+
+
241 T operator()(T x) {
+
242 return static_cast<T>(fast_erfinv(static_cast<float>(x)));
+
243 };
+
+
244};
+
+
245
+
+
246struct Exp {
+
247 template <typename T>
+
+
248 T operator()(T x) {
+
249 return fast_exp(x);
+
250 };
+
+
251
+
+ +
253 return std::exp(x);
+
254 }
+
+
255};
+
+
256
+
+
257struct Expm1 {
+
258 template <typename T>
+
+
259 T operator()(T x) {
+
260 return expm1(x);
+
261 };
+
+
262};
+
+
263
+
+
264struct Floor {
+
265 template <typename T>
+
+
266 T operator()(T x) {
+
267 return std::floor(x);
+
268 };
+
+
+
269 int8_t operator()(int8_t x) {
+
270 return x;
+
271 };
+
+
+
272 int16_t operator()(int16_t x) {
+
273 return x;
+
274 };
+
+
+
275 int32_t operator()(int32_t x) {
+
276 return x;
+
277 };
+
+
+
278 int64_t operator()(int64_t x) {
+
279 return x;
+
280 };
+
+
+
281 uint8_t operator()(uint8_t x) {
+
282 return x;
+
283 };
+
+
+
284 uint16_t operator()(uint16_t x) {
+
285 return x;
+
286 };
+
+
+
287 uint32_t operator()(uint32_t x) {
+
288 return x;
+
289 };
+
+
+
290 uint64_t operator()(uint64_t x) {
+
291 return x;
+
292 };
+
+
+
293 bool operator()(bool x) {
+
294 return x;
+
295 };
+
+
296};
+
+
297
+
+
298struct Log {
+
299 template <typename T>
+
+
300 T operator()(T x) {
+
301 return std::log(x);
+
302 };
+
+
303};
+
+
304
+
+
305struct Log2 {
+
306 template <typename T>
+
+
307 T operator()(T x) {
+
308 return std::log2(x);
+
309 };
+
+
310};
+
+
311
+
+
312struct Log10 {
+
313 template <typename T>
+
+
314 T operator()(T x) {
+
315 return std::log10(x);
+
316 };
+
+
317};
+
+
318
+
+
319struct Log1p {
+
320 template <typename T>
+
+
321 T operator()(T x) {
+
322 return log1p(x);
+
323 };
+
+
324};
+
+
325
+
+ +
327 template <typename T>
+
+
328 T operator()(T x) {
+
329 return !x;
+
330 };
+
+
331};
+
+
332
+
+
333struct Negative {
+
334 template <typename T>
+
+
335 T operator()(T x) {
+
336 return -x;
+
337 };
+
+
338};
+
+
339
+
+
340struct Round {
+
341 template <typename T>
+
+
342 T operator()(T x) {
+
343 return std::rint(x);
+
344 }
+
+
345
+
+ +
347 return {std::rint(x.real()), std::rint(x.imag())};
+
348 }
+
+
349};
+
+
350
+
+
351struct Sigmoid {
+
352 template <typename T>
+
+
353 T operator()(T x) {
+
354 auto one = static_cast<decltype(x)>(1.0);
+
355 return one / (one + fast_exp(-x));
+
356 }
+
+
357};
+
+
358
+
+
359struct Sign {
+
360 template <typename T>
+
+
361 T operator()(T x) {
+
362 return (x > T(0)) - (x < T(0));
+
363 }
+
+
+
364 uint8_t operator()(uint8_t x) {
+
365 return x != 0;
+
366 }
+
+
+
367 uint16_t operator()(uint16_t x) {
+
368 return x != 0;
+
369 }
+
+
+
370 uint32_t operator()(uint32_t x) {
+
371 return x != 0;
+
372 }
+
+
+
373 uint64_t operator()(uint64_t x) {
+
374 return x != 0;
+
375 }
+
+
376};
+
+
377
+
+
378struct Sin {
+
379 template <typename T>
+
+
380 T operator()(T x) {
+
381 return std::sin(x);
+
382 };
+
+
383};
+
+
384
+
+
385struct Sinh {
+
386 template <typename T>
+
+
387 T operator()(T x) {
+
388 return std::sinh(x);
+
389 };
+
+
390};
+
+
391
+
+
392struct Square {
+
393 template <typename T>
+
+
394 T operator()(T x) {
+
395 return x * x;
+
396 };
+
+
397};
+
+
398
+
+
399struct Sqrt {
+
400 template <typename T>
+
+
401 T operator()(T x) {
+
402 return std::sqrt(x);
+
403 };
+
+
404};
+
+
405
+
+
406struct Rsqrt {
+
407 template <typename T>
+
+
408 T operator()(T x) {
+
409 return static_cast<decltype(x)>(1.0) / std::sqrt(x);
+
410 };
+
+
411};
+
+
412
+
+
413struct Tan {
+
414 template <typename T>
+
+
415 T operator()(T x) {
+
416 return std::tan(x);
+
417 };
+
+
418};
+
+
419
+
+
420struct Tanh {
+
421 template <typename T>
+
+
422 T operator()(T x) {
+
423 return std::tanh(x);
+
424 };
+
+
425};
+
+
426
+
+
427struct Add {
+
428 template <typename T>
+
+
429 T operator()(T x, T y) {
+
430 return x + y;
+
431 }
+
+
432};
+
+
433
+
+
434struct Divide {
+
435 template <typename T>
+
+
436 T operator()(T x, T y) {
+
437 return x / y;
+
438 }
+
+
439};
+
+
440
+
+
441struct Remainder {
+
442 template <typename T>
+
+
443 std::enable_if_t<std::is_integral_v<T> & !std::is_signed_v<T>, T> operator()(
+
444 T numerator,
+
445 T denominator) {
+
446 return numerator % denominator;
+
447 }
+
+
448
+
449 template <typename T>
+
+
450 std::enable_if_t<std::is_integral_v<T> & std::is_signed_v<T>, T> operator()(
+
451 T numerator,
+
452 T denominator) {
+
453 auto r = numerator % denominator;
+
454 if (r != 0 && (r < 0 != denominator < 0))
+
455 r += denominator;
+
456 return r;
+
457 }
+
+
458
+
459 template <typename T>
+
+
460 std::enable_if_t<!std::is_integral_v<T>, T> operator()(
+
461 T numerator,
+
462 T denominator) {
+
463 auto r = std::fmod(numerator, denominator);
+
464 if (r != 0 && (r < 0 != denominator < 0)) {
+
465 r += denominator;
+
466 }
+
467 return r;
+
468 }
+
+
469
+
+ +
471 return numerator % denominator;
+
472 }
+
+
473};
+
+
474
+
+
475struct Equal {
+
476 template <typename T>
+
+
477 bool operator()(T x, T y) {
+
478 return x == y;
+
479 }
+
+
480};
+
+
481
+
+
482struct NaNEqual {
+
483 template <typename T>
+
+
484 bool operator()(T x, T y) {
+
485 return x == y || (std::isnan(x) && std::isnan(y));
+
486 }
+
+
487};
+
+
488
+
+
489struct Greater {
+
490 template <typename T>
+
+
491 bool operator()(T x, T y) {
+
492 return x > y;
+
493 }
+
+
494};
+
+
495
+
+ +
497 template <typename T>
+
+
498 bool operator()(T x, T y) {
+
499 return x >= y;
+
500 }
+
+
501};
+
+
502
+
+
503struct Less {
+
504 template <typename T>
+
+
505 bool operator()(T x, T y) {
+
506 return x < y;
+
507 }
+
+
508};
+
+
509
+
+
510struct LessEqual {
+
511 template <typename T>
+
+
512 bool operator()(T x, T y) {
+
513 return x <= y;
+
514 }
+
+
515};
+
+
516
+
+
517struct Maximum {
+
518 template <typename T>
+
+
519 std::enable_if_t<std::is_integral_v<T>, T> operator()(T x, T y) {
+
520 return (x > y) ? x : y;
+
521 }
+
+
522
+
523 template <typename T>
+
+
524 std::enable_if_t<!std::is_integral_v<T>, T> operator()(T x, T y) {
+
525 if (std::isnan(x)) {
+
526 return x;
+
527 }
+
528 return (x > y) ? x : y;
+
529 }
+
+
530};
+
+
531
+
+
532struct Minimum {
+
533 template <typename T>
+
+
534 std::enable_if_t<std::is_integral_v<T>, T> operator()(T x, T y) {
+
535 return x < y ? x : y;
+
536 }
+
+
537
+
538 template <typename T>
+
+
539 std::enable_if_t<!std::is_integral_v<T>, T> operator()(T x, T y) {
+
540 if (std::isnan(x)) {
+
541 return x;
+
542 }
+
543 return x < y ? x : y;
+
544 }
+
+
545};
+
+
546
+
+
547struct LogAddExp {
+
548 template <typename T>
+
+
549 T operator()(T x, T y) {
+
550 constexpr float inf = std::numeric_limits<float>::infinity();
+
551 auto maxval = Maximum()(x, y);
+
552 auto minval = Minimum()(x, y);
+
553 return (minval == -inf || maxval == inf)
+
554 ? maxval
+
555 : static_cast<decltype(x)>(
+
556 maxval + std::log1p(fast_exp(minval - maxval)));
+
557 };
+
+
558};
+
+
559
+
+
560struct Multiply {
+
561 template <typename T>
+
+
562 T operator()(T x, T y) {
+
563 return x * y;
+
564 }
+
+
565};
+
+
566
+
+
567struct NotEqual {
+
568 template <typename T>
+
+
569 bool operator()(T x, T y) {
+
570 return x != y;
+
571 }
+
+
572};
+
+
573
+
+
574struct Power {
+
575 template <typename T>
+
+
576 std::enable_if_t<!std::is_integral_v<T>, T> operator()(T base, T exp) {
+
577 return std::pow(base, exp);
+
578 }
+
+
579
+
580 template <typename T>
+
+
581 std::enable_if_t<std::is_integral_v<T>, T> operator()(T base, T exp) {
+
582 T res = 1;
+
583 while (exp) {
+
584 if (exp & 1) {
+
585 res *= base;
+
586 }
+
587 exp >>= 1;
+
588 base *= base;
+
589 }
+
590 return res;
+
591 }
+
+
592};
+
+
593
+
+
594struct Subtract {
+
595 template <typename T>
+
+
596 T operator()(T x, T y) {
+
597 return x - y;
+
598 }
+
+
599};
+
+
600
+
+ +
602 template <typename T>
+
+
603 T operator()(T x, T y) {
+
604 return x && y;
+
605 };
+
+
606};
+
+
607
+
+
608struct LogicalOr {
+
609 template <typename T>
+
+
610 T operator()(T x, T y) {
+
611 return x || y;
+
612 };
+
+
613};
+
+
614
+
+
615struct Select {
+
616 template <typename T>
+
+
617 T operator()(bool condition, T x, T y) {
+
618 return condition ? x : y;
+
619 }
+
+
620};
+
+
621
+
+ +
623 template <typename T>
+
+
624 T operator()(T x, T y) {
+
625 return x & y;
+
626 };
+
+
627};
+
+
628
+
+
629struct BitwiseOr {
+
630 template <typename T>
+
+
631 T operator()(T x, T y) {
+
632 return x | y;
+
633 };
+
+
634};
+
+
635
+
+ +
637 template <typename T>
+
+
638 T operator()(T x, T y) {
+
639 return x ^ y;
+
640 };
+
+
641};
+
+
642
+
+
643struct LeftShift {
+
644 template <typename T>
+
+
645 T operator()(T x, T y) {
+
646 return x << y;
+
647 };
+
+
648};
+
+
649
+
+ +
651 template <typename T>
+
+
652 T operator()(T x, T y) {
+
653 return x >> y;
+
654 };
+
+
655};
+
+
656
+
657} // namespace mlx::core::detail
+
+
array log1p(const array &a, StreamOrDevice s={})
Natural logarithm of one plus elements in the array: log(1 + a).
+
array expm1(const array &a, StreamOrDevice s={})
Computes the expm1 function of the elements of an array.
+
array exp(const array &a, StreamOrDevice s={})
Exponential of the elements of an array.
+
Definition ops.h:8
+
float fast_exp(float x)
Definition ops.h:19
+
float fast_erf(float a)
Definition ops.h:47
+
float fast_erfinv(float a)
Definition ops.h:78
+
Definition complex.h:34
+
Definition ops.h:107
+
T operator()(T x)
Definition ops.h:109
+
uint8_t operator()(uint8_t x)
Definition ops.h:112
+
uint64_t operator()(uint64_t x)
Definition ops.h:121
+
uint16_t operator()(uint16_t x)
Definition ops.h:115
+
bool operator()(bool x)
Definition ops.h:124
+
uint32_t operator()(uint32_t x)
Definition ops.h:118
+
Definition ops.h:427
+
T operator()(T x, T y)
Definition ops.h:429
+
Definition ops.h:129
+
T operator()(T x)
Definition ops.h:131
+
Definition ops.h:136
+
T operator()(T x)
Definition ops.h:138
+
Definition ops.h:143
+
T operator()(T x)
Definition ops.h:145
+
Definition ops.h:150
+
T operator()(T x)
Definition ops.h:152
+
Definition ops.h:164
+
T operator()(T y, T x)
Definition ops.h:166
+
Definition ops.h:157
+
T operator()(T x)
Definition ops.h:159
+
Definition ops.h:171
+
T operator()(T x)
Definition ops.h:173
+
Definition ops.h:622
+
T operator()(T x, T y)
Definition ops.h:624
+
Definition ops.h:629
+
T operator()(T x, T y)
Definition ops.h:631
+
Definition ops.h:636
+
T operator()(T x, T y)
Definition ops.h:638
+
Definition ops.h:178
+
uint8_t operator()(uint8_t x)
Definition ops.h:195
+
T operator()(T x)
Definition ops.h:180
+
uint32_t operator()(uint32_t x)
Definition ops.h:201
+
int8_t operator()(int8_t x)
Definition ops.h:183
+
int16_t operator()(int16_t x)
Definition ops.h:186
+
bool operator()(bool x)
Definition ops.h:207
+
uint16_t operator()(uint16_t x)
Definition ops.h:198
+
uint64_t operator()(uint64_t x)
Definition ops.h:204
+
int32_t operator()(int32_t x)
Definition ops.h:189
+
int64_t operator()(int64_t x)
Definition ops.h:192
+
Definition ops.h:212
+
complex64_t operator()(complex64_t x)
Definition ops.h:213
+
Definition ops.h:218
+
T operator()(T x)
Definition ops.h:220
+
Definition ops.h:225
+
T operator()(T x)
Definition ops.h:227
+
Definition ops.h:434
+
T operator()(T x, T y)
Definition ops.h:436
+
Definition ops.h:475
+
bool operator()(T x, T y)
Definition ops.h:477
+
Definition ops.h:232
+
T operator()(T x)
Definition ops.h:234
+
Definition ops.h:239
+
T operator()(T x)
Definition ops.h:241
+
Definition ops.h:246
+
T operator()(T x)
Definition ops.h:248
+
complex64_t operator()(complex64_t x)
Definition ops.h:252
+
Definition ops.h:257
+
T operator()(T x)
Definition ops.h:259
+
Definition ops.h:264
+
T operator()(T x)
Definition ops.h:266
+
uint32_t operator()(uint32_t x)
Definition ops.h:287
+
uint16_t operator()(uint16_t x)
Definition ops.h:284
+
uint8_t operator()(uint8_t x)
Definition ops.h:281
+
int32_t operator()(int32_t x)
Definition ops.h:275
+
int64_t operator()(int64_t x)
Definition ops.h:278
+
bool operator()(bool x)
Definition ops.h:293
+
int8_t operator()(int8_t x)
Definition ops.h:269
+
uint64_t operator()(uint64_t x)
Definition ops.h:290
+
int16_t operator()(int16_t x)
Definition ops.h:272
+ +
bool operator()(T x, T y)
Definition ops.h:498
+
Definition ops.h:489
+
bool operator()(T x, T y)
Definition ops.h:491
+
Definition ops.h:643
+
T operator()(T x, T y)
Definition ops.h:645
+
Definition ops.h:510
+
bool operator()(T x, T y)
Definition ops.h:512
+
Definition ops.h:503
+
bool operator()(T x, T y)
Definition ops.h:505
+
Definition ops.h:312
+
T operator()(T x)
Definition ops.h:314
+
Definition ops.h:319
+
T operator()(T x)
Definition ops.h:321
+
Definition ops.h:305
+
T operator()(T x)
Definition ops.h:307
+
Definition ops.h:547
+
T operator()(T x, T y)
Definition ops.h:549
+
Definition ops.h:298
+
T operator()(T x)
Definition ops.h:300
+
Definition ops.h:601
+
T operator()(T x, T y)
Definition ops.h:603
+
Definition ops.h:326
+
T operator()(T x)
Definition ops.h:328
+
Definition ops.h:608
+
T operator()(T x, T y)
Definition ops.h:610
+
Definition ops.h:517
+
std::enable_if_t< std::is_integral_v< T >, T > operator()(T x, T y)
Definition ops.h:519
+
std::enable_if_t<!std::is_integral_v< T >, T > operator()(T x, T y)
Definition ops.h:524
+
Definition ops.h:532
+
std::enable_if_t<!std::is_integral_v< T >, T > operator()(T x, T y)
Definition ops.h:539
+
std::enable_if_t< std::is_integral_v< T >, T > operator()(T x, T y)
Definition ops.h:534
+
Definition ops.h:560
+
T operator()(T x, T y)
Definition ops.h:562
+
Definition ops.h:482
+
bool operator()(T x, T y)
Definition ops.h:484
+
Definition ops.h:333
+
T operator()(T x)
Definition ops.h:335
+
Definition ops.h:567
+
bool operator()(T x, T y)
Definition ops.h:569
+
Definition ops.h:574
+
std::enable_if_t<!std::is_integral_v< T >, T > operator()(T base, T exp)
Definition ops.h:576
+
std::enable_if_t< std::is_integral_v< T >, T > operator()(T base, T exp)
Definition ops.h:581
+
Definition ops.h:441
+
std::enable_if_t<!std::is_integral_v< T >, T > operator()(T numerator, T denominator)
Definition ops.h:460
+
std::enable_if_t< std::is_integral_v< T > &!std::is_signed_v< T >, T > operator()(T numerator, T denominator)
Definition ops.h:443
+
std::enable_if_t< std::is_integral_v< T > &std::is_signed_v< T >, T > operator()(T numerator, T denominator)
Definition ops.h:450
+
complex64_t operator()(complex64_t numerator, complex64_t denominator)
Definition ops.h:470
+
Definition ops.h:650
+
T operator()(T x, T y)
Definition ops.h:652
+
Definition ops.h:340
+
T operator()(T x)
Definition ops.h:342
+
complex64_t operator()(complex64_t x)
Definition ops.h:346
+
Definition ops.h:406
+
T operator()(T x)
Definition ops.h:408
+
Definition ops.h:615
+
T operator()(bool condition, T x, T y)
Definition ops.h:617
+
Definition ops.h:351
+
T operator()(T x)
Definition ops.h:353
+
Definition ops.h:359
+
uint64_t operator()(uint64_t x)
Definition ops.h:373
+
T operator()(T x)
Definition ops.h:361
+
uint8_t operator()(uint8_t x)
Definition ops.h:364
+
uint16_t operator()(uint16_t x)
Definition ops.h:367
+
uint32_t operator()(uint32_t x)
Definition ops.h:370
+
Definition ops.h:378
+
T operator()(T x)
Definition ops.h:380
+
Definition ops.h:385
+
T operator()(T x)
Definition ops.h:387
+
Definition ops.h:399
+
T operator()(T x)
Definition ops.h:401
+
Definition ops.h:392
+
T operator()(T x)
Definition ops.h:394
+
Definition ops.h:594
+
T operator()(T x, T y)
Definition ops.h:596
+
Definition ops.h:413
+
T operator()(T x)
Definition ops.h:415
+
Definition ops.h:420
+
T operator()(T x)
Definition ops.h:422
+
uint32_t u
Definition bf16.h:17
+ +
float f
Definition ops.h:16
+
int i
Definition ops.h:15
+
+ + + + diff --git a/docs/build/html/backend_2common_2utils_8h.html b/docs/build/html/backend_2common_2utils_8h.html new file mode 100644 index 000000000..8789ebdcf --- /dev/null +++ b/docs/build/html/backend_2common_2utils_8h.html @@ -0,0 +1,121 @@ + + + + + + + +MLX: mlx/backend/common/utils.h File Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
utils.h File Reference
+
+
+
#include <vector>
+#include "mlx/array.h"
+
+

Go to the source code of this file.

+ + + + + + +

+Namespaces

namespace  mlx
 
namespace  mlx::core
 
+ + + + + + + + + + + + + + + + + +

+Functions

template<typename stride_t >
stride_t mlx::core::elem_to_loc (int elem, const std::vector< int > &shape, const std::vector< stride_t > &strides)
 
size_t mlx::core::elem_to_loc (int elem, const array &a)
 
template<typename stride_t >
std::tuple< std::vector< int >, std::vector< std::vector< stride_t > > > mlx::core::collapse_contiguous_dims (const std::vector< int > &shape, const std::vector< std::vector< stride_t > > strides)
 
std::tuple< std::vector< int >, std::vector< std::vector< size_t > > > mlx::core::collapse_contiguous_dims (const std::vector< array > &xs)
 
template<typename... Arrays, typename = enable_for_arrays_t<Arrays...>>
auto mlx::core::collapse_contiguous_dims (Arrays &&... xs)
 
template<typename stride_t >
auto mlx::core::check_contiguity (const std::vector< int > &shape, const std::vector< stride_t > &strides)
 
+
+ + + + diff --git a/docs/build/html/backend_2common_2utils_8h_source.html b/docs/build/html/backend_2common_2utils_8h_source.html new file mode 100644 index 000000000..b2b4439da --- /dev/null +++ b/docs/build/html/backend_2common_2utils_8h_source.html @@ -0,0 +1,236 @@ + + + + + + + +MLX: mlx/backend/common/utils.h Source File + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
utils.h
+
+
+Go to the documentation of this file.
1// Copyright © 2023-2024 Apple Inc.
+
2
+
3#pragma once
+
4
+
5#include <vector>
+
6
+
7#include "mlx/array.h"
+
8
+
9namespace mlx::core {
+
10
+
11template <typename stride_t>
+
+
12inline stride_t elem_to_loc(
+
13 int elem,
+
14 const std::vector<int>& shape,
+
15 const std::vector<stride_t>& strides) {
+
16 stride_t loc = 0;
+
17 for (int i = shape.size() - 1; i >= 0; --i) {
+
18 auto q_and_r = ldiv(elem, shape[i]);
+
19 loc += q_and_r.rem * strides[i];
+
20 elem = q_and_r.quot;
+
21 }
+
22 return loc;
+
23}
+
+
24
+
+
25inline size_t elem_to_loc(int elem, const array& a) {
+
26 if (a.flags().row_contiguous) {
+
27 return elem;
+
28 }
+
29 return elem_to_loc(elem, a.shape(), a.strides());
+
30}
+
+
31
+
32// Collapse dims that are contiguous to possibly route to a better kernel
+
33// e.g. for x = transpose(array({0, 1, 2, 3, 4, 5, 6, 7}, {2, 2, 2}), {2, 0, 1})
+
34// should return {{2, 4}, {{1, 2}}}.
+
35//
+
36// When multiple arrays are passed they should all have the same shape. The
+
37// collapsed axes are also the same so one shape is returned.
+
38template <typename stride_t>
+
39inline std::tuple<std::vector<int>, std::vector<std::vector<stride_t>>>
+
+ +
41 const std::vector<int>& shape,
+
42 const std::vector<std::vector<stride_t>> strides) {
+
43 // Make a vector that has axes separated with -1. Collapse all axes between
+
44 // -1.
+
45 std::vector<int> to_collapse;
+
46 if (shape.size() > 0) {
+
47 to_collapse.push_back(0);
+
48 for (int i = 1; i < shape.size(); i++) {
+
49 bool contiguous = true;
+
50 for (const std::vector<stride_t>& st : strides) {
+
51 if (st[i] * shape[i] != st[i - 1]) {
+
52 contiguous = false;
+
53 }
+
54 if (!contiguous) {
+
55 break;
+
56 }
+
57 }
+
58 if (!contiguous) {
+
59 to_collapse.push_back(-1);
+
60 }
+
61 to_collapse.push_back(i);
+
62 }
+
63 to_collapse.push_back(-1);
+
64 }
+
65
+
66 std::vector<int> out_shape;
+
67 std::vector<std::vector<stride_t>> out_strides(strides.size());
+
68 for (int i = 0; i < to_collapse.size(); i++) {
+
69 int current_shape = shape[to_collapse[i]];
+
70 while (to_collapse[++i] != -1) {
+
71 current_shape *= shape[to_collapse[i]];
+
72 }
+
73 out_shape.push_back(current_shape);
+
74 for (int j = 0; j < strides.size(); j++) {
+
75 const std::vector<stride_t>& st = strides[j];
+
76 out_strides[j].push_back(st[to_collapse[i - 1]]);
+
77 }
+
78 }
+
79
+
80 return std::make_tuple(out_shape, out_strides);
+
81}
+
+
82
+
83inline std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
+
+
84collapse_contiguous_dims(const std::vector<array>& xs) {
+
85 std::vector<std::vector<size_t>> strides;
+
86 for (auto& x : xs) {
+
87 strides.emplace_back(x.strides());
+
88 }
+
89 return collapse_contiguous_dims(xs[0].shape(), strides);
+
90}
+
+
91
+
92template <typename... Arrays, typename = enable_for_arrays_t<Arrays...>>
+
+
93inline auto collapse_contiguous_dims(Arrays&&... xs) {
+ +
95 std::vector<array>{std::forward<Arrays>(xs)...});
+
96}
+
+
97
+
98template <typename stride_t>
+
+
99inline auto check_contiguity(
+
100 const std::vector<int>& shape,
+
101 const std::vector<stride_t>& strides) {
+
102 size_t data_size = 1;
+
103 size_t f_stride = 1;
+
104 size_t b_stride = 1;
+
105 bool is_row_contiguous = true;
+
106 bool is_col_contiguous = true;
+
107
+
108 for (int i = 0, ri = shape.size() - 1; ri >= 0; i++, ri--) {
+
109 is_row_contiguous &= strides[i] == f_stride || shape[i] == 1;
+
110 is_col_contiguous &= strides[ri] == b_stride || shape[ri] == 1;
+
111 f_stride *= shape[i];
+
112 b_stride *= shape[ri];
+
113 if (strides[i] > 0) {
+
114 data_size *= shape[i];
+
115 }
+
116 }
+
117
+
118 return std::make_tuple(data_size, is_row_contiguous, is_col_contiguous);
+
119}
+
+
120
+
121} // namespace mlx::core
+ +
Definition array.h:20
+
const Flags & flags() const
Get the Flags bit-field.
Definition array.h:290
+
const std::vector< size_t > & strides() const
The strides of the array.
Definition array.h:113
+
const std::vector< int > & shape() const
The shape of the array as a vector of integers.
Definition array.h:99
+
Definition allocator.h:7
+
stride_t elem_to_loc(int elem, const std::vector< int > &shape, const std::vector< stride_t > &strides)
Definition utils.h:12
+
auto check_contiguity(const std::vector< int > &shape, const std::vector< stride_t > &strides)
Definition utils.h:99
+
std::tuple< std::vector< int >, std::vector< std::vector< stride_t > > > collapse_contiguous_dims(const std::vector< int > &shape, const std::vector< std::vector< stride_t > > strides)
Definition utils.h:40
+
typename std::enable_if_t< is_arrays_v< T... > > enable_for_arrays_t
Definition array.h:565
+
bool row_contiguous
Definition array.h:226
+
+ + + + diff --git a/docs/build/html/backend_2metal_2allocator_8h.html b/docs/build/html/backend_2metal_2allocator_8h.html new file mode 100644 index 000000000..a8d38c66a --- /dev/null +++ b/docs/build/html/backend_2metal_2allocator_8h.html @@ -0,0 +1,161 @@ + + + + + + + +MLX: mlx/backend/metal/allocator.h File Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
allocator.h File Reference
+
+
+
#include <map>
+#include <mutex>
+#include <vector>
+#include "mlx/allocator.h"
+#include "mlx/backend/metal/device.h"
+
+

Go to the source code of this file.

+ + + + +

+Classes

class  mlx::core::metal::MetalAllocator
 
+ + + + + + + +

+Namespaces

namespace  mlx
 
namespace  mlx::core
 
namespace  mlx::core::metal
 
+ + + +

+Functions

MetalAllocatormlx::core::metal::allocator ()
 
+

Variable Documentation

+ +

◆ buf

+ +
+
+ + + + +
MTL::Buffer* buf
+
+ +
+
+ +

◆ next

+ +
+
+ + + + +
BufferHolder* next
+
+ +
+
+ +

◆ prev

+ +
+
+ + + + +
BufferHolder* prev
+
+ +
+
+
+ + + + diff --git a/docs/build/html/backend_2metal_2allocator_8h_source.html b/docs/build/html/backend_2metal_2allocator_8h_source.html new file mode 100644 index 000000000..6588aac61 --- /dev/null +++ b/docs/build/html/backend_2metal_2allocator_8h_source.html @@ -0,0 +1,221 @@ + + + + + + + +MLX: mlx/backend/metal/allocator.h Source File + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
allocator.h
+
+
+Go to the documentation of this file.
1// Copyright © 2023-2024 Apple Inc.
+
2
+
3#pragma once
+
4
+
5#include <map>
+
6#include <mutex>
+
7#include <vector>
+
8
+
9#include "mlx/allocator.h"
+ +
11
+
+ +
13
+ +
15
+
16namespace {
+
17
+
18class BufferCache {
+
19 public:
+
20 BufferCache(MTL::Device* device);
+
21 ~BufferCache();
+
22
+
23 MTL::Buffer* reuse_from_cache(size_t size);
+
24 void recycle_to_cache(MTL::Buffer* buf);
+
25 void release_cached_buffers(size_t min_bytes_to_free);
+
26 size_t cache_size() {
+
27 return pool_size_;
+
28 }
+
29 void clear();
+
30
+
31 private:
+
32 struct BufferHolder {
+
33 public:
+
34 BufferHolder(MTL::Buffer* buf_) : buf(buf_), prev(nullptr), next(nullptr) {}
+
35
+
36 BufferHolder* prev;
+
37 BufferHolder* next;
+
38 MTL::Buffer* buf;
+
39 };
+
40
+
41 void add_at_head(BufferHolder* to_add);
+
42 void remove_from_list(BufferHolder* to_remove);
+
43
+
44 MTL::Device* device_;
+
45
+
46 std::multimap<size_t, BufferHolder*> buffer_pool_;
+
47 BufferHolder* head_;
+
48 BufferHolder* tail_;
+
49 size_t pool_size_;
+
50};
+
51
+
52} // namespace
+
53
+
+ +
56 public:
+
57 virtual Buffer malloc(size_t size, bool allow_swap = false) override;
+
58 virtual void free(Buffer buffer) override;
+
+ +
60 return active_memory_;
+
61 };
+
+
+
62 size_t get_peak_memory() {
+
63 return peak_memory_;
+
64 };
+
+
+ +
66 std::unique_lock lk(mutex_);
+
67 peak_memory_ = 0;
+
68 };
+
+
+ +
70 return buffer_cache_.cache_size();
+
71 };
+
+
72 size_t set_cache_limit(size_t limit);
+
73 size_t set_memory_limit(size_t limit, bool relaxed);
+ +
75
+
76 private:
+
77 MTL::Device* device_;
+ + +
80
+
81 // Caching allocator
+
82 BufferCache buffer_cache_;
+
83
+
84 // Allocation stats
+
85 size_t block_limit_;
+
86 size_t gc_limit_;
+
87 size_t active_memory_{0};
+
88 size_t peak_memory_{0};
+
89 size_t max_pool_size_;
+
90 bool relaxed_{true};
+
91
+
92 std::mutex mutex_;
+
93};
+
+
94
+ +
96
+
97} // namespace mlx::core::metal
+
+ +
MTL::Buffer * buf
Definition allocator.h:38
+
BufferHolder * prev
Definition allocator.h:36
+
BufferHolder * next
Definition allocator.h:37
+ +
Definition allocator.h:39
+
Definition allocator.h:12
+
Definition allocator.h:54
+
virtual void free(Buffer buffer) override
+
size_t set_memory_limit(size_t limit, bool relaxed)
+
void reset_peak_memory()
Definition allocator.h:65
+ +
virtual Buffer malloc(size_t size, bool allow_swap=false) override
Allocator for Metal GPUs.
+
size_t get_active_memory()
Definition allocator.h:59
+
size_t get_peak_memory()
Definition allocator.h:62
+
size_t get_cache_memory()
Definition allocator.h:69
+
size_t set_cache_limit(size_t limit)
+
friend MetalAllocator & allocator()
+
Definition allocator.h:12
+
MetalAllocator & allocator()
+
Device & device(mlx::core::Device)
+
+ + + + diff --git a/docs/build/html/backend_2metal_2device_8h.html b/docs/build/html/backend_2metal_2device_8h.html new file mode 100644 index 000000000..573f61f73 --- /dev/null +++ b/docs/build/html/backend_2metal_2device_8h.html @@ -0,0 +1,135 @@ + + + + + + + +MLX: mlx/backend/metal/device.h File Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
device.h File Reference
+
+
+
#include <Metal/Metal.hpp>
+#include <functional>
+#include <mutex>
+#include <string>
+#include <unordered_map>
+#include <unordered_set>
+#include <dlfcn.h>
+#include <filesystem>
+#include "mlx/array.h"
+#include "mlx/device.h"
+
+

Go to the source code of this file.

+ + + + + + + + +

+Classes

struct  mlx::core::metal::CommandEncoder
 
struct  mlx::core::metal::CommandEncoder::ConcurrentContext
 
class  mlx::core::metal::Device
 
+ + + + + + + +

+Namespaces

namespace  mlx
 
namespace  mlx::core
 
namespace  mlx::core::metal
 
+ + + +

+Typedefs

using mlx::core::metal::MTLFCList
 
+ + + + + +

+Functions

std::string mlx::core::metal::get_colocated_mtllib_path (const std::string &lib_name)
 
Devicemlx::core::metal::device (mlx::core::Device)
 
+
+ + + + diff --git a/docs/build/html/backend_2metal_2device_8h_source.html b/docs/build/html/backend_2metal_2device_8h_source.html new file mode 100644 index 000000000..cac1cfce8 --- /dev/null +++ b/docs/build/html/backend_2metal_2device_8h_source.html @@ -0,0 +1,389 @@ + + + + + + + +MLX: mlx/backend/metal/device.h Source File + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
device.h
+
+
+Go to the documentation of this file.
1// Copyright © 2023-2024 Apple Inc.
+
2
+
3#pragma once
+
4
+
5#include <Metal/Metal.hpp>
+
6#include <functional>
+
7#include <mutex>
+
8#include <string>
+
9#include <unordered_map>
+
10#include <unordered_set>
+
11
+
12#include <dlfcn.h>
+
13#include <filesystem>
+
14
+
15#include "mlx/array.h"
+
16#include "mlx/device.h"
+
17
+
18namespace fs = std::filesystem;
+
19
+
20namespace mlx::core::metal {
+
21
+
+
22inline std::string get_colocated_mtllib_path(const std::string& lib_name) {
+
23 Dl_info info;
+
24 std::string mtllib_path;
+
25 std::string lib_ext = lib_name + ".metallib";
+
26
+
27 int success = dladdr((void*)get_colocated_mtllib_path, &info);
+
28 if (success) {
+
29 auto mtllib = fs::path(info.dli_fname).remove_filename() / lib_ext;
+
30 mtllib_path = mtllib.c_str();
+
31 }
+
32
+
33 return mtllib_path;
+
34}
+
+
35
+
36using MTLFCList =
+
37 std::vector<std::tuple<const void*, MTL::DataType, NS::UInteger>>;
+
38
+
+ +
+
40 CommandEncoder(MTL::CommandBuffer* cbuf) : cbuf(cbuf) {
+
41 enc = cbuf->computeCommandEncoder(MTL::DispatchTypeConcurrent);
+
42 enc->retain();
+
43 };
+
+ + +
46
+
+ +
+ +
49 enc.concurrent = true;
+
50 }
+
+
+ +
52 enc.concurrent = false;
+
53 enc.outputs.insert(
+
54 enc.concurrent_outputs.begin(), enc.concurrent_outputs.end());
+
55 enc.concurrent_outputs.clear();
+
56 }
+
+
57
+
58 private:
+
59 CommandEncoder& enc;
+
60 };
+
+
61
+
+
62 MTL::ComputeCommandEncoder* operator->() {
+
63 return enc;
+
64 }
+
+
65
+
+
66 void set_input_array(const array& a, int idx, int offset = 0) {
+
67 auto r_buf =
+
68 static_cast<MTL::Resource*>(const_cast<void*>(a.buffer().ptr()));
+
69 if (auto it = outputs.find(r_buf); it != outputs.end()) {
+
70 // Insert a barrier
+
71 enc->memoryBarrier(&r_buf, 1);
+
72
+
73 // Remove the output
+
74 outputs.erase(it);
+
75 }
+
76 auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
+
77 auto base_offset = a.data<char>() -
+
78 static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
+
79 base_offset += offset;
+
80 enc->setBuffer(a_buf, base_offset, idx);
+
81 }
+
+
82
+
+
83 void set_output_array(array& a, int idx, int offset = 0) {
+
84 // Add barriers before adding the output to the output set
+
85 set_input_array(a, idx, offset);
+
86 auto buf = static_cast<MTL::Resource*>(a.buffer().ptr());
+
87 if (concurrent) {
+
88 concurrent_outputs.insert(buf);
+
89 } else {
+
90 outputs.insert(buf);
+
91 }
+
92 }
+
+
93
+
94 void dispatchThreadgroups(MTL::Size grid_dims, MTL::Size group_dims);
+
95 void dispatchThreads(MTL::Size grid_dims, MTL::Size group_dims);
+
96
+
+ +
98 return ConcurrentContext(*this);
+
99 }
+
+
100
+
+ +
102 enc->endEncoding();
+
103 enc->release();
+
104 }
+
+
105
+
106 private:
+
107 void maybe_split();
+
108
+
109 int num_dispatches{0};
+
110 MTL::CommandBuffer* cbuf;
+
111 MTL::ComputeCommandEncoder* enc;
+
112 bool concurrent{false};
+
113 std::unordered_set<MTL::Resource*> outputs;
+
114 std::unordered_set<MTL::Resource*> concurrent_outputs;
+
115};
+
+
116
+
+
117class Device {
+
118 public:
+ +
120 Device(const Device&) = delete;
+
121 Device& operator=(const Device&) = delete;
+ +
123
+
+
124 MTL::Device* mtl_device() {
+
125 return device_;
+
126 };
+
+
127
+
128 void new_queue(int index);
+
129 MTL::CommandBuffer* get_command_buffer(int index);
+ + +
132 void commit_command_buffer(int index);
+ +
134 void end_encoding(int index);
+
135
+ +
137 const std::string& lib_name,
+
138 const std::string& lib_path);
+ +
140 const std::string& lib_name,
+
141 const std::function<std::string(const std::string&)>& lib_path_func =
+ +
143
+
144 MTL::Library* get_library(const std::string& name);
+
145
+
146 MTL::Library* get_library(
+
147 const std::string& name,
+
148 const std::string& source_string,
+
149 bool cache = true);
+
150
+
151 MTL::Library* get_library(
+
152 const std::string& name,
+
153 const MTL::StitchedLibraryDescriptor* desc,
+
154 bool cache = true);
+
155
+
156 MTL::Function* get_function(
+
157 const std::string& base_name,
+
158 MTL::Library* mtl_lib,
+
159 const std::string& specialized_name = "",
+
160 const MTLFCList& func_consts = {});
+
161
+
162 MTL::Function* get_function(
+
163 const std::string& base_name,
+
164 const std::string& lib_name = "mlx",
+
165 const std::string& specialized_name = "",
+
166 const MTLFCList& func_consts = {});
+
167
+
168 MTL::ComputePipelineState* get_kernel(
+
169 const std::string& base_name,
+
170 MTL::Library* mtl_lib,
+
171 const std::string& hash_name = "",
+
172 const MTLFCList& func_consts = {},
+
173 const std::vector<MTL::Function*>& linked_functions = {});
+
174
+
175 MTL::ComputePipelineState* get_kernel(
+
176 const std::string& base_name,
+
177 const std::string& lib_name = "mlx",
+
178 const std::string& hash_name = "",
+
179 const MTLFCList& func_consts = {},
+
180 const std::vector<MTL::Function*>& linked_functions = {});
+
181
+
182 MTL::ArgumentEncoder* argument_encoder(
+
183 const std::vector<MTL::ArgumentDescriptor*>& arg_descs) const;
+
184
+
185 private:
+
186 MTL::Library* get_library_cache_(const std::string& name);
+
187
+
188 MTL::Library* get_library_(const std::string& source_string);
+
189 MTL::Library* get_library_(const MTL::StitchedLibraryDescriptor* desc);
+
190
+
191 MTL::Function* get_function_(const std::string& name, MTL::Library* mtl_lib);
+
192
+
193 MTL::Function* get_function_(
+
194 const std::string& name,
+
195 const std::string& specialized_name,
+
196 const MTLFCList& func_consts,
+
197 MTL::Library* mtl_lib);
+
198
+
199 MTL::LinkedFunctions* get_linked_functions_(
+
200 const std::vector<MTL::Function*>& funcs);
+
201
+
202 MTL::ComputePipelineState* get_kernel_(
+
203 const std::string& name,
+
204 const MTL::Function* mtl_function);
+
205
+
206 MTL::ComputePipelineState* get_kernel_(
+
207 const std::string& name,
+
208 const MTL::Function* mtl_function,
+
209 const MTL::LinkedFunctions* linked_functions);
+
210
+
211 MTL::Device* device_;
+
212 std::unordered_map<int32_t, MTL::CommandQueue*> queue_map_;
+
213 std::unordered_map<int32_t, std::pair<int, MTL::CommandBuffer*>> buffer_map_;
+
214 std::unordered_map<int32_t, std::unique_ptr<CommandEncoder>> encoder_map_;
+
215 std::unordered_map<std::string, MTL::ComputePipelineState*> kernel_map_;
+
216 std::unordered_map<std::string, MTL::Library*> library_map_;
+
217 std::mutex mtx_;
+
218};
+
+
219
+ +
221
+
222} // namespace mlx::core::metal
+ +
MTL::Buffer * buf
Definition allocator.h:38
+
const void * ptr() const
Definition allocator.h:23
+
Definition array.h:20
+
T * data()
Definition array.h:313
+
allocator::Buffer & buffer()
Definition array.h:299
+
Definition device.h:117
+
int get_command_buffer_ops(int index)
+
MTL::Device * mtl_device()
Definition device.h:124
+
void register_library(const std::string &lib_name, const std::string &lib_path)
+ +
MTL::CommandBuffer * get_command_buffer(int index)
+
void end_encoding(int index)
+
MTL::ComputePipelineState * get_kernel(const std::string &base_name, MTL::Library *mtl_lib, const std::string &hash_name="", const MTLFCList &func_consts={}, const std::vector< MTL::Function * > &linked_functions={})
+
void register_library(const std::string &lib_name, const std::function< std::string(const std::string &)> &lib_path_func=get_colocated_mtllib_path)
+
MTL::ArgumentEncoder * argument_encoder(const std::vector< MTL::ArgumentDescriptor * > &arg_descs) const
+
void increment_command_buffer_ops(int index)
+
void new_queue(int index)
+
MTL::Library * get_library(const std::string &name)
+
MTL::Library * get_library(const std::string &name, const MTL::StitchedLibraryDescriptor *desc, bool cache=true)
+
void commit_command_buffer(int index)
+
MTL::Library * get_library(const std::string &name, const std::string &source_string, bool cache=true)
+
MTL::Function * get_function(const std::string &base_name, MTL::Library *mtl_lib, const std::string &specialized_name="", const MTLFCList &func_consts={})
+
Device(const Device &)=delete
+
MTL::Function * get_function(const std::string &base_name, const std::string &lib_name="mlx", const std::string &specialized_name="", const MTLFCList &func_consts={})
+
Device & operator=(const Device &)=delete
+ +
MTL::ComputePipelineState * get_kernel(const std::string &base_name, const std::string &lib_name="mlx", const std::string &hash_name="", const MTLFCList &func_consts={}, const std::vector< MTL::Function * > &linked_functions={})
+
CommandEncoder & get_command_encoder(int index)
+ +
Definition allocator.h:12
+
std::string get_colocated_mtllib_path(const std::string &lib_name)
Definition device.h:22
+
std::vector< std::tuple< const void *, MTL::DataType, NS::UInteger > > MTLFCList
Definition device.h:36
+
Device & device(mlx::core::Device)
+
Definition device.h:7
+ + +
ConcurrentContext(CommandEncoder &enc)
Definition device.h:48
+
Definition device.h:39
+
void dispatchThreads(MTL::Size grid_dims, MTL::Size group_dims)
+
CommandEncoder(MTL::CommandBuffer *cbuf)
Definition device.h:40
+
CommandEncoder & operator=(const CommandEncoder &)=delete
+
ConcurrentContext start_concurrent()
Definition device.h:97
+
void dispatchThreadgroups(MTL::Size grid_dims, MTL::Size group_dims)
+
void set_input_array(const array &a, int idx, int offset=0)
Definition device.h:66
+
~CommandEncoder()
Definition device.h:101
+
MTL::ComputeCommandEncoder * operator->()
Definition device.h:62
+
CommandEncoder(const CommandEncoder &)=delete
+
void set_output_array(array &a, int idx, int offset=0)
Definition device.h:83
+
+ + + + diff --git a/docs/build/html/backend_2metal_2kernels_2bf16_8h.html b/docs/build/html/backend_2metal_2kernels_2bf16_8h.html new file mode 100644 index 000000000..ccea11b13 --- /dev/null +++ b/docs/build/html/backend_2metal_2kernels_2bf16_8h.html @@ -0,0 +1,10952 @@ + + + + + + + +MLX: mlx/backend/metal/kernels/bf16.h File Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
bf16.h File Reference
+
+
+
#include <metal_stdlib>
+#include "mlx/backend/metal/kernels/bf16_math.h"
+
+

Go to the source code of this file.

+ + + + + + + + +

+Classes

struct  _MLX_BFloat16
 
struct  _MLX_BFloat16::bits_to_bfloat_struct
 
struct  metal::_numeric_limits_impl< bfloat16_t >
 
+ + + +

+Namespaces

namespace  metal
 
+ + + + + + + + + + + + + + + + + + + +

+Macros

#define bfloat_binop_base(__op__, __operator__, otype, atype, btype, ctype)
 
#define bfloat_binop_helper(__op__, __operator__, otype, itype, ctype)
 
#define bfloat_binop(_op_, _operator_)
 
#define bfloat_compop(__op__, __operator__)
 
#define bfloat_inplace_op_helper(__op__, __operator__, itype, addr_space)
 
#define bfloat_inplace_op_addr_space_helper(__op__, __operator__, itype)
 
#define bfloat_inplace_op(itype)
 
#define bfloat_inplace_op_helper(__op__, __operator__, addr_space)
 
#define bfloat_inplace_op_addr_space_helper(__op__, __operator__)
 
+ + + +

+Typedefs

typedef struct _MLX_BFloat16 bfloat16_t
 
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Functions

constexpr METAL_FUNC uint16_t float_to_bfloat_bits (float x)
 
constexpr METAL_FUNC float bfloat_bits_to_float (uint16_t x)
 
constexpr METAL_FUNC _MLX_BFloat16 operator- (_MLX_BFloat16 x)
 
constexpr METAL_FUNC _MLX_BFloat16 operator+ (_MLX_BFloat16 lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC float operator+ (_MLX_BFloat16 lhs, float rhs)
 
constexpr METAL_FUNC float operator+ (float lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC float operator+ (_MLX_BFloat16 lhs, half rhs)
 
constexpr METAL_FUNC float operator+ (half lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC _MLX_BFloat16 operator+ (_MLX_BFloat16 lhs, int32_t rhs)
 
constexpr METAL_FUNC _MLX_BFloat16 operator+ (int32_t lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC _MLX_BFloat16 operator+ (_MLX_BFloat16 lhs, uint32_t rhs)
 
constexpr METAL_FUNC _MLX_BFloat16 operator+ (uint32_t lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC _MLX_BFloat16 operator+ (_MLX_BFloat16 lhs, int64_t rhs)
 
constexpr METAL_FUNC _MLX_BFloat16 operator+ (int64_t lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC _MLX_BFloat16 operator+ (_MLX_BFloat16 lhs, uint64_t rhs)
 
constexpr METAL_FUNC _MLX_BFloat16 operator+ (uint64_t lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC _MLX_BFloat16 operator- (_MLX_BFloat16 lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC float operator- (_MLX_BFloat16 lhs, float rhs)
 
constexpr METAL_FUNC float operator- (float lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC float operator- (_MLX_BFloat16 lhs, half rhs)
 
constexpr METAL_FUNC float operator- (half lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC _MLX_BFloat16 operator- (_MLX_BFloat16 lhs, int32_t rhs)
 
constexpr METAL_FUNC _MLX_BFloat16 operator- (int32_t lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC _MLX_BFloat16 operator- (_MLX_BFloat16 lhs, uint32_t rhs)
 
constexpr METAL_FUNC _MLX_BFloat16 operator- (uint32_t lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC _MLX_BFloat16 operator- (_MLX_BFloat16 lhs, int64_t rhs)
 
constexpr METAL_FUNC _MLX_BFloat16 operator- (int64_t lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC _MLX_BFloat16 operator- (_MLX_BFloat16 lhs, uint64_t rhs)
 
constexpr METAL_FUNC _MLX_BFloat16 operator- (uint64_t lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC _MLX_BFloat16 operator* (_MLX_BFloat16 lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC float operator* (_MLX_BFloat16 lhs, float rhs)
 
constexpr METAL_FUNC float operator* (float lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC float operator* (_MLX_BFloat16 lhs, half rhs)
 
constexpr METAL_FUNC float operator* (half lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC _MLX_BFloat16 operator* (_MLX_BFloat16 lhs, int32_t rhs)
 
constexpr METAL_FUNC _MLX_BFloat16 operator* (int32_t lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC _MLX_BFloat16 operator* (_MLX_BFloat16 lhs, uint32_t rhs)
 
constexpr METAL_FUNC _MLX_BFloat16 operator* (uint32_t lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC _MLX_BFloat16 operator* (_MLX_BFloat16 lhs, int64_t rhs)
 
constexpr METAL_FUNC _MLX_BFloat16 operator* (int64_t lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC _MLX_BFloat16 operator* (_MLX_BFloat16 lhs, uint64_t rhs)
 
constexpr METAL_FUNC _MLX_BFloat16 operator* (uint64_t lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC _MLX_BFloat16 operator/ (_MLX_BFloat16 lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC float operator/ (_MLX_BFloat16 lhs, float rhs)
 
constexpr METAL_FUNC float operator/ (float lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC float operator/ (_MLX_BFloat16 lhs, half rhs)
 
constexpr METAL_FUNC float operator/ (half lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC _MLX_BFloat16 operator/ (_MLX_BFloat16 lhs, int32_t rhs)
 
constexpr METAL_FUNC _MLX_BFloat16 operator/ (int32_t lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC _MLX_BFloat16 operator/ (_MLX_BFloat16 lhs, uint32_t rhs)
 
constexpr METAL_FUNC _MLX_BFloat16 operator/ (uint32_t lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC _MLX_BFloat16 operator/ (_MLX_BFloat16 lhs, int64_t rhs)
 
constexpr METAL_FUNC _MLX_BFloat16 operator/ (int64_t lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC _MLX_BFloat16 operator/ (_MLX_BFloat16 lhs, uint64_t rhs)
 
constexpr METAL_FUNC _MLX_BFloat16 operator/ (uint64_t lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC bool operator> (_MLX_BFloat16 lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC bool operator> (_MLX_BFloat16 lhs, float rhs)
 
constexpr METAL_FUNC bool operator> (float lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC bool operator> (_MLX_BFloat16 lhs, half rhs)
 
constexpr METAL_FUNC bool operator> (half lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC bool operator> (_MLX_BFloat16 lhs, int32_t rhs)
 
constexpr METAL_FUNC bool operator> (int32_t lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC bool operator> (_MLX_BFloat16 lhs, uint32_t rhs)
 
constexpr METAL_FUNC bool operator> (uint32_t lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC bool operator> (_MLX_BFloat16 lhs, int64_t rhs)
 
constexpr METAL_FUNC bool operator> (int64_t lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC bool operator> (_MLX_BFloat16 lhs, uint64_t rhs)
 
constexpr METAL_FUNC bool operator> (uint64_t lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC bool operator< (_MLX_BFloat16 lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC bool operator< (_MLX_BFloat16 lhs, float rhs)
 
constexpr METAL_FUNC bool operator< (float lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC bool operator< (_MLX_BFloat16 lhs, half rhs)
 
constexpr METAL_FUNC bool operator< (half lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC bool operator< (_MLX_BFloat16 lhs, int32_t rhs)
 
constexpr METAL_FUNC bool operator< (int32_t lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC bool operator< (_MLX_BFloat16 lhs, uint32_t rhs)
 
constexpr METAL_FUNC bool operator< (uint32_t lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC bool operator< (_MLX_BFloat16 lhs, int64_t rhs)
 
constexpr METAL_FUNC bool operator< (int64_t lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC bool operator< (_MLX_BFloat16 lhs, uint64_t rhs)
 
constexpr METAL_FUNC bool operator< (uint64_t lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC bool operator>= (_MLX_BFloat16 lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC bool operator>= (_MLX_BFloat16 lhs, float rhs)
 
constexpr METAL_FUNC bool operator>= (float lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC bool operator>= (_MLX_BFloat16 lhs, half rhs)
 
constexpr METAL_FUNC bool operator>= (half lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC bool operator>= (_MLX_BFloat16 lhs, int32_t rhs)
 
constexpr METAL_FUNC bool operator>= (int32_t lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC bool operator>= (_MLX_BFloat16 lhs, uint32_t rhs)
 
constexpr METAL_FUNC bool operator>= (uint32_t lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC bool operator>= (_MLX_BFloat16 lhs, int64_t rhs)
 
constexpr METAL_FUNC bool operator>= (int64_t lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC bool operator>= (_MLX_BFloat16 lhs, uint64_t rhs)
 
constexpr METAL_FUNC bool operator>= (uint64_t lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC bool operator<= (_MLX_BFloat16 lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC bool operator<= (_MLX_BFloat16 lhs, float rhs)
 
constexpr METAL_FUNC bool operator<= (float lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC bool operator<= (_MLX_BFloat16 lhs, half rhs)
 
constexpr METAL_FUNC bool operator<= (half lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC bool operator<= (_MLX_BFloat16 lhs, int32_t rhs)
 
constexpr METAL_FUNC bool operator<= (int32_t lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC bool operator<= (_MLX_BFloat16 lhs, uint32_t rhs)
 
constexpr METAL_FUNC bool operator<= (uint32_t lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC bool operator<= (_MLX_BFloat16 lhs, int64_t rhs)
 
constexpr METAL_FUNC bool operator<= (int64_t lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC bool operator<= (_MLX_BFloat16 lhs, uint64_t rhs)
 
constexpr METAL_FUNC bool operator<= (uint64_t lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC bool operator== (_MLX_BFloat16 lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC bool operator== (_MLX_BFloat16 lhs, float rhs)
 
constexpr METAL_FUNC bool operator== (float lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC bool operator== (_MLX_BFloat16 lhs, half rhs)
 
constexpr METAL_FUNC bool operator== (half lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC bool operator== (_MLX_BFloat16 lhs, int32_t rhs)
 
constexpr METAL_FUNC bool operator== (int32_t lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC bool operator== (_MLX_BFloat16 lhs, uint32_t rhs)
 
constexpr METAL_FUNC bool operator== (uint32_t lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC bool operator== (_MLX_BFloat16 lhs, int64_t rhs)
 
constexpr METAL_FUNC bool operator== (int64_t lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC bool operator== (_MLX_BFloat16 lhs, uint64_t rhs)
 
constexpr METAL_FUNC bool operator== (uint64_t lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC bool operator!= (_MLX_BFloat16 lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC bool operator!= (_MLX_BFloat16 lhs, float rhs)
 
constexpr METAL_FUNC bool operator!= (float lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC bool operator!= (_MLX_BFloat16 lhs, half rhs)
 
constexpr METAL_FUNC bool operator!= (half lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC bool operator!= (_MLX_BFloat16 lhs, int32_t rhs)
 
constexpr METAL_FUNC bool operator!= (int32_t lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC bool operator!= (_MLX_BFloat16 lhs, uint32_t rhs)
 
constexpr METAL_FUNC bool operator!= (uint32_t lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC bool operator!= (_MLX_BFloat16 lhs, int64_t rhs)
 
constexpr METAL_FUNC bool operator!= (int64_t lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC bool operator!= (_MLX_BFloat16 lhs, uint64_t rhs)
 
constexpr METAL_FUNC bool operator!= (uint64_t lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC device _MLX_BFloat16operator+= (device _MLX_BFloat16 &lhs, float rhs)
 
constexpr METAL_FUNC device float & operator+= (device float &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC thread _MLX_BFloat16operator+= (thread _MLX_BFloat16 &lhs, float rhs)
 
constexpr METAL_FUNC thread float & operator+= (thread float &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC threadgroup _MLX_BFloat16operator+= (threadgroup _MLX_BFloat16 &lhs, float rhs)
 
constexpr METAL_FUNC threadgroup float & operator+= (threadgroup float &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC device _MLX_BFloat16operator-= (device _MLX_BFloat16 &lhs, float rhs)
 
constexpr METAL_FUNC device float & operator-= (device float &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC thread _MLX_BFloat16operator-= (thread _MLX_BFloat16 &lhs, float rhs)
 
constexpr METAL_FUNC thread float & operator-= (thread float &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC threadgroup _MLX_BFloat16operator-= (threadgroup _MLX_BFloat16 &lhs, float rhs)
 
constexpr METAL_FUNC threadgroup float & operator-= (threadgroup float &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC device _MLX_BFloat16operator*= (device _MLX_BFloat16 &lhs, float rhs)
 
constexpr METAL_FUNC device float & operator*= (device float &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC thread _MLX_BFloat16operator*= (thread _MLX_BFloat16 &lhs, float rhs)
 
constexpr METAL_FUNC thread float & operator*= (thread float &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC threadgroup _MLX_BFloat16operator*= (threadgroup _MLX_BFloat16 &lhs, float rhs)
 
constexpr METAL_FUNC threadgroup float & operator*= (threadgroup float &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC device _MLX_BFloat16operator/= (device _MLX_BFloat16 &lhs, float rhs)
 
constexpr METAL_FUNC device float & operator/= (device float &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC thread _MLX_BFloat16operator/= (thread _MLX_BFloat16 &lhs, float rhs)
 
constexpr METAL_FUNC thread float & operator/= (thread float &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC threadgroup _MLX_BFloat16operator/= (threadgroup _MLX_BFloat16 &lhs, float rhs)
 
constexpr METAL_FUNC threadgroup float & operator/= (threadgroup float &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC device _MLX_BFloat16operator+= (device _MLX_BFloat16 &lhs, half rhs)
 
constexpr METAL_FUNC device half & operator+= (device half &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC thread _MLX_BFloat16operator+= (thread _MLX_BFloat16 &lhs, half rhs)
 
constexpr METAL_FUNC thread half & operator+= (thread half &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC threadgroup _MLX_BFloat16operator+= (threadgroup _MLX_BFloat16 &lhs, half rhs)
 
constexpr METAL_FUNC threadgroup half & operator+= (threadgroup half &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC device _MLX_BFloat16operator-= (device _MLX_BFloat16 &lhs, half rhs)
 
constexpr METAL_FUNC device half & operator-= (device half &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC thread _MLX_BFloat16operator-= (thread _MLX_BFloat16 &lhs, half rhs)
 
constexpr METAL_FUNC thread half & operator-= (thread half &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC threadgroup _MLX_BFloat16operator-= (threadgroup _MLX_BFloat16 &lhs, half rhs)
 
constexpr METAL_FUNC threadgroup half & operator-= (threadgroup half &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC device _MLX_BFloat16operator*= (device _MLX_BFloat16 &lhs, half rhs)
 
constexpr METAL_FUNC device half & operator*= (device half &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC thread _MLX_BFloat16operator*= (thread _MLX_BFloat16 &lhs, half rhs)
 
constexpr METAL_FUNC thread half & operator*= (thread half &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC threadgroup _MLX_BFloat16operator*= (threadgroup _MLX_BFloat16 &lhs, half rhs)
 
constexpr METAL_FUNC threadgroup half & operator*= (threadgroup half &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC device _MLX_BFloat16operator/= (device _MLX_BFloat16 &lhs, half rhs)
 
constexpr METAL_FUNC device half & operator/= (device half &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC thread _MLX_BFloat16operator/= (thread _MLX_BFloat16 &lhs, half rhs)
 
constexpr METAL_FUNC thread half & operator/= (thread half &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC threadgroup _MLX_BFloat16operator/= (threadgroup _MLX_BFloat16 &lhs, half rhs)
 
constexpr METAL_FUNC threadgroup half & operator/= (threadgroup half &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC device _MLX_BFloat16operator+= (device _MLX_BFloat16 &lhs, int16_t rhs)
 
constexpr METAL_FUNC device int16_t & operator+= (device int16_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC thread _MLX_BFloat16operator+= (thread _MLX_BFloat16 &lhs, int16_t rhs)
 
constexpr METAL_FUNC thread int16_t & operator+= (thread int16_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC threadgroup _MLX_BFloat16operator+= (threadgroup _MLX_BFloat16 &lhs, int16_t rhs)
 
constexpr METAL_FUNC threadgroup int16_t & operator+= (threadgroup int16_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC device _MLX_BFloat16operator-= (device _MLX_BFloat16 &lhs, int16_t rhs)
 
constexpr METAL_FUNC device int16_t & operator-= (device int16_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC thread _MLX_BFloat16operator-= (thread _MLX_BFloat16 &lhs, int16_t rhs)
 
constexpr METAL_FUNC thread int16_t & operator-= (thread int16_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC threadgroup _MLX_BFloat16operator-= (threadgroup _MLX_BFloat16 &lhs, int16_t rhs)
 
constexpr METAL_FUNC threadgroup int16_t & operator-= (threadgroup int16_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC device _MLX_BFloat16operator*= (device _MLX_BFloat16 &lhs, int16_t rhs)
 
constexpr METAL_FUNC device int16_t & operator*= (device int16_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC thread _MLX_BFloat16operator*= (thread _MLX_BFloat16 &lhs, int16_t rhs)
 
constexpr METAL_FUNC thread int16_t & operator*= (thread int16_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC threadgroup _MLX_BFloat16operator*= (threadgroup _MLX_BFloat16 &lhs, int16_t rhs)
 
constexpr METAL_FUNC threadgroup int16_t & operator*= (threadgroup int16_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC device _MLX_BFloat16operator/= (device _MLX_BFloat16 &lhs, int16_t rhs)
 
constexpr METAL_FUNC device int16_t & operator/= (device int16_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC thread _MLX_BFloat16operator/= (thread _MLX_BFloat16 &lhs, int16_t rhs)
 
constexpr METAL_FUNC thread int16_t & operator/= (thread int16_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC threadgroup _MLX_BFloat16operator/= (threadgroup _MLX_BFloat16 &lhs, int16_t rhs)
 
constexpr METAL_FUNC threadgroup int16_t & operator/= (threadgroup int16_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC device _MLX_BFloat16operator+= (device _MLX_BFloat16 &lhs, int32_t rhs)
 
constexpr METAL_FUNC device int32_t & operator+= (device int32_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC thread _MLX_BFloat16operator+= (thread _MLX_BFloat16 &lhs, int32_t rhs)
 
constexpr METAL_FUNC thread int32_t & operator+= (thread int32_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC threadgroup _MLX_BFloat16operator+= (threadgroup _MLX_BFloat16 &lhs, int32_t rhs)
 
constexpr METAL_FUNC threadgroup int32_t & operator+= (threadgroup int32_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC device _MLX_BFloat16operator-= (device _MLX_BFloat16 &lhs, int32_t rhs)
 
constexpr METAL_FUNC device int32_t & operator-= (device int32_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC thread _MLX_BFloat16operator-= (thread _MLX_BFloat16 &lhs, int32_t rhs)
 
constexpr METAL_FUNC thread int32_t & operator-= (thread int32_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC threadgroup _MLX_BFloat16operator-= (threadgroup _MLX_BFloat16 &lhs, int32_t rhs)
 
constexpr METAL_FUNC threadgroup int32_t & operator-= (threadgroup int32_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC device _MLX_BFloat16operator*= (device _MLX_BFloat16 &lhs, int32_t rhs)
 
constexpr METAL_FUNC device int32_t & operator*= (device int32_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC thread _MLX_BFloat16operator*= (thread _MLX_BFloat16 &lhs, int32_t rhs)
 
constexpr METAL_FUNC thread int32_t & operator*= (thread int32_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC threadgroup _MLX_BFloat16operator*= (threadgroup _MLX_BFloat16 &lhs, int32_t rhs)
 
constexpr METAL_FUNC threadgroup int32_t & operator*= (threadgroup int32_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC device _MLX_BFloat16operator/= (device _MLX_BFloat16 &lhs, int32_t rhs)
 
constexpr METAL_FUNC device int32_t & operator/= (device int32_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC thread _MLX_BFloat16operator/= (thread _MLX_BFloat16 &lhs, int32_t rhs)
 
constexpr METAL_FUNC thread int32_t & operator/= (thread int32_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC threadgroup _MLX_BFloat16operator/= (threadgroup _MLX_BFloat16 &lhs, int32_t rhs)
 
constexpr METAL_FUNC threadgroup int32_t & operator/= (threadgroup int32_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC device _MLX_BFloat16operator+= (device _MLX_BFloat16 &lhs, int64_t rhs)
 
constexpr METAL_FUNC device int64_t & operator+= (device int64_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC thread _MLX_BFloat16operator+= (thread _MLX_BFloat16 &lhs, int64_t rhs)
 
constexpr METAL_FUNC thread int64_t & operator+= (thread int64_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC threadgroup _MLX_BFloat16operator+= (threadgroup _MLX_BFloat16 &lhs, int64_t rhs)
 
constexpr METAL_FUNC threadgroup int64_t & operator+= (threadgroup int64_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC device _MLX_BFloat16operator-= (device _MLX_BFloat16 &lhs, int64_t rhs)
 
constexpr METAL_FUNC device int64_t & operator-= (device int64_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC thread _MLX_BFloat16operator-= (thread _MLX_BFloat16 &lhs, int64_t rhs)
 
constexpr METAL_FUNC thread int64_t & operator-= (thread int64_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC threadgroup _MLX_BFloat16operator-= (threadgroup _MLX_BFloat16 &lhs, int64_t rhs)
 
constexpr METAL_FUNC threadgroup int64_t & operator-= (threadgroup int64_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC device _MLX_BFloat16operator*= (device _MLX_BFloat16 &lhs, int64_t rhs)
 
constexpr METAL_FUNC device int64_t & operator*= (device int64_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC thread _MLX_BFloat16operator*= (thread _MLX_BFloat16 &lhs, int64_t rhs)
 
constexpr METAL_FUNC thread int64_t & operator*= (thread int64_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC threadgroup _MLX_BFloat16operator*= (threadgroup _MLX_BFloat16 &lhs, int64_t rhs)
 
constexpr METAL_FUNC threadgroup int64_t & operator*= (threadgroup int64_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC device _MLX_BFloat16operator/= (device _MLX_BFloat16 &lhs, int64_t rhs)
 
constexpr METAL_FUNC device int64_t & operator/= (device int64_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC thread _MLX_BFloat16operator/= (thread _MLX_BFloat16 &lhs, int64_t rhs)
 
constexpr METAL_FUNC thread int64_t & operator/= (thread int64_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC threadgroup _MLX_BFloat16operator/= (threadgroup _MLX_BFloat16 &lhs, int64_t rhs)
 
constexpr METAL_FUNC threadgroup int64_t & operator/= (threadgroup int64_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC device _MLX_BFloat16operator+= (device _MLX_BFloat16 &lhs, uint16_t rhs)
 
constexpr METAL_FUNC device uint16_t & operator+= (device uint16_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC thread _MLX_BFloat16operator+= (thread _MLX_BFloat16 &lhs, uint16_t rhs)
 
constexpr METAL_FUNC thread uint16_t & operator+= (thread uint16_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC threadgroup _MLX_BFloat16operator+= (threadgroup _MLX_BFloat16 &lhs, uint16_t rhs)
 
constexpr METAL_FUNC threadgroup uint16_t & operator+= (threadgroup uint16_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC device _MLX_BFloat16operator-= (device _MLX_BFloat16 &lhs, uint16_t rhs)
 
constexpr METAL_FUNC device uint16_t & operator-= (device uint16_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC thread _MLX_BFloat16operator-= (thread _MLX_BFloat16 &lhs, uint16_t rhs)
 
constexpr METAL_FUNC thread uint16_t & operator-= (thread uint16_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC threadgroup _MLX_BFloat16operator-= (threadgroup _MLX_BFloat16 &lhs, uint16_t rhs)
 
constexpr METAL_FUNC threadgroup uint16_t & operator-= (threadgroup uint16_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC device _MLX_BFloat16operator*= (device _MLX_BFloat16 &lhs, uint16_t rhs)
 
constexpr METAL_FUNC device uint16_t & operator*= (device uint16_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC thread _MLX_BFloat16operator*= (thread _MLX_BFloat16 &lhs, uint16_t rhs)
 
constexpr METAL_FUNC thread uint16_t & operator*= (thread uint16_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC threadgroup _MLX_BFloat16operator*= (threadgroup _MLX_BFloat16 &lhs, uint16_t rhs)
 
constexpr METAL_FUNC threadgroup uint16_t & operator*= (threadgroup uint16_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC device _MLX_BFloat16operator/= (device _MLX_BFloat16 &lhs, uint16_t rhs)
 
constexpr METAL_FUNC device uint16_t & operator/= (device uint16_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC thread _MLX_BFloat16operator/= (thread _MLX_BFloat16 &lhs, uint16_t rhs)
 
constexpr METAL_FUNC thread uint16_t & operator/= (thread uint16_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC threadgroup _MLX_BFloat16operator/= (threadgroup _MLX_BFloat16 &lhs, uint16_t rhs)
 
constexpr METAL_FUNC threadgroup uint16_t & operator/= (threadgroup uint16_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC device _MLX_BFloat16operator+= (device _MLX_BFloat16 &lhs, uint32_t rhs)
 
constexpr METAL_FUNC device uint32_t & operator+= (device uint32_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC thread _MLX_BFloat16operator+= (thread _MLX_BFloat16 &lhs, uint32_t rhs)
 
constexpr METAL_FUNC thread uint32_t & operator+= (thread uint32_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC threadgroup _MLX_BFloat16operator+= (threadgroup _MLX_BFloat16 &lhs, uint32_t rhs)
 
constexpr METAL_FUNC threadgroup uint32_t & operator+= (threadgroup uint32_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC device _MLX_BFloat16operator-= (device _MLX_BFloat16 &lhs, uint32_t rhs)
 
constexpr METAL_FUNC device uint32_t & operator-= (device uint32_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC thread _MLX_BFloat16operator-= (thread _MLX_BFloat16 &lhs, uint32_t rhs)
 
constexpr METAL_FUNC thread uint32_t & operator-= (thread uint32_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC threadgroup _MLX_BFloat16operator-= (threadgroup _MLX_BFloat16 &lhs, uint32_t rhs)
 
constexpr METAL_FUNC threadgroup uint32_t & operator-= (threadgroup uint32_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC device _MLX_BFloat16operator*= (device _MLX_BFloat16 &lhs, uint32_t rhs)
 
constexpr METAL_FUNC device uint32_t & operator*= (device uint32_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC thread _MLX_BFloat16operator*= (thread _MLX_BFloat16 &lhs, uint32_t rhs)
 
constexpr METAL_FUNC thread uint32_t & operator*= (thread uint32_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC threadgroup _MLX_BFloat16operator*= (threadgroup _MLX_BFloat16 &lhs, uint32_t rhs)
 
constexpr METAL_FUNC threadgroup uint32_t & operator*= (threadgroup uint32_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC device _MLX_BFloat16operator/= (device _MLX_BFloat16 &lhs, uint32_t rhs)
 
constexpr METAL_FUNC device uint32_t & operator/= (device uint32_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC thread _MLX_BFloat16operator/= (thread _MLX_BFloat16 &lhs, uint32_t rhs)
 
constexpr METAL_FUNC thread uint32_t & operator/= (thread uint32_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC threadgroup _MLX_BFloat16operator/= (threadgroup _MLX_BFloat16 &lhs, uint32_t rhs)
 
constexpr METAL_FUNC threadgroup uint32_t & operator/= (threadgroup uint32_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC device _MLX_BFloat16operator+= (device _MLX_BFloat16 &lhs, uint64_t rhs)
 
constexpr METAL_FUNC device uint64_t & operator+= (device uint64_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC thread _MLX_BFloat16operator+= (thread _MLX_BFloat16 &lhs, uint64_t rhs)
 
constexpr METAL_FUNC thread uint64_t & operator+= (thread uint64_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC threadgroup _MLX_BFloat16operator+= (threadgroup _MLX_BFloat16 &lhs, uint64_t rhs)
 
constexpr METAL_FUNC threadgroup uint64_t & operator+= (threadgroup uint64_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC device _MLX_BFloat16operator-= (device _MLX_BFloat16 &lhs, uint64_t rhs)
 
constexpr METAL_FUNC device uint64_t & operator-= (device uint64_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC thread _MLX_BFloat16operator-= (thread _MLX_BFloat16 &lhs, uint64_t rhs)
 
constexpr METAL_FUNC thread uint64_t & operator-= (thread uint64_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC threadgroup _MLX_BFloat16operator-= (threadgroup _MLX_BFloat16 &lhs, uint64_t rhs)
 
constexpr METAL_FUNC threadgroup uint64_t & operator-= (threadgroup uint64_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC device _MLX_BFloat16operator*= (device _MLX_BFloat16 &lhs, uint64_t rhs)
 
constexpr METAL_FUNC device uint64_t & operator*= (device uint64_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC thread _MLX_BFloat16operator*= (thread _MLX_BFloat16 &lhs, uint64_t rhs)
 
constexpr METAL_FUNC thread uint64_t & operator*= (thread uint64_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC threadgroup _MLX_BFloat16operator*= (threadgroup _MLX_BFloat16 &lhs, uint64_t rhs)
 
constexpr METAL_FUNC threadgroup uint64_t & operator*= (threadgroup uint64_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC device _MLX_BFloat16operator/= (device _MLX_BFloat16 &lhs, uint64_t rhs)
 
constexpr METAL_FUNC device uint64_t & operator/= (device uint64_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC thread _MLX_BFloat16operator/= (thread _MLX_BFloat16 &lhs, uint64_t rhs)
 
constexpr METAL_FUNC thread uint64_t & operator/= (thread uint64_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC threadgroup _MLX_BFloat16operator/= (threadgroup _MLX_BFloat16 &lhs, uint64_t rhs)
 
constexpr METAL_FUNC threadgroup uint64_t & operator/= (threadgroup uint64_t &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC device _MLX_BFloat16operator+= (device _MLX_BFloat16 &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC thread _MLX_BFloat16operator+= (thread _MLX_BFloat16 &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC threadgroup _MLX_BFloat16operator+= (threadgroup _MLX_BFloat16 &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC device _MLX_BFloat16operator-= (device _MLX_BFloat16 &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC thread _MLX_BFloat16operator-= (thread _MLX_BFloat16 &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC threadgroup _MLX_BFloat16operator-= (threadgroup _MLX_BFloat16 &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC device _MLX_BFloat16operator*= (device _MLX_BFloat16 &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC thread _MLX_BFloat16operator*= (thread _MLX_BFloat16 &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC threadgroup _MLX_BFloat16operator*= (threadgroup _MLX_BFloat16 &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC device _MLX_BFloat16operator/= (device _MLX_BFloat16 &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC thread _MLX_BFloat16operator/= (thread _MLX_BFloat16 &lhs, _MLX_BFloat16 rhs)
 
constexpr METAL_FUNC threadgroup _MLX_BFloat16operator/= (threadgroup _MLX_BFloat16 &lhs, _MLX_BFloat16 rhs)
 
METAL_FUNC bool metal::isnan (_MLX_BFloat16 x)
 
+ + + + + + + +

+Variables

template<typename T >
static constexpr constant bool can_convert_to_bfloat
 
template<typename T >
static constexpr constant bool can_convert_from_bfloat
 
+

Macro Definition Documentation

+ +

◆ bfloat_binop

+ +
+
+ + + + + + + + + + + +
#define bfloat_binop( _op_,
_operator_ )
+
+Value:
+
_op_, _operator_, _MLX_BFloat16, _MLX_BFloat16, _MLX_BFloat16, float); \
+
bfloat_binop_helper(_op_, _operator_, float, float, float); \
+
bfloat_binop_helper(_op_, _operator_, float, half, float); \
+
bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int32_t, float); \
+
bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint32_t, float); \
+
bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int64_t, float); \
+
bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint64_t, float);
+
#define bfloat_binop_base(__op__, __operator__, otype, atype, btype, ctype)
Definition bf16.h:141
+
Definition bf16.h:54
+
+
+
+ +

◆ bfloat_binop_base

+ +
+
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
#define bfloat_binop_base( __op__,
__operator__,
otype,
atype,
btype,
ctype )
+
+Value:
constexpr METAL_FUNC otype __operator__(atype lhs, btype rhs) { \
+
return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs); \
+
}
+
+
+
+ +

◆ bfloat_binop_helper

+ +
+
+ + + + + + + + + + + + + + + + + + + + + + + + + + +
#define bfloat_binop_helper( __op__,
__operator__,
otype,
itype,
ctype )
+
+Value:
constexpr METAL_FUNC otype __operator__(_MLX_BFloat16 lhs, itype rhs) { \
+
return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs); \
+
} \
+
constexpr METAL_FUNC otype __operator__(itype lhs, _MLX_BFloat16 rhs) { \
+
return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs); \
+
}
+
+
+
+ +

◆ bfloat_compop

+ +
+
+ + + + + + + + + + + +
#define bfloat_compop( __op__,
__operator__ )
+
+Value:
+
__op__, __operator__, bool, _MLX_BFloat16, _MLX_BFloat16, float); \
+
bfloat_binop_helper(__op__, __operator__, bool, float, float); \
+
bfloat_binop_helper(__op__, __operator__, bool, half, float); \
+
bfloat_binop_helper(__op__, __operator__, bool, int32_t, float); \
+
bfloat_binop_helper(__op__, __operator__, bool, uint32_t, float); \
+
bfloat_binop_helper(__op__, __operator__, bool, int64_t, float); \
+
bfloat_binop_helper(__op__, __operator__, bool, uint64_t, float);
+
+
+
+ +

◆ bfloat_inplace_op

+ +
+
+ + + + + + + +
#define bfloat_inplace_op( itype)
+
+Value:
bfloat_inplace_op_addr_space_helper(+, operator+=, itype); \
+
bfloat_inplace_op_addr_space_helper(-, operator-=, itype); \
+
bfloat_inplace_op_addr_space_helper(*, operator*=, itype); \
+
bfloat_inplace_op_addr_space_helper(/, operator/=, itype);
+
#define bfloat_inplace_op_addr_space_helper(__op__, __operator__, itype)
Definition bf16.h:209
+
+
+
+ +

◆ bfloat_inplace_op_addr_space_helper [1/2]

+ +
+
+ + + + + + + + + + + +
#define bfloat_inplace_op_addr_space_helper( __op__,
__operator__ )
+
+Value:
bfloat_inplace_op_helper(__op__, __operator__, device); \
+
bfloat_inplace_op_helper(__op__, __operator__, thread); \
+
bfloat_inplace_op_helper(__op__, __operator__, threadgroup);
+
#define bfloat_inplace_op_helper(__op__, __operator__, itype, addr_space)
Definition bf16.h:197
+
+
+
+ +

◆ bfloat_inplace_op_addr_space_helper [2/2]

+ +
+
+ + + + + + + + + + + + + + + + +
#define bfloat_inplace_op_addr_space_helper( __op__,
__operator__,
itype )
+
+Value:
bfloat_inplace_op_helper(__op__, __operator__, itype, device); \
+
bfloat_inplace_op_helper(__op__, __operator__, itype, thread); \
+
bfloat_inplace_op_helper(__op__, __operator__, itype, threadgroup);
+
+
+
+ +

◆ bfloat_inplace_op_helper [1/2]

+ +
+
+ + + + + + + + + + + + + + + + +
#define bfloat_inplace_op_helper( __op__,
__operator__,
addr_space )
+
+Value:
constexpr METAL_FUNC addr_space _MLX_BFloat16& __operator__( \
+
addr_space _MLX_BFloat16& lhs, _MLX_BFloat16 rhs) { \
+
lhs = static_cast<float>(lhs) __op__ static_cast<float>(rhs); \
+
return lhs; \
+
}
+
+
+
+ +

◆ bfloat_inplace_op_helper [2/2]

+ +
+
+ + + + + + + + + + + + + + + + + + + + + +
#define bfloat_inplace_op_helper( __op__,
__operator__,
itype,
addr_space )
+
+Value:
constexpr METAL_FUNC addr_space _MLX_BFloat16& __operator__( \
+
addr_space _MLX_BFloat16& lhs, itype rhs) { \
+
lhs = static_cast<float>(lhs) __op__ static_cast<float>(rhs); \
+
return lhs; \
+
} \
+
constexpr METAL_FUNC addr_space itype& __operator__( \
+
addr_space itype& lhs, _MLX_BFloat16 rhs) { \
+
lhs = static_cast<float>(lhs) __op__ static_cast<float>(rhs); \
+
return lhs; \
+
}
+
+
+
+

Typedef Documentation

+ +

◆ bfloat16_t

+ +
+
+ + + + +
typedef struct _MLX_BFloat16 bfloat16_t
+
+ +
+
+

Function Documentation

+ +

◆ bfloat_bits_to_float()

+ +
+
+ + + + + +
+ + + + + + + +
constexpr METAL_FUNC float bfloat_bits_to_float (uint16_t x)
+
+constexpr
+
+ +
+
+ +

◆ float_to_bfloat_bits()

+ +
+
+ + + + + +
+ + + + + + + +
constexpr METAL_FUNC uint16_t float_to_bfloat_bits (float x)
+
+constexpr
+
+ +
+
+ +

◆ operator!=() [1/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator!= (_MLX_BFloat16 lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator!=() [2/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator!= (_MLX_BFloat16 lhs,
float rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator!=() [3/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator!= (_MLX_BFloat16 lhs,
half rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator!=() [4/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator!= (_MLX_BFloat16 lhs,
int32_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator!=() [5/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator!= (_MLX_BFloat16 lhs,
int64_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator!=() [6/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator!= (_MLX_BFloat16 lhs,
uint32_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator!=() [7/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator!= (_MLX_BFloat16 lhs,
uint64_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator!=() [8/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator!= (float lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator!=() [9/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator!= (half lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator!=() [10/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator!= (int32_t lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator!=() [11/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator!= (int64_t lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator!=() [12/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator!= (uint32_t lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator!=() [13/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator!= (uint64_t lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*() [1/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC _MLX_BFloat16 operator* (_MLX_BFloat16 lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*() [2/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC float operator* (_MLX_BFloat16 lhs,
float rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*() [3/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC float operator* (_MLX_BFloat16 lhs,
half rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*() [4/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC _MLX_BFloat16 operator* (_MLX_BFloat16 lhs,
int32_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*() [5/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC _MLX_BFloat16 operator* (_MLX_BFloat16 lhs,
int64_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*() [6/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC _MLX_BFloat16 operator* (_MLX_BFloat16 lhs,
uint32_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*() [7/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC _MLX_BFloat16 operator* (_MLX_BFloat16 lhs,
uint64_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*() [8/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC float operator* (float lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*() [9/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC float operator* (half lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*() [10/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC _MLX_BFloat16 operator* (int32_t lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*() [11/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC _MLX_BFloat16 operator* (int64_t lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*() [12/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC _MLX_BFloat16 operator* (uint32_t lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*() [13/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC _MLX_BFloat16 operator* (uint64_t lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*=() [1/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device _MLX_BFloat16 & operator*= (device _MLX_BFloat16 & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*=() [2/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device _MLX_BFloat16 & operator*= (device _MLX_BFloat16 & lhs,
float rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*=() [3/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device _MLX_BFloat16 & operator*= (device _MLX_BFloat16 & lhs,
half rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*=() [4/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device _MLX_BFloat16 & operator*= (device _MLX_BFloat16 & lhs,
int16_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*=() [5/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device _MLX_BFloat16 & operator*= (device _MLX_BFloat16 & lhs,
int32_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*=() [6/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device _MLX_BFloat16 & operator*= (device _MLX_BFloat16 & lhs,
int64_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*=() [7/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device _MLX_BFloat16 & operator*= (device _MLX_BFloat16 & lhs,
uint16_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*=() [8/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device _MLX_BFloat16 & operator*= (device _MLX_BFloat16 & lhs,
uint32_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*=() [9/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device _MLX_BFloat16 & operator*= (device _MLX_BFloat16 & lhs,
uint64_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*=() [10/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device float & operator*= (device float & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*=() [11/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device half & operator*= (device half & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*=() [12/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device int16_t & operator*= (device int16_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*=() [13/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device int32_t & operator*= (device int32_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*=() [14/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device int64_t & operator*= (device int64_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*=() [15/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device uint16_t & operator*= (device uint16_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*=() [16/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device uint32_t & operator*= (device uint32_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*=() [17/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device uint64_t & operator*= (device uint64_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*=() [18/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread _MLX_BFloat16 & operator*= (thread _MLX_BFloat16 & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*=() [19/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread _MLX_BFloat16 & operator*= (thread _MLX_BFloat16 & lhs,
float rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*=() [20/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread _MLX_BFloat16 & operator*= (thread _MLX_BFloat16 & lhs,
half rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*=() [21/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread _MLX_BFloat16 & operator*= (thread _MLX_BFloat16 & lhs,
int16_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*=() [22/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread _MLX_BFloat16 & operator*= (thread _MLX_BFloat16 & lhs,
int32_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*=() [23/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread _MLX_BFloat16 & operator*= (thread _MLX_BFloat16 & lhs,
int64_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*=() [24/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread _MLX_BFloat16 & operator*= (thread _MLX_BFloat16 & lhs,
uint16_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*=() [25/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread _MLX_BFloat16 & operator*= (thread _MLX_BFloat16 & lhs,
uint32_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*=() [26/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread _MLX_BFloat16 & operator*= (thread _MLX_BFloat16 & lhs,
uint64_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*=() [27/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread float & operator*= (thread float & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*=() [28/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread half & operator*= (thread half & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*=() [29/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread int16_t & operator*= (thread int16_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*=() [30/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread int32_t & operator*= (thread int32_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*=() [31/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread int64_t & operator*= (thread int64_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*=() [32/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread uint16_t & operator*= (thread uint16_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*=() [33/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread uint32_t & operator*= (thread uint32_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*=() [34/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread uint64_t & operator*= (thread uint64_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*=() [35/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup _MLX_BFloat16 & operator*= (threadgroup _MLX_BFloat16 & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*=() [36/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup _MLX_BFloat16 & operator*= (threadgroup _MLX_BFloat16 & lhs,
float rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*=() [37/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup _MLX_BFloat16 & operator*= (threadgroup _MLX_BFloat16 & lhs,
half rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*=() [38/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup _MLX_BFloat16 & operator*= (threadgroup _MLX_BFloat16 & lhs,
int16_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*=() [39/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup _MLX_BFloat16 & operator*= (threadgroup _MLX_BFloat16 & lhs,
int32_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*=() [40/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup _MLX_BFloat16 & operator*= (threadgroup _MLX_BFloat16 & lhs,
int64_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*=() [41/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup _MLX_BFloat16 & operator*= (threadgroup _MLX_BFloat16 & lhs,
uint16_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*=() [42/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup _MLX_BFloat16 & operator*= (threadgroup _MLX_BFloat16 & lhs,
uint32_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*=() [43/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup _MLX_BFloat16 & operator*= (threadgroup _MLX_BFloat16 & lhs,
uint64_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*=() [44/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup float & operator*= (threadgroup float & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*=() [45/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup half & operator*= (threadgroup half & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*=() [46/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup int16_t & operator*= (threadgroup int16_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*=() [47/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup int32_t & operator*= (threadgroup int32_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*=() [48/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup int64_t & operator*= (threadgroup int64_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*=() [49/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup uint16_t & operator*= (threadgroup uint16_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*=() [50/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup uint32_t & operator*= (threadgroup uint32_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator*=() [51/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup uint64_t & operator*= (threadgroup uint64_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+() [1/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC _MLX_BFloat16 operator+ (_MLX_BFloat16 lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+() [2/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC float operator+ (_MLX_BFloat16 lhs,
float rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+() [3/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC float operator+ (_MLX_BFloat16 lhs,
half rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+() [4/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC _MLX_BFloat16 operator+ (_MLX_BFloat16 lhs,
int32_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+() [5/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC _MLX_BFloat16 operator+ (_MLX_BFloat16 lhs,
int64_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+() [6/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC _MLX_BFloat16 operator+ (_MLX_BFloat16 lhs,
uint32_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+() [7/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC _MLX_BFloat16 operator+ (_MLX_BFloat16 lhs,
uint64_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+() [8/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC float operator+ (float lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+() [9/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC float operator+ (half lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+() [10/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC _MLX_BFloat16 operator+ (int32_t lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+() [11/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC _MLX_BFloat16 operator+ (int64_t lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+() [12/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC _MLX_BFloat16 operator+ (uint32_t lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+() [13/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC _MLX_BFloat16 operator+ (uint64_t lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+=() [1/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device _MLX_BFloat16 & operator+= (device _MLX_BFloat16 & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+=() [2/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device _MLX_BFloat16 & operator+= (device _MLX_BFloat16 & lhs,
float rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+=() [3/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device _MLX_BFloat16 & operator+= (device _MLX_BFloat16 & lhs,
half rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+=() [4/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device _MLX_BFloat16 & operator+= (device _MLX_BFloat16 & lhs,
int16_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+=() [5/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device _MLX_BFloat16 & operator+= (device _MLX_BFloat16 & lhs,
int32_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+=() [6/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device _MLX_BFloat16 & operator+= (device _MLX_BFloat16 & lhs,
int64_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+=() [7/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device _MLX_BFloat16 & operator+= (device _MLX_BFloat16 & lhs,
uint16_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+=() [8/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device _MLX_BFloat16 & operator+= (device _MLX_BFloat16 & lhs,
uint32_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+=() [9/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device _MLX_BFloat16 & operator+= (device _MLX_BFloat16 & lhs,
uint64_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+=() [10/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device float & operator+= (device float & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+=() [11/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device half & operator+= (device half & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+=() [12/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device int16_t & operator+= (device int16_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+=() [13/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device int32_t & operator+= (device int32_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+=() [14/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device int64_t & operator+= (device int64_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+=() [15/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device uint16_t & operator+= (device uint16_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+=() [16/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device uint32_t & operator+= (device uint32_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+=() [17/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device uint64_t & operator+= (device uint64_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+=() [18/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread _MLX_BFloat16 & operator+= (thread _MLX_BFloat16 & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+=() [19/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread _MLX_BFloat16 & operator+= (thread _MLX_BFloat16 & lhs,
float rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+=() [20/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread _MLX_BFloat16 & operator+= (thread _MLX_BFloat16 & lhs,
half rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+=() [21/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread _MLX_BFloat16 & operator+= (thread _MLX_BFloat16 & lhs,
int16_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+=() [22/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread _MLX_BFloat16 & operator+= (thread _MLX_BFloat16 & lhs,
int32_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+=() [23/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread _MLX_BFloat16 & operator+= (thread _MLX_BFloat16 & lhs,
int64_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+=() [24/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread _MLX_BFloat16 & operator+= (thread _MLX_BFloat16 & lhs,
uint16_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+=() [25/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread _MLX_BFloat16 & operator+= (thread _MLX_BFloat16 & lhs,
uint32_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+=() [26/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread _MLX_BFloat16 & operator+= (thread _MLX_BFloat16 & lhs,
uint64_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+=() [27/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread float & operator+= (thread float & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+=() [28/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread half & operator+= (thread half & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+=() [29/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread int16_t & operator+= (thread int16_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+=() [30/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread int32_t & operator+= (thread int32_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+=() [31/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread int64_t & operator+= (thread int64_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+=() [32/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread uint16_t & operator+= (thread uint16_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+=() [33/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread uint32_t & operator+= (thread uint32_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+=() [34/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread uint64_t & operator+= (thread uint64_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+=() [35/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup _MLX_BFloat16 & operator+= (threadgroup _MLX_BFloat16 & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+=() [36/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup _MLX_BFloat16 & operator+= (threadgroup _MLX_BFloat16 & lhs,
float rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+=() [37/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup _MLX_BFloat16 & operator+= (threadgroup _MLX_BFloat16 & lhs,
half rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+=() [38/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup _MLX_BFloat16 & operator+= (threadgroup _MLX_BFloat16 & lhs,
int16_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+=() [39/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup _MLX_BFloat16 & operator+= (threadgroup _MLX_BFloat16 & lhs,
int32_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+=() [40/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup _MLX_BFloat16 & operator+= (threadgroup _MLX_BFloat16 & lhs,
int64_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+=() [41/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup _MLX_BFloat16 & operator+= (threadgroup _MLX_BFloat16 & lhs,
uint16_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+=() [42/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup _MLX_BFloat16 & operator+= (threadgroup _MLX_BFloat16 & lhs,
uint32_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+=() [43/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup _MLX_BFloat16 & operator+= (threadgroup _MLX_BFloat16 & lhs,
uint64_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+=() [44/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup float & operator+= (threadgroup float & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+=() [45/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup half & operator+= (threadgroup half & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+=() [46/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup int16_t & operator+= (threadgroup int16_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+=() [47/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup int32_t & operator+= (threadgroup int32_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+=() [48/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup int64_t & operator+= (threadgroup int64_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+=() [49/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup uint16_t & operator+= (threadgroup uint16_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+=() [50/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup uint32_t & operator+= (threadgroup uint32_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator+=() [51/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup uint64_t & operator+= (threadgroup uint64_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-() [1/14]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC _MLX_BFloat16 operator- (_MLX_BFloat16 lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-() [2/14]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC float operator- (_MLX_BFloat16 lhs,
float rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-() [3/14]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC float operator- (_MLX_BFloat16 lhs,
half rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-() [4/14]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC _MLX_BFloat16 operator- (_MLX_BFloat16 lhs,
int32_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-() [5/14]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC _MLX_BFloat16 operator- (_MLX_BFloat16 lhs,
int64_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-() [6/14]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC _MLX_BFloat16 operator- (_MLX_BFloat16 lhs,
uint32_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-() [7/14]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC _MLX_BFloat16 operator- (_MLX_BFloat16 lhs,
uint64_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-() [8/14]

+ +
+
+ + + + + +
+ + + + + + + +
constexpr METAL_FUNC _MLX_BFloat16 operator- (_MLX_BFloat16 x)
+
+constexpr
+
+ +
+
+ +

◆ operator-() [9/14]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC float operator- (float lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-() [10/14]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC float operator- (half lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-() [11/14]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC _MLX_BFloat16 operator- (int32_t lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-() [12/14]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC _MLX_BFloat16 operator- (int64_t lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-() [13/14]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC _MLX_BFloat16 operator- (uint32_t lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-() [14/14]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC _MLX_BFloat16 operator- (uint64_t lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-=() [1/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device _MLX_BFloat16 & operator-= (device _MLX_BFloat16 & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-=() [2/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device _MLX_BFloat16 & operator-= (device _MLX_BFloat16 & lhs,
float rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-=() [3/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device _MLX_BFloat16 & operator-= (device _MLX_BFloat16 & lhs,
half rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-=() [4/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device _MLX_BFloat16 & operator-= (device _MLX_BFloat16 & lhs,
int16_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-=() [5/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device _MLX_BFloat16 & operator-= (device _MLX_BFloat16 & lhs,
int32_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-=() [6/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device _MLX_BFloat16 & operator-= (device _MLX_BFloat16 & lhs,
int64_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-=() [7/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device _MLX_BFloat16 & operator-= (device _MLX_BFloat16 & lhs,
uint16_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-=() [8/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device _MLX_BFloat16 & operator-= (device _MLX_BFloat16 & lhs,
uint32_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-=() [9/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device _MLX_BFloat16 & operator-= (device _MLX_BFloat16 & lhs,
uint64_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-=() [10/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device float & operator-= (device float & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-=() [11/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device half & operator-= (device half & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-=() [12/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device int16_t & operator-= (device int16_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-=() [13/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device int32_t & operator-= (device int32_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-=() [14/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device int64_t & operator-= (device int64_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-=() [15/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device uint16_t & operator-= (device uint16_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-=() [16/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device uint32_t & operator-= (device uint32_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-=() [17/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device uint64_t & operator-= (device uint64_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-=() [18/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread _MLX_BFloat16 & operator-= (thread _MLX_BFloat16 & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-=() [19/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread _MLX_BFloat16 & operator-= (thread _MLX_BFloat16 & lhs,
float rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-=() [20/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread _MLX_BFloat16 & operator-= (thread _MLX_BFloat16 & lhs,
half rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-=() [21/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread _MLX_BFloat16 & operator-= (thread _MLX_BFloat16 & lhs,
int16_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-=() [22/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread _MLX_BFloat16 & operator-= (thread _MLX_BFloat16 & lhs,
int32_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-=() [23/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread _MLX_BFloat16 & operator-= (thread _MLX_BFloat16 & lhs,
int64_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-=() [24/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread _MLX_BFloat16 & operator-= (thread _MLX_BFloat16 & lhs,
uint16_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-=() [25/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread _MLX_BFloat16 & operator-= (thread _MLX_BFloat16 & lhs,
uint32_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-=() [26/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread _MLX_BFloat16 & operator-= (thread _MLX_BFloat16 & lhs,
uint64_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-=() [27/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread float & operator-= (thread float & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-=() [28/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread half & operator-= (thread half & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-=() [29/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread int16_t & operator-= (thread int16_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-=() [30/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread int32_t & operator-= (thread int32_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-=() [31/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread int64_t & operator-= (thread int64_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-=() [32/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread uint16_t & operator-= (thread uint16_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-=() [33/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread uint32_t & operator-= (thread uint32_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-=() [34/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread uint64_t & operator-= (thread uint64_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-=() [35/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup _MLX_BFloat16 & operator-= (threadgroup _MLX_BFloat16 & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-=() [36/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup _MLX_BFloat16 & operator-= (threadgroup _MLX_BFloat16 & lhs,
float rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-=() [37/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup _MLX_BFloat16 & operator-= (threadgroup _MLX_BFloat16 & lhs,
half rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-=() [38/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup _MLX_BFloat16 & operator-= (threadgroup _MLX_BFloat16 & lhs,
int16_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-=() [39/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup _MLX_BFloat16 & operator-= (threadgroup _MLX_BFloat16 & lhs,
int32_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-=() [40/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup _MLX_BFloat16 & operator-= (threadgroup _MLX_BFloat16 & lhs,
int64_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-=() [41/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup _MLX_BFloat16 & operator-= (threadgroup _MLX_BFloat16 & lhs,
uint16_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-=() [42/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup _MLX_BFloat16 & operator-= (threadgroup _MLX_BFloat16 & lhs,
uint32_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-=() [43/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup _MLX_BFloat16 & operator-= (threadgroup _MLX_BFloat16 & lhs,
uint64_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-=() [44/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup float & operator-= (threadgroup float & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-=() [45/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup half & operator-= (threadgroup half & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-=() [46/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup int16_t & operator-= (threadgroup int16_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-=() [47/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup int32_t & operator-= (threadgroup int32_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-=() [48/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup int64_t & operator-= (threadgroup int64_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-=() [49/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup uint16_t & operator-= (threadgroup uint16_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-=() [50/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup uint32_t & operator-= (threadgroup uint32_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator-=() [51/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup uint64_t & operator-= (threadgroup uint64_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/() [1/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC _MLX_BFloat16 operator/ (_MLX_BFloat16 lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/() [2/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC float operator/ (_MLX_BFloat16 lhs,
float rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/() [3/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC float operator/ (_MLX_BFloat16 lhs,
half rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/() [4/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC _MLX_BFloat16 operator/ (_MLX_BFloat16 lhs,
int32_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/() [5/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC _MLX_BFloat16 operator/ (_MLX_BFloat16 lhs,
int64_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/() [6/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC _MLX_BFloat16 operator/ (_MLX_BFloat16 lhs,
uint32_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/() [7/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC _MLX_BFloat16 operator/ (_MLX_BFloat16 lhs,
uint64_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/() [8/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC float operator/ (float lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/() [9/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC float operator/ (half lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/() [10/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC _MLX_BFloat16 operator/ (int32_t lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/() [11/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC _MLX_BFloat16 operator/ (int64_t lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/() [12/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC _MLX_BFloat16 operator/ (uint32_t lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/() [13/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC _MLX_BFloat16 operator/ (uint64_t lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/=() [1/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device _MLX_BFloat16 & operator/= (device _MLX_BFloat16 & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/=() [2/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device _MLX_BFloat16 & operator/= (device _MLX_BFloat16 & lhs,
float rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/=() [3/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device _MLX_BFloat16 & operator/= (device _MLX_BFloat16 & lhs,
half rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/=() [4/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device _MLX_BFloat16 & operator/= (device _MLX_BFloat16 & lhs,
int16_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/=() [5/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device _MLX_BFloat16 & operator/= (device _MLX_BFloat16 & lhs,
int32_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/=() [6/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device _MLX_BFloat16 & operator/= (device _MLX_BFloat16 & lhs,
int64_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/=() [7/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device _MLX_BFloat16 & operator/= (device _MLX_BFloat16 & lhs,
uint16_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/=() [8/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device _MLX_BFloat16 & operator/= (device _MLX_BFloat16 & lhs,
uint32_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/=() [9/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device _MLX_BFloat16 & operator/= (device _MLX_BFloat16 & lhs,
uint64_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/=() [10/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device float & operator/= (device float & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/=() [11/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device half & operator/= (device half & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/=() [12/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device int16_t & operator/= (device int16_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/=() [13/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device int32_t & operator/= (device int32_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/=() [14/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device int64_t & operator/= (device int64_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/=() [15/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device uint16_t & operator/= (device uint16_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/=() [16/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device uint32_t & operator/= (device uint32_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/=() [17/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC device uint64_t & operator/= (device uint64_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/=() [18/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread _MLX_BFloat16 & operator/= (thread _MLX_BFloat16 & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/=() [19/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread _MLX_BFloat16 & operator/= (thread _MLX_BFloat16 & lhs,
float rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/=() [20/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread _MLX_BFloat16 & operator/= (thread _MLX_BFloat16 & lhs,
half rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/=() [21/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread _MLX_BFloat16 & operator/= (thread _MLX_BFloat16 & lhs,
int16_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/=() [22/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread _MLX_BFloat16 & operator/= (thread _MLX_BFloat16 & lhs,
int32_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/=() [23/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread _MLX_BFloat16 & operator/= (thread _MLX_BFloat16 & lhs,
int64_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/=() [24/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread _MLX_BFloat16 & operator/= (thread _MLX_BFloat16 & lhs,
uint16_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/=() [25/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread _MLX_BFloat16 & operator/= (thread _MLX_BFloat16 & lhs,
uint32_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/=() [26/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread _MLX_BFloat16 & operator/= (thread _MLX_BFloat16 & lhs,
uint64_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/=() [27/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread float & operator/= (thread float & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/=() [28/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread half & operator/= (thread half & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/=() [29/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread int16_t & operator/= (thread int16_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/=() [30/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread int32_t & operator/= (thread int32_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/=() [31/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread int64_t & operator/= (thread int64_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/=() [32/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread uint16_t & operator/= (thread uint16_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/=() [33/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread uint32_t & operator/= (thread uint32_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/=() [34/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC thread uint64_t & operator/= (thread uint64_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/=() [35/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup _MLX_BFloat16 & operator/= (threadgroup _MLX_BFloat16 & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/=() [36/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup _MLX_BFloat16 & operator/= (threadgroup _MLX_BFloat16 & lhs,
float rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/=() [37/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup _MLX_BFloat16 & operator/= (threadgroup _MLX_BFloat16 & lhs,
half rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/=() [38/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup _MLX_BFloat16 & operator/= (threadgroup _MLX_BFloat16 & lhs,
int16_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/=() [39/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup _MLX_BFloat16 & operator/= (threadgroup _MLX_BFloat16 & lhs,
int32_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/=() [40/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup _MLX_BFloat16 & operator/= (threadgroup _MLX_BFloat16 & lhs,
int64_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/=() [41/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup _MLX_BFloat16 & operator/= (threadgroup _MLX_BFloat16 & lhs,
uint16_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/=() [42/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup _MLX_BFloat16 & operator/= (threadgroup _MLX_BFloat16 & lhs,
uint32_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/=() [43/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup _MLX_BFloat16 & operator/= (threadgroup _MLX_BFloat16 & lhs,
uint64_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/=() [44/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup float & operator/= (threadgroup float & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/=() [45/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup half & operator/= (threadgroup half & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/=() [46/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup int16_t & operator/= (threadgroup int16_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/=() [47/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup int32_t & operator/= (threadgroup int32_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/=() [48/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup int64_t & operator/= (threadgroup int64_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/=() [49/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup uint16_t & operator/= (threadgroup uint16_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/=() [50/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup uint32_t & operator/= (threadgroup uint32_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator/=() [51/51]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC threadgroup uint64_t & operator/= (threadgroup uint64_t & lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator<() [1/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator< (_MLX_BFloat16 lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator<() [2/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator< (_MLX_BFloat16 lhs,
float rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator<() [3/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator< (_MLX_BFloat16 lhs,
half rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator<() [4/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator< (_MLX_BFloat16 lhs,
int32_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator<() [5/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator< (_MLX_BFloat16 lhs,
int64_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator<() [6/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator< (_MLX_BFloat16 lhs,
uint32_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator<() [7/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator< (_MLX_BFloat16 lhs,
uint64_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator<() [8/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator< (float lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator<() [9/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator< (half lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator<() [10/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator< (int32_t lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator<() [11/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator< (int64_t lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator<() [12/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator< (uint32_t lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator<() [13/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator< (uint64_t lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator<=() [1/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator<= (_MLX_BFloat16 lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator<=() [2/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator<= (_MLX_BFloat16 lhs,
float rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator<=() [3/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator<= (_MLX_BFloat16 lhs,
half rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator<=() [4/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator<= (_MLX_BFloat16 lhs,
int32_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator<=() [5/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator<= (_MLX_BFloat16 lhs,
int64_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator<=() [6/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator<= (_MLX_BFloat16 lhs,
uint32_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator<=() [7/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator<= (_MLX_BFloat16 lhs,
uint64_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator<=() [8/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator<= (float lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator<=() [9/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator<= (half lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator<=() [10/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator<= (int32_t lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator<=() [11/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator<= (int64_t lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator<=() [12/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator<= (uint32_t lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator<=() [13/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator<= (uint64_t lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator==() [1/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator== (_MLX_BFloat16 lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator==() [2/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator== (_MLX_BFloat16 lhs,
float rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator==() [3/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator== (_MLX_BFloat16 lhs,
half rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator==() [4/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator== (_MLX_BFloat16 lhs,
int32_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator==() [5/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator== (_MLX_BFloat16 lhs,
int64_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator==() [6/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator== (_MLX_BFloat16 lhs,
uint32_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator==() [7/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator== (_MLX_BFloat16 lhs,
uint64_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator==() [8/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator== (float lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator==() [9/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator== (half lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator==() [10/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator== (int32_t lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator==() [11/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator== (int64_t lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator==() [12/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator== (uint32_t lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator==() [13/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator== (uint64_t lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator>() [1/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator> (_MLX_BFloat16 lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator>() [2/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator> (_MLX_BFloat16 lhs,
float rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator>() [3/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator> (_MLX_BFloat16 lhs,
half rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator>() [4/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator> (_MLX_BFloat16 lhs,
int32_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator>() [5/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator> (_MLX_BFloat16 lhs,
int64_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator>() [6/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator> (_MLX_BFloat16 lhs,
uint32_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator>() [7/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator> (_MLX_BFloat16 lhs,
uint64_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator>() [8/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator> (float lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator>() [9/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator> (half lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator>() [10/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator> (int32_t lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator>() [11/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator> (int64_t lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator>() [12/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator> (uint32_t lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator>() [13/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator> (uint64_t lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator>=() [1/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator>= (_MLX_BFloat16 lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator>=() [2/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator>= (_MLX_BFloat16 lhs,
float rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator>=() [3/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator>= (_MLX_BFloat16 lhs,
half rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator>=() [4/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator>= (_MLX_BFloat16 lhs,
int32_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator>=() [5/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator>= (_MLX_BFloat16 lhs,
int64_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator>=() [6/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator>= (_MLX_BFloat16 lhs,
uint32_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator>=() [7/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator>= (_MLX_BFloat16 lhs,
uint64_t rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator>=() [8/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator>= (float lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator>=() [9/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator>= (half lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator>=() [10/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator>= (int32_t lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator>=() [11/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator>= (int64_t lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator>=() [12/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator>= (uint32_t lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+ +

◆ operator>=() [13/13]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr METAL_FUNC bool operator>= (uint64_t lhs,
_MLX_BFloat16 rhs )
+
+constexpr
+
+ +
+
+

Variable Documentation

+ +

◆ can_convert_from_bfloat

+ +
+
+
+template<typename T >
+ + + + + +
+ + + + +
constexpr constant bool can_convert_from_bfloat
+
+staticconstexpr
+
+Initial value:
=
+
!is_same_v<T, _MLX_BFloat16> && is_convertible_v<float, T>
+
+
+
+ +

◆ can_convert_to_bfloat

+ +
+
+
+template<typename T >
+ + + + + +
+ + + + +
constexpr constant bool can_convert_to_bfloat
+
+staticconstexpr
+
+Initial value:
=
+
!is_same_v<T, _MLX_BFloat16> && is_convertible_v<T, float>
+
+
+
+
+ + + + diff --git a/docs/build/html/backend_2metal_2kernels_2bf16_8h_source.html b/docs/build/html/backend_2metal_2kernels_2bf16_8h_source.html new file mode 100644 index 000000000..83ebc44f4 --- /dev/null +++ b/docs/build/html/backend_2metal_2kernels_2bf16_8h_source.html @@ -0,0 +1,489 @@ + + + + + + + +MLX: mlx/backend/metal/kernels/bf16.h Source File + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
bf16.h
+
+
+Go to the documentation of this file.
1// Copyright © 2023 Apple Inc.
+
2
+
3#pragma once
+
4
+
5#include <metal_stdlib>
+
6
+
7using namespace metal;
+
8
+
9#if defined(__HAVE_BFLOAT__)
+
10
+
11typedef bfloat bfloat16_t;
+
12
+
13#else
+
14
+
16// Helpers
+
18
+
+
19constexpr METAL_FUNC uint16_t float_to_bfloat_bits(float x) {
+
20 // Check for nan
+
21 if ((as_type<uint32_t>(x) & ~_fp_encoding_traits<float>::sign_mask) >
+
22 _fp_encoding_traits<float>::inf_mask) {
+
23 return uint16_t(as_type<uint32_t>(0x7FC0));
+
24 }
+
25 // Take bits
+
26 uint32_t float_bits = as_type<uint32_t>(x);
+
27
+
28 // Round to nearest even
+
29 float_bits += ((float_bits >> 16) & 1) + as_type<uint32_t>(0x7FFF);
+
30
+
31 // Take upper 16 bits
+
32 return float_bits >> 16;
+
33}
+
+
34
+
+
35constexpr METAL_FUNC float bfloat_bits_to_float(uint16_t x) {
+
36 // Upper 16 bits are the data and lower 16 bits are 0s
+
37 return as_type<float>((uint32_t)x << 16);
+
38}
+
+
39
+
40struct _MLX_BFloat16;
+
41
+
42template <typename T>
+
43static constexpr constant bool can_convert_to_bfloat =
+
44 !is_same_v<T, _MLX_BFloat16> && is_convertible_v<T, float>;
+
45
+
46template <typename T>
+
47static constexpr constant bool can_convert_from_bfloat =
+
48 !is_same_v<T, _MLX_BFloat16> && is_convertible_v<float, T>;
+
49
+
51// Bfloat struct
+
53
+
+ +
56 // Constructors
+
57 uint16_t bits_;
+
58 _MLX_BFloat16() thread = default;
+
59 _MLX_BFloat16() threadgroup = default;
+
60 _MLX_BFloat16() device = default;
+
61 _MLX_BFloat16() constant = default;
+
62
+ +
+
64 static constexpr METAL_FUNC bits_to_bfloat_struct bits_to_bfloat() {
+
65 return bits_to_bfloat_struct();
+
66 }
+
+
+
67 constexpr METAL_FUNC _MLX_BFloat16(uint16_t bits, bits_to_bfloat_struct)
+
68 : bits_(bits) {}
+
+
69
+
71 // Conversions to bfloat
+
72
+
73 template <
+
74 typename T,
+
75 typename = typename enable_if<can_convert_to_bfloat<T>>::type>
+
+
76 constexpr METAL_FUNC _MLX_BFloat16(T x) thread
+
77 : bits_(float_to_bfloat_bits(static_cast<float>(x))) {}
+
+
78
+
79 template <
+
80 typename T,
+
81 typename = typename enable_if<can_convert_to_bfloat<T>>::type>
+
+
82 constexpr METAL_FUNC _MLX_BFloat16(T x) threadgroup
+
83 : bits_(float_to_bfloat_bits(static_cast<float>(x))) {}
+
+
84
+
85 template <
+
86 typename T,
+
87 typename = typename enable_if<can_convert_to_bfloat<T>>::type>
+
+
88 constexpr METAL_FUNC _MLX_BFloat16(T x) device
+
89 : bits_(float_to_bfloat_bits(static_cast<float>(x))) {}
+
+
90
+
91 template <
+
92 typename T,
+
93 typename = typename enable_if<can_convert_to_bfloat<T>>::type>
+
+
94 constexpr METAL_FUNC _MLX_BFloat16(T x) constant
+
95 : bits_(float_to_bfloat_bits(static_cast<float>(x))) {}
+
+
96
+
98 // Conversions from bfloat
+
99
+
100 template <
+
101 typename T,
+
102 typename = typename enable_if<can_convert_from_bfloat<T>>::type>
+
+
103 constexpr METAL_FUNC operator T() const thread {
+
104 return static_cast<T>(bfloat_bits_to_float(bits_));
+
105 }
+
+
106
+
107 template <
+
108 typename T,
+
109 typename = typename enable_if<can_convert_from_bfloat<T>>::type>
+
+
110 constexpr METAL_FUNC operator T() const threadgroup {
+
111 return static_cast<T>(bfloat_bits_to_float(bits_));
+
112 }
+
+
113
+
114 template <
+
115 typename T,
+
116 typename = typename enable_if<can_convert_from_bfloat<T>>::type>
+
+
117 constexpr METAL_FUNC operator T() const device {
+
118 return static_cast<T>(bfloat_bits_to_float(bits_));
+
119 }
+
+
120
+
121 template <
+
122 typename T,
+
123 typename = typename enable_if<can_convert_from_bfloat<T>>::type>
+
+
124 constexpr METAL_FUNC operator T() const constant {
+
125 return static_cast<T>(bfloat_bits_to_float(bits_));
+
126 }
+
+
127};
+
+
128
+
130// Bfloat operators
+
132
+
134// Unary ops
+
+
135constexpr METAL_FUNC _MLX_BFloat16 operator-(_MLX_BFloat16 x) {
+
136 return -static_cast<float>(x);
+
137}
+
+
138
+
140// Binary operators
+
+
141#define bfloat_binop_base(__op__, __operator__, otype, atype, btype, ctype) \
+
142 constexpr METAL_FUNC otype __operator__(atype lhs, btype rhs) { \
+
143 return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs); \
+
144 }
+
+
145
+
+
146#define bfloat_binop_helper(__op__, __operator__, otype, itype, ctype) \
+
147 constexpr METAL_FUNC otype __operator__(_MLX_BFloat16 lhs, itype rhs) { \
+
148 return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs); \
+
149 } \
+
150 constexpr METAL_FUNC otype __operator__(itype lhs, _MLX_BFloat16 rhs) { \
+
151 return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs); \
+
152 }
+
+
153
+
155// Arithmetic Operators
+
+
156#define bfloat_binop(_op_, _operator_) \
+
157 bfloat_binop_base( \
+
158 _op_, _operator_, _MLX_BFloat16, _MLX_BFloat16, _MLX_BFloat16, float); \
+
159 bfloat_binop_helper(_op_, _operator_, float, float, float); \
+
160 bfloat_binop_helper(_op_, _operator_, float, half, float); \
+
161 bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int32_t, float); \
+
162 bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint32_t, float); \
+
163 bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int64_t, float); \
+
164 bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint64_t, float);
+
+
165
+
166bfloat_binop(+, operator+);
+
167bfloat_binop(-, operator-);
+
168bfloat_binop(*, operator*);
+
169bfloat_binop(/, operator/);
+
170
+
172// Comparison ops
+
+
173#define bfloat_compop(__op__, __operator__) \
+
174 bfloat_binop_base( \
+
175 __op__, __operator__, bool, _MLX_BFloat16, _MLX_BFloat16, float); \
+
176 bfloat_binop_helper(__op__, __operator__, bool, float, float); \
+
177 bfloat_binop_helper(__op__, __operator__, bool, half, float); \
+
178 bfloat_binop_helper(__op__, __operator__, bool, int32_t, float); \
+
179 bfloat_binop_helper(__op__, __operator__, bool, uint32_t, float); \
+
180 bfloat_binop_helper(__op__, __operator__, bool, int64_t, float); \
+
181 bfloat_binop_helper(__op__, __operator__, bool, uint64_t, float);
+
+
182
+
183bfloat_compop(>, operator>);
+
184bfloat_compop(<, operator<);
+
185bfloat_compop(>=, operator>=);
+
186bfloat_compop(<=, operator<=);
+
187bfloat_compop(==, operator==);
+
188bfloat_compop(!=, operator!=);
+
189
+
190#undef bfloat_compop
+
191#undef bfloat_binop_base
+
192#undef bfloat_binop_helper
+
193#undef bfloat_binop
+
194
+
196// Inplace Operators
+
+
197#define bfloat_inplace_op_helper(__op__, __operator__, itype, addr_space) \
+
198 constexpr METAL_FUNC addr_space _MLX_BFloat16& __operator__( \
+
199 addr_space _MLX_BFloat16& lhs, itype rhs) { \
+
200 lhs = static_cast<float>(lhs) __op__ static_cast<float>(rhs); \
+
201 return lhs; \
+
202 } \
+
203 constexpr METAL_FUNC addr_space itype& __operator__( \
+
204 addr_space itype& lhs, _MLX_BFloat16 rhs) { \
+
205 lhs = static_cast<float>(lhs) __op__ static_cast<float>(rhs); \
+
206 return lhs; \
+
207 }
+
+
208
+
+
209#define bfloat_inplace_op_addr_space_helper(__op__, __operator__, itype) \
+
210 bfloat_inplace_op_helper(__op__, __operator__, itype, device); \
+
211 bfloat_inplace_op_helper(__op__, __operator__, itype, thread); \
+
212 bfloat_inplace_op_helper(__op__, __operator__, itype, threadgroup);
+
+
213
+
+
214#define bfloat_inplace_op(itype) \
+
215 bfloat_inplace_op_addr_space_helper(+, operator+=, itype); \
+
216 bfloat_inplace_op_addr_space_helper(-, operator-=, itype); \
+
217 bfloat_inplace_op_addr_space_helper(*, operator*=, itype); \
+
218 bfloat_inplace_op_addr_space_helper(/, operator/=, itype);
+
+
219
+ + + + + + + + +
228
+
229#undef bfloat_inplace_op_helper
+
230#undef bfloat_inplace_op_addr_space_helper
+
231#undef bfloat_inplace_op
+
232
+
233#define bfloat_inplace_op_helper(__op__, __operator__, addr_space) \
+
234 constexpr METAL_FUNC addr_space _MLX_BFloat16& __operator__( \
+
235 addr_space _MLX_BFloat16& lhs, _MLX_BFloat16 rhs) { \
+
236 lhs = static_cast<float>(lhs) __op__ static_cast<float>(rhs); \
+
237 return lhs; \
+
238 }
+
239
+
240#define bfloat_inplace_op_addr_space_helper(__op__, __operator__) \
+
241 bfloat_inplace_op_helper(__op__, __operator__, device); \
+
242 bfloat_inplace_op_helper(__op__, __operator__, thread); \
+
243 bfloat_inplace_op_helper(__op__, __operator__, threadgroup);
+
244
+ + + + +
249
+
250#undef bfloat_inplace_op_helper
+
251#undef bfloat_inplace_op_addr_space_helper
+
252
+
254// Bfloat typedef
+
256
+ +
258
+
260// Bfloat numeric limits
+
262
+
263#pragma METAL internals : enable
+
264
+
+
265namespace metal {
+
266
+
267template <>
+
+
268struct _numeric_limits_impl<bfloat16_t> : _fp_numeric_limits_impl_base {
+
269 static constexpr constant int digits = 8;
+
270 static constexpr constant int digits10 = 2;
+
271 static constexpr constant int max_digits10 = 4;
+
272 static constexpr constant int radix = 2;
+
273 static constexpr constant int min_exponent = -125;
+
274 static constexpr constant int min_exponent10 = -37;
+
275 static constexpr constant int max_exponent = 128;
+
276 static constexpr constant int max_exponent10 = 38;
+
277
+
+
278 static constexpr bfloat16_t min() {
+ +
280 }
+
+
+
281 static constexpr bfloat16_t lowest() {
+ +
283 }
+
+
+
284 static constexpr bfloat16_t max() {
+ +
286 }
+
+
+
287 static constexpr bfloat16_t epsilon() {
+ +
289 }
+
+
+
290 static constexpr bfloat16_t round_error() {
+ +
292 }
+
+
+
293 static constexpr bfloat16_t infinity() {
+ +
295 }
+
+
+
296 static constexpr bfloat16_t quiet_NaN() {
+ +
298 }
+
+
+
299 static constexpr bfloat16_t signaling_NaN() {
+ +
301 }
+
+
+
302 static constexpr bfloat16_t denorm_min() {
+ +
304 }
+
+
305};
+
+
306
+
+
307METAL_FUNC bool isnan(_MLX_BFloat16 x) {
+
308 return x != x;
+
309}
+
+
310
+
311} // namespace metal
+
+
312
+
313#pragma METAL internals : disable
+
314
+
315#endif // defined(__HAVE_BFLOAT__)
+
316
+ +
constexpr METAL_FUNC uint16_t float_to_bfloat_bits(float x)
Definition bf16.h:19
+
#define bfloat_compop(__op__, __operator__)
Definition bf16.h:173
+
constexpr METAL_FUNC float bfloat_bits_to_float(uint16_t x)
Definition bf16.h:35
+
#define bfloat_inplace_op(itype)
Definition bf16.h:214
+
constexpr METAL_FUNC _MLX_BFloat16 operator-(_MLX_BFloat16 x)
Definition bf16.h:135
+
#define bfloat_binop(_op_, _operator_)
Definition bf16.h:156
+
struct _MLX_BFloat16 bfloat16_t
Definition bf16.h:257
+
static constexpr constant bool can_convert_from_bfloat
Definition bf16.h:47
+
static constexpr constant bool can_convert_to_bfloat
Definition bf16.h:43
+
#define bfloat_inplace_op_addr_space_helper(__op__, __operator__, itype)
Definition bf16.h:209
+ +
Definition bf16.h:265
+
METAL_FUNC bool isnan(_MLX_BFloat16 x)
Definition bf16.h:307
+ +
Definition bf16.h:54
+
constexpr METAL_FUNC _MLX_BFloat16(T x) thread
Definition bf16.h:76
+
uint16_t bits_
Definition bf16.h:57
+
constexpr METAL_FUNC _MLX_BFloat16(uint16_t bits, bits_to_bfloat_struct)
Definition bf16.h:67
+
static constexpr METAL_FUNC bits_to_bfloat_struct bits_to_bfloat()
Definition bf16.h:64
+
_MLX_BFloat16() thread=default
+
constexpr METAL_FUNC _MLX_BFloat16(T x) device
Definition bf16.h:88
+
constexpr METAL_FUNC _MLX_BFloat16(T x) threadgroup
Definition bf16.h:82
+
constexpr METAL_FUNC _MLX_BFloat16(T x) const ant
Definition bf16.h:94
+
static constexpr bfloat16_t infinity()
Definition bf16.h:293
+
static constexpr bfloat16_t denorm_min()
Definition bf16.h:302
+
static constexpr bfloat16_t max()
Definition bf16.h:284
+
static constexpr bfloat16_t epsilon()
Definition bf16.h:287
+
static constexpr bfloat16_t signaling_NaN()
Definition bf16.h:299
+
static constexpr bfloat16_t min()
Definition bf16.h:278
+
static constexpr bfloat16_t lowest()
Definition bf16.h:281
+
static constexpr bfloat16_t quiet_NaN()
Definition bf16.h:296
+
static constexpr bfloat16_t round_error()
Definition bf16.h:290
+
+ + + + diff --git a/docs/build/html/backend_2metal_2kernels_2complex_8h.html b/docs/build/html/backend_2metal_2kernels_2complex_8h.html new file mode 100644 index 000000000..f45ed8b0a --- /dev/null +++ b/docs/build/html/backend_2metal_2kernels_2complex_8h.html @@ -0,0 +1,504 @@ + + + + + + + +MLX: mlx/backend/metal/kernels/complex.h File Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
complex.h File Reference
+
+
+
#include <metal_stdlib>
+
+

Go to the source code of this file.

+ + + + +

+Classes

struct  complex64_t
 
+ + + + + + + + + + + + + + + + + + + + + + + +

+Functions

constexpr complex64_t operator- (complex64_t x)
 
constexpr bool operator>= (complex64_t a, complex64_t b)
 
constexpr bool operator> (complex64_t a, complex64_t b)
 
constexpr bool operator<= (complex64_t a, complex64_t b)
 
constexpr bool operator< (complex64_t a, complex64_t b)
 
constexpr bool operator== (complex64_t a, complex64_t b)
 
constexpr complex64_t operator+ (complex64_t a, complex64_t b)
 
constexpr complex64_t operator- (complex64_t a, complex64_t b)
 
constexpr complex64_t operator* (complex64_t a, complex64_t b)
 
constexpr complex64_t operator/ (complex64_t a, complex64_t b)
 
constexpr complex64_t operator% (complex64_t a, complex64_t b)
 
+ + + + + + + +

+Variables

template<typename T >
static constexpr constant bool can_convert_to_complex64
 
template<typename T >
static constexpr constant bool can_convert_from_complex64
 
+

Function Documentation

+ +

◆ operator%()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr complex64_t operator% (complex64_t a,
complex64_t b )
+
+constexpr
+
+ +
+
+ +

◆ operator*()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr complex64_t operator* (complex64_t a,
complex64_t b )
+
+constexpr
+
+ +
+
+ +

◆ operator+()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr complex64_t operator+ (complex64_t a,
complex64_t b )
+
+constexpr
+
+ +
+
+ +

◆ operator-() [1/2]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr complex64_t operator- (complex64_t a,
complex64_t b )
+
+constexpr
+
+ +
+
+ +

◆ operator-() [2/2]

+ +
+
+ + + + + +
+ + + + + + + +
constexpr complex64_t operator- (complex64_t x)
+
+constexpr
+
+ +
+
+ +

◆ operator/()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr complex64_t operator/ (complex64_t a,
complex64_t b )
+
+constexpr
+
+ +
+
+ +

◆ operator<()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr bool operator< (complex64_t a,
complex64_t b )
+
+constexpr
+
+ +
+
+ +

◆ operator<=()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr bool operator<= (complex64_t a,
complex64_t b )
+
+constexpr
+
+ +
+
+ +

◆ operator==()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr bool operator== (complex64_t a,
complex64_t b )
+
+constexpr
+
+ +
+
+ +

◆ operator>()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr bool operator> (complex64_t a,
complex64_t b )
+
+constexpr
+
+ +
+
+ +

◆ operator>=()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
constexpr bool operator>= (complex64_t a,
complex64_t b )
+
+constexpr
+
+ +
+
+

Variable Documentation

+ +

◆ can_convert_from_complex64

+ +
+
+
+template<typename T >
+ + + + + +
+ + + + +
constexpr constant bool can_convert_from_complex64
+
+staticconstexpr
+
+Initial value:
=
+
!is_same_v<T, complex64_t> &&
+
(is_convertible_v<float, T> || is_convertible_v<bfloat16_t, T>)
+
+
+
+ +

◆ can_convert_to_complex64

+ +
+
+
+template<typename T >
+ + + + + +
+ + + + +
constexpr constant bool can_convert_to_complex64
+
+staticconstexpr
+
+Initial value:
=
+
!is_same_v<T, complex64_t> && is_convertible_v<T, float>
+
+
+
+
+ + + + diff --git a/docs/build/html/backend_2metal_2kernels_2complex_8h_source.html b/docs/build/html/backend_2metal_2kernels_2complex_8h_source.html new file mode 100644 index 000000000..c47bbd109 --- /dev/null +++ b/docs/build/html/backend_2metal_2kernels_2complex_8h_source.html @@ -0,0 +1,276 @@ + + + + + + + +MLX: mlx/backend/metal/kernels/complex.h Source File + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
complex.h
+
+
+Go to the documentation of this file.
1// Copyright © 2023 Apple Inc.
+
2
+
3#pragma once
+
4
+
5#include <metal_stdlib>
+
6
+
7using namespace metal;
+
8
+
9struct complex64_t;
+
10
+
11template <typename T>
+
12static constexpr constant bool can_convert_to_complex64 =
+
13 !is_same_v<T, complex64_t> && is_convertible_v<T, float>;
+
14
+
15template <typename T>
+
16static constexpr constant bool can_convert_from_complex64 =
+
17 !is_same_v<T, complex64_t> &&
+
18 (is_convertible_v<float, T> || is_convertible_v<bfloat16_t, T>);
+
19
+
+ +
21 float real;
+
22 float imag;
+
23
+
24 // Constructors
+
25 constexpr complex64_t(float real, float imag) : real(real), imag(imag) {};
+
26
+
27 // Conversions to complex64_t
+
28 template <
+
29 typename T,
+
30 typename = typename enable_if<can_convert_to_complex64<T>>::type>
+
31 constexpr complex64_t(T x) thread : real(x), imag(0) {}
+
32
+
33 template <
+
34 typename T,
+
35 typename = typename enable_if<can_convert_to_complex64<T>>::type>
+
36 constexpr complex64_t(T x) threadgroup : real(x), imag(0) {}
+
37
+
38 template <
+
39 typename T,
+
40 typename = typename enable_if<can_convert_to_complex64<T>>::type>
+
41 constexpr complex64_t(T x) device : real(x), imag(0) {}
+
42
+
43 template <
+
44 typename T,
+
45 typename = typename enable_if<can_convert_to_complex64<T>>::type>
+
46 constexpr complex64_t(T x) constant : real(x), imag(0) {}
+
47
+
48 // Conversions from complex64_t
+
49 template <
+
50 typename T,
+
51 typename = typename enable_if<can_convert_from_complex64<T>>::type>
+
+
52 constexpr operator T() const thread {
+
53 return static_cast<T>(real);
+
54 }
+
+
55
+
56 template <
+
57 typename T,
+
58 typename = typename enable_if<can_convert_from_complex64<T>>::type>
+
+
59 constexpr operator T() const threadgroup {
+
60 return static_cast<T>(real);
+
61 }
+
+
62
+
63 template <
+
64 typename T,
+
65 typename = typename enable_if<can_convert_from_complex64<T>>::type>
+
+
66 constexpr operator T() const device {
+
67 return static_cast<T>(real);
+
68 }
+
+
69
+
70 template <
+
71 typename T,
+
72 typename = typename enable_if<can_convert_from_complex64<T>>::type>
+
+
73 constexpr operator T() const constant {
+
74 return static_cast<T>(real);
+
75 }
+
+
76};
+
+
77
+
+ +
79 return {-x.real, -x.imag};
+
80}
+
+
81
+
+
82constexpr bool operator>=(complex64_t a, complex64_t b) {
+
83 return (a.real > b.real) || (a.real == b.real && a.imag >= b.imag);
+
84}
+
+
85
+
+
86constexpr bool operator>(complex64_t a, complex64_t b) {
+
87 return (a.real > b.real) || (a.real == b.real && a.imag > b.imag);
+
88}
+
+
89
+
+
90constexpr bool operator<=(complex64_t a, complex64_t b) {
+
91 return operator>=(b, a);
+
92}
+
+
93
+
+
94constexpr bool operator<(complex64_t a, complex64_t b) {
+
95 return operator>(b, a);
+
96}
+
+
97
+
+
98constexpr bool operator==(complex64_t a, complex64_t b) {
+
99 return a.real == b.real && a.imag == b.imag;
+
100}
+
+
101
+
+ +
103 return {a.real + b.real, a.imag + b.imag};
+
104}
+
+
105
+
+ +
107 return {a.real - b.real, a.imag - b.imag};
+
108}
+
+
109
+
+ +
111 return {a.real * b.real - a.imag * b.imag, a.real * b.imag + a.imag * b.real};
+
112}
+
+
113
+
+ +
115 auto denom = b.real * b.real + b.imag * b.imag;
+
116 auto x = a.real * b.real + a.imag * b.imag;
+
117 auto y = a.imag * b.real - a.real * b.imag;
+
118 return {x / denom, y / denom};
+
119}
+
+
120
+
+ +
122 auto real = a.real - (b.real * static_cast<int64_t>(a.real / b.real));
+
123 auto imag = a.imag - (b.imag * static_cast<int64_t>(a.imag / b.imag));
+
124 if (real != 0 && (real < 0 != b.real < 0)) {
+
125 real += b.real;
+
126 }
+
127 if (imag != 0 && (imag < 0 != b.imag < 0)) {
+
128 imag += b.imag;
+
129 }
+
130 return {real, imag};
+
131}
+
+
constexpr bool operator>(complex64_t a, complex64_t b)
Definition complex.h:86
+
constexpr complex64_t operator-(complex64_t x)
Definition complex.h:78
+
static constexpr constant bool can_convert_to_complex64
Definition complex.h:12
+
constexpr bool operator<(complex64_t a, complex64_t b)
Definition complex.h:94
+
constexpr complex64_t operator*(complex64_t a, complex64_t b)
Definition complex.h:110
+
constexpr complex64_t operator%(complex64_t a, complex64_t b)
Definition complex.h:121
+
constexpr bool operator>=(complex64_t a, complex64_t b)
Definition complex.h:82
+
static constexpr constant bool can_convert_from_complex64
Definition complex.h:16
+
constexpr bool operator==(complex64_t a, complex64_t b)
Definition complex.h:98
+
constexpr complex64_t operator+(complex64_t a, complex64_t b)
Definition complex.h:102
+
constexpr complex64_t operator/(complex64_t a, complex64_t b)
Definition complex.h:114
+
constexpr bool operator<=(complex64_t a, complex64_t b)
Definition complex.h:90
+
Definition bf16.h:265
+
Definition complex.h:20
+
constexpr complex64_t(T x) const ant
Definition complex.h:46
+
constexpr complex64_t(T x) thread
Definition complex.h:31
+
constexpr complex64_t(T x) threadgroup
Definition complex.h:36
+
float imag
Definition complex.h:22
+
float real
Definition complex.h:21
+
constexpr complex64_t(T x) device
Definition complex.h:41
+
constexpr complex64_t(float real, float imag)
Definition complex.h:25
+
+ + + + diff --git a/docs/build/html/backend_2metal_2kernels_2reduction_2ops_8h.html b/docs/build/html/backend_2metal_2kernels_2reduction_2ops_8h.html new file mode 100644 index 000000000..d991b9612 --- /dev/null +++ b/docs/build/html/backend_2metal_2kernels_2reduction_2ops_8h.html @@ -0,0 +1,116 @@ + + + + + + + +MLX: mlx/backend/metal/kernels/reduction/ops.h File Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
ops.h File Reference
+
+
+
#include <metal_atomic>
+#include <metal_simdgroup>
+#include "mlx/backend/metal/kernels/atomic.h"
+#include "mlx/backend/metal/kernels/bf16.h"
+#include "mlx/backend/metal/kernels/utils.h"
+
+

Go to the source code of this file.

+ + + + + + + + + + + + + + + + + + +

+Classes

union  bool4_or_uint
 
struct  None
 
struct  And
 
struct  Or
 
struct  Sum< U >
 
struct  Prod< U >
 
struct  Min< U >
 
struct  Max< U >
 
+
+ + + + diff --git a/docs/build/html/backend_2metal_2kernels_2reduction_2ops_8h_source.html b/docs/build/html/backend_2metal_2kernels_2reduction_2ops_8h_source.html new file mode 100644 index 000000000..58bc01ea0 --- /dev/null +++ b/docs/build/html/backend_2metal_2kernels_2reduction_2ops_8h_source.html @@ -0,0 +1,385 @@ + + + + + + + +MLX: mlx/backend/metal/kernels/reduction/ops.h Source File + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
ops.h
+
+
+Go to the documentation of this file.
1// Copyright © 2023-2024 Apple Inc.
+
2
+
3#pragma once
+
4
+
5#include <metal_atomic>
+
6#include <metal_simdgroup>
+
7
+ + + +
11
+
+ +
13 bool4 b;
+
14 unsigned int i;
+
15};
+
+
16
+
+
17struct None {
+
18 template <typename T>
+
+
19 void atomic_update(device mlx_atomic<T>* out, T val, uint offset = 0) {
+
20 mlx_atomic_store_explicit(out, val, offset);
+
21 }
+
+
22};
+
+
23
+
+
24struct And {
+
+
25 bool simd_reduce(bool val) {
+
26 return simd_all(val);
+
27 };
+
+
28
+
29 static constexpr constant bool init = true;
+
30
+
+ +
32 device mlx_atomic<unsigned int>* out,
+
33 bool val,
+
34 int elem_idx,
+
35 int offset = 0) {
+
36 if (!val) {
+ +
38 update.b = {true, true, true, true};
+
39 update.b[elem_idx] = false;
+ +
41 }
+
42 }
+
+
43
+
+
44 void atomic_update(device mlx_atomic<bool>* out, bool val, uint offset = 0) {
+
45 if (!val) {
+
46 mlx_atomic_store_explicit(out, val, offset);
+
47 }
+
48 }
+
+
49
+
50 // Non atomic update
+
+
51 void update(device bool* out, bool val) {
+
52 *out &= val;
+
53 }
+
+
54
+
55 // Operator
+
+
56 bool operator()(bool a, bool b) {
+
57 return a && b;
+
58 }
+
+
59};
+
+
60
+
+
61struct Or {
+
+
62 bool simd_reduce(bool val) {
+
63 return simd_any(val);
+
64 };
+
+
65
+
66 static constexpr constant bool init = false;
+
67
+
+ +
69 device mlx_atomic<unsigned int>* out,
+
70 bool val,
+
71 uint elem_idx,
+
72 uint offset = 0) {
+
73 if (val) {
+ +
75 update.b = {false, false, false, false};
+
76 update.b[elem_idx] = true;
+ +
78 }
+
79 }
+
+
80
+
+
81 void atomic_update(device mlx_atomic<bool>* out, bool val, uint offset = 0) {
+
82 if (val) {
+
83 mlx_atomic_store_explicit(out, val, offset);
+
84 }
+
85 }
+
+
86
+
87 // Non atomic update
+
+
88 void update(device bool* out, bool val) {
+
89 *out |= val;
+
90 }
+
+
91
+
92 // Operator
+
+
93 bool operator()(bool a, bool b) {
+
94 return a || b;
+
95 }
+
+
96};
+
+
97
+
98template <typename U>
+
+
99struct Sum {
+
100 template <typename T>
+
+
101 T simd_reduce(T val) {
+
102 return simd_sum(val);
+
103 };
+
+
104
+
105 static constexpr constant U init = U(0);
+
106
+
107 template <typename T>
+
+
108 void atomic_update(device mlx_atomic<T>* out, T val, uint offset = 0) {
+
109 mlx_atomic_fetch_add_explicit(out, val, offset);
+
110 }
+
+
111
+
112 // Operator
+
+
113 U operator()(U a, U b) {
+
114 return a + b;
+
115 }
+
+
116};
+
+
117
+
118template <typename U>
+
+
119struct Prod {
+
120 template <typename T>
+
+
121 T simd_reduce(T val) {
+
122 return simd_product(val);
+
123 };
+
+
124
+
125 static constexpr constant U init = U(1);
+
126
+
127 template <typename T>
+
+
128 void atomic_update(device mlx_atomic<T>* out, T val, uint offset = 0) {
+
129 mlx_atomic_fetch_mul_explicit(out, val, offset);
+
130 }
+
+
131
+
132 // Operator
+
+
133 U operator()(U a, U b) {
+
134 return a * b;
+
135 }
+
+
136};
+
+
137
+
138template <typename U>
+
+
139struct Min {
+
140 template <typename T>
+
+
141 T simd_reduce(T val) {
+
142 return simd_min(val);
+
143 };
+
+
144
+
145 static constexpr constant U init = Limits<U>::max;
+
146
+
147 template <typename T>
+
+
148 void atomic_update(device mlx_atomic<T>* out, T val, uint offset = 0) {
+
149 mlx_atomic_fetch_min_explicit(out, val, offset);
+
150 }
+
+
151
+
152 // Operator
+
+
153 U operator()(U a, U b) {
+
154 return a < b ? a : b;
+
155 }
+
+
156};
+
+
157
+
158template <typename U>
+
+
159struct Max {
+
160 template <typename T>
+
+
161 T simd_reduce(T val) {
+
162 return simd_max(val);
+
163 };
+
+
164
+
165 static constexpr constant U init = Limits<U>::min;
+
166
+
167 template <typename T>
+
+
168 void atomic_update(device mlx_atomic<T>* out, T val, uint offset = 0) {
+
169 mlx_atomic_fetch_max_explicit(out, val, offset);
+
170 }
+
+
171
+
172 // Operator
+
+
173 U operator()(U a, U b) {
+
174 return a > b ? a : b;
+
175 }
+
+
176};
+
+ +
METAL_FUNC void mlx_atomic_fetch_add_explicit(device mlx_atomic< T > *object, T val, uint offset)
Definition atomic.h:82
+
METAL_FUNC void mlx_atomic_fetch_and_explicit(device mlx_atomic< T > *object, T val, uint offset)
Definition atomic.h:52
+
METAL_FUNC void mlx_atomic_store_explicit(device mlx_atomic< T > *object, T val, uint offset)
Definition atomic.h:47
+
METAL_FUNC void mlx_atomic_fetch_or_explicit(device mlx_atomic< T > *object, T val, uint offset)
Definition atomic.h:61
+
METAL_FUNC void mlx_atomic_fetch_max_explicit(device mlx_atomic< T > *object, T val, uint offset)
Definition atomic.h:74
+
METAL_FUNC void mlx_atomic_fetch_min_explicit(device mlx_atomic< T > *object, T val, uint offset)
Definition atomic.h:66
+
METAL_FUNC void mlx_atomic_fetch_mul_explicit(device mlx_atomic< T > *object, T val, uint offset)
Definition atomic.h:90
+ + +
METAL_FUNC bfloat16_t simd_max(bfloat16_t data)
Definition bf16_math.h:392
+
METAL_FUNC bfloat16_t simd_sum(bfloat16_t data)
Definition bf16_math.h:392
+
METAL_FUNC bfloat16_t simd_product(bfloat16_t data)
Definition bf16_math.h:392
+
METAL_FUNC bfloat16_t simd_min(bfloat16_t data)
Definition bf16_math.h:392
+
Definition ops.h:24
+
void atomic_update(device mlx_atomic< bool > *out, bool val, uint offset=0)
Definition ops.h:44
+
bool operator()(bool a, bool b)
Definition ops.h:56
+
bool simd_reduce(bool val)
Definition ops.h:25
+
static constexpr constant bool init
Definition ops.h:29
+
void atomic_update(device mlx_atomic< unsigned int > *out, bool val, int elem_idx, int offset=0)
Definition ops.h:31
+
void update(device bool *out, bool val)
Definition ops.h:51
+
Definition utils.h:14
+
Definition ops.h:159
+
T simd_reduce(T val)
Definition ops.h:161
+
U operator()(U a, U b)
Definition ops.h:173
+
static constexpr constant U init
Definition ops.h:165
+
void atomic_update(device mlx_atomic< T > *out, T val, uint offset=0)
Definition ops.h:168
+
Definition ops.h:139
+
void atomic_update(device mlx_atomic< T > *out, T val, uint offset=0)
Definition ops.h:148
+
U operator()(U a, U b)
Definition ops.h:153
+
static constexpr constant U init
Definition ops.h:145
+
T simd_reduce(T val)
Definition ops.h:141
+
Definition ops.h:17
+
void atomic_update(device mlx_atomic< T > *out, T val, uint offset=0)
Definition ops.h:19
+
Definition ops.h:61
+
void atomic_update(device mlx_atomic< bool > *out, bool val, uint offset=0)
Definition ops.h:81
+
bool operator()(bool a, bool b)
Definition ops.h:93
+
void atomic_update(device mlx_atomic< unsigned int > *out, bool val, uint elem_idx, uint offset=0)
Definition ops.h:68
+
void update(device bool *out, bool val)
Definition ops.h:88
+
bool simd_reduce(bool val)
Definition ops.h:62
+
static constexpr constant bool init
Definition ops.h:66
+
Definition ops.h:119
+
U operator()(U a, U b)
Definition ops.h:133
+
void atomic_update(device mlx_atomic< T > *out, T val, uint offset=0)
Definition ops.h:128
+
T simd_reduce(T val)
Definition ops.h:121
+
static constexpr constant U init
Definition ops.h:125
+
Definition ops.h:99
+
void atomic_update(device mlx_atomic< T > *out, T val, uint offset=0)
Definition ops.h:108
+
static constexpr constant U init
Definition ops.h:105
+
T simd_reduce(T val)
Definition ops.h:101
+
U operator()(U a, U b)
Definition ops.h:113
+
Definition atomic.h:26
+
Definition ops.h:12
+
bool4 b
Definition ops.h:13
+
unsigned int i
Definition ops.h:14
+
+ + + + diff --git a/docs/build/html/backend_2metal_2kernels_2reduction_2utils_8h.html b/docs/build/html/backend_2metal_2kernels_2reduction_2utils_8h.html new file mode 100644 index 000000000..2c79e0cea --- /dev/null +++ b/docs/build/html/backend_2metal_2kernels_2reduction_2utils_8h.html @@ -0,0 +1,126 @@ + + + + + + + +MLX: mlx/backend/metal/kernels/reduction/utils.h File Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
utils.h File Reference
+
+
+
#include <metal_atomic>
+#include <metal_simdgroup>
+#include "mlx/backend/metal/kernels/defines.h"
+#include "mlx/backend/metal/kernels/steel/utils.h"
+#include "mlx/backend/metal/kernels/utils.h"
+#include "mlx/backend/metal/kernels/reduction/ops.h"
+
+

Go to the source code of this file.

+ + + + +

+Variables

static constant constexpr const uint8_t simd_size = 32
 
+

Variable Documentation

+ +

◆ simd_size

+ +
+
+ + + + + +
+ + + + +
constant constexpr const uint8_t simd_size = 32
+
+staticconstexpr
+
+ +
+
+
+ + + + diff --git a/docs/build/html/backend_2metal_2kernels_2reduction_2utils_8h_source.html b/docs/build/html/backend_2metal_2kernels_2reduction_2utils_8h_source.html new file mode 100644 index 000000000..c20c0da98 --- /dev/null +++ b/docs/build/html/backend_2metal_2kernels_2reduction_2utils_8h_source.html @@ -0,0 +1,111 @@ + + + + + + + +MLX: mlx/backend/metal/kernels/reduction/utils.h Source File + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
utils.h
+
+
+Go to the documentation of this file.
1// Copyright © 2024 Apple Inc.
+
2
+
3#pragma once
+
4
+
5#include <metal_atomic>
+
6#include <metal_simdgroup>
+
7
+ + + +
11
+ +
13
+
14static constant constexpr const uint8_t simd_size = 32;
+ +
static constant constexpr const uint8_t simd_size
Definition utils.h:14
+ + + +
+ + + + diff --git a/docs/build/html/backend_2metal_2kernels_2steel_2gemm_2transforms_8h.html b/docs/build/html/backend_2metal_2kernels_2steel_2gemm_2transforms_8h.html new file mode 100644 index 000000000..2f5a4eaa4 --- /dev/null +++ b/docs/build/html/backend_2metal_2kernels_2steel_2gemm_2transforms_8h.html @@ -0,0 +1,114 @@ + + + + + + + +MLX: mlx/backend/metal/kernels/steel/gemm/transforms.h File Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
transforms.h File Reference
+
+
+ +

Go to the source code of this file.

+ + + + + + + + + + + + +

+Classes

struct  mlx::steel::TransformNone< OutT, InT >
 
struct  mlx::steel::TransformAdd< OutT, InT >
 
struct  mlx::steel::TransformAxpby< OutT, InT >
 
struct  mlx::steel::AccumHelper< T >
 
struct  mlx::steel::BlockSwizzle
 
+ + + + + +

+Namespaces

namespace  mlx
 
namespace  mlx::steel
 
+
+ + + + diff --git a/docs/build/html/backend_2metal_2kernels_2steel_2gemm_2transforms_8h_source.html b/docs/build/html/backend_2metal_2kernels_2steel_2gemm_2transforms_8h_source.html new file mode 100644 index 000000000..f9ed6a5f8 --- /dev/null +++ b/docs/build/html/backend_2metal_2kernels_2steel_2gemm_2transforms_8h_source.html @@ -0,0 +1,192 @@ + + + + + + + +MLX: mlx/backend/metal/kernels/steel/gemm/transforms.h Source File + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
transforms.h
+
+
+Go to the documentation of this file.
1// Copyright © 2024 Apple Inc.
+
2
+
3#pragma once
+
4
+ +
6
+
8// Transforms and Epilogues
+
10
+
11namespace mlx {
+
12namespace steel {
+
13
+
14template <typename OutT, typename InT>
+
+ +
+
16 static METAL_FUNC OutT apply(InT x) {
+
17 return static_cast<OutT>(x);
+
18 }
+
+
19
+
+
20 static METAL_FUNC OutT apply(InT x, OutT) {
+
21 return static_cast<OutT>(x);
+
22 }
+
+
23};
+
+
24
+
25template <typename OutT, typename InT>
+
+ +
27 TransformAdd(const float, const float) {}
+
28
+
+
29 static METAL_FUNC OutT apply(InT x, OutT c) {
+
30 return static_cast<OutT>(x) + c;
+
31 }
+
+
32};
+
+
33
+
34template <typename OutT, typename InT>
+
+ +
36 const float alpha;
+
37 const float beta;
+
38
+
+
39 TransformAxpby(const float alpha_, const float beta_)
+
40 : alpha(alpha_), beta(beta_) {}
+
+
41
+
+
42 METAL_FUNC OutT apply(InT x, OutT c) const {
+
43 return static_cast<OutT>(x * alpha + (beta * c));
+
44 }
+
+
45};
+
+
46
+
47template <typename T>
+
+ +
49 typedef float accum_type;
+
50};
+
+
51
+
+ +
53 static METAL_FUNC int2
+
+
54 swizzle(uint3 tid [[threadgroup_position_in_grid]], const int swizzle_log) {
+
55 const int tid_x = (tid.x) >> swizzle_log;
+
56 const int tid_y =
+
57 ((tid.y) << swizzle_log) + ((tid.x) & ((1 << swizzle_log) - 1));
+
58 return int2(tid_x, tid_y);
+
59 }
+
+
60};
+
+
61
+
62} // namespace steel
+
63} // namespace mlx
+ +
Definition allocator.h:7
+
Definition transforms.h:48
+
float accum_type
Definition transforms.h:49
+
Definition transforms.h:52
+
static METAL_FUNC int2 swizzle(uint3 tid, const int swizzle_log)
Definition transforms.h:54
+
Definition transforms.h:26
+
static METAL_FUNC OutT apply(InT x, OutT c)
Definition transforms.h:29
+
TransformAdd(const float, const float)
Definition transforms.h:27
+
Definition transforms.h:35
+
const float beta
Definition transforms.h:37
+
METAL_FUNC OutT apply(InT x, OutT c) const
Definition transforms.h:42
+
const float alpha
Definition transforms.h:36
+
TransformAxpby(const float alpha_, const float beta_)
Definition transforms.h:39
+
Definition transforms.h:15
+
static METAL_FUNC OutT apply(InT x)
Definition transforms.h:16
+
static METAL_FUNC OutT apply(InT x, OutT)
Definition transforms.h:20
+
+ + + + diff --git a/docs/build/html/backend_2metal_2kernels_2steel_2utils_8h.html b/docs/build/html/backend_2metal_2kernels_2steel_2utils_8h.html new file mode 100644 index 000000000..31eb63250 --- /dev/null +++ b/docs/build/html/backend_2metal_2kernels_2steel_2utils_8h.html @@ -0,0 +1,215 @@ + + + + + + + +MLX: mlx/backend/metal/kernels/steel/utils.h File Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
utils.h File Reference
+
+
+
#include <metal_stdlib>
+
+

Go to the source code of this file.

+ + + + + + +

+Macros

#define STEEL_CONST   static constant constexpr const
 
#define STEEL_PRAGMA_UNROLL   _Pragma("clang loop unroll(full)")
 
+ + + + + +

+Functions

METAL_FUNC ulong2 elem_to_loc_broadcast (uint elem, constant const int *shape, constant const size_t *a_strides, constant const size_t *b_strides, int ndim)
 
METAL_FUNC ulong3 elem_to_loc_broadcast (uint elem, constant const int *shape, constant const size_t *a_strides, constant const size_t *b_strides, constant const size_t *c_strides, int ndim)
 
+

Macro Definition Documentation

+ +

◆ STEEL_CONST

+ +
+
+ + + + +
#define STEEL_CONST   static constant constexpr const
+
+ +
+
+ +

◆ STEEL_PRAGMA_UNROLL

+ +
+
+ + + + +
#define STEEL_PRAGMA_UNROLL   _Pragma("clang loop unroll(full)")
+
+ +
+
+

Function Documentation

+ +

◆ elem_to_loc_broadcast() [1/2]

+ +
+
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
METAL_FUNC ulong3 elem_to_loc_broadcast (uint elem,
constant const int * shape,
constant const size_t * a_strides,
constant const size_t * b_strides,
constant const size_t * c_strides,
int ndim )
+
+ +
+
+ +

◆ elem_to_loc_broadcast() [2/2]

+ +
+
+ + + + + + + + + + + + + + + + + + + + + + + + + + +
METAL_FUNC ulong2 elem_to_loc_broadcast (uint elem,
constant const int * shape,
constant const size_t * a_strides,
constant const size_t * b_strides,
int ndim )
+
+ +
+
+
+ + + + diff --git a/docs/build/html/backend_2metal_2kernels_2steel_2utils_8h_source.html b/docs/build/html/backend_2metal_2kernels_2steel_2utils_8h_source.html new file mode 100644 index 000000000..1621ac55e --- /dev/null +++ b/docs/build/html/backend_2metal_2kernels_2steel_2utils_8h_source.html @@ -0,0 +1,142 @@ + + + + + + + +MLX: mlx/backend/metal/kernels/steel/utils.h Source File + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
utils.h
+
+
+Go to the documentation of this file.
1// Copyright © 2024 Apple Inc.
+
2
+
3#pragma once
+
4
+
5#include <metal_stdlib>
+
6
+
7#define STEEL_CONST static constant constexpr const
+
8#define STEEL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)")
+
9
+
+
10METAL_FUNC ulong2 elem_to_loc_broadcast(
+
11 uint elem,
+
12 constant const int* shape,
+
13 constant const size_t* a_strides,
+
14 constant const size_t* b_strides,
+
15 int ndim) {
+
16 ulong loc_a{0};
+
17 ulong loc_b{0};
+
18 for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
+
19 int pos_in_dim = (elem % shape[i]);
+
20 elem /= shape[i];
+
21 loc_a += pos_in_dim * a_strides[i];
+
22 loc_b += pos_in_dim * b_strides[i];
+
23 }
+
24 return ulong2(loc_a, loc_b);
+
25}
+
+
26
+
+
27METAL_FUNC ulong3 elem_to_loc_broadcast(
+
28 uint elem,
+
29 constant const int* shape,
+
30 constant const size_t* a_strides,
+
31 constant const size_t* b_strides,
+
32 constant const size_t* c_strides,
+
33 int ndim) {
+
34 ulong loc_a{0};
+
35 ulong loc_b{0};
+
36 ulong loc_c{0};
+
37 for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
+
38 int pos_in_dim = (elem % shape[i]);
+
39 elem /= shape[i];
+
40 loc_a += pos_in_dim * a_strides[i];
+
41 loc_b += pos_in_dim * b_strides[i];
+
42 loc_c += pos_in_dim * c_strides[i];
+
43 }
+
44 return ulong3(loc_a, loc_b, loc_c);
+
45}
+
+
METAL_FUNC ulong2 elem_to_loc_broadcast(uint elem, constant const int *shape, constant const size_t *a_strides, constant const size_t *b_strides, int ndim)
Definition utils.h:10
+
+ + + + diff --git a/docs/build/html/backend_2metal_2kernels_2utils_8h.html b/docs/build/html/backend_2metal_2kernels_2utils_8h.html new file mode 100644 index 000000000..bce4b00aa --- /dev/null +++ b/docs/build/html/backend_2metal_2kernels_2utils_8h.html @@ -0,0 +1,862 @@ + + + + + + + +MLX: mlx/backend/metal/kernels/utils.h File Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
utils.h File Reference
+
+
+
#include <metal_math>
+#include "mlx/backend/metal/kernels/bf16.h"
+#include "mlx/backend/metal/kernels/complex.h"
+
+

Go to the source code of this file.

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Classes

struct  Limits< U >
 
struct  Limits< uint8_t >
 
struct  Limits< uint16_t >
 
struct  Limits< uint32_t >
 
struct  Limits< uint64_t >
 
struct  Limits< int8_t >
 
struct  Limits< int16_t >
 
struct  Limits< int32_t >
 
struct  Limits< int64_t >
 
struct  Limits< half >
 
struct  Limits< float >
 
struct  Limits< bfloat16_t >
 
struct  Limits< bool >
 
+ + + + + + + +

+Macros

#define instantiate_default_limit(type)
 
#define instantiate_float_limit(type)
 
#define MLX_MTL_PRAGMA_UNROLL   _Pragma("clang loop unroll(full)")
 
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Functions

template<typename stride_t >
METAL_FUNC stride_t elem_to_loc (uint elem, device const int *shape, device const stride_t *strides, int ndim)
 
template<typename stride_t >
METAL_FUNC stride_t elem_to_loc (uint elem, constant const int *shape, constant const stride_t *strides, int ndim)
 
template<typename stride_t >
METAL_FUNC stride_t elem_to_loc (uint3 elem, constant const int *shape, constant const stride_t *strides, int ndim)
 
template<typename stride_t >
METAL_FUNC stride_t elem_to_loc_1 (uint elem, constant const stride_t &stride)
 
template<typename stride_t >
METAL_FUNC stride_t elem_to_loc_2 (uint2 elem, constant const stride_t strides[2])
 
template<typename stride_t >
METAL_FUNC stride_t elem_to_loc_3 (uint3 elem, constant const stride_t strides[3])
 
template<int NDIM>
METAL_FUNC size_t elem_to_loc_nd (uint elem, device const int *shape, device const size_t *strides)
 
template<int NDIM>
METAL_FUNC size_t elem_to_loc_nd (uint3 elem, constant const int shape[NDIM], constant const size_t strides[NDIM])
 
template<int NDIM>
METAL_FUNC int64_t elem_to_loc_nd (uint elem, constant const int shape[NDIM], constant const int64_t strides[NDIM])
 
template<int NDIM>
METAL_FUNC int64_t elem_to_loc_nd (uint3 elem, constant const int shape[NDIM], constant const int64_t strides[NDIM])
 
METAL_FUNC uint2 elem_to_loc_2_nd (uint3 elem, constant const int *shape, constant const size_t *a_strides, constant const size_t *b_strides, int ndim)
 
METAL_FUNC uint3 elem_to_loc_3_nd (uint3 elem, constant const int *shape, constant const size_t *a_strides, constant const size_t *b_strides, constant const size_t *c_strides, int ndim)
 
template<int NDIM>
METAL_FUNC uint2 elem_to_loc_2_nd (uint3 elem, constant const int shape[NDIM], constant const size_t a_strides[NDIM], constant const size_t b_strides[NDIM])
 
template<int NDIM>
METAL_FUNC uint3 elem_to_loc_3_nd (uint3 elem, constant const int shape[NDIM], constant const size_t a_strides[NDIM], constant const size_t b_strides[NDIM], constant const size_t c_strides[NDIM])
 
size_t ceildiv (size_t N, size_t M)
 Compute ceil((float)N/(float)M)
 
float log1p (float x)
 
bfloat16_t log1p (bfloat16_t x)
 
uint64_t simd_shuffle_down (uint64_t data, uint16_t delta)
 
int64_t simd_shuffle_down (int64_t data, uint16_t delta)
 
bool simd_shuffle_down (bool data, uint16_t delta)
 
+

Macro Definition Documentation

+ +

◆ instantiate_default_limit

+ +
+
+ + + + + + + +
#define instantiate_default_limit( type)
+
+Value:
template <> \
+
struct Limits<type> { \
+
static constexpr constant type max = metal::numeric_limits<type>::max(); \
+
static constexpr constant type min = metal::numeric_limits<type>::min(); \
+
static constexpr constant type finite_max = \
+
metal::numeric_limits<type>::max(); \
+
static constexpr constant type finite_min = \
+
metal::numeric_limits<type>::min(); \
+
};
+
Definition utils.h:14
+
static const constant U max
Definition utils.h:15
+
static const constant U finite_max
Definition utils.h:17
+
static const constant U min
Definition utils.h:16
+
static const constant U finite_min
Definition utils.h:18
+
+
+
+ +

◆ instantiate_float_limit

+ +
+
+ + + + + + + +
#define instantiate_float_limit( type)
+
+Value:
template <> \
+
struct Limits<type> { \
+
static constexpr constant type max = \
+
metal::numeric_limits<type>::infinity(); \
+
static constexpr constant type min = \
+
-metal::numeric_limits<type>::infinity(); \
+
static constexpr constant type finite_max = \
+
metal::numeric_limits<type>::max(); \
+
static constexpr constant type finite_min = \
+
-metal::numeric_limits<type>::max(); \
+
};
+
+
+
+ +

◆ MLX_MTL_PRAGMA_UNROLL

+ +
+
+ + + + +
#define MLX_MTL_PRAGMA_UNROLL   _Pragma("clang loop unroll(full)")
+
+ +
+
+

Function Documentation

+ +

◆ ceildiv()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
size_t ceildiv (size_t N,
size_t M )
+
+inline
+
+ +

Compute ceil((float)N/(float)M)

+ +
+
+ +

◆ elem_to_loc() [1/3]

+ +
+
+
+template<typename stride_t >
+ + + + + + + + + + + + + + + + + + + + + +
METAL_FUNC stride_t elem_to_loc (uint elem,
constant const int * shape,
constant const stride_t * strides,
int ndim )
+
+ +
+
+ +

◆ elem_to_loc() [2/3]

+ +
+
+
+template<typename stride_t >
+ + + + + + + + + + + + + + + + + + + + + +
METAL_FUNC stride_t elem_to_loc (uint elem,
device const int * shape,
device const stride_t * strides,
int ndim )
+
+ +
+
+ +

◆ elem_to_loc() [3/3]

+ +
+
+
+template<typename stride_t >
+ + + + + + + + + + + + + + + + + + + + + +
METAL_FUNC stride_t elem_to_loc (uint3 elem,
constant const int * shape,
constant const stride_t * strides,
int ndim )
+
+ +
+
+ +

◆ elem_to_loc_1()

+ +
+
+
+template<typename stride_t >
+ + + + + + + + + + + +
METAL_FUNC stride_t elem_to_loc_1 (uint elem,
constant const stride_t & stride )
+
+ +
+
+ +

◆ elem_to_loc_2()

+ +
+
+
+template<typename stride_t >
+ + + + + + + + + + + +
METAL_FUNC stride_t elem_to_loc_2 (uint2 elem,
constant const stride_t strides[2] )
+
+ +
+
+ +

◆ elem_to_loc_2_nd() [1/2]

+ +
+
+ + + + + + + + + + + + + + + + + + + + + + + + + + +
METAL_FUNC uint2 elem_to_loc_2_nd (uint3 elem,
constant const int * shape,
constant const size_t * a_strides,
constant const size_t * b_strides,
int ndim )
+
+ +
+
+ +

◆ elem_to_loc_2_nd() [2/2]

+ +
+
+
+template<int NDIM>
+ + + + + + + + + + + + + + + + + + + + + +
METAL_FUNC uint2 elem_to_loc_2_nd (uint3 elem,
constant const int shape[NDIM],
constant const size_t a_strides[NDIM],
constant const size_t b_strides[NDIM] )
+
+ +
+
+ +

◆ elem_to_loc_3()

+ +
+
+
+template<typename stride_t >
+ + + + + + + + + + + +
METAL_FUNC stride_t elem_to_loc_3 (uint3 elem,
constant const stride_t strides[3] )
+
+ +
+
+ +

◆ elem_to_loc_3_nd() [1/2]

+ +
+
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
METAL_FUNC uint3 elem_to_loc_3_nd (uint3 elem,
constant const int * shape,
constant const size_t * a_strides,
constant const size_t * b_strides,
constant const size_t * c_strides,
int ndim )
+
+ +
+
+ +

◆ elem_to_loc_3_nd() [2/2]

+ +
+
+
+template<int NDIM>
+ + + + + + + + + + + + + + + + + + + + + + + + + + +
METAL_FUNC uint3 elem_to_loc_3_nd (uint3 elem,
constant const int shape[NDIM],
constant const size_t a_strides[NDIM],
constant const size_t b_strides[NDIM],
constant const size_t c_strides[NDIM] )
+
+ +
+
+ +

◆ elem_to_loc_nd() [1/4]

+ +
+
+
+template<int NDIM>
+ + + + + + + + + + + + + + + + +
METAL_FUNC int64_t elem_to_loc_nd (uint elem,
constant const int shape[NDIM],
constant const int64_t strides[NDIM] )
+
+ +
+
+ +

◆ elem_to_loc_nd() [2/4]

+ +
+
+
+template<int NDIM>
+ + + + + + + + + + + + + + + + +
METAL_FUNC size_t elem_to_loc_nd (uint elem,
device const int * shape,
device const size_t * strides )
+
+ +
+
+ +

◆ elem_to_loc_nd() [3/4]

+ +
+
+
+template<int NDIM>
+ + + + + + + + + + + + + + + + +
METAL_FUNC int64_t elem_to_loc_nd (uint3 elem,
constant const int shape[NDIM],
constant const int64_t strides[NDIM] )
+
+ +
+
+ +

◆ elem_to_loc_nd() [4/4]

+ +
+
+
+template<int NDIM>
+ + + + + + + + + + + + + + + + +
METAL_FUNC size_t elem_to_loc_nd (uint3 elem,
constant const int shape[NDIM],
constant const size_t strides[NDIM] )
+
+ +
+
+ +

◆ log1p() [1/2]

+ +
+
+ + + + + +
+ + + + + + + +
bfloat16_t log1p (bfloat16_t x)
+
+inline
+
+ +
+
+ +

◆ log1p() [2/2]

+ +
+
+ + + + + +
+ + + + + + + +
float log1p (float x)
+
+inline
+
+ +
+
+ +

◆ simd_shuffle_down() [1/3]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
bool simd_shuffle_down (bool data,
uint16_t delta )
+
+inline
+
+ +
+
+ +

◆ simd_shuffle_down() [2/3]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
int64_t simd_shuffle_down (int64_t data,
uint16_t delta )
+
+inline
+
+ +
+
+ +

◆ simd_shuffle_down() [3/3]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
uint64_t simd_shuffle_down (uint64_t data,
uint16_t delta )
+
+inline
+
+ +
+
+
+ + + + diff --git a/docs/build/html/backend_2metal_2kernels_2utils_8h_source.html b/docs/build/html/backend_2metal_2kernels_2utils_8h_source.html new file mode 100644 index 000000000..32aeb0750 --- /dev/null +++ b/docs/build/html/backend_2metal_2kernels_2utils_8h_source.html @@ -0,0 +1,488 @@ + + + + + + + +MLX: mlx/backend/metal/kernels/utils.h Source File + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
utils.h
+
+
+Go to the documentation of this file.
1// Copyright © 2023-2024 Apple Inc.
+
2
+
3#pragma once
+
4
+
5#include <metal_math>
+ + +
8
+
10// Type limits utils
+
12
+
13template <typename U>
+
+
14struct Limits {
+
15 static const constant U max = metal::numeric_limits<U>::max();
+
16 static const constant U min = metal::numeric_limits<U>::min();
+
17 static const constant U finite_max = metal::numeric_limits<U>::max();
+
18 static const constant U finite_min = metal::numeric_limits<U>::min();
+
19};
+
+
20
+
+
21#define instantiate_default_limit(type) \
+
22 template <> \
+
23 struct Limits<type> { \
+
24 static constexpr constant type max = metal::numeric_limits<type>::max(); \
+
25 static constexpr constant type min = metal::numeric_limits<type>::min(); \
+
26 static constexpr constant type finite_max = \
+
27 metal::numeric_limits<type>::max(); \
+
28 static constexpr constant type finite_min = \
+
29 metal::numeric_limits<type>::min(); \
+
30 };
+
+
31
+ + + + + + + + +
40
+
+
41#define instantiate_float_limit(type) \
+
42 template <> \
+
43 struct Limits<type> { \
+
44 static constexpr constant type max = \
+
45 metal::numeric_limits<type>::infinity(); \
+
46 static constexpr constant type min = \
+
47 -metal::numeric_limits<type>::infinity(); \
+
48 static constexpr constant type finite_max = \
+
49 metal::numeric_limits<type>::max(); \
+
50 static constexpr constant type finite_min = \
+
51 -metal::numeric_limits<type>::max(); \
+
52 };
+
+
53
+ + + +
57
+
58template <>
+
+
59struct Limits<bool> {
+
60 static constexpr constant bool max = true;
+
61 static constexpr constant bool min = false;
+
62};
+
+
63
+
65// Indexing utils
+
67
+
68#define MLX_MTL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)")
+
69
+
71// Single Array with generic dims
+
72
+
73template <typename stride_t>
+
+
74METAL_FUNC stride_t elem_to_loc(
+
75 uint elem,
+
76 device const int* shape,
+
77 device const stride_t* strides,
+
78 int ndim) {
+
79 stride_t loc = 0;
+
80 for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
+
81 loc += (elem % shape[i]) * strides[i];
+
82 elem /= shape[i];
+
83 }
+
84 return loc;
+
85}
+
+
86
+
87template <typename stride_t>
+
+
88METAL_FUNC stride_t elem_to_loc(
+
89 uint elem,
+
90 constant const int* shape,
+
91 constant const stride_t* strides,
+
92 int ndim) {
+
93 stride_t loc = 0;
+
94 for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
+
95 loc += (elem % shape[i]) * strides[i];
+
96 elem /= shape[i];
+
97 }
+
98 return loc;
+
99}
+
+
100
+
101// Non templated version to handle arbitrary dims
+
102template <typename stride_t>
+
+
103METAL_FUNC stride_t elem_to_loc(
+
104 uint3 elem,
+
105 constant const int* shape,
+
106 constant const stride_t* strides,
+
107 int ndim) {
+
108 stride_t loc = elem.x * strides[ndim - 1] + elem.y * strides[ndim - 2];
+
109 for (int d = ndim - 3; d >= 0; --d) {
+
110 loc += (elem.z % shape[d]) * strides[d];
+
111 elem.z /= shape[d];
+
112 }
+
113 return loc;
+
114}
+
+
115
+
117// Single Array with fixed N dims
+
118
+
119template <typename stride_t>
+
+
120METAL_FUNC stride_t elem_to_loc_1(uint elem, constant const stride_t& stride) {
+
121 return elem * stride;
+
122}
+
+
123
+
124template <typename stride_t>
+
125METAL_FUNC stride_t
+
+
126elem_to_loc_2(uint2 elem, constant const stride_t strides[2]) {
+
127 return elem.x * strides[1] + elem.y * strides[0];
+
128}
+
+
129
+
130template <typename stride_t>
+
131METAL_FUNC stride_t
+
+
132elem_to_loc_3(uint3 elem, constant const stride_t strides[3]) {
+
133 return elem.x * strides[2] + elem.y * strides[1] + elem.z * strides[0];
+
134}
+
+
135
+
136template <int NDIM>
+
+
137METAL_FUNC size_t elem_to_loc_nd(
+
138 uint elem,
+
139 device const int* shape,
+
140 device const size_t* strides) {
+
141 size_t loc = (elem % shape[NDIM - 1]) * strides[NDIM - 1];
+
142
+ +
144 for (int d = NDIM - 2; d >= 0; --d) {
+
145 elem /= shape[d + 1];
+
146 loc += (elem % shape[d]) * strides[d];
+
147 }
+
148
+
149 return loc;
+
150}
+
+
151
+
152template <int NDIM>
+
+
153METAL_FUNC size_t elem_to_loc_nd(
+
154 uint3 elem,
+
155 constant const int shape[NDIM],
+
156 constant const size_t strides[NDIM]) {
+
157 size_t loc = elem.x * strides[NDIM - 1] + elem.y * strides[NDIM - 2];
+
158 for (int d = NDIM - 3; d >= 0; --d) {
+
159 loc += (elem.z % shape[d]) * strides[d];
+
160 elem.z /= shape[d];
+
161 }
+
162 return loc;
+
163}
+
+
164
+
165template <int NDIM>
+
+
166METAL_FUNC int64_t elem_to_loc_nd(
+
167 uint elem,
+
168 constant const int shape[NDIM],
+
169 constant const int64_t strides[NDIM]) {
+
170 int64_t loc = (elem % shape[NDIM - 1]) * strides[NDIM - 1];
+
171
+ +
173 for (int d = NDIM - 2; d >= 0; --d) {
+
174 elem /= shape[d + 1];
+
175 loc += (elem % shape[d]) * strides[d];
+
176 }
+
177
+
178 return loc;
+
179}
+
+
180
+
181template <int NDIM>
+
+
182METAL_FUNC int64_t elem_to_loc_nd(
+
183 uint3 elem,
+
184 constant const int shape[NDIM],
+
185 constant const int64_t strides[NDIM]) {
+
186 int64_t loc = elem.x * strides[NDIM - 1] + elem.y * strides[NDIM - 2];
+
187 for (int d = NDIM - 3; d >= 0; --d) {
+
188 loc += (elem.z % shape[d]) * strides[d];
+
189 elem.z /= shape[d];
+
190 }
+
191 return loc;
+
192}
+
+
193
+
195// Multiple Arrays with generic dims
+
196
+
+
197METAL_FUNC uint2 elem_to_loc_2_nd(
+
198 uint3 elem,
+
199 constant const int* shape,
+
200 constant const size_t* a_strides,
+
201 constant const size_t* b_strides,
+
202 int ndim) {
+
203 uint2 loc = {
+
204 static_cast<uint>(
+
205 elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2]),
+
206 static_cast<uint>(
+
207 elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2])};
+
208 for (int d = ndim - 3; d >= 0; --d) {
+
209 uint l = elem.z % shape[d];
+
210 loc.x += l * a_strides[d];
+
211 loc.y += l * b_strides[d];
+
212 elem.z /= shape[d];
+
213 }
+
214 return loc;
+
215}
+
+
216
+
+
217METAL_FUNC uint3 elem_to_loc_3_nd(
+
218 uint3 elem,
+
219 constant const int* shape,
+
220 constant const size_t* a_strides,
+
221 constant const size_t* b_strides,
+
222 constant const size_t* c_strides,
+
223 int ndim) {
+
224 uint3 loc = {
+
225 static_cast<uint>(
+
226 elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2]),
+
227 static_cast<uint>(
+
228 elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2]),
+
229 static_cast<uint>(
+
230 elem.x * c_strides[ndim - 1] + elem.y * c_strides[ndim - 2])};
+
231 for (int d = ndim - 3; d >= 0; --d) {
+
232 uint l = elem.z % shape[d];
+
233 loc.x += l * a_strides[d];
+
234 loc.y += l * b_strides[d];
+
235 loc.z += l * c_strides[d];
+
236 elem.z /= shape[d];
+
237 }
+
238 return loc;
+
239}
+
+
240
+
242// Multiple Arrays with fixed N dims
+
243
+
244template <int NDIM>
+
+
245METAL_FUNC uint2 elem_to_loc_2_nd(
+
246 uint3 elem,
+
247 constant const int shape[NDIM],
+
248 constant const size_t a_strides[NDIM],
+
249 constant const size_t b_strides[NDIM]) {
+
250 uint2 loc = {
+
251 static_cast<uint>(
+
252 elem.x * a_strides[NDIM - 1] + elem.y * a_strides[NDIM - 2]),
+
253 static_cast<uint>(
+
254 elem.x * b_strides[NDIM - 1] + elem.y * b_strides[NDIM - 2])};
+
255 for (int d = NDIM - 3; d >= 0; --d) {
+
256 uint l = elem.z % shape[d];
+
257 loc.x += l * a_strides[d];
+
258 loc.y += l * b_strides[d];
+
259 elem.z /= shape[d];
+
260 }
+
261 return loc;
+
262}
+
+
263
+
264template <int NDIM>
+
+
265METAL_FUNC uint3 elem_to_loc_3_nd(
+
266 uint3 elem,
+
267 constant const int shape[NDIM],
+
268 constant const size_t a_strides[NDIM],
+
269 constant const size_t b_strides[NDIM],
+
270 constant const size_t c_strides[NDIM]) {
+
271 uint3 loc = {
+
272 static_cast<uint>(
+
273 elem.x * a_strides[NDIM - 1] + elem.y * a_strides[NDIM - 2]),
+
274 static_cast<uint>(
+
275 elem.x * b_strides[NDIM - 1] + elem.y * b_strides[NDIM - 2]),
+
276 static_cast<uint>(
+
277 elem.x * c_strides[NDIM - 1] + elem.y * c_strides[NDIM - 2])};
+
278 for (int d = NDIM - 3; d >= 0; --d) {
+
279 uint l = elem.z % shape[d];
+
280 loc.x += l * a_strides[d];
+
281 loc.y += l * b_strides[d];
+
282 loc.z += l * c_strides[d];
+
283 elem.z /= shape[d];
+
284 }
+
285 return loc;
+
286}
+
+
287
+
289// Calculation utils
+
291
+
+
293inline size_t ceildiv(size_t N, size_t M) {
+
294 return (N + M - 1) / M;
+
295}
+
+
296
+
297// https://docs.oracle.com/cd/E19957-01/806-3568/ncg_goldberg.html#1202
+
+
298inline float log1p(float x) {
+
299 float xp1 = 1.0f + x;
+
300 if (xp1 == Limits<float>::max) {
+
301 return Limits<float>::max;
+
302 }
+
303 if (xp1 == 1.0f) {
+
304 return x;
+
305 }
+
306
+
307 return x * (metal::log(xp1) / (xp1 - 1.0f));
+
308}
+
+
309
+
+ +
311 float xp1 = 1.0f + static_cast<float>(x);
+
312 if (xp1 == Limits<float>::max) {
+ +
314 }
+
315 if (xp1 == 1.0f) {
+
316 return x;
+
317 }
+
318
+
319 return bfloat16_t(x * (metal::log(xp1) / (xp1 - 1.0f)));
+
320}
+
+
321
+
323// SIMD shuffle ops
+
325
+
+
326inline uint64_t simd_shuffle_down(uint64_t data, uint16_t delta) {
+
327 return as_type<uint64_t>(
+
328 metal::simd_shuffle_down(as_type<uint2>(data), delta));
+
329}
+
+
330
+
+
331inline int64_t simd_shuffle_down(int64_t data, uint16_t delta) {
+
332 return as_type<int64_t>(
+
333 metal::simd_shuffle_down(as_type<uint2>(data), delta));
+
334}
+
+
335
+
+
336inline bool simd_shuffle_down(bool data, uint16_t delta) {
+
337 return simd_shuffle_down(static_cast<uint32_t>(data), delta);
+
338}
+
+ +
struct _MLX_BFloat16 bfloat16_t
Definition bf16.h:257
+ +
#define MLX_MTL_PRAGMA_UNROLL
Definition utils.h:68
+
METAL_FUNC stride_t elem_to_loc_1(uint elem, constant const stride_t &stride)
Definition utils.h:120
+
#define instantiate_float_limit(type)
Definition utils.h:41
+
float log1p(float x)
Definition utils.h:298
+
METAL_FUNC stride_t elem_to_loc_3(uint3 elem, constant const stride_t strides[3])
Definition utils.h:132
+
METAL_FUNC stride_t elem_to_loc(uint elem, device const int *shape, device const stride_t *strides, int ndim)
Definition utils.h:74
+
METAL_FUNC uint2 elem_to_loc_2_nd(uint3 elem, constant const int *shape, constant const size_t *a_strides, constant const size_t *b_strides, int ndim)
Definition utils.h:197
+
size_t ceildiv(size_t N, size_t M)
Compute ceil((float)N/(float)M)
Definition utils.h:293
+
METAL_FUNC uint3 elem_to_loc_3_nd(uint3 elem, constant const int *shape, constant const size_t *a_strides, constant const size_t *b_strides, constant const size_t *c_strides, int ndim)
Definition utils.h:217
+
METAL_FUNC size_t elem_to_loc_nd(uint elem, device const int *shape, device const size_t *strides)
Definition utils.h:137
+
#define instantiate_default_limit(type)
Definition utils.h:21
+
METAL_FUNC stride_t elem_to_loc_2(uint2 elem, constant const stride_t strides[2])
Definition utils.h:126
+
METAL_FUNC bfloat16_t log(bfloat16_t x)
Definition bf16_math.h:234
+
METAL_FUNC bfloat16_t simd_shuffle_down(bfloat16_t data, ushort delta)
Definition bf16_math.h:391
+
Definition bf16.h:54
+
Definition utils.h:14
+
static const constant U max
Definition utils.h:15
+
static const constant U finite_max
Definition utils.h:17
+
static const constant U min
Definition utils.h:16
+
static const constant U finite_min
Definition utils.h:18
+
+ + + + diff --git a/docs/build/html/backend_2metal_2utils_8h.html b/docs/build/html/backend_2metal_2utils_8h.html new file mode 100644 index 000000000..9a2da99d3 --- /dev/null +++ b/docs/build/html/backend_2metal_2utils_8h.html @@ -0,0 +1,102 @@ + + + + + + + +MLX: mlx/backend/metal/utils.h File Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
utils.h File Reference
+
+
+
#include "mlx/array.h"
+#include "mlx/backend/metal/device.h"
+#include "mlx/primitives.h"
+
+

Go to the source code of this file.

+ + + + + + +

+Namespaces

namespace  mlx
 
namespace  mlx::core
 
+
+ + + + diff --git a/docs/build/html/backend_2metal_2utils_8h_source.html b/docs/build/html/backend_2metal_2utils_8h_source.html new file mode 100644 index 000000000..5bfd2bb38 --- /dev/null +++ b/docs/build/html/backend_2metal_2utils_8h_source.html @@ -0,0 +1,249 @@ + + + + + + + +MLX: mlx/backend/metal/utils.h Source File + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
utils.h
+
+
+Go to the documentation of this file.
1// Copyright © 2023-2024 Apple Inc.
+
2
+
3#pragma once
+
4
+
5#include "mlx/array.h"
+ +
7#include "mlx/primitives.h"
+
8
+
9namespace mlx::core {
+
10
+
11namespace {
+
12
+
13using metal::CommandEncoder;
+
14
+
15template <typename T>
+
16inline void set_vector_bytes(
+
17 CommandEncoder& enc,
+
18 const std::vector<T>& vec,
+
19 size_t nelems,
+
20 int idx) {
+
21 enc->setBytes(vec.data(), nelems * sizeof(T), idx);
+
22}
+
23
+
24template <typename T>
+
25inline void
+
26set_vector_bytes(CommandEncoder& enc, const std::vector<T>& vec, int idx) {
+
27 return set_vector_bytes(enc, vec, vec.size(), idx);
+
28}
+
29
+
30std::string type_to_name(const array& a) {
+
31 std::string tname;
+
32 switch (a.dtype()) {
+
33 case bool_:
+
34 tname = "bool_";
+
35 break;
+
36 case uint8:
+
37 tname = "uint8";
+
38 break;
+
39 case uint16:
+
40 tname = "uint16";
+
41 break;
+
42 case uint32:
+
43 tname = "uint32";
+
44 break;
+
45 case uint64:
+
46 tname = "uint64";
+
47 break;
+
48 case int8:
+
49 tname = "int8";
+
50 break;
+
51 case int16:
+
52 tname = "int16";
+
53 break;
+
54 case int32:
+
55 tname = "int32";
+
56 break;
+
57 case int64:
+
58 tname = "int64";
+
59 break;
+
60 case float16:
+
61 tname = "float16";
+
62 break;
+
63 case float32:
+
64 tname = "float32";
+
65 break;
+
66 case bfloat16:
+
67 tname = "bfloat16";
+
68 break;
+
69 case complex64:
+
70 tname = "complex64";
+
71 break;
+
72 }
+
73 return tname;
+
74}
+
75
+
76MTL::Size get_block_dims(int dim0, int dim1, int dim2) {
+
77 int pows[3] = {0, 0, 0};
+
78 int sum = 0;
+
79 while (true) {
+
80 int presum = sum;
+
81 // Check all the pows
+
82 if (dim0 >= (1 << (pows[0] + 1))) {
+
83 pows[0]++;
+
84 sum++;
+
85 }
+
86 if (sum == 10) {
+
87 break;
+
88 }
+
89 if (dim1 >= (1 << (pows[1] + 1))) {
+
90 pows[1]++;
+
91 sum++;
+
92 }
+
93 if (sum == 10) {
+
94 break;
+
95 }
+
96 if (dim2 >= (1 << (pows[2] + 1))) {
+
97 pows[2]++;
+
98 sum++;
+
99 }
+
100 if (sum == presum || sum == 10) {
+
101 break;
+
102 }
+
103 }
+
104 return MTL::Size{1ul << pows[0], 1ul << pows[1], 1ul << pows[2]};
+
105}
+
106
+
107inline NS::String* make_string(std::ostringstream& os) {
+
108 std::string string = os.str();
+
109 return NS::String::string(string.c_str(), NS::UTF8StringEncoding);
+
110}
+
111
+
112inline void debug_set_stream_queue_label(MTL::CommandQueue* queue, int index) {
+
113#ifdef MLX_METAL_DEBUG
+
114 std::ostringstream label;
+
115 label << "Stream " << index;
+
116 queue->setLabel(make_string(label));
+
117#endif
+
118}
+
119
+
120inline void debug_set_primitive_buffer_label(
+
121 MTL::CommandBuffer* command_buffer,
+
122 Primitive& primitive) {
+
123#ifdef MLX_METAL_DEBUG
+
124 std::ostringstream label;
+
125 if (auto cbuf_label = command_buffer->label(); cbuf_label) {
+
126 label << cbuf_label->utf8String();
+
127 }
+
128 primitive.print(label);
+
129 command_buffer->setLabel(make_string(label));
+
130#endif
+
131}
+
132
+
133bool is_power_of_2(int n) {
+
134 return ((n & (n - 1)) == 0) && n != 0;
+
135}
+
136
+
137} // namespace
+
138
+
139} // namespace mlx::core
+ + +
array sum(const array &a, bool keepdims, StreamOrDevice s={})
Sums the elements of an array.
+
Definition allocator.h:7
+
constexpr Dtype bool_
Definition dtype.h:60
+
constexpr Dtype uint64
Definition dtype.h:65
+
constexpr Dtype uint16
Definition dtype.h:63
+
constexpr Dtype bfloat16
Definition dtype.h:74
+
constexpr Dtype int32
Definition dtype.h:69
+
constexpr Dtype float32
Definition dtype.h:73
+
constexpr Dtype int16
Definition dtype.h:68
+
constexpr Dtype int8
Definition dtype.h:67
+
constexpr Dtype int64
Definition dtype.h:70
+
constexpr Dtype uint8
Definition dtype.h:62
+
constexpr Dtype float16
Definition dtype.h:72
+
constexpr Dtype uint32
Definition dtype.h:64
+
constexpr Dtype complex64
Definition dtype.h:75
+ +
+ + + + diff --git a/docs/build/html/bc_s.png b/docs/build/html/bc_s.png new file mode 100644 index 0000000000000000000000000000000000000000..224b29aa9847d5a4b3902efd602b7ddf7d33e6c2 GIT binary patch literal 676 zcmV;V0$crwP)y__>=_9%My z{n931IS})GlGUF8K#6VIbs%684A^L3@%PlP2>_sk`UWPq@f;rU*V%rPy_ekbhXT&s z(GN{DxFv}*vZp`F>S!r||M`I*nOwwKX+BC~3P5N3-)Y{65c;ywYiAh-1*hZcToLHK ztpl1xomJ+Yb}K(cfbJr2=GNOnT!UFA7Vy~fBz8?J>XHsbZoDad^8PxfSa0GDgENZS zuLCEqzb*xWX2CG*b&5IiO#NzrW*;`VC9455M`o1NBh+(k8~`XCEEoC1Ybwf;vr4K3 zg|EB<07?SOqHp9DhLpS&bzgo70I+ghB_#)K7H%AMU3v}xuyQq9&Bm~++VYhF09a+U zl7>n7Jjm$K#b*FONz~fj;I->Bf;ule1prFN9FovcDGBkpg>)O*-}eLnC{6oZHZ$o% zXKW$;0_{8hxHQ>l;_*HATI(`7t#^{$(zLe}h*mqwOc*nRY9=?Sx4OOeVIfI|0V(V2 zBrW#G7Ss9wvzr@>H*`r>zE z+e8bOBgqIgldUJlG(YUDviMB`9+DH8n-s9SXRLyJHO1!=wY^79WYZMTa(wiZ!zP66 zA~!21vmF3H2{ngD;+`6j#~6j;$*f*G_2ZD1E;9(yaw7d-QnSCpK(cR1zU3qU0000< KMNUMnLSTYoA~SLT literal 0 HcmV?d00001 diff --git a/docs/build/html/bc_sd.png b/docs/build/html/bc_sd.png new file mode 100644 index 0000000000000000000000000000000000000000..31ca888dc71049713b35c351933a8d0f36180bf1 GIT binary patch literal 635 zcmV->0)+jEP)Jwi0r1~gdSq#w{Bu1q z`craw(p2!hu$4C_$Oc3X(sI6e=9QSTwPt{G) z=htT&^~&c~L2~e{r5_5SYe7#Is-$ln>~Kd%$F#tC65?{LvQ}8O`A~RBB0N~`2M+waajO;5>3B&-viHGJeEK2TQOiPRa zfDKyqwMc4wfaEh4jt>H`nW_Zidwk@Bowp`}(VUaj-pSI(-1L>FJVsX}Yl9~JsqgsZ zUD9(rMwf23Gez6KPa|wwInZodP-2}9@fK0Ga_9{8SOjU&4l`pH4@qlQp83>>HT$xW zER^U>)MyV%t(Lu=`d=Y?{k1@}&r7ZGkFQ%z%N+sE9BtYjovzxyxCPxN6&@wLK{soQ zSmkj$aLI}miuE^p@~4}mg9OjDfGEkgY4~^XzLRUBB*O{+&vq<3v(E%+k_i%=`~j%{ Vj14gnt9}3g002ovPDHLkV1n!oC4m3{ literal 0 HcmV?d00001 diff --git a/docs/build/html/bf16__math_8h.html b/docs/build/html/bf16__math_8h.html new file mode 100644 index 000000000..78e4d59e8 --- /dev/null +++ b/docs/build/html/bf16__math_8h.html @@ -0,0 +1,594 @@ + + + + + + + +MLX: mlx/backend/metal/kernels/bf16_math.h File Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
bf16_math.h File Reference
+
+
+ +

Go to the source code of this file.

+ + + + + + + + +

+Namespaces

namespace  metal
 
namespace  metal::fast
 
namespace  metal::precise
 
+ + + + + + + + + + + +

+Macros

#define instantiate_metal_math_funcs(itype, otype, ctype, mfast)
 
#define instantiate_metal_simd_comm_funcs( itype, otype, ctype, itype_to_ctype, ctype_to_otype)
 
#define instantiate_metal_simd_reduction_funcs(itype, otype, ctype)
 
#define bfloat16_to_uint16(x)   x.bits_
 
#define uint16_to_bfloat16(x)   _MLX_BFloat16(x, _MLX_BFloat16::bits_to_bfloat())
 
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Functions

METAL_FUNC bfloat16_t metal::abs (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::acos (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::acosh (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::asin (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::asinh (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::atan (bfloat16_t y_over_x)
 
METAL_FUNC bfloat16_t metal::atan2 (bfloat16_t y, bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::atanh (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::ceil (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::cos (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::cosh (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::cospi (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::divide (bfloat16_t x, bfloat16_t y)
 
METAL_FUNC bfloat16_t metal::exp (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::exp10 (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::exp2 (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::fabs (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::fdim (bfloat16_t x, bfloat16_t y)
 
METAL_FUNC bfloat16_t metal::floor (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::fma (bfloat16_t x, bfloat16_t y, bfloat16_t z)
 
METAL_FUNC bfloat16_t metal::fmax (bfloat16_t x, bfloat16_t y)
 
METAL_FUNC bfloat16_t metal::fmax3 (bfloat16_t x, bfloat16_t y, bfloat16_t z)
 
METAL_FUNC bfloat16_t metal::fmedian3 (bfloat16_t x, bfloat16_t y, bfloat16_t z)
 
METAL_FUNC bfloat16_t metal::fmin (bfloat16_t x, bfloat16_t y)
 
METAL_FUNC bfloat16_t metal::fmin3 (bfloat16_t x, bfloat16_t y, bfloat16_t z)
 
METAL_FUNC bfloat16_t metal::fmod (bfloat16_t x, bfloat16_t y)
 
METAL_FUNC bfloat16_t metal::fract (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::frexp (bfloat16_t x, thread int &exp)
 
METAL_FUNC bfloat16_t metal::ldexp (bfloat16_t x, int k)
 
METAL_FUNC bfloat16_t metal::log (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::log10 (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::log2 (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::max (bfloat16_t x, bfloat16_t y)
 
METAL_FUNC bfloat16_t metal::max3 (bfloat16_t x, bfloat16_t y, bfloat16_t z)
 
METAL_FUNC bfloat16_t metal::median3 (bfloat16_t x, bfloat16_t y, bfloat16_t z)
 
METAL_FUNC bfloat16_t metal::min (bfloat16_t x, bfloat16_t y)
 
METAL_FUNC bfloat16_t metal::min3 (bfloat16_t x, bfloat16_t y, bfloat16_t z)
 
METAL_FUNC bfloat16_t metal::nextafter (bfloat16_t x, bfloat16_t y)
 
METAL_FUNC bfloat16_t metal::pow (bfloat16_t x, bfloat16_t y)
 
METAL_FUNC bfloat16_t metal::powr (bfloat16_t x, bfloat16_t y)
 
METAL_FUNC bfloat16_t metal::rint (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::round (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::rsqrt (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::sin (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::sinh (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::sinpi (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::sqrt (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::tan (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::tanh (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::tanpi (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::trunc (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::fast::abs (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::fast::acos (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::fast::acosh (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::fast::asin (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::fast::asinh (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::fast::atan (bfloat16_t y_over_x)
 
METAL_FUNC bfloat16_t metal::fast::atan2 (bfloat16_t y, bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::fast::atanh (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::fast::ceil (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::fast::cos (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::fast::cosh (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::fast::cospi (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::fast::divide (bfloat16_t x, bfloat16_t y)
 
METAL_FUNC bfloat16_t metal::fast::exp (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::fast::exp10 (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::fast::exp2 (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::fast::fabs (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::fast::fdim (bfloat16_t x, bfloat16_t y)
 
METAL_FUNC bfloat16_t metal::fast::floor (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::fast::fma (bfloat16_t x, bfloat16_t y, bfloat16_t z)
 
METAL_FUNC bfloat16_t metal::fast::fmax (bfloat16_t x, bfloat16_t y)
 
METAL_FUNC bfloat16_t metal::fast::fmax3 (bfloat16_t x, bfloat16_t y, bfloat16_t z)
 
METAL_FUNC bfloat16_t metal::fast::fmedian3 (bfloat16_t x, bfloat16_t y, bfloat16_t z)
 
METAL_FUNC bfloat16_t metal::fast::fmin (bfloat16_t x, bfloat16_t y)
 
METAL_FUNC bfloat16_t metal::fast::fmin3 (bfloat16_t x, bfloat16_t y, bfloat16_t z)
 
METAL_FUNC bfloat16_t metal::fast::fmod (bfloat16_t x, bfloat16_t y)
 
METAL_FUNC bfloat16_t metal::fast::fract (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::fast::frexp (bfloat16_t x, thread int &exp)
 
METAL_FUNC bfloat16_t metal::fast::ldexp (bfloat16_t x, int k)
 
METAL_FUNC bfloat16_t metal::fast::log (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::fast::log10 (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::fast::log2 (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::fast::max (bfloat16_t x, bfloat16_t y)
 
METAL_FUNC bfloat16_t metal::fast::max3 (bfloat16_t x, bfloat16_t y, bfloat16_t z)
 
METAL_FUNC bfloat16_t metal::fast::median3 (bfloat16_t x, bfloat16_t y, bfloat16_t z)
 
METAL_FUNC bfloat16_t metal::fast::min (bfloat16_t x, bfloat16_t y)
 
METAL_FUNC bfloat16_t metal::fast::min3 (bfloat16_t x, bfloat16_t y, bfloat16_t z)
 
METAL_FUNC bfloat16_t metal::fast::nextafter (bfloat16_t x, bfloat16_t y)
 
METAL_FUNC bfloat16_t metal::fast::pow (bfloat16_t x, bfloat16_t y)
 
METAL_FUNC bfloat16_t metal::fast::powr (bfloat16_t x, bfloat16_t y)
 
METAL_FUNC bfloat16_t metal::fast::rint (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::fast::round (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::fast::rsqrt (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::fast::sin (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::fast::sinh (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::fast::sinpi (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::fast::sqrt (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::fast::tan (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::fast::tanh (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::fast::tanpi (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::fast::trunc (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::precise::abs (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::precise::acos (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::precise::acosh (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::precise::asin (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::precise::asinh (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::precise::atan (bfloat16_t y_over_x)
 
METAL_FUNC bfloat16_t metal::precise::atan2 (bfloat16_t y, bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::precise::atanh (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::precise::ceil (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::precise::cos (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::precise::cosh (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::precise::cospi (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::precise::divide (bfloat16_t x, bfloat16_t y)
 
METAL_FUNC bfloat16_t metal::precise::exp (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::precise::exp10 (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::precise::exp2 (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::precise::fabs (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::precise::fdim (bfloat16_t x, bfloat16_t y)
 
METAL_FUNC bfloat16_t metal::precise::floor (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::precise::fma (bfloat16_t x, bfloat16_t y, bfloat16_t z)
 
METAL_FUNC bfloat16_t metal::precise::fmax (bfloat16_t x, bfloat16_t y)
 
METAL_FUNC bfloat16_t metal::precise::fmax3 (bfloat16_t x, bfloat16_t y, bfloat16_t z)
 
METAL_FUNC bfloat16_t metal::precise::fmedian3 (bfloat16_t x, bfloat16_t y, bfloat16_t z)
 
METAL_FUNC bfloat16_t metal::precise::fmin (bfloat16_t x, bfloat16_t y)
 
METAL_FUNC bfloat16_t metal::precise::fmin3 (bfloat16_t x, bfloat16_t y, bfloat16_t z)
 
METAL_FUNC bfloat16_t metal::precise::fmod (bfloat16_t x, bfloat16_t y)
 
METAL_FUNC bfloat16_t metal::precise::fract (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::precise::frexp (bfloat16_t x, thread int &exp)
 
METAL_FUNC bfloat16_t metal::precise::ldexp (bfloat16_t x, int k)
 
METAL_FUNC bfloat16_t metal::precise::log (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::precise::log10 (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::precise::log2 (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::precise::max (bfloat16_t x, bfloat16_t y)
 
METAL_FUNC bfloat16_t metal::precise::max3 (bfloat16_t x, bfloat16_t y, bfloat16_t z)
 
METAL_FUNC bfloat16_t metal::precise::median3 (bfloat16_t x, bfloat16_t y, bfloat16_t z)
 
METAL_FUNC bfloat16_t metal::precise::min (bfloat16_t x, bfloat16_t y)
 
METAL_FUNC bfloat16_t metal::precise::min3 (bfloat16_t x, bfloat16_t y, bfloat16_t z)
 
METAL_FUNC bfloat16_t metal::precise::nextafter (bfloat16_t x, bfloat16_t y)
 
METAL_FUNC bfloat16_t metal::precise::pow (bfloat16_t x, bfloat16_t y)
 
METAL_FUNC bfloat16_t metal::precise::powr (bfloat16_t x, bfloat16_t y)
 
METAL_FUNC bfloat16_t metal::precise::rint (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::precise::round (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::precise::rsqrt (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::precise::sin (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::precise::sinh (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::precise::sinpi (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::precise::sqrt (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::precise::tan (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::precise::tanh (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::precise::tanpi (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::precise::trunc (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::simd_broadcast (bfloat16_t data, ushort broadcast_lane_id)
 
METAL_FUNC bfloat16_t metal::simd_shuffle (bfloat16_t data, ushort simd_lane_id)
 
METAL_FUNC bfloat16_t metal::simd_shuffle_and_fill_down (bfloat16_t data, bfloat16_t filling_data, ushort delta, ushort modulo)
 
METAL_FUNC bfloat16_t metal::simd_shuffle_and_fill_down (bfloat16_t data, bfloat16_t filling_data, ushort delta)
 
METAL_FUNC bfloat16_t metal::simd_shuffle_and_fill_up (bfloat16_t data, bfloat16_t filling_data, ushort delta, ushort modulo)
 
METAL_FUNC bfloat16_t metal::simd_shuffle_and_fill_up (bfloat16_t data, bfloat16_t filling_data, ushort delta)
 
METAL_FUNC bfloat16_t metal::simd_shuffle_down (bfloat16_t data, ushort delta)
 
METAL_FUNC bfloat16_t metal::simd_shuffle_rotate_down (bfloat16_t data, ushort delta)
 
METAL_FUNC bfloat16_t metal::simd_shuffle_rotate_up (bfloat16_t data, ushort delta)
 
METAL_FUNC bfloat16_t metal::simd_shuffle_up (bfloat16_t data, ushort delta)
 
METAL_FUNC bfloat16_t metal::simd_shuffle_xor (bfloat16_t data, ushort mask)
 
METAL_FUNC bfloat16_t metal::simd_max (bfloat16_t data)
 
METAL_FUNC bfloat16_t metal::simd_min (bfloat16_t data)
 
METAL_FUNC bfloat16_t metal::simd_prefix_exclusive_product (bfloat16_t data)
 
METAL_FUNC bfloat16_t metal::simd_prefix_exclusive_sum (bfloat16_t data)
 
METAL_FUNC bfloat16_t metal::simd_prefix_inclusive_product (bfloat16_t data)
 
METAL_FUNC bfloat16_t metal::simd_prefix_inclusive_sum (bfloat16_t data)
 
METAL_FUNC bfloat16_t metal::simd_product (bfloat16_t data)
 
METAL_FUNC bfloat16_t metal::simd_sum (bfloat16_t data)
 
METAL_FUNC bfloat16_t metal::simd_xor (bfloat16_t data)
 
+

Macro Definition Documentation

+ +

◆ bfloat16_to_uint16

+ +
+
+ + + + + + + +
#define bfloat16_to_uint16( x)   x.bits_
+
+ +
+
+ +

◆ instantiate_metal_math_funcs

+ +
+
+ + + + + + + + + + + + + + + + + + + + + +
#define instantiate_metal_math_funcs( itype,
otype,
ctype,
mfast )
+
+ +
+
+ +

◆ instantiate_metal_simd_comm_funcs

+ +
+
+ + + + + + + + + + + + + + + + + + + + + + + + + + +
#define instantiate_metal_simd_comm_funcs( itype,
otype,
ctype,
itype_to_ctype,
ctype_to_otype )
+
+ +
+
+ +

◆ instantiate_metal_simd_reduction_funcs

+ +
+
+ + + + + + + + + + + + + + + + +
#define instantiate_metal_simd_reduction_funcs( itype,
otype,
ctype )
+
+ +
+
+ +

◆ uint16_to_bfloat16

+ +
+
+ + + + + + + +
#define uint16_to_bfloat16( x)   _MLX_BFloat16(x, _MLX_BFloat16::bits_to_bfloat())
+
+ +
+
+
+ + + + diff --git a/docs/build/html/bf16__math_8h_source.html b/docs/build/html/bf16__math_8h_source.html new file mode 100644 index 000000000..217403ab0 --- /dev/null +++ b/docs/build/html/bf16__math_8h_source.html @@ -0,0 +1,498 @@ + + + + + + + +MLX: mlx/backend/metal/kernels/bf16_math.h Source File + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
bf16_math.h
+
+
+Go to the documentation of this file.
1// Copyright © 2023 Apple Inc.
+
2
+
3#pragma once
+
4
+ +
6
+
8// Metal math for bfloat16
+
10
+
11/*
+
12
+
13Following the Metal Shading Language Specification (Metal 3.1)
+
14
+
15"bfloat is an extended itypeing point type that only allows implicit conversion
+
16 to a type of greater itypeing point rank. While bfloat can be implicitly
+
17 converted to itype, it cannot be implicitly converted to half, and neither
+
18 itype nor half can be implicitly converted to bfloat."
+
19
+
20Further, as far as I can tell, the stdlib math/simd functions are not defined
+
21for bfloat and calling with an argument of type bfloat will result in that
+
22argument getting implicitly converted to itype which then returns an output
+
23that is (likely) a itype which cannot be implicitly converted into a bfloat
+
24
+
25This leads to situations where
+
26bfloat a = 5.0bf;
+
27bfloat b = metal::abs(a); // this will throw an error since abs return itype
+
28bfloat c = static_cast<bfloat>(metal::abs(a)); // this is fine
+
29
+
30For the moment, I will be adding overloaded instantiations of the math
+
31functions to accordingly automatically handle the casting
+
32
+
33*/
+
34
+
+
35#define instantiate_metal_math_funcs(itype, otype, ctype, mfast) \
+
36 \
+
37 METAL_FUNC otype abs(itype x) { \
+
38 return static_cast<otype>(__metal_fabs(static_cast<ctype>(x), mfast)); \
+
39 } \
+
40 METAL_FUNC otype acos(itype x) { \
+
41 return static_cast<otype>(__metal_acos(static_cast<ctype>(x), mfast)); \
+
42 } \
+
43 METAL_FUNC otype acosh(itype x) { \
+
44 return static_cast<otype>(__metal_acosh(static_cast<ctype>(x), mfast)); \
+
45 } \
+
46 METAL_FUNC otype asin(itype x) { \
+
47 return static_cast<otype>(__metal_asin(static_cast<ctype>(x), mfast)); \
+
48 } \
+
49 METAL_FUNC otype asinh(itype x) { \
+
50 return static_cast<otype>(__metal_asinh(static_cast<ctype>(x), mfast)); \
+
51 } \
+
52 METAL_FUNC otype atan(itype y_over_x) { \
+
53 return static_cast<otype>( \
+
54 __metal_atan(static_cast<ctype>(y_over_x), mfast)); \
+
55 } \
+
56 METAL_FUNC otype atan2(itype y, itype x) { \
+
57 return static_cast<otype>( \
+
58 __metal_atan2(static_cast<ctype>(y), static_cast<ctype>(x), mfast)); \
+
59 } \
+
60 METAL_FUNC otype atanh(itype x) { \
+
61 return static_cast<otype>(__metal_atanh(static_cast<ctype>(x), mfast)); \
+
62 } \
+
63 METAL_FUNC otype ceil(itype x) { \
+
64 return static_cast<otype>(__metal_ceil(static_cast<ctype>(x), mfast)); \
+
65 } \
+
66 METAL_FUNC otype cos(itype x) { \
+
67 return static_cast<otype>(__metal_cos(static_cast<ctype>(x), mfast)); \
+
68 } \
+
69 METAL_FUNC otype cosh(itype x) { \
+
70 return static_cast<otype>(__metal_cosh(static_cast<ctype>(x), mfast)); \
+
71 } \
+
72 METAL_FUNC otype cospi(itype x) { \
+
73 return static_cast<otype>(__metal_cospi(static_cast<ctype>(x), mfast)); \
+
74 } \
+
75 METAL_FUNC otype divide(itype x, itype y) { \
+
76 return static_cast<otype>( \
+
77 __metal_divide(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
+
78 } \
+
79 METAL_FUNC otype exp(itype x) { \
+
80 return static_cast<otype>(__metal_exp(static_cast<ctype>(x), mfast)); \
+
81 } \
+
82 METAL_FUNC otype exp10(itype x) { \
+
83 return static_cast<otype>(__metal_exp10(static_cast<ctype>(x), mfast)); \
+
84 } \
+
85 METAL_FUNC otype exp2(itype x) { \
+
86 return static_cast<otype>(__metal_exp2(static_cast<ctype>(x), mfast)); \
+
87 } \
+
88 METAL_FUNC otype fabs(itype x) { \
+
89 return static_cast<otype>(__metal_fabs(static_cast<ctype>(x), mfast)); \
+
90 } \
+
91 METAL_FUNC otype fdim(itype x, itype y) { \
+
92 ctype t = static_cast<ctype>(x - y); \
+
93 return static_cast<otype>(select(t, ctype(0), t < ctype(0) || x == y)); \
+
94 } \
+
95 METAL_FUNC otype floor(itype x) { \
+
96 return static_cast<otype>(__metal_floor(static_cast<ctype>(x), mfast)); \
+
97 } \
+
98 METAL_FUNC otype fma(itype x, itype y, itype z) { \
+
99 return static_cast<otype>(__metal_fma( \
+
100 static_cast<ctype>(x), static_cast<ctype>(y), static_cast<ctype>(z))); \
+
101 } \
+
102 METAL_FUNC otype fmax(itype x, itype y) { \
+
103 return static_cast<otype>( \
+
104 __metal_fmax(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
+
105 } \
+
106 METAL_FUNC otype fmax3(itype x, itype y, itype z) { \
+
107 return static_cast<otype>(__metal_fmax3( \
+
108 static_cast<ctype>(x), \
+
109 static_cast<ctype>(y), \
+
110 static_cast<ctype>(z), \
+
111 mfast)); \
+
112 } \
+
113 METAL_FUNC otype fmedian3(itype x, itype y, itype z) { \
+
114 return static_cast<otype>(__metal_fmedian3( \
+
115 static_cast<ctype>(x), \
+
116 static_cast<ctype>(y), \
+
117 static_cast<ctype>(z), \
+
118 mfast)); \
+
119 } \
+
120 METAL_FUNC otype fmin(itype x, itype y) { \
+
121 return static_cast<otype>( \
+
122 __metal_fmin(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
+
123 } \
+
124 METAL_FUNC otype fmin3(itype x, itype y, itype z) { \
+
125 return static_cast<otype>(__metal_fmin3( \
+
126 static_cast<ctype>(x), \
+
127 static_cast<ctype>(y), \
+
128 static_cast<ctype>(z), \
+
129 mfast)); \
+
130 } \
+
131 METAL_FUNC otype fmod(itype x, itype y) { \
+
132 return static_cast<otype>( \
+
133 __metal_fmod(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
+
134 } \
+
135 METAL_FUNC otype fract(itype x) { \
+
136 return static_cast<otype>(__metal_fract(static_cast<ctype>(x), mfast)); \
+
137 } \
+
138 METAL_FUNC otype frexp(itype x, thread int& exp) { \
+
139 return static_cast<otype>(__metal_frexp(static_cast<ctype>(x), &exp)); \
+
140 } \
+
141 METAL_FUNC otype ldexp(itype x, int k) { \
+
142 return static_cast<otype>(__metal_ldexp(static_cast<ctype>(x), k, mfast)); \
+
143 } \
+
144 METAL_FUNC otype log(itype x) { \
+
145 return static_cast<otype>(__metal_log(static_cast<ctype>(x), mfast)); \
+
146 } \
+
147 METAL_FUNC otype log10(itype x) { \
+
148 return static_cast<otype>(__metal_log10(static_cast<ctype>(x), mfast)); \
+
149 } \
+
150 METAL_FUNC otype log2(itype x) { \
+
151 return static_cast<otype>(__metal_log2(static_cast<ctype>(x), mfast)); \
+
152 } \
+
153 METAL_FUNC otype max(itype x, itype y) { \
+
154 return static_cast<otype>( \
+
155 __metal_fmax(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
+
156 } \
+
157 METAL_FUNC otype max3(itype x, itype y, itype z) { \
+
158 return static_cast<otype>(__metal_fmax3( \
+
159 static_cast<ctype>(x), \
+
160 static_cast<ctype>(y), \
+
161 static_cast<ctype>(z), \
+
162 mfast)); \
+
163 } \
+
164 METAL_FUNC otype median3(itype x, itype y, itype z) { \
+
165 return static_cast<otype>(__metal_fmedian3( \
+
166 static_cast<ctype>(x), \
+
167 static_cast<ctype>(y), \
+
168 static_cast<ctype>(z), \
+
169 mfast)); \
+
170 } \
+
171 METAL_FUNC otype min(itype x, itype y) { \
+
172 return static_cast<otype>( \
+
173 __metal_fmin(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
+
174 } \
+
175 METAL_FUNC otype min3(itype x, itype y, itype z) { \
+
176 return static_cast<otype>(__metal_fmin3( \
+
177 static_cast<ctype>(x), \
+
178 static_cast<ctype>(y), \
+
179 static_cast<ctype>(z), \
+
180 mfast)); \
+
181 } \
+
182 METAL_FUNC otype nextafter(itype x, itype y) { \
+
183 return static_cast<otype>( \
+
184 __metal_nextafter(static_cast<ctype>(x), static_cast<ctype>(y))); \
+
185 } \
+
186 METAL_FUNC otype pow(itype x, itype y) { \
+
187 return static_cast<otype>( \
+
188 __metal_pow(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
+
189 } \
+
190 METAL_FUNC otype powr(itype x, itype y) { \
+
191 return static_cast<otype>( \
+
192 __metal_powr(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
+
193 } \
+
194 METAL_FUNC otype rint(itype x) { \
+
195 return static_cast<otype>(__metal_rint(static_cast<ctype>(x), mfast)); \
+
196 } \
+
197 METAL_FUNC otype round(itype x) { \
+
198 return static_cast<otype>(__metal_round(static_cast<ctype>(x), mfast)); \
+
199 } \
+
200 METAL_FUNC otype rsqrt(itype x) { \
+
201 return static_cast<otype>(__metal_rsqrt(static_cast<ctype>(x), mfast)); \
+
202 } \
+
203 METAL_FUNC otype sin(itype x) { \
+
204 return static_cast<otype>(__metal_sin(static_cast<ctype>(x), mfast)); \
+
205 } \
+
206 METAL_FUNC otype sinh(itype x) { \
+
207 return static_cast<otype>(__metal_sinh(static_cast<ctype>(x), mfast)); \
+
208 } \
+
209 METAL_FUNC otype sinpi(itype x) { \
+
210 return static_cast<otype>(__metal_sinpi(static_cast<ctype>(x), mfast)); \
+
211 } \
+
212 METAL_FUNC otype sqrt(itype x) { \
+
213 return static_cast<otype>(__metal_sqrt(static_cast<ctype>(x), mfast)); \
+
214 } \
+
215 METAL_FUNC otype tan(itype x) { \
+
216 return static_cast<otype>(__metal_tan(static_cast<ctype>(x), mfast)); \
+
217 } \
+
218 METAL_FUNC otype tanh(itype x) { \
+
219 return static_cast<otype>(__metal_tanh(static_cast<ctype>(x), mfast)); \
+
220 } \
+
221 METAL_FUNC otype tanpi(itype x) { \
+
222 return static_cast<otype>(__metal_tanpi(static_cast<ctype>(x), mfast)); \
+
223 } \
+
224 METAL_FUNC otype trunc(itype x) { \
+
225 return static_cast<otype>(__metal_trunc(static_cast<ctype>(x), mfast)); \
+
226 }
+
+
227
+
228namespace metal {
+
229
+ + + +
233 float,
+
234 __METAL_MAYBE_FAST_MATH__);
+
235
+
+
236namespace fast {
+
237
+ + + +
241 float,
+
242 __METAL_FAST_MATH__);
+
243
+
244} // namespace fast
+
+
245
+
+
246namespace precise {
+
247
+ + + +
251 float,
+
252 __METAL_PRECISE_MATH__);
+
253
+
254} // namespace precise
+
+
255
+
256} // namespace metal
+
257
+
259// Metal simd for bfloat16
+
261
+
262#define instantiate_metal_simd_comm_funcs( \
+
263 itype, otype, ctype, itype_to_ctype, ctype_to_otype) \
+
264 \
+
265 METAL_FUNC otype simd_broadcast(itype data, ushort broadcast_lane_id) { \
+
266 return ctype_to_otype( \
+
267 __metal_simd_broadcast(itype_to_ctype(data), broadcast_lane_id)); \
+
268 } \
+
269 \
+
270 METAL_FUNC otype simd_shuffle(itype data, ushort simd_lane_id) { \
+
271 return ctype_to_otype( \
+
272 __metal_simd_shuffle(itype_to_ctype(data), simd_lane_id)); \
+
273 } \
+
274 \
+
275 METAL_FUNC otype simd_shuffle_and_fill_down( \
+
276 itype data, itype filling_data, ushort delta, ushort modulo) { \
+
277 return ctype_to_otype(__metal_simd_shuffle_and_fill_down( \
+
278 itype_to_ctype(data), itype_to_ctype(filling_data), delta, modulo)); \
+
279 } \
+
280 \
+
281 METAL_FUNC otype simd_shuffle_and_fill_down( \
+
282 itype data, itype filling_data, ushort delta) { \
+
283 return ctype_to_otype(__metal_simd_shuffle_and_fill_down( \
+
284 itype_to_ctype(data), \
+
285 itype_to_ctype(filling_data), \
+
286 delta, \
+
287 __metal_get_simdgroup_size(ushort()))); \
+
288 } \
+
289 \
+
290 METAL_FUNC otype simd_shuffle_and_fill_up( \
+
291 itype data, itype filling_data, ushort delta, ushort modulo) { \
+
292 return ctype_to_otype(__metal_simd_shuffle_and_fill_up( \
+
293 itype_to_ctype(data), itype_to_ctype(filling_data), delta, modulo)); \
+
294 } \
+
295 \
+
296 METAL_FUNC otype simd_shuffle_and_fill_up( \
+
297 itype data, itype filling_data, ushort delta) { \
+
298 return ctype_to_otype(__metal_simd_shuffle_and_fill_up( \
+
299 itype_to_ctype(data), \
+
300 itype_to_ctype(filling_data), \
+
301 delta, \
+
302 __metal_get_simdgroup_size(ushort()))); \
+
303 } \
+
304 \
+
305 METAL_FUNC otype simd_shuffle_down(itype data, ushort delta) { \
+
306 return ctype_to_otype( \
+
307 __metal_simd_shuffle_down(itype_to_ctype(data), delta)); \
+
308 } \
+
309 \
+
310 METAL_FUNC otype simd_shuffle_rotate_down(itype data, ushort delta) { \
+
311 return ctype_to_otype( \
+
312 __metal_simd_shuffle_rotate_down(itype_to_ctype(data), delta)); \
+
313 } \
+
314 \
+
315 METAL_FUNC otype simd_shuffle_rotate_up(itype data, ushort delta) { \
+
316 return ctype_to_otype( \
+
317 __metal_simd_shuffle_rotate_up(itype_to_ctype(data), delta)); \
+
318 } \
+
319 \
+
320 METAL_FUNC otype simd_shuffle_up(itype data, ushort delta) { \
+
321 return ctype_to_otype( \
+
322 __metal_simd_shuffle_up(itype_to_ctype(data), delta)); \
+
323 } \
+
324 \
+
325 METAL_FUNC otype simd_shuffle_xor(itype data, ushort mask) { \
+
326 return ctype_to_otype( \
+
327 __metal_simd_shuffle_xor(itype_to_ctype(data), mask)); \
+
328 }
+
329
+
+
330#define instantiate_metal_simd_reduction_funcs(itype, otype, ctype) \
+
331 \
+
332 METAL_FUNC otype simd_max(itype data) { \
+
333 return static_cast<otype>(__metal_simd_max(static_cast<ctype>(data))); \
+
334 } \
+
335 \
+
336 METAL_FUNC otype simd_min(itype data) { \
+
337 return static_cast<otype>(__metal_simd_min(static_cast<ctype>(data))); \
+
338 } \
+
339 \
+
340 METAL_FUNC otype simd_prefix_exclusive_product(itype data) { \
+
341 return static_cast<otype>( \
+
342 __metal_simd_prefix_exclusive_product(static_cast<ctype>(data))); \
+
343 } \
+
344 \
+
345 METAL_FUNC otype simd_prefix_exclusive_sum(itype data) { \
+
346 return static_cast<otype>( \
+
347 __metal_simd_prefix_exclusive_sum(static_cast<ctype>(data))); \
+
348 } \
+
349 \
+
350 METAL_FUNC otype simd_prefix_inclusive_product(itype data) { \
+
351 return static_cast<otype>( \
+
352 __metal_simd_prefix_inclusive_product(static_cast<ctype>(data))); \
+
353 } \
+
354 \
+
355 METAL_FUNC otype simd_prefix_inclusive_sum(itype data) { \
+
356 return static_cast<otype>( \
+
357 __metal_simd_prefix_inclusive_sum(static_cast<ctype>(data))); \
+
358 } \
+
359 \
+
360 METAL_FUNC otype simd_product(itype data) { \
+
361 return static_cast<otype>(__metal_simd_product(static_cast<ctype>(data))); \
+
362 } \
+
363 \
+
364 METAL_FUNC otype simd_sum(itype data) { \
+
365 return static_cast<otype>(__metal_simd_sum(static_cast<ctype>(data))); \
+
366 } \
+
367 \
+
368 METAL_FUNC otype simd_xor(itype data) { \
+
369 return static_cast<otype>(__metal_simd_xor(static_cast<ctype>(data))); \
+
370 }
+
+
371
+
372#if defined(__HAVE_BFLOAT__)
+
373
+
374#define bfloat16_to_uint16(x) as_type<uint16_t>(x)
+
375#define uint16_to_bfloat16(x) as_type<bfloat16_t>(x)
+
376
+
377#else
+
378
+
379#define bfloat16_to_uint16(x) x.bits_
+
380#define uint16_to_bfloat16(x) _MLX_BFloat16(x, _MLX_BFloat16::bits_to_bfloat())
+
381
+
382#endif
+
383
+
384namespace metal {
+
385
+ + + +
389 uint16_t,
+ + + +
393
+
394} // namespace metal
+ +
#define uint16_to_bfloat16(x)
Definition bf16_math.h:380
+
#define instantiate_metal_simd_reduction_funcs(itype, otype, ctype)
Definition bf16_math.h:330
+
#define bfloat16_to_uint16(x)
Definition bf16_math.h:379
+
#define instantiate_metal_math_funcs(itype, otype, ctype, mfast)
Definition bf16_math.h:35
+
#define instantiate_metal_simd_comm_funcs( itype, otype, ctype, itype_to_ctype, ctype_to_otype)
Definition bf16_math.h:262
+
Definition bf16.h:265
+
Definition bf16.h:54
+
+ + + + diff --git a/docs/build/html/binary__two_8h.html b/docs/build/html/binary__two_8h.html new file mode 100644 index 000000000..9e8fa3dee --- /dev/null +++ b/docs/build/html/binary__two_8h.html @@ -0,0 +1,101 @@ + + + + + + + +MLX: mlx/backend/common/binary_two.h File Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
binary_two.h File Reference
+
+
+ +

Go to the source code of this file.

+ + + + + + +

+Namespaces

namespace  mlx
 
namespace  mlx::core
 
+
+ + + + diff --git a/docs/build/html/binary__two_8h_source.html b/docs/build/html/binary__two_8h_source.html new file mode 100644 index 000000000..f6d72e878 --- /dev/null +++ b/docs/build/html/binary__two_8h_source.html @@ -0,0 +1,646 @@ + + + + + + + +MLX: mlx/backend/common/binary_two.h Source File + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
binary_two.h
+
+
+Go to the documentation of this file.
1// Copyright © 2023 Apple Inc.
+
2
+
3#pragma once
+
4
+ + +
7
+
8namespace mlx::core {
+
9
+
10namespace {
+
11
+
12template <typename T, typename U, typename Op>
+
13void binary_op_dims1(
+
14 const array& a,
+
15 const array& b,
+
16 array& out_a,
+
17 array& out_b,
+
18 Op op) {
+
19 const T* a_ptr = a.data<T>();
+
20 const T* b_ptr = b.data<T>();
+
21 U* dst_a = out_a.data<U>();
+
22 U* dst_b = out_b.data<U>();
+
23 size_t a_idx = 0;
+
24 size_t b_idx = 0;
+
25 for (size_t i = 0; i < out_a.size(); ++i) {
+
26 auto dst = op(a_ptr[a_idx], b_ptr[b_idx]);
+
27 dst_a[i] = dst.first;
+
28 dst_b[i] = dst.second;
+
29 a_idx += a.strides()[0];
+
30 b_idx += b.strides()[0];
+
31 }
+
32}
+
33
+
34template <typename T, typename U, typename Op>
+
35void binary_op_dims1(
+
36 const array& a,
+
37 const array& b,
+
38 array& out_a,
+
39 array& out_b,
+
40 Op op,
+
41 int stride) {
+
42 const T* a_ptr = a.data<T>();
+
43 const T* b_ptr = b.data<T>();
+
44 U* dst_a = out_a.data<U>();
+
45 U* dst_b = out_b.data<U>();
+
46 size_t a_idx = 0;
+
47 size_t b_idx = 0;
+
48 for (size_t i = 0; i < a.shape()[0]; i++) {
+
49 op(a_ptr + a_idx, b_ptr + b_idx, dst_a, dst_b, stride);
+
50 a_idx += a.strides()[0];
+
51 b_idx += b.strides()[0];
+
52 dst_a += stride;
+
53 dst_b += stride;
+
54 }
+
55}
+
56
+
57template <typename T, typename U, typename Op>
+
58void binary_op_dims2(
+
59 const array& a,
+
60 const array& b,
+
61 array& out_a,
+
62 array& out_b,
+
63 Op op) {
+
64 const T* a_ptr = a.data<T>();
+
65 const T* b_ptr = b.data<T>();
+
66 U* dst_a = out_a.data<U>();
+
67 U* dst_b = out_b.data<U>();
+
68 size_t a_idx = 0;
+
69 size_t b_idx = 0;
+
70 size_t out_idx = 0;
+
71 for (size_t i = 0; i < a.shape()[0]; ++i) {
+
72 for (size_t j = 0; j < a.shape()[1]; ++j) {
+
73 auto dst = op(a_ptr[a_idx], b_ptr[b_idx]);
+
74 dst_a[out_idx] = dst.first;
+
75 dst_b[out_idx++] = dst.second;
+
76 a_idx += a.strides()[1];
+
77 b_idx += b.strides()[1];
+
78 }
+
79 a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
+
80 b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
+
81 }
+
82}
+
83
+
84template <typename T, typename U, typename Op>
+
85void binary_op_dims2(
+
86 const array& a,
+
87 const array& b,
+
88 array& out_a,
+
89 array& out_b,
+
90 Op op,
+
91 int stride) {
+
92 const T* a_ptr = a.data<T>();
+
93 const T* b_ptr = b.data<T>();
+
94 U* dst_a = out_a.data<U>();
+
95 U* dst_b = out_b.data<U>();
+
96 size_t a_idx = 0;
+
97 size_t b_idx = 0;
+
98 for (size_t i = 0; i < a.shape()[0]; ++i) {
+
99 for (size_t j = 0; j < a.shape()[1]; ++j) {
+
100 op(a_ptr + a_idx, b_ptr + b_idx, dst_a, dst_b, stride);
+
101 a_idx += a.strides()[1];
+
102 b_idx += b.strides()[1];
+
103 dst_a += stride;
+
104 dst_b += stride;
+
105 }
+
106 a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
+
107 b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
+
108 }
+
109}
+
110
+
111template <typename T, typename U, typename Op>
+
112void binary_op_dims3(
+
113 const array& a,
+
114 const array& b,
+
115 array& out_a,
+
116 array& out_b,
+
117 Op op) {
+
118 const T* a_ptr = a.data<T>();
+
119 const T* b_ptr = b.data<T>();
+
120 U* dst_a = out_a.data<U>();
+
121 U* dst_b = out_b.data<U>();
+
122 size_t a_idx = 0;
+
123 size_t b_idx = 0;
+
124 size_t out_idx = 0;
+
125 for (size_t i = 0; i < a.shape()[0]; ++i) {
+
126 for (size_t j = 0; j < a.shape()[1]; ++j) {
+
127 for (size_t k = 0; k < a.shape()[2]; ++k) {
+
128 auto dst = op(a_ptr[a_idx], b_ptr[b_idx]);
+
129 dst_a[out_idx] = dst.first;
+
130 dst_b[out_idx++] = dst.second;
+
131 a_idx += a.strides()[2];
+
132 b_idx += b.strides()[2];
+
133 }
+
134 a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
+
135 b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
+
136 }
+
137 a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
+
138 b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
+
139 }
+
140}
+
141
+
142template <typename T, typename U, typename Op>
+
143void binary_op_dims4(
+
144 const array& a,
+
145 const array& b,
+
146 array& out_a,
+
147 array& out_b,
+
148 Op op) {
+
149 const T* a_ptr = a.data<T>();
+
150 const T* b_ptr = b.data<T>();
+
151 U* dst_a = out_a.data<U>();
+
152 U* dst_b = out_b.data<U>();
+
153 size_t a_idx = 0;
+
154 size_t b_idx = 0;
+
155 size_t out_idx = 0;
+
156 for (size_t i = 0; i < a.shape()[0]; ++i) {
+
157 for (size_t j = 0; j < a.shape()[1]; ++j) {
+
158 for (size_t k = 0; k < a.shape()[2]; ++k) {
+
159 for (size_t ii = 0; ii < a.shape()[3]; ++ii) {
+
160 auto dst = op(a_ptr[a_idx], b_ptr[b_idx]);
+
161 dst_a[out_idx] = dst.first;
+
162 dst_b[out_idx++] = dst.second;
+
163 a_idx += a.strides()[3];
+
164 b_idx += b.strides()[3];
+
165 }
+
166 a_idx += a.strides()[2] - a.strides()[3] * a.shape()[3];
+
167 b_idx += b.strides()[2] - b.strides()[3] * b.shape()[3];
+
168 }
+
169 a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
+
170 b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
+
171 }
+
172 a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
+
173 b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
+
174 }
+
175}
+
176
+
177template <typename T, typename U, typename Op>
+
178void binary_op_dispatch_dims(
+
179 const array& a,
+
180 const array& b,
+
181 array& out_a,
+
182 array& out_b,
+
183 Op op) {
+
184 switch (out_a.ndim()) {
+
185 case 1:
+
186 binary_op_dims1<T, U, Op>(a, b, out_a, out_b, op);
+
187 return;
+
188 case 2:
+
189 binary_op_dims2<T, U, Op>(a, b, out_a, out_b, op);
+
190 return;
+
191 case 3:
+
192 binary_op_dims3<T, U, Op>(a, b, out_a, out_b, op);
+
193 return;
+
194 case 4:
+
195 binary_op_dims4<T, U, Op>(a, b, out_a, out_b, op);
+
196 return;
+
197 }
+
198
+
199 const T* a_ptr = a.data<T>();
+
200 const T* b_ptr = b.data<T>();
+
201 U* dst_a = out_a.data<U>();
+
202 U* dst_b = out_b.data<U>();
+
203 for (size_t i = 0; i < out_a.size(); i++) {
+
204 int a_idx = elem_to_loc(i, a.shape(), a.strides());
+
205 int b_idx = elem_to_loc(i, b.shape(), b.strides());
+
206 std::tie(dst_a[i], dst_b[i]) = op(a_ptr[a_idx], b_ptr[b_idx]);
+
207 }
+
208}
+
209
+
210template <typename T, typename U, typename Op>
+
211void binary_op_dispatch_dims(
+
212 const array& a,
+
213 const array& b,
+
214 array& out_a,
+
215 array& out_b,
+
216 Op op,
+
217 int dim,
+
218 int stride) {
+
219 // Number of dimensions to loop over for vectorized ops
+
220 switch (dim) {
+
221 case 1:
+
222 binary_op_dims1<T, U, Op>(a, b, out_a, out_b, op, stride);
+
223 return;
+
224 case 2:
+
225 binary_op_dims2<T, U, Op>(a, b, out_a, out_b, op, stride);
+
226 return;
+
227 }
+
228
+
229 const T* a_ptr = a.data<T>();
+
230 const T* b_ptr = b.data<T>();
+
231 U* dst_a = out_a.data<U>();
+
232 U* dst_b = out_b.data<U>();
+
233 for (size_t i = 0; i < out_a.size(); i += stride) {
+
234 int a_idx = elem_to_loc(i, a.shape(), a.strides());
+
235 int b_idx = elem_to_loc(i, b.shape(), b.strides());
+
236 op(a_ptr + a_idx, b_ptr + b_idx, dst_a, dst_b, stride);
+
237 dst_a += stride;
+
238 dst_b += stride;
+
239 }
+
240}
+
241
+
242template <
+
243 typename T,
+
244 typename U,
+
245 typename Op,
+
246 typename OpSV,
+
247 typename OpVS,
+
248 typename OpVV>
+
249void binary_op(
+
250 const array& a,
+
251 const array& b,
+
252 array& out_a,
+
253 array& out_b,
+
254 Op op,
+
255 OpSV opsv,
+
256 OpVS opvs,
+
257 OpVV opvv) {
+
258 auto bopt = get_binary_op_type(a, b);
+
259 set_binary_op_output_data(a, b, out_a, bopt);
+
260 set_binary_op_output_data(a, b, out_b, bopt);
+
261
+
262 // The full computation is scalar scalar so call the base op once
+
263 if (bopt == BinaryOpType::ScalarScalar) {
+
264 std::tie(*(out_a.data<U>()), *(out_b.data<U>())) =
+
265 op(*a.data<T>(), *b.data<T>());
+
266 return;
+
267 }
+
268
+
269 // The full computation is scalar vector so delegate to the op
+
270 if (bopt == BinaryOpType::ScalarVector) {
+
271 opsv(
+
272 a.data<T>(),
+
273 b.data<T>(),
+
274 out_a.data<U>(),
+
275 out_b.data<U>(),
+
276 b.data_size());
+
277 return;
+
278 }
+
279
+
280 // The full computation is vector scalar so delegate to the op
+
281 if (bopt == BinaryOpType::VectorScalar) {
+
282 opvs(
+
283 a.data<T>(),
+
284 b.data<T>(),
+
285 out_a.data<U>(),
+
286 out_b.data<U>(),
+
287 a.data_size());
+
288 return;
+
289 }
+
290
+
291 // The full computation is vector vector so delegate to the op
+
292 if (bopt == BinaryOpType::VectorVector) {
+
293 opvv(
+
294 a.data<T>(),
+
295 b.data<T>(),
+
296 out_a.data<U>(),
+
297 out_b.data<U>(),
+
298 out_a.size());
+
299 return;
+
300 }
+
301
+
302 // General computation so let's try to optimize
+
303
+
304 // Get the left-most dim such that the array is row contiguous after
+
305 auto& strides = out_a.strides();
+
306 auto leftmost_rc_dim = [&strides](const array& arr) {
+
307 int d = arr.ndim() - 1;
+
308 for (; d >= 0 && arr.strides()[d] == strides[d]; d--) {
+
309 }
+
310 return d + 1;
+
311 };
+
312 auto a_rc_dim = leftmost_rc_dim(a);
+
313 auto b_rc_dim = leftmost_rc_dim(b);
+
314
+
315 // Get the left-most dim such that the array is a broadcasted "scalar" after
+
316 auto leftmost_s_dim = [](const array& arr) {
+
317 int d = arr.ndim() - 1;
+
318 for (; d >= 0 && arr.strides()[d] == 0; d--) {
+
319 }
+
320 return d + 1;
+
321 };
+
322 auto a_s_dim = leftmost_s_dim(a);
+
323 auto b_s_dim = leftmost_s_dim(b);
+
324
+
325 auto ndim = out_a.ndim();
+
326
+
327 // Case 1: LxM and FxM where L and F are broadcastable and M is row contiguous
+
328 int dim = ndim;
+
329 if (int d = std::max(a_rc_dim, b_rc_dim); d < ndim) {
+
330 bopt = BinaryOpType::VectorVector;
+
331 dim = d;
+
332 // Case 2: LxM and Fx1 where L and F are broadcastable and M is row
+
333 // contiguous
+
334 } else if (int d = std::max(a_rc_dim, b_s_dim); d < ndim) {
+
335 bopt = BinaryOpType::VectorScalar;
+
336 dim = d;
+
337 // Case 3: Lx1 and FxM where L and F are broadcastable and M is row
+
338 // contiguous
+
339 } else if (int d = std::max(a_s_dim, b_rc_dim); d < ndim) {
+
340 bopt = BinaryOpType::ScalarVector;
+
341 dim = d;
+
342 }
+
343
+
344 // Can be sure dim > 0 since otherwise we would have used one of the fully
+
345 // contiguous methods above. Except for the case that the flags do not
+
346 // correspond to the underlying contiguity.
+
347 size_t stride;
+
348 if (dim == 0 || strides[dim - 1] < 16) {
+
349 stride = 1;
+
350 bopt = BinaryOpType::General;
+
351 dim = ndim;
+
352 } else {
+
353 stride = strides[dim - 1];
+
354 }
+
355
+
356 switch (bopt) {
+
357 case BinaryOpType::VectorVector:
+
358 binary_op_dispatch_dims<T, U>(a, b, out_a, out_b, opvv, dim, stride);
+
359 break;
+
360 case BinaryOpType::VectorScalar:
+
361 binary_op_dispatch_dims<T, U>(a, b, out_a, out_b, opvs, dim, stride);
+
362 break;
+
363 case BinaryOpType::ScalarVector:
+
364 binary_op_dispatch_dims<T, U>(a, b, out_a, out_b, opsv, dim, stride);
+
365 break;
+
366 default:
+
367 binary_op_dispatch_dims<T, U>(a, b, out_a, out_b, op);
+
368 break;
+
369 }
+
370}
+
371
+
372template <typename T, typename Op, typename OpSV, typename OpVS, typename OpVV>
+
373void binary_op(
+
374 const array& a,
+
375 const array& b,
+
376 std::vector<array>& outputs,
+
377 Op op,
+
378 OpSV opsv,
+
379 OpVS opvs,
+
380 OpVV opvv) {
+
381 // TODO: The following mess of constexpr evaluations can probably be achieved
+
382 // with template specializations and overloading. Would it be simpler?
+
383
+
384 if (std::is_same<decltype(opsv), UseDefaultBinaryOp>::value) {
+
385 if (std::is_same<decltype(opvs), UseDefaultBinaryOp>::value) {
+
386 if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
+
387 // All ops are UseDefaultBinaryOp (why oh why would someone call that?)
+
388 binary_op<T, T>(
+
389 a,
+
390 b,
+
391 outputs[0],
+
392 outputs[1],
+
393 op,
+
394 DefaultScalarVector<T, T, Op>(op),
+
395 DefaultVectorScalar<T, T, Op>(op),
+
396 DefaultVectorVector<T, T, Op>(op));
+
397 } else {
+
398 // opsv and opvs were UseDefaultBinaryOp
+
399 binary_op<T, T>(
+
400 a,
+
401 b,
+
402 outputs[0],
+
403 outputs[1],
+
404 op,
+
405 DefaultScalarVector<T, T, Op>(op),
+
406 DefaultVectorScalar<T, T, Op>(op),
+
407 opvv);
+
408 }
+
409 } else if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
+
410 // opsv and opvv were UseDefaultBinaryOp
+
411 binary_op<T, T>(
+
412 a,
+
413 b,
+
414 outputs[0],
+
415 outputs[1],
+
416 op,
+
417 DefaultScalarVector<T, T, Op>(op),
+
418 opvs,
+
419 DefaultVectorVector<T, T, Op>(op));
+
420 } else {
+
421 // opsv was UseDefaultBinaryOp
+
422 binary_op<T, T>(
+
423 a,
+
424 b,
+
425 outputs[0],
+
426 outputs[1],
+
427 op,
+
428 DefaultScalarVector<T, T, Op>(op),
+
429 opvs,
+
430 opvv);
+
431 }
+
432 } else if (std::is_same<decltype(opvs), UseDefaultBinaryOp>::value) {
+
433 if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
+
434 // opvs and opvv were UseDefaultBinaryOp
+
435 binary_op<T, T>(
+
436 a,
+
437 b,
+
438 outputs[0],
+
439 outputs[1],
+
440 op,
+
441 opsv,
+
442 DefaultVectorScalar<T, T, Op>(op),
+
443 DefaultVectorVector<T, T, Op>(op));
+
444 } else {
+
445 // opvs was UseDefaultBinaryOp
+
446 binary_op<T, T>(
+
447 a,
+
448 b,
+
449 outputs[0],
+
450 outputs[1],
+
451 op,
+
452 opsv,
+
453 DefaultVectorScalar<T, T, Op>(op),
+
454 opvv);
+
455 }
+
456 } else if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
+
457 // opvv was UseDefaultBinaryOp
+
458 binary_op<T, T>(
+
459 a,
+
460 b,
+
461 outputs[0],
+
462 outputs[1],
+
463 op,
+
464 opsv,
+
465 opvs,
+
466 DefaultVectorVector<T, T, Op>(op));
+
467 } else {
+
468 // All ops provided
+
469 binary_op<T, T>(a, b, outputs[0], outputs[1], op, opsv, opvs, opvv);
+
470 }
+
471}
+
472
+
473template <typename T, typename Op>
+
474void binary_op(
+
475 const array& a,
+
476 const array& b,
+
477 std::vector<array>& outputs,
+
478 Op op) {
+
479 DefaultScalarVector<T, T, Op> opsv(op);
+
480 DefaultVectorScalar<T, T, Op> opvs(op);
+
481 DefaultVectorVector<T, T, Op> opvv(op);
+
482 binary_op<T, T>(a, b, outputs[0], outputs[1], op, opsv, opvs, opvv);
+
483}
+
484
+
485template <typename... Ops>
+
486void binary(
+
487 const array& a,
+
488 const array& b,
+
489 std::vector<array>& outputs,
+
490 Ops... ops) {
+
491 switch (outputs[0].dtype()) {
+
492 case bool_:
+
493 binary_op<bool>(a, b, outputs, ops...);
+
494 break;
+
495 case uint8:
+
496 binary_op<uint8_t>(a, b, outputs, ops...);
+
497 break;
+
498 case uint16:
+
499 binary_op<uint16_t>(a, b, outputs, ops...);
+
500 break;
+
501 case uint32:
+
502 binary_op<uint32_t>(a, b, outputs, ops...);
+
503 break;
+
504 case uint64:
+
505 binary_op<uint64_t>(a, b, outputs, ops...);
+
506 break;
+
507 case int8:
+
508 binary_op<int8_t>(a, b, outputs, ops...);
+
509 break;
+
510 case int16:
+
511 binary_op<int16_t>(a, b, outputs, ops...);
+
512 break;
+
513 case int32:
+
514 binary_op<int32_t>(a, b, outputs, ops...);
+
515 break;
+
516 case int64:
+
517 binary_op<int64_t>(a, b, outputs, ops...);
+
518 break;
+
519 case float16:
+
520 binary_op<float16_t>(a, b, outputs, ops...);
+
521 break;
+
522 case float32:
+
523 binary_op<float>(a, b, outputs, ops...);
+
524 break;
+
525 case bfloat16:
+
526 binary_op<bfloat16_t>(a, b, outputs, ops...);
+
527 break;
+
528 case complex64:
+
529 binary_op<complex64_t>(a, b, outputs, ops...);
+
530 break;
+
531 }
+
532}
+
533
+
534} // namespace
+
535
+
536} // namespace mlx::core
+ + +
Op op
Definition binary.h:139
+
Definition allocator.h:7
+
constexpr Dtype bool_
Definition dtype.h:60
+
constexpr Dtype uint64
Definition dtype.h:65
+
constexpr Dtype uint16
Definition dtype.h:63
+
stride_t elem_to_loc(int elem, const std::vector< int > &shape, const std::vector< stride_t > &strides)
Definition utils.h:12
+
constexpr Dtype bfloat16
Definition dtype.h:74
+
constexpr Dtype int32
Definition dtype.h:69
+
constexpr Dtype float32
Definition dtype.h:73
+
constexpr Dtype int16
Definition dtype.h:68
+
constexpr Dtype int8
Definition dtype.h:67
+
constexpr Dtype int64
Definition dtype.h:70
+
constexpr Dtype uint8
Definition dtype.h:62
+
constexpr Dtype float16
Definition dtype.h:72
+
constexpr Dtype uint32
Definition dtype.h:64
+
constexpr Dtype complex64
Definition dtype.h:75
+
+ + + + diff --git a/docs/build/html/class_m_p_s_1_1_kernel-members.html b/docs/build/html/class_m_p_s_1_1_kernel-members.html new file mode 100644 index 000000000..f12b0a050 --- /dev/null +++ b/docs/build/html/class_m_p_s_1_1_kernel-members.html @@ -0,0 +1,92 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
MPS::Kernel Member List
+
+
+ +

This is the complete list of members for MPS::Kernel, including all inherited members.

+ + + +
device() constMPS::Kernel
label() constMPS::Kernel
+ + + + diff --git a/docs/build/html/class_m_p_s_1_1_kernel.html b/docs/build/html/class_m_p_s_1_1_kernel.html new file mode 100644 index 000000000..de2a0c584 --- /dev/null +++ b/docs/build/html/class_m_p_s_1_1_kernel.html @@ -0,0 +1,144 @@ + + + + + + + +MLX: MPS::Kernel Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
MPS::Kernel Class Reference
+
+
+ +

#include <gemm.h>

+
+Inheritance diagram for MPS::Kernel:
+
+
+ +
+ + + + + + +

+Public Member Functions

NS::String * label () const
 
MTL::Device * device () const
 
+

Member Function Documentation

+ +

◆ device()

+ +
+
+ + + + + + + +
_MTL_INLINE MTL::Device * MPS::Kernel::device () const
+
+ +
+
+ +

◆ label()

+ +
+
+ + + + + + + +
_MTL_INLINE NS::String * MPS::Kernel::label () const
+
+ +
+
+
The documentation for this class was generated from the following file:
    +
  • mlx/backend/metal/mps/gemm.h
  • +
+
+ + + + diff --git a/docs/build/html/class_m_p_s_1_1_kernel.png b/docs/build/html/class_m_p_s_1_1_kernel.png new file mode 100644 index 0000000000000000000000000000000000000000..0ed62c014d946e0ba1bb07f7be8739e94bc15dbd GIT binary patch literal 656 zcmV;B0&o3^P)vTJr#LVva2S`&=-}Ys|Ns9r%~qrU000SeQchC<|NsC0|NsC0Hv*f~0006T zNkl3iy_{@y)J&m?)QLzDUCYYUR*>Xdc4ce8T(_rT9_Ajw@FQr7j}s8P-# zUE*4zJFh4+^^oPsOsqF;I~Xs+a+2;EDQi*}cQ(R(Xm)kF8<)E_QN7i@E4K5pUcN3a zTbJgTLtSoB4%5z+Gj|R><(?OJG}}ErJ+3ZmOI1e}+)VFM`u3~&@y9RKQ=P3iN6PK) zaaP0qJF(pH;`ca;Wh=F9${%HXt3rKU?t1!JUHack-#xbfT6el>XJmpn_IS(hP@gY* zoFqS@2>+j@eEkznlCSmQSo*(LlKhoEpK+4>tCEr=W0aI68KWnqh{)9i0Ks|yz~yxd z%E#jln75#OLf!!L9+cVR01&051b`SNB>==IDFGlxNeKWkN=g8TQBndxjFJ)nVw98s z5Tm37fEXnu0K_OM0U$>IFQtgcI_bIFWbDZ06zAs}pIQ05< zV17ZlESQGM{Si04qQGdEJ2^nNYl!)`sdy}?b7+fF*E$5mC&R#K{} q#waOORb!Nts;V)%4W)?4)%6eBqshFLMxG!50000 + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
MPS::Matrix Member List
+
+
+ +

This is the complete list of members for MPS::Matrix, including all inherited members.

+ + + + +
alloc()MPS::Matrixstatic
init(MTL::Buffer *buffer, MatrixDescriptor *descriptor)MPS::Matrix
init(const MTL::Buffer *buffer, MatrixDescriptor *descriptor)MPS::Matrix
+ + + + diff --git a/docs/build/html/class_m_p_s_1_1_matrix.html b/docs/build/html/class_m_p_s_1_1_matrix.html new file mode 100644 index 000000000..8fc4607bd --- /dev/null +++ b/docs/build/html/class_m_p_s_1_1_matrix.html @@ -0,0 +1,183 @@ + + + + + + + +MLX: MPS::Matrix Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+ +
+ +

#include <gemm.h>

+
+Inheritance diagram for MPS::Matrix:
+
+
+ +
+ + + + + + +

+Public Member Functions

Matrixinit (MTL::Buffer *buffer, MatrixDescriptor *descriptor)
 
Matrixinit (const MTL::Buffer *buffer, MatrixDescriptor *descriptor)
 
+ + + +

+Static Public Member Functions

static class Matrixalloc ()
 
+

Member Function Documentation

+ +

◆ alloc()

+ +
+
+ + + + + +
+ + + + + + + +
_MTL_INLINE Matrix * MPS::Matrix::alloc ()
+
+static
+
+ +
+
+ +

◆ init() [1/2]

+ +
+
+ + + + + + + + + + + +
_MTL_INLINE Matrix * MPS::Matrix::init (const MTL::Buffer * buffer,
MatrixDescriptor * descriptor )
+
+ +
+
+ +

◆ init() [2/2]

+ +
+
+ + + + + + + + + + + +
_MTL_INLINE Matrix * MPS::Matrix::init (MTL::Buffer * buffer,
MatrixDescriptor * descriptor )
+
+ +
+
+
The documentation for this class was generated from the following file:
    +
  • mlx/backend/metal/mps/gemm.h
  • +
+
+ + + + diff --git a/docs/build/html/class_m_p_s_1_1_matrix.png b/docs/build/html/class_m_p_s_1_1_matrix.png new file mode 100644 index 0000000000000000000000000000000000000000..a14d0bf98ab9ec89bc7677aba5189d2cf8871b16 GIT binary patch literal 665 zcmeAS@N?(olHy`uVBq!ia0vp^i-9~~^1IA5 zhU?xxts81ouD^b!S0}pqk9E7%mFer hwYP+;%euAlbBm9?DLuGqJ1`M3c)I$ztaD0e0swvcJA42D literal 0 HcmV?d00001 diff --git a/docs/build/html/class_m_p_s_1_1_matrix_descriptor-members.html b/docs/build/html/class_m_p_s_1_1_matrix_descriptor-members.html new file mode 100644 index 000000000..dee3102ef --- /dev/null +++ b/docs/build/html/class_m_p_s_1_1_matrix_descriptor-members.html @@ -0,0 +1,93 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
MPS::MatrixDescriptor Member List
+
+
+ +

This is the complete list of members for MPS::MatrixDescriptor, including all inherited members.

+ + + + +
matrixDescriptor(NS::UInteger rows, NS::UInteger columns, NS::UInteger rowBytes, NS::UInteger dataType)MPS::MatrixDescriptorstatic
matrixDescriptor(NS::UInteger rows, NS::UInteger columns, NS::UInteger matrices, NS::UInteger rowBytes, NS::UInteger matrixBytes, NS::UInteger dataType)MPS::MatrixDescriptorstatic
rows() constMPS::MatrixDescriptor
+ + + + diff --git a/docs/build/html/class_m_p_s_1_1_matrix_descriptor.html b/docs/build/html/class_m_p_s_1_1_matrix_descriptor.html new file mode 100644 index 000000000..1b88da36a --- /dev/null +++ b/docs/build/html/class_m_p_s_1_1_matrix_descriptor.html @@ -0,0 +1,221 @@ + + + + + + + +MLX: MPS::MatrixDescriptor Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
MPS::MatrixDescriptor Class Reference
+
+
+ +

#include <gemm.h>

+
+Inheritance diagram for MPS::MatrixDescriptor:
+
+
+ +
+ + + + +

+Public Member Functions

NS::UInteger rows () const
 
+ + + + + +

+Static Public Member Functions

static class MatrixDescriptormatrixDescriptor (NS::UInteger rows, NS::UInteger columns, NS::UInteger rowBytes, NS::UInteger dataType)
 
static class MatrixDescriptormatrixDescriptor (NS::UInteger rows, NS::UInteger columns, NS::UInteger matrices, NS::UInteger rowBytes, NS::UInteger matrixBytes, NS::UInteger dataType)
 
+

Member Function Documentation

+ +

◆ matrixDescriptor() [1/2]

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
_MTL_INLINE MatrixDescriptor * MPS::MatrixDescriptor::matrixDescriptor (NS::UInteger rows,
NS::UInteger columns,
NS::UInteger matrices,
NS::UInteger rowBytes,
NS::UInteger matrixBytes,
NS::UInteger dataType )
+
+static
+
+ +
+
+ +

◆ matrixDescriptor() [2/2]

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
_MTL_INLINE MatrixDescriptor * MPS::MatrixDescriptor::matrixDescriptor (NS::UInteger rows,
NS::UInteger columns,
NS::UInteger rowBytes,
NS::UInteger dataType )
+
+static
+
+ +
+
+ +

◆ rows()

+ +
+
+ + + + + + + +
_MTL_INLINE NS::UInteger MPS::MatrixDescriptor::rows () const
+
+ +
+
+
The documentation for this class was generated from the following file:
    +
  • mlx/backend/metal/mps/gemm.h
  • +
+
+ + + + diff --git a/docs/build/html/class_m_p_s_1_1_matrix_descriptor.png b/docs/build/html/class_m_p_s_1_1_matrix_descriptor.png new file mode 100644 index 0000000000000000000000000000000000000000..661b7f470ecb13762aefc4337df516e94f33e8bc GIT binary patch literal 813 zcmeAS@N?(olHy`uVBq!ia0vp^CxAGBgBeI>Z$8xoq@)9ULR|m<{|{uoc=NTi|Il&^ z1I+@7>1SR%c<=xyZhAIs2~du+B*-tA0mugfbEer>7#Ns#c)B=-R4~4sdwbey1p$}t z=^=mri}wfW-EHW6Jo8Lq#F`~42idAGi^mH(aTs$ns{G$4#^Kqwgw<)1vSzc);m^Nv z7_S}ruNt{epI`B=j@#KY_wqbGg}hCg^kzlFo5y9Ii(~7xvo5+m&fI*q+-yHjk;~$^ zggrNGZ=G2byQNq7eD+rj?oIl3dLJjt&%AsBCzIL~p3;TYkSMb+v@$H>5*ZO8c?Mk^jwQbBF--#XO^9c3Yn|`xGL@oM_ zYRb9)(|7MWSk5WkHzRDh!p2u;-$$@A5!SRK#n zC8AF@RPnL9NJ=*w$NypSw7R~&`3uiH_xXAsH?QvyJE^eNv!1W}?IiK((v!ZpO%Kj4 zKj5nLQ*Qq`1zyKZ@x{l_MO}W!>shr}T0!+{tWqlT`@F(s%t2FwcxS45=9lcfaXHia zcaXA~jNw{8$47xPZvN*C65Vb7p|2$9H0!%^#~Qa1hp70RT~8mbaSPf~C#$m8-RfV? z=Z(KDCd9=2T{rd6EbskyCM+?RdV99+_H0%qPx%Wb3X@8_;r`mv*dlO< + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
MPS::MatrixMultiplication Member List
+
+
+ +

This is the complete list of members for MPS::MatrixMultiplication, including all inherited members.

+ + + + + + + + + +
alloc()MPS::MatrixMultiplicationstatic
encodeToCommandBuffer(MTL::CommandBuffer *commandBuffer, Matrix *leftMatrix, Matrix *rightMatrix, Matrix *resultMatrix)MPS::MatrixMultiplication
init(MTL::Device *device, bool transposeLeft, bool transposeRight, NS::UInteger resultRows, NS::UInteger resultColumns, NS::UInteger interiorColumns, double alpha, double beta)MPS::MatrixMultiplication
setBatchSize(NS::UInteger batchSize)MPS::MatrixMultiplication
setBatchStart(NS::UInteger batchStart)MPS::MatrixMultiplication
setLeftMatrixOrigin(MTL::Origin origin)MPS::MatrixMultiplication
setResultMatrixOrigin(MTL::Origin origin)MPS::MatrixMultiplication
setRightMatrixOrigin(MTL::Origin origin)MPS::MatrixMultiplication
+ + + + diff --git a/docs/build/html/class_m_p_s_1_1_matrix_multiplication.html b/docs/build/html/class_m_p_s_1_1_matrix_multiplication.html new file mode 100644 index 000000000..f751a2bd4 --- /dev/null +++ b/docs/build/html/class_m_p_s_1_1_matrix_multiplication.html @@ -0,0 +1,318 @@ + + + + + + + +MLX: MPS::MatrixMultiplication Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
MPS::MatrixMultiplication Class Reference
+
+
+ +

#include <gemm.h>

+
+Inheritance diagram for MPS::MatrixMultiplication:
+
+
+ +
+ + + + + + + + + + + + + + + + +

+Public Member Functions

MatrixMultiplicationinit (MTL::Device *device, bool transposeLeft, bool transposeRight, NS::UInteger resultRows, NS::UInteger resultColumns, NS::UInteger interiorColumns, double alpha, double beta)
 
void encodeToCommandBuffer (MTL::CommandBuffer *commandBuffer, Matrix *leftMatrix, Matrix *rightMatrix, Matrix *resultMatrix)
 
void setLeftMatrixOrigin (MTL::Origin origin)
 
void setRightMatrixOrigin (MTL::Origin origin)
 
void setResultMatrixOrigin (MTL::Origin origin)
 
void setBatchStart (NS::UInteger batchStart)
 
void setBatchSize (NS::UInteger batchSize)
 
+ + + +

+Static Public Member Functions

static class MatrixMultiplicationalloc ()
 
+

Member Function Documentation

+ +

◆ alloc()

+ +
+
+ + + + + +
+ + + + + + + +
_MTL_INLINE MatrixMultiplication * MPS::MatrixMultiplication::alloc ()
+
+static
+
+ +
+
+ +

◆ encodeToCommandBuffer()

+ +
+
+ + + + + + + + + + + + + + + + + + + + + +
_MTL_INLINE void MPS::MatrixMultiplication::encodeToCommandBuffer (MTL::CommandBuffer * commandBuffer,
Matrix * leftMatrix,
Matrix * rightMatrix,
Matrix * resultMatrix )
+
+ +
+
+ +

◆ init()

+ +
+
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
_MTL_INLINE MatrixMultiplication * MPS::MatrixMultiplication::init (MTL::Device * device,
bool transposeLeft,
bool transposeRight,
NS::UInteger resultRows,
NS::UInteger resultColumns,
NS::UInteger interiorColumns,
double alpha,
double beta )
+
+ +
+
+ +

◆ setBatchSize()

+ +
+
+ + + + + + + +
_MTL_INLINE void MPS::MatrixMultiplication::setBatchSize (NS::UInteger batchSize)
+
+ +
+
+ +

◆ setBatchStart()

+ +
+
+ + + + + + + +
_MTL_INLINE void MPS::MatrixMultiplication::setBatchStart (NS::UInteger batchStart)
+
+ +
+
+ +

◆ setLeftMatrixOrigin()

+ +
+
+ + + + + + + +
_MTL_INLINE void MPS::MatrixMultiplication::setLeftMatrixOrigin (MTL::Origin origin)
+
+ +
+
+ +

◆ setResultMatrixOrigin()

+ +
+
+ + + + + + + +
_MTL_INLINE void MPS::MatrixMultiplication::setResultMatrixOrigin (MTL::Origin origin)
+
+ +
+
+ +

◆ setRightMatrixOrigin()

+ +
+
+ + + + + + + +
_MTL_INLINE void MPS::MatrixMultiplication::setRightMatrixOrigin (MTL::Origin origin)
+
+ +
+
+
The documentation for this class was generated from the following file:
    +
  • mlx/backend/metal/mps/gemm.h
  • +
+
+ + + + diff --git a/docs/build/html/class_m_p_s_1_1_matrix_multiplication.png b/docs/build/html/class_m_p_s_1_1_matrix_multiplication.png new file mode 100644 index 0000000000000000000000000000000000000000..7da3f69a96ea78b09a467274e22e6f2ac0159a49 GIT binary patch literal 1041 zcmeAS@N?(olHy`uVBq!ia0y~yV3Y>312~w0q@1)O2au8u@CkAK|NlRb`Qpvj(*8rs zEetdZB&MHv@!-J&pt$MTuq8k_&XOR%U-KE zqV+HT{U2W=Y!G3_JNstKUgOYJA0?Xv582&#Bq}KA>T-mU#ZyehGBD5MQZN$UygQQr z{QvV@ch~c$pD&)Y=b2H;?iq51-)~IV%KKa;xVZMQah%lj(#?16md%{BZ~uu|ch<~* zav|r0z|*#m$7ja(|Gjj!?7zg8-I3GI?`pf{745=Ea1F5R(qH~aSZ>f?JjwmtW;Z1y?Z zBJ=57vefTo)A#IMl=!Ay=W~|tSGAM2SEQdmyI{LS>d@zJH`MKpeLJ9SzPWv;f-k?{ zWr@3UdTT3scZz-IHMlZ+t;6@JcS@JO3pP(SJiqyzdAfP^3GcjbCgOS9WDWMc3KQj9 zy{6M<$H%P|9aew-UE5_SY|a-~p8WmF#a`LpGoEJM{gRt?ch8pPF;695pFJ77QmkK? zZ_{>#V(Yc$$IhPo5x?3zbG}Sg;kswlYvp&fFW8ds?EMvyeR(^L&2^lQv>X&2%6>;k{Wu6D|L$nFs5Z!D5=fbn@}^ zyE5Fa_tJ0oeNt6ha>wuWS^dJx*Zx@i_&a-3UEsMpHu~?v*6j;E-TG(luf{x;mu~<1 z1E)!QewqTrf5glueOa$*KIsWEuG|3*pR|9>7oUfRF0!cC17= literal 0 HcmV?d00001 diff --git a/docs/build/html/class_m_p_s_1_1_matrix_vector_multiplication-members.html b/docs/build/html/class_m_p_s_1_1_matrix_vector_multiplication-members.html new file mode 100644 index 000000000..abb569f03 --- /dev/null +++ b/docs/build/html/class_m_p_s_1_1_matrix_vector_multiplication-members.html @@ -0,0 +1,93 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
MPS::MatrixVectorMultiplication Member List
+
+
+ +

This is the complete list of members for MPS::MatrixVectorMultiplication, including all inherited members.

+ + + + +
alloc()MPS::MatrixVectorMultiplicationstatic
encodeToCommandBuffer(MTL::CommandBuffer *commandBuffer, Matrix *inputMatrix, Vector *inputVector, Vector *resultVector)MPS::MatrixVectorMultiplication
init(MTL::Device *device, bool transpose, NS::UInteger rows, NS::UInteger columns, double alpha, double beta)MPS::MatrixVectorMultiplication
+ + + + diff --git a/docs/build/html/class_m_p_s_1_1_matrix_vector_multiplication.html b/docs/build/html/class_m_p_s_1_1_matrix_vector_multiplication.html new file mode 100644 index 000000000..074f379d7 --- /dev/null +++ b/docs/build/html/class_m_p_s_1_1_matrix_vector_multiplication.html @@ -0,0 +1,213 @@ + + + + + + + +MLX: MPS::MatrixVectorMultiplication Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
MPS::MatrixVectorMultiplication Class Reference
+
+
+ +

#include <gemm.h>

+
+Inheritance diagram for MPS::MatrixVectorMultiplication:
+
+
+ +
+ + + + + + +

+Public Member Functions

MatrixVectorMultiplicationinit (MTL::Device *device, bool transpose, NS::UInteger rows, NS::UInteger columns, double alpha, double beta)
 
void encodeToCommandBuffer (MTL::CommandBuffer *commandBuffer, Matrix *inputMatrix, Vector *inputVector, Vector *resultVector)
 
+ + + +

+Static Public Member Functions

static class MatrixVectorMultiplicationalloc ()
 
+

Member Function Documentation

+ +

◆ alloc()

+ +
+
+ + + + + +
+ + + + + + + +
_MTL_INLINE MatrixVectorMultiplication * MPS::MatrixVectorMultiplication::alloc ()
+
+static
+
+ +
+
+ +

◆ encodeToCommandBuffer()

+ +
+
+ + + + + + + + + + + + + + + + + + + + + +
_MTL_INLINE void MPS::MatrixVectorMultiplication::encodeToCommandBuffer (MTL::CommandBuffer * commandBuffer,
Matrix * inputMatrix,
Vector * inputVector,
Vector * resultVector )
+
+ +
+
+ +

◆ init()

+ +
+
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
_MTL_INLINE MatrixVectorMultiplication * MPS::MatrixVectorMultiplication::init (MTL::Device * device,
bool transpose,
NS::UInteger rows,
NS::UInteger columns,
double alpha,
double beta )
+
+ +
+
+
The documentation for this class was generated from the following file:
    +
  • mlx/backend/metal/mps/gemm.h
  • +
+
+ + + + diff --git a/docs/build/html/class_m_p_s_1_1_matrix_vector_multiplication.png b/docs/build/html/class_m_p_s_1_1_matrix_vector_multiplication.png new file mode 100644 index 0000000000000000000000000000000000000000..10e54c8bd6ac07f8af73d45b3c9f0584059b3703 GIT binary patch literal 1145 zcmeAS@N?(olHy`uVBq!ia0y~yU~~ks12~w0#Rfbr z?&fnV|Eu3qkjV{JttOS37aa?B-}!(4xA*_rie2`TpKR6h{U(1`=(Ep#u9J4lIxk!_YfGP$=Tf}* z^zt%!p*!*NcYXXmGv2mq{O_UC{!!)NiHixsx!=~;tn;t%EU2$HvEs0ED(pTkHnrdU z)|r2dwqLi=EH zwzJDGS2O*s2q z^@vvqE(TVQISb}z9Y46d<8@o(_l%##O~z|?2Ke1BZkOepH@AR;^S)=*Dc;h8e0^JO zs}S>AuYw%ce^>rJHT?JeOkv`tx&wDF2D9+Ti?dXJ{E&Eg)m5R5x$C(P?wHIIcKYw~ zBm;{dQ&y}~>6^}S(4afK=^>}Pez@TcZfB1VD|Py`9!0SG-;>(FdsF1f!J`)=e)W8L zyoHaA<$qN7`d151pIJKjM_ftCs>CXTqJ!q&ADyi6{xsvk1-Izfey-H$wB+)u-SYN5 z%V#F%oxho}{`%e-D)$$a|5iWgIX~8D>(TE*Ha~Mze)`oGoxH<$U#3}C*5i}(*R|RE z?7KZb)rSeHob2gR@JtI~@ibbj?UFc>GEYsyQYvSW@?PDm_xucOekhx5 zuYbFK|3>%sg)jfl`nluMvnSuKwWvS`Q7m^Y{xZgh5TMk zPyTc!P+IM$;7to_Pp{XyEY*LdgWej7uV_&^>BWQKMY|?h#0jw+*7>)TW%-k6&UH&q zrf)O0U*Fulc;lCuPmX8G{&DBn_NP<9#wz*zlbOIoV7_}pz`c5_8?j#>9p~RC{@l!c#7$uXYTQ6XH62@)9zAV; R7Fb9yc)I$ztaD0e0swr#1StRj literal 0 HcmV?d00001 diff --git a/docs/build/html/class_m_p_s_1_1_vector-members.html b/docs/build/html/class_m_p_s_1_1_vector-members.html new file mode 100644 index 000000000..31e25f93d --- /dev/null +++ b/docs/build/html/class_m_p_s_1_1_vector-members.html @@ -0,0 +1,93 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
MPS::Vector Member List
+
+
+ +

This is the complete list of members for MPS::Vector, including all inherited members.

+ + + + +
alloc()MPS::Vectorstatic
init(MTL::Buffer *buffer, VectorDescriptor *descriptor)MPS::Vector
init(const MTL::Buffer *buffer, VectorDescriptor *descriptor)MPS::Vector
+ + + + diff --git a/docs/build/html/class_m_p_s_1_1_vector.html b/docs/build/html/class_m_p_s_1_1_vector.html new file mode 100644 index 000000000..9e1039ead --- /dev/null +++ b/docs/build/html/class_m_p_s_1_1_vector.html @@ -0,0 +1,183 @@ + + + + + + + +MLX: MPS::Vector Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+ +
+ +

#include <gemm.h>

+
+Inheritance diagram for MPS::Vector:
+
+
+ +
+ + + + + + +

+Public Member Functions

Vectorinit (MTL::Buffer *buffer, VectorDescriptor *descriptor)
 
Vectorinit (const MTL::Buffer *buffer, VectorDescriptor *descriptor)
 
+ + + +

+Static Public Member Functions

static class Vectoralloc ()
 
+

Member Function Documentation

+ +

◆ alloc()

+ +
+
+ + + + + +
+ + + + + + + +
_MTL_INLINE Vector * MPS::Vector::alloc ()
+
+static
+
+ +
+
+ +

◆ init() [1/2]

+ +
+
+ + + + + + + + + + + +
_MTL_INLINE Vector * MPS::Vector::init (const MTL::Buffer * buffer,
VectorDescriptor * descriptor )
+
+ +
+
+ +

◆ init() [2/2]

+ +
+
+ + + + + + + + + + + +
_MTL_INLINE Vector * MPS::Vector::init (MTL::Buffer * buffer,
VectorDescriptor * descriptor )
+
+ +
+
+
The documentation for this class was generated from the following file:
    +
  • mlx/backend/metal/mps/gemm.h
  • +
+
+ + + + diff --git a/docs/build/html/class_m_p_s_1_1_vector.png b/docs/build/html/class_m_p_s_1_1_vector.png new file mode 100644 index 0000000000000000000000000000000000000000..e7adea3626eca42f23064dae57f6cb4ab2cb040c GIT binary patch literal 658 zcmeAS@N?(olHy`uVBq!ia0vp^%YZn5gBeIJnfuQXNJ$6ygt-3y{~ySF@#br3|Doj; z2ATyD)6cwk@ZbSZ-1KbN5}+JsNswPK1CS2}=1jA%FfcI5db&7&J22n3nwf4dWEupw+8Nc>45vby~w`%)Bcv;dyWF6_t}3L6cP$ub36Rxo&;y z)w=iZu2oO^GgH2OORdMY^YPPiXPqzmskL_YyEUGlgu+d?d+hUDG3)$-yoq-o|DLn) z`;_H(-p3qQ*wCMG`bz8HT`GUmujoAACHY%o`?5amySX3gLak3_-Ti5CXHMOJy@NG* z*8=>CikDg5etkB0=MSG{ecAiF%2cl$vbFoY?6qF|yX~{|Q@6i=v~1tioOZ*EbKWbY zO=ILc)2_Z@DpHwcdiU9^{Z?Nhr}_2$U4CcIpPzgG-+$x1LUiNm6YF+MKRINfd3)Bk zJl6P^DmOp>+#Yb^*Ziw8T1_hFKRnF%mLE0A=J<2|xhndTXQ@cKW=%>F>0UPT=3B;w z3^|5^`hCuyKmOo4|NcUZ&i;d+y-hY>VpQn6!qT7-B+RhL!?bc3Xc^;B+76wJYbgD_8q&e!Ap-lGDZu-8p$xgiMxmkbV zvYS8uSD&5!SjHrF + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
MPS::VectorDescriptor Member List
+
+
+ +

This is the complete list of members for MPS::VectorDescriptor, including all inherited members.

+ + + +
vectorDescriptor(NS::UInteger length, NS::UInteger dataType)MPS::VectorDescriptorstatic
vectorDescriptor(NS::UInteger length, NS::UInteger vectors, NS::UInteger vectorBytes, NS::UInteger dataType)MPS::VectorDescriptorstatic
+ + + + diff --git a/docs/build/html/class_m_p_s_1_1_vector_descriptor.html b/docs/build/html/class_m_p_s_1_1_vector_descriptor.html new file mode 100644 index 000000000..681f3428e --- /dev/null +++ b/docs/build/html/class_m_p_s_1_1_vector_descriptor.html @@ -0,0 +1,178 @@ + + + + + + + +MLX: MPS::VectorDescriptor Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
MPS::VectorDescriptor Class Reference
+
+
+ +

#include <gemm.h>

+
+Inheritance diagram for MPS::VectorDescriptor:
+
+
+ +
+ + + + + + +

+Static Public Member Functions

static class VectorDescriptorvectorDescriptor (NS::UInteger length, NS::UInteger dataType)
 
static class VectorDescriptorvectorDescriptor (NS::UInteger length, NS::UInteger vectors, NS::UInteger vectorBytes, NS::UInteger dataType)
 
+

Member Function Documentation

+ +

◆ vectorDescriptor() [1/2]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
_MTL_INLINE VectorDescriptor * MPS::VectorDescriptor::vectorDescriptor (NS::UInteger length,
NS::UInteger dataType )
+
+static
+
+ +
+
+ +

◆ vectorDescriptor() [2/2]

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
_MTL_INLINE VectorDescriptor * MPS::VectorDescriptor::vectorDescriptor (NS::UInteger length,
NS::UInteger vectors,
NS::UInteger vectorBytes,
NS::UInteger dataType )
+
+static
+
+ +
+
+
The documentation for this class was generated from the following file:
    +
  • mlx/backend/metal/mps/gemm.h
  • +
+
+ + + + diff --git a/docs/build/html/class_m_p_s_1_1_vector_descriptor.png b/docs/build/html/class_m_p_s_1_1_vector_descriptor.png new file mode 100644 index 0000000000000000000000000000000000000000..00b2efd0e245b4a7485a5d74f370f82473449fd6 GIT binary patch literal 794 zcmV+#1LgdQP)vTJr#LVva2S`&=-}Ys|Ns9r%~qrU000SeQchC<|NsC0|NsC0Hv*f~0007| zNklN4wz)Gb z=|vn@4=FdvjE-X)pDT-E61&K9I+HKib0b7g3$o5nbC3*HW+J27QJ(gC8_JAp<$gXm zgYed?>8dO~{2l3i$8r>b6vMqNIU-IhHW@EiY%jhLPky(|1@%(yJy5N&3@dAxVE)SV9P43<1FR z^#XuFG!V<{H4e-`EDOpoFe9A8ZG@+ zF}B}FUl^QxRWH(5NouP0C`rBQWjWTmvP32`VJFK6i~Ko@ojM*@>Tp}3v?_<} z7E99@DoHo$jj8SNfn`k<=BwD`R^=s9!C}f`R-hd#cx@v=sDMR1eb#!yMRk;8lc@q#9cBz;XIu}IP*O%_S=r^zBo{ + + + + + + +MLX: Class Index + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + +
+ +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ +
+
Class Index
+
+
+
A | B | C | D | E | F | G | I | K | L | M | N | O | P | Q | R | S | T | U | V | W | _
+
+
+
A
+
Abs
Abs (mlx::core)
Abs (mlx::core::detail)
AccumHelper (mlx::steel)
Add
Add (mlx::core)
Add (mlx::core::detail)
add_vec (pocketfft::detail)
add_vec< cmplx< T > > (pocketfft::detail)
AddMM (mlx::core)
aligned_allocator (pocketfft::detail::threading)
Allocator (mlx::core::allocator)
And
Arange (mlx::core)
ArcCos
ArcCos (mlx::core)
ArcCos (mlx::core::detail)
ArcCosh
ArcCosh (mlx::core)
ArcCosh (mlx::core::detail)
ArcSin
ArcSin (mlx::core)
ArcSin (mlx::core::detail)
ArcSinh
ArcSinh (mlx::core)
ArcSinh (mlx::core::detail)
ArcTan
ArcTan (mlx::core)
ArcTan (mlx::core::detail)
ArcTan2
ArcTan2 (mlx::core)
ArcTan2 (mlx::core::detail)
ArcTanh
ArcTanh (mlx::core)
ArcTanh (mlx::core::detail)
ArgPartition (mlx::core)
ArgReduce (mlx::core)
ArgSort (mlx::core)
arr (pocketfft::detail)
arr_info (pocketfft::detail)
array (mlx::core)
array::ArrayIterator (mlx::core)
AsStrided (mlx::core)
AsType (mlx::core)
+
+
B
+
_MLX_BFloat16::bits_to_bfloat_struct
BitwiseAnd
BitwiseAnd (mlx::core::detail)
BitwiseBinary (mlx::core)
BitwiseOr
BitwiseOr (mlx::core::detail)
BitwiseXor
BitwiseXor (mlx::core::detail)
BlockLoader (mlx::steel)
BlockMaskedMM (mlx::core)
BlockMMA (mlx::steel)
BlockSparseMM (mlx::core)
BlockSwizzle (mlx::steel)
bool4_or_uint
Broadcast (mlx::core)
Buffer (mlx::core::allocator)
+
+
C
+
Ceil
Ceil (mlx::core)
Ceil (mlx::core::detail)
cfftp (pocketfft::detail)
ChannelHelper (mlx::steel)
ChannelHelper< 1 > (mlx::steel)
ChannelHelper< 2 > (mlx::steel)
ChannelHelper< 3 > (mlx::steel)
ChannelHelper< 4 > (mlx::steel)
cmplx (pocketfft::detail)
cndarr (pocketfft::detail)
CommandEncoder (mlx::core::metal)
CommonAllocator (mlx::core::allocator)
Compiled (mlx::core)
complex128_t (mlx::core)
complex64_t
complex64_t (mlx::core)
Concatenate (mlx::core)
concurrent_queue (pocketfft::detail::threading)
CommandEncoder::ConcurrentContext (mlx::core::metal)
Conjugate
Conjugate (mlx::core)
Conjugate (mlx::core::detail)
Conv2DGeneralBaseInfo (mlx::steel)
Conv2DGeneralJumpParams (mlx::steel)
Conv2DInputBlockLoaderGeneral (mlx::steel)
Conv2DInputBlockLoaderLargeFilter (mlx::steel)
Conv2DInputBlockLoaderSmallChannels (mlx::steel)
Conv2DInputBlockLoaderSmallFilter (mlx::steel)
Conv2DWeightBlockLoader (mlx::steel)
Conv2DWeightBlockLoaderGeneral (mlx::steel)
Conv2DWeightBlockLoaderSmallChannels (mlx::steel)
Convolution (mlx::core)
Copy (mlx::core)
Cos
Cos (mlx::core)
Cos (mlx::core::detail)
Cosh
Cosh (mlx::core)
Cosh (mlx::core::detail)
Custom (mlx::core::fast)
CustomVJP (mlx::core)
+
+
D
+
array::Data (mlx::core)
Depends (mlx::core)
Device (mlx::core)
Device (mlx::core::metal)
Divide
Divide (mlx::core::detail)
Divide (mlx::core)
DivMod (mlx::core)
Dtype (mlx::core)
+
+
E
+
Equal
Equal (mlx::core::detail)
Equal (mlx::core)
Erf
Erf (mlx::core::detail)
Erf (mlx::core)
ErfInv
ErfInv (mlx::core::detail)
ErfInv (mlx::core)
Event (mlx::core)
ExecC2C (pocketfft::detail)
ExecDcst (pocketfft::detail)
ExecHartley (pocketfft::detail)
ExecR2R (pocketfft::detail)
Exp
Exp (mlx::core::detail)
Exp (mlx::core)
Expm1
Expm1 (mlx::core::detail)
Expm1 (mlx::core)
+
+
F
+
FFT (mlx::core)
fftblue (pocketfft::detail)
FileReader (mlx::core::io)
FileWriter (mlx::core::io)
array::Flags (mlx::core)
Floor
Floor (mlx::core::detail)
Floor (mlx::core)
Full (mlx::core)
+
+
G
+
Gather (mlx::core)
GEMMAddMMParams (mlx::steel)
GEMMKernel (mlx::steel)
GEMMParams (mlx::steel)
GEMMSpiltKParams (mlx::steel)
Greater
Greater (mlx::core::detail)
Greater (mlx::core)
GreaterEqual
GreaterEqual (mlx::core::detail)
GreaterEqual (mlx::core)
+
+
I
+
ImplicitGemmConv2DParams (mlx::steel)
Indices
IntOrFloat (mlx::core::detail)
InTracing (mlx::core::detail)
Inverse (mlx::core)
+
+
K
+
Kernel (MPS)
KeySequence (mlx::core::random)
+
+
L
+
latch (pocketfft::detail::threading)
LayerNorm (mlx::core::fast)
LayerNormVJP (mlx::core::fast)
LeftShift
LeftShift (mlx::core::detail)
Less
Less (mlx::core::detail)
Less (mlx::core)
LessEqual
LessEqual (mlx::core::detail)
LessEqual (mlx::core)
Limits
Limits< bfloat16_t >
Limits< bool >
Limits< float >
Limits< half >
Limits< int16_t >
Limits< int32_t >
Limits< int64_t >
Limits< int8_t >
Limits< uint16_t >
Limits< uint32_t >
Limits< uint64_t >
Limits< uint8_t >
Load (mlx::core)
Log
Log (mlx::core::detail)
Log (mlx::core)
Log10
Log10 (mlx::core::detail)
Log1p
Log1p (mlx::core::detail)
Log1p (mlx::core)
Log2
Log2 (mlx::core::detail)
LogAddExp
LogAddExp (mlx::core::detail)
LogAddExp (mlx::core)
LogicalAnd
LogicalAnd (mlx::core::detail)
LogicalAnd (mlx::core)
LogicalNot
LogicalNot (mlx::core::detail)
LogicalNot (mlx::core)
LogicalOr
LogicalOr (mlx::core::detail)
LogicalOr (mlx::core)
LoopAlignment (mlx::steel)
+
+
M
+
Matmul (mlx::core)
Matrix (MPS)
MatrixDescriptor (MPS)
MatrixMultiplication (MPS)
MatrixVectorMultiplication (MPS)
Max
Maximum
Maximum (mlx::core::detail)
Maximum (mlx::core)
MetalAllocator (mlx::core::metal)
Min
Minimum
Minimum (mlx::core::detail)
Minimum (mlx::core)
mlx_atomic
mlx_atomic< T, enable_if_t< is_metal_atomic< T > > >
MLXConvParams
MLXScaledDotProductAttentionParams
multi_iter (pocketfft::detail)
Multiply (mlx::core::detail)
Multiply (mlx::core)
Multiply
+
+
N
+
NaNEqual (mlx::core::detail)
NaNEqual
ndarr (pocketfft::detail)
Negative (mlx::core::detail)
Negative (mlx::core)
Negative
NodeNamer (mlx::core)
None
NotEqual (mlx::core::detail)
NotEqual (mlx::core)
NotEqual
NumberOfElements (mlx::core)
+
+
O
+
Or
+
+
P
+
Pad (mlx::core)
Partition (mlx::core)
pocketfft_c (pocketfft::detail)
pocketfft_r (pocketfft::detail)
Power (mlx::core::detail)
Power (mlx::core)
Power
Primitive (mlx::core)
PrintFormatter (mlx::core)
Prod
+
+
Q
+
QRF (mlx::core)
QuantizedMatmul (mlx::core)
+
+
R
+
RandomBits (mlx::core)
Reader (mlx::core::io)
BlockLoader::ReadVector (mlx::steel)
Reduce (mlx::core)
ReductionPlan (mlx::core)
Remainder (mlx::core::detail)
Remainder (mlx::core)
Remainder
Reshape (mlx::core)
rev_iter (pocketfft::detail)
rfftp (pocketfft::detail)
RightShift (mlx::core::detail)
RightShift
RMSNorm (mlx::core::fast)
RMSNormVJP (mlx::core::fast)
RoPE (mlx::core::fast)
Round (mlx::core::detail)
Round (mlx::core)
Round
Rsqrt (mlx::core::detail)
Rsqrt
+
+
S
+
ScaledDotProductAttention (mlx::core::fast)
Scan (mlx::core)
Scatter (mlx::core)
Scheduler (mlx::core::scheduler)
Select (mlx::core::detail)
Select (mlx::core)
Select
Sigmoid (mlx::core::detail)
Sigmoid (mlx::core)
Sigmoid
Sign (mlx::core::detail)
Sign (mlx::core)
Sign
simple_iter (pocketfft::detail)
Sin (mlx::core::detail)
Sin (mlx::core)
Sin
sincos_2pibyn (pocketfft::detail)
Sinh (mlx::core::detail)
Sinh (mlx::core)
Sinh
Slice (mlx::core)
SliceUpdate (mlx::core)
Softmax (mlx::core)
Sort (mlx::core)
Split (mlx::core)
Sqrt (mlx::core::detail)
Sqrt (mlx::core)
Sqrt
Square (mlx::core::detail)
Square (mlx::core)
Square
StopGradient (mlx::core)
Stream (mlx::core)
StreamContext (mlx::core)
StreamThread (mlx::core::scheduler)
Subtract (mlx::core::detail)
Subtract (mlx::core)
Subtract
Sum
SVD (mlx::core)
+
+
T
+
T_dcst23 (pocketfft::detail)
T_dcst4 (pocketfft::detail)
T_dct1 (pocketfft::detail)
T_dst1 (pocketfft::detail)
Tan (mlx::core::detail)
Tan (mlx::core)
Tan
Tanh (mlx::core::detail)
Tanh (mlx::core)
Tanh
thread_pool (pocketfft::detail::threading)
TransformAdd (mlx::steel)
TransformAxpby (mlx::steel)
TransformNone (mlx::steel)
Transpose (mlx::core)
TypeToDtype (mlx::core)
+
+
U
+
UnaryPrimitive (mlx::core)
Uniform (mlx::core)
util (pocketfft::detail)
+
+
V
+
Vector (MPS)
VectorDescriptor (MPS)
VLEN (pocketfft::detail)
VTYPE (pocketfft::detail)
+
+
W
+
Writer (mlx::core::io)
+
+
_
+
_MLX_BFloat16
_MLX_BFloat16 (mlx::core)
_MLX_Float16 (mlx::core)
_numeric_limits_impl< bfloat16_t > (metal)
+
+
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_abs-members.html b/docs/build/html/classmlx_1_1core_1_1_abs-members.html new file mode 100644 index 000000000..c4d6a962b --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_abs-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Abs Member List
+
+
+ +

This is the complete list of members for mlx::core::Abs, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
Abs(Stream stream)mlx::core::Absinlineexplicit
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Absvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Absvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::Absinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::Absvirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::Absinlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Absinlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Absvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Absvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_abs.html b/docs/build/html/classmlx_1_1core_1_1_abs.html new file mode 100644 index 000000000..a78741814 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_abs.html @@ -0,0 +1,463 @@ + + + + + + + +MLX: mlx::core::Abs Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::Abs Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Abs:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Abs (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Abs()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::Abs::Abs (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Abs::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Abs::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Abs::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Abs::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::Abs::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Abs::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Abs::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Abs::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_abs.png b/docs/build/html/classmlx_1_1core_1_1_abs.png new file mode 100644 index 0000000000000000000000000000000000000000..ee6584fe99fc35e9877aa595ed31186df27fcc95 GIT binary patch literal 872 zcmeAS@N?(olHy`uVBq!ia0vp^(}1{xgBeH~F+Z{fQqloFA+G=b{|7Q(y!l$%e`vXd zfo6fk^fNCWJa_;UH$5A+1SrQ@666=m0OW&#In(Sb3=B-4JzX3_Dj46+ec1O(LBNe) zzH{dH`bW{XZ!s;klG=UxY~RBr5eXJN3+hi=KL4oTeJP`5no1?7uIHym7rZ7tQNL@y zcVlp^=A!p& z_Q0U!S2msw?~m1r+EsXpW73sRI?<_9)TDzhP5hO!b(8)2tF}*Q zNA;Ula<8XHO)6>fUTyX1?M9T7nri-pGh$<9j{L>=VF#Q74 z2cam2I4AA{-WiM)oVpD93t1niZeiGiB>Ar)?3`skd*mcHxao_-CaKsLe_x(_>F?!k zyQ|ZcJWcHveqD0><^f-A+cVF-3=?8sJvJ=A9lpPGp46>6?=7}>eyy96`D#zC=Hqu! z#dlO>y=^zug!=3<`|>;V{VwCF$**&!*3QcD@N}P*GxzJ=b78yZYWpdAWnWGDuJqYIW`pkhacV3oVtHn8YwuWc?^RxF4$p8BM^Tw;7wddFVo4-HH^H-^p_oO{NtEU#9JIwe& zH;>`5;j(LQ&6ZrB5jFX0M_hB(QD2r1!i0hb8g4(CqcoSA$f+1_zR00#Y^=X9HSO8c z*tOrqH}+H%-+3C9DOuEeP3F(N&2v+Kb!y5NZ(n*Z%hqbv%)Lieo3{I=9S^d3=J`GR z_U*l?wb^OkM9R0FTKP>eTX$xp_x7W^r;BIYwRfF0JwNr@nKN-88pC$Jn4BZK{^7l% zi)vx+`)s4v-db1s^`_RknCEJFUoRy}mLK#rG5+o8bo$I0?~IuXS + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Add Member List
+
+
+ +

This is the complete list of members for mlx::core::Add, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
Add(Stream stream)mlx::core::Addinlineexplicit
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Addvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Addvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::Addinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::Addvirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::Addinlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Addinlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Addvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Addvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_add.html b/docs/build/html/classmlx_1_1core_1_1_add.html new file mode 100644 index 000000000..a5a7ad12b --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_add.html @@ -0,0 +1,463 @@ + + + + + + + +MLX: mlx::core::Add Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::Add Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Add:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Add (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Add()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::Add::Add (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Add::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Add::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Add::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Add::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::Add::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Add::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Add::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Add::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_add.png b/docs/build/html/classmlx_1_1core_1_1_add.png new file mode 100644 index 0000000000000000000000000000000000000000..39bba292a843ebf0497a42bb6bc0289f5598f73a GIT binary patch literal 874 zcmeAS@N?(olHy`uVBq!ia0vp^(}1{xgBeH~F+Z{fQqloFA+G=b{|7Q(y!l$%e`vXd zfo6fk^fNCWJa_;UH$5A+1SrQ@666=m0OW&#In(Sb3=B+PJzX3_Dj46+eb~2JLBNe) z-qZN~|0CC@K9ZN2pQ_A@`hBF|%=jr?za+U&c}ex`Fy;j+xvzI(lbZu)I` zmpJ80X6c)v7dyR=PR^TS)TFX?PU`BOODWD?w}P*SY5%TWn>STHOE)!e?W7R1`C*k! ze4DoFzn0&s@{(bCSm>wR+_z>Q)7O5A@VGQDe9JZ7wf>K8y9T7nri-pGh$<9f{L>=VF#Q742cam2I4AA{ z-WiM)oVpD93t1niZeiGiB>Ar)?3`skd*mcHtz3%3+p_45C`Ggv3W?rY`D<+<_u{pQJS`7C#FTfymwpUW~oRRtfryXDyq z&s*E8QY}OFlvG}rAKqJfKV;q7&7tQqHz}&V&D>OKQoeTWx8ilzZe1&_1FbxU z$A-(Uy)|2MeMZ#es~vI8Sw}rtJ_r*E8fduvWRB8YY9gm%y!j%BuCcNH!ql{9PhHnu z-=7%0@k@-N-P)}CcmfqBgzp6WZ+PlfW7S6oMcs=`Wz3VLb^7+@J7R^sxeevG=Z8L9b z>Su?l-oF;MIxMSp+qN#{>z19_`+c$tS=QTjrlrNJZaH&iPs_BKGpAoLGX89|F3d}l b_qcsu>)e>BKaalz<_ZQ+S3j3^P6 + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::AddMM Member List
+
+
+ +

This is the complete list of members for mlx::core::AddMM, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
AddMM(Stream stream, float alpha, float beta)mlx::core::AddMMinlineexplicit
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::AddMMvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::AddMMvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::AddMMvirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)mlx::core::Primitivevirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs)mlx::core::Primitivevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::AddMMinlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::AddMMvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::AddMMvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_add_m_m.html b/docs/build/html/classmlx_1_1core_1_1_add_m_m.html new file mode 100644 index 000000000..16d98170e --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_add_m_m.html @@ -0,0 +1,404 @@ + + + + + + + +MLX: mlx::core::AddMM Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::AddMM Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::AddMM:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 AddMM (Stream stream, float alpha, float beta)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)
 The Jacobian-vector product.
 
virtual std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs)
 Get the output shapes of the primitive.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ AddMM()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
mlx::core::AddMM::AddMM (Stream stream,
float alpha,
float beta )
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::AddMM::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::AddMM::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::AddMM::is_equivalent (const Primitive & other) const
+
+overridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::AddMM::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::AddMM::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::AddMM::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_add_m_m.png b/docs/build/html/classmlx_1_1core_1_1_add_m_m.png new file mode 100644 index 0000000000000000000000000000000000000000..5e054780b6598dd29eb1019067f39135616a7ba2 GIT binary patch literal 905 zcmeAS@N?(olHy`uVBq!ia0vp^(}1{xgBeH~F+Z{fQqloFA+G=b{|7Q(y!l$%e`vXd zfo6fk^fNCWJa_;UH$5A+1SrQ@666=m0OW&#In(Sb3=GU7o-U3d6^w7^zV3Uiz{A$g zzw*;>`;VT*iW4TRi{2XBcTIHh1(q{MW_;7Pd~PF`;wja8VUmjeLRHT=nJCXmf0q1S zZeIRM-*-~o1Li3HXFHr@vzNZT`^-Sp*I$MGTzV}^?5vJw>F3O= z>MxX{pKtmX=siiVAk$Q}e$Cp;YX8=R%5!^KeeMandEuMNv^w>*o>@_!yY%O}c>XeV z@}9J(WtvJQr>^HGqjh2DEcw|F?B2oHb9_eB_S+NUlCylnl&g4GiK+`8@XuhZ;M8T% zUEbqm9u7O{ru7nnW>MKQ!VaUbwTNIvDycsl2E#RSz%6PW3i*{YtOjQQ&=)bHBw zeGy!rIqAwL-nHz?>(s>7Zkc(0Zq}h+t8S;tZQEZfX=7{A7rp51!)w)Fy)Hlaw6t$` z_SwS8C3$bp{0iIp;@1NEbvv)euAFx}ZM9tSiMA;v&nB#0y>aQbzV)B;)&}!#mk!jH zy|3W;d|U1hx$Ub%=l6H#lxD5os@?jdC2ac>?`*Gqo9`_4t?^xLCuW~Hl{eQv;^q7E z{MLKYe-|vXj`9_=XNzAg@dtkF>PKf>CK{aZabtoauId zUHP&0*sa-mxA$(#PP_I{IVy90P^RwP-%DT3nps-0z?6NzeQfy6=~0hzr>CB^IZbJhFnsiTu3oxgkTtcid9 zZ{?*uHP%^EH?7!}mL@wVs#SR^{7j9lKTm$tq5uE@ literal 0 HcmV?d00001 diff --git a/docs/build/html/classmlx_1_1core_1_1_arange-members.html b/docs/build/html/classmlx_1_1core_1_1_arange-members.html new file mode 100644 index 000000000..3d8271ae2 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_arange-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Arange Member List
+
+
+ +

This is the complete list of members for mlx::core::Arange, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
Arange(Stream stream, double start, double stop, double step)mlx::core::Arangeinlineexplicit
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Arangevirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Arangevirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::Arangevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)mlx::core::Primitivevirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs)mlx::core::Primitivevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Arangeinlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs)mlx::core::Primitivevirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes)mlx::core::Primitivevirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_arange.html b/docs/build/html/classmlx_1_1core_1_1_arange.html new file mode 100644 index 000000000..936235fe3 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_arange.html @@ -0,0 +1,332 @@ + + + + + + + +MLX: mlx::core::Arange Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::Arange Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Arange:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Arange (Stream stream, double start, double stop, double step)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)
 The Jacobian-vector product.
 
virtual std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs)
 The vector-Jacobian product.
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes)
 The primitive must know how to vectorize itself across the given axes.
 
virtual std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs)
 Get the output shapes of the primitive.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Arange()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
mlx::core::Arange::Arange (Stream stream,
double start,
double stop,
double step )
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Arange::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Arange::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Arange::is_equivalent (const Primitive & other) const
+
+overridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Arange::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_arange.png b/docs/build/html/classmlx_1_1core_1_1_arange.png new file mode 100644 index 0000000000000000000000000000000000000000..b5f5fd908820b268b6341df21e452a303bedbbab GIT binary patch literal 907 zcmeAS@N?(olHy`uVBq!ia0vp^(}1{xgBeH~F+Z{fQqloFA+G=b{|7Q(y!l$%e`vXd zfo6fk^fNCWJa_;UH$5A+1SrQ@666=m0OW&#In(Sb3=GU-o-U3d6^w7^zPG_njp7fpP6*%XcBiqh9{YD_-ukQ{-{o%){4ytvPsoqUF(9>0H|Jw<0IKDqB?Q zyLidmYtPo@-wj_Wb@Oa+`h+F(HfLRp+?D0w8GUQ-)zfwH+2zyavvhASnyI=n&u;B6 z7TJ*O-D~BuJTEcotxY@qy*hhyNB`DW$%>xQ^Db&}-_g{q!$PTrID^sJs*eC{yg2R%K8$A-(UJvCZ#eMZ#es~vI8Sw}rtJ_tuK#5r*v z@XlbY;M8T%UEbqm9u7O{ru7nnW>AtZnDZz(!gX>($di4@G_oUM~oe$G7X|9r{o zbp78~_M3WU{cOADdF0}eMN{8BD_%Zp!=6`vjoR1d$2To+w0k!#*LVA!)$>hQr|XyA z3(Y<^dr59~`nBz=+Vz-!7oOdFSN&@7?X0VYv8GFw=*DjCOFDh;Rs8e2QL}CD=(()D zvfpviwzvBpWZ#Qj`>RGKBs)ESYgwwY-NLQApL%C+`*ZH`IkAtXYkx)9Tml8ihD-6s zcZr=9sh zYz)J(WS_OCQ$5xjZdF??xR>o}*JO?d>R7`AGq66%XDpp^sba#JGr%Nr=FFa!X)|Yv z&%gHhm3_@k{>-KFRdch>o?QRLNWB285 znN?xDnQi;ENgvll%`G$Evvss^G-&?chuD`ecLwIi5w@|0kXU?475v94*M5dm>f^&}InjWt& Qz-+_d>FVdQ&MBb@07Wm)i~s-t literal 0 HcmV?d00001 diff --git a/docs/build/html/classmlx_1_1core_1_1_arc_cos-members.html b/docs/build/html/classmlx_1_1core_1_1_arc_cos-members.html new file mode 100644 index 000000000..233f33000 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_arc_cos-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::ArcCos Member List
+
+
+ +

This is the complete list of members for mlx::core::ArcCos, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
ArcCos(Stream stream)mlx::core::ArcCosinlineexplicit
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::ArcCosvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::ArcCosvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::ArcCosinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::ArcCosvirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::ArcCosinlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::ArcCosinlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::ArcCosvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::ArcCosvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_arc_cos.html b/docs/build/html/classmlx_1_1core_1_1_arc_cos.html new file mode 100644 index 000000000..b465b208b --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_arc_cos.html @@ -0,0 +1,463 @@ + + + + + + + +MLX: mlx::core::ArcCos Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::ArcCos Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::ArcCos:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 ArcCos (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ ArcCos()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::ArcCos::ArcCos (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::ArcCos::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::ArcCos::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::ArcCos::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::ArcCos::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::ArcCos::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::ArcCos::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::ArcCos::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::ArcCos::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_arc_cos.png b/docs/build/html/classmlx_1_1core_1_1_arc_cos.png new file mode 100644 index 0000000000000000000000000000000000000000..2daeb8d48a2150ddda1796400d1dd81551bf6b6a GIT binary patch literal 897 zcmeAS@N?(olHy`uVBq!ia0vp^(}1{xgBeH~F+Z{fQqloFA+G=b{|7Q(y!l$%e`vXd zfo6fk^fNCWJa_;UH$5A+1SrQ@666=m0OW&#In(Sb3=GUXo-U3d6^w7^=Jr2U;9+ZL zU-{|(|D)5>q?uZlRz`TsM0K-K~lG z%4>D<+Qv2Wvpg?x>aAV&AnPaR+HT2o?(VQ# zXRfzr9efG$fpVvoCs+$XnN9u*pl%QGslV;w`X_+>2=JX3j#-ATus=8aV!m(>ycHNeQIk7;S4-V$Kb>W~Sy*rDv=y_Xj8;c~ z-Foc)Zm;*-N=VGY)^eG!Kgn}`AdzKov7U=XodxlAK=KuDVFa~B822WQ%mvv4FO#pv+!N&jq literal 0 HcmV?d00001 diff --git a/docs/build/html/classmlx_1_1core_1_1_arc_cosh-members.html b/docs/build/html/classmlx_1_1core_1_1_arc_cosh-members.html new file mode 100644 index 000000000..29d5b9722 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_arc_cosh-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::ArcCosh Member List
+
+
+ +

This is the complete list of members for mlx::core::ArcCosh, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
ArcCosh(Stream stream)mlx::core::ArcCoshinlineexplicit
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::ArcCoshvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::ArcCoshvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::ArcCoshinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::ArcCoshvirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::ArcCoshinlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::ArcCoshinlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::ArcCoshvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::ArcCoshvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_arc_cosh.html b/docs/build/html/classmlx_1_1core_1_1_arc_cosh.html new file mode 100644 index 000000000..54a745725 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_arc_cosh.html @@ -0,0 +1,463 @@ + + + + + + + +MLX: mlx::core::ArcCosh Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::ArcCosh Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::ArcCosh:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 ArcCosh (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ ArcCosh()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::ArcCosh::ArcCosh (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::ArcCosh::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::ArcCosh::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::ArcCosh::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::ArcCosh::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::ArcCosh::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::ArcCosh::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::ArcCosh::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::ArcCosh::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_arc_cosh.png b/docs/build/html/classmlx_1_1core_1_1_arc_cosh.png new file mode 100644 index 0000000000000000000000000000000000000000..2242caeb7f52244dab16af18cfd137121009a01e GIT binary patch literal 909 zcmeAS@N?(olHy`uVBq!ia0vp^(}1{xgBeH~F+Z{fQqloFA+G=b{|7Q(y!l$%e`vXd zfo6fk^fNCWJa_;UH$5A+1SrQ@666=m0OW&#In(Sb3=GT?o-U3d6^w7^=Jr2U;9+ZL zU-{|(|D)HJaWXmu<*(m;`elq^)~Rk|iOs*K&MBV9ze(k+TZZSP=@%xc=<}^r@vOV_ z*f^@XzRt|E>Y#Z>?3%)MtLt93?@ephdNQl$ZT8G<&(r7sk%)A9Q~2)kss6I5OLoO( z*dEfzym0GA>Gd+Nt0xTnWF03}t=kmp{&JI|s%&QV>gj(g*CtM#pS3pi;+jbz=JUg9 zoBcMG>b#yFHL0Y@d$s1zecSlje?+a`CpM|%iFVXehugCw=CNn`23?+|yRTGO_2qRA zUC&QKQInpiZc#aTdQX()a~oNP`H$EPYTU_cxNzH zaOyJXFJyh7x`km+i&(?-3rrt`q8Q?kB>y*towGd89yuvZ8EC3EC=3>dO;WKx{pGaj z%a4+-`|a&#t8D$#a_wl(=2lnj+-Ikqw4~y|6i~_(@uttlZjVPk!Heb+s?*>&k6gzY4A` zTJY9a=X`0@>Rs{HyHelFEt+}ry + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::ArcSin Member List
+
+
+ +

This is the complete list of members for mlx::core::ArcSin, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
ArcSin(Stream stream)mlx::core::ArcSininlineexplicit
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::ArcSinvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::ArcSinvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::ArcSininlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::ArcSinvirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::ArcSininlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::ArcSininlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::ArcSinvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::ArcSinvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_arc_sin.html b/docs/build/html/classmlx_1_1core_1_1_arc_sin.html new file mode 100644 index 000000000..8f35ffa70 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_arc_sin.html @@ -0,0 +1,463 @@ + + + + + + + +MLX: mlx::core::ArcSin Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::ArcSin Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::ArcSin:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 ArcSin (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ ArcSin()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::ArcSin::ArcSin (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::ArcSin::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::ArcSin::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::ArcSin::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::ArcSin::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::ArcSin::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::ArcSin::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::ArcSin::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::ArcSin::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_arc_sin.png b/docs/build/html/classmlx_1_1core_1_1_arc_sin.png new file mode 100644 index 0000000000000000000000000000000000000000..644ab73d9f61e878574c75e7c990b56da38092c0 GIT binary patch literal 895 zcmeAS@N?(olHy`uVBq!ia0vp^(}1{xgBeH~F+Z{fQqloFA+G=b{|7Q(y!l$%e`vXd zfo6fk^fNCWJa_;UH$5A+1SrQ@666=m0OW&#In(Sb3=GU%o-U3d6^w7^zRr8CAi&nn zuX*cx{iDeiW^6r^eWz@ZDb;pc7IW=zRp$Rm#pgadPFr%5L)Y_@P}HO+M;A_2IqARa zZ_l^5-`~|d{eLuFYqotC@H{Hh`u3%jPU}wIkUPD}*8caS+h>H5`io=3>N3kcSKVe_ z8!e&A8zsG4?$*8_Q{$ds_X$hpx!+v1xa8)9Nqt+kUOoM9&$SJw{I_1)lu_onYPSFC zeU9okt)_hq*7dYHxb*6jn)vAU{vR=`_lZp^c_O;@ljH5#5x<(ZsIAOAD|mlb=cF&$ zEz?vgIiaphjaxgXcphJae<93=#yRs|l>=&{=P~F0?r$ww` z`UR#BLQxEHPTU8)GZ-s4bs6-LB<&YWMf3mD6+r)=-}L9Es6OmfKmYx{5A)u>^W)z#ONA$I zkxK0x@&9c9F4)iCzINZLuluh4|GYkGQr+AOlT`FSXFZ+sxrCvni~B&$nU$qGeO{U_ zzjigtZN2B#M`}zz&;kNyDE(nvJ0)lq|D>5Wb6Td&oH_l1k@4q8SE}yTZg^aFc+SRs zOQ-flPh(79a6HQU!*9;p3%(amKfdYBz4u#It<0_Bys2YmoL#wS>l^>H<6*Yb%HKGr zmh@iGi4I@8I4#Y2Uh3+)tNV)Hsv5t}O#k)poefaTZuwTRFL8I7*V-)v^8Q4X-t*sd zHM2Z;_A~dbvnIXIUL9_omD~PnQrOQ0Z+&(CTCEMeo8N!_%Cwm`*9%0YrNud&K6A!9 lW9H1tnc> + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::ArcSinh Member List
+
+
+ +

This is the complete list of members for mlx::core::ArcSinh, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
ArcSinh(Stream stream)mlx::core::ArcSinhinlineexplicit
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::ArcSinhvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::ArcSinhvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::ArcSinhinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::ArcSinhvirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::ArcSinhinlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::ArcSinhinlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::ArcSinhvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::ArcSinhvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_arc_sinh.html b/docs/build/html/classmlx_1_1core_1_1_arc_sinh.html new file mode 100644 index 000000000..3c5272e30 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_arc_sinh.html @@ -0,0 +1,463 @@ + + + + + + + +MLX: mlx::core::ArcSinh Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::ArcSinh Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::ArcSinh:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 ArcSinh (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ ArcSinh()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::ArcSinh::ArcSinh (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::ArcSinh::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::ArcSinh::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::ArcSinh::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::ArcSinh::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::ArcSinh::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::ArcSinh::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::ArcSinh::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::ArcSinh::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_arc_sinh.png b/docs/build/html/classmlx_1_1core_1_1_arc_sinh.png new file mode 100644 index 0000000000000000000000000000000000000000..728cb98d33b26311eb475e6419bdc830d42310fc GIT binary patch literal 901 zcmeAS@N?(olHy`uVBq!ia0vp^(}1{xgBeH~F+Z{fQqloFA+G=b{|7Q(y!l$%e`vXd zfo6fk^fNCWJa_;UH$5A+1SrQ@666=m0OW&#In(Sb3=GTyo-U3d6^w7^zRr8CAi&nn zuX*cx{iDeiW^6robEj;PDb;o>+qtGaHS_dWge z7piViIeB_dl;v|9S%&wI*bL;0v+m{vSDjs1x_#Eg?4Y$fIvMJmxDR+|FjjEtGUzX4 zeW1F9VNZ)#!}JSGAB3V9;*cck9kid%dETt!xmghCcd*HU+Mbi&p*G*ZuqL zXRB=e({k--&*oNF?c8Uly|*UZef2G3+1uUvlixG{-IKn>_Rg*F`j>V){{$VsyLH+Q zPpfTZsdqzbOJpU^Xa8!8Uv+)k%~j7!ZcdogS8~&8c68S2`{#CVYZFgSzc%I1gW~Y^AG5-0b@yG;V!w57!>4`ya?G4gpE={5F>_|+$Bf$e znhlQ%_s*JldWHN&uj8uT?w5bBx+e5z=V7xK#rL0X^r_xg9(DDC*`Ah|>RF~~VOzD= zrY0I$Z+o?o<^FE3_fiXIZaWv1zOFefja_eZX#J}Ik&|btZcW^``>wxrMp|k0fdh_ z{<$+}{uNu8n)Xa}%b7EKTBgmM2@HiGuS~0djHcV?XnfiG<}fhBFnGH9xvX + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::ArcTan Member List
+
+
+ +

This is the complete list of members for mlx::core::ArcTan, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
ArcTan(Stream stream)mlx::core::ArcTaninlineexplicit
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::ArcTanvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::ArcTanvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::ArcTaninlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::ArcTanvirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::ArcTaninlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::ArcTaninlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::ArcTanvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::ArcTanvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_arc_tan.html b/docs/build/html/classmlx_1_1core_1_1_arc_tan.html new file mode 100644 index 000000000..52bf53085 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_arc_tan.html @@ -0,0 +1,463 @@ + + + + + + + +MLX: mlx::core::ArcTan Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::ArcTan Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::ArcTan:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 ArcTan (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ ArcTan()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::ArcTan::ArcTan (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::ArcTan::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::ArcTan::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::ArcTan::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::ArcTan::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::ArcTan::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::ArcTan::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::ArcTan::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::ArcTan::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_arc_tan.png b/docs/build/html/classmlx_1_1core_1_1_arc_tan.png new file mode 100644 index 0000000000000000000000000000000000000000..61bf8d991ceab8d3da11e185d62fa8119883f9aa GIT binary patch literal 895 zcmeAS@N?(olHy`uVBq!ia0vp^(}1{xgBeH~F+Z{fQqloFA+G=b{|7Q(y!l$%e`vXd zfo6fk^fNCWJa_;UH$5A+1SrQ@666=m0OW&#In(Sb3=GU%o-U3d6^w7^zRr8CAi&nn zuX*cx{iDYg`xuU$Q}c@I+a(&D%>4S}rY-d+EuVi>@V=DMGEJqDQ`hs;qYGY>o~Ylg zugYHjZ?Wek;e88AHy)FH6{EfOzFB$K1aG5jr)_f2pSSyIXViYaLdk)u>a8qEuJr*scA< zGB@RQ#5Mn|Dlb{4hlPI1&6PF#xHGI$N9E-S?Wo&r*VGGNmtUH-WL4<=AJ;smEQwd% zqH@wZ!*kN~3zJmzA7?$C^SOYbrc3%jMbFC8-MlYNmt9LSU37gxRG|>#pBAx(=@*zj z2t_f(IdLEG&S0$I)Me0L$ofEa3&S2H$$t%D=Pdi#BPY4ROzk`+DEXnCSG?b1s8|;Zn-$oOfZ-Z=cshy$(41)oqH`TXB__ zH?HyCe_Xcp)aSxUYyEPsmPIS>TX0qP=h{%!`;~XEs_k1ID%YJ34v?JJ?;oG#dw(FP z)7ZH9gX-EV(Vv{Ac(1kmQCm}a+N|{Fo3BRxtBp2hg*u;p|1(c&$EH2kW{InBPS)1@ z%>B*%_GMM4!p;9?h3aqqt`hxp^|wxM{h5*8wn6rK)?aIkH+S`ChJM|1w{(k~^ZZL% zlj|m}b^C1)lKUuq%c@EF*{@Y}-(L8>WgG8W|D?hL-?oUA-k*2GeNXw;GiU1BFBlnr s7K%zsi*q`C=FI6GQJPClWd1YU37^xvZR(>Pz^uaH>FVdQ&MBb@09$FjhyVZp literal 0 HcmV?d00001 diff --git a/docs/build/html/classmlx_1_1core_1_1_arc_tan2-members.html b/docs/build/html/classmlx_1_1core_1_1_arc_tan2-members.html new file mode 100644 index 000000000..0fc8af254 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_arc_tan2-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::ArcTan2 Member List
+
+
+ +

This is the complete list of members for mlx::core::ArcTan2, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
ArcTan2(Stream stream)mlx::core::ArcTan2inlineexplicit
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::ArcTan2virtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::ArcTan2virtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::ArcTan2inlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::ArcTan2virtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::ArcTan2inlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::ArcTan2inlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::ArcTan2virtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::ArcTan2virtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_arc_tan2.html b/docs/build/html/classmlx_1_1core_1_1_arc_tan2.html new file mode 100644 index 000000000..2e293234e --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_arc_tan2.html @@ -0,0 +1,463 @@ + + + + + + + +MLX: mlx::core::ArcTan2 Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::ArcTan2 Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::ArcTan2:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 ArcTan2 (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ ArcTan2()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::ArcTan2::ArcTan2 (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::ArcTan2::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::ArcTan2::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::ArcTan2::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::ArcTan2::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::ArcTan2::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::ArcTan2::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::ArcTan2::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::ArcTan2::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_arc_tan2.png b/docs/build/html/classmlx_1_1core_1_1_arc_tan2.png new file mode 100644 index 0000000000000000000000000000000000000000..ff2449809de9c29a3400c3f1f7091052e7e901da GIT binary patch literal 913 zcmeAS@N?(olHy`uVBq!ia0vp^(}1{xgBeH~F+Z{fQqloFA+G=b{|7Q(y!l$%e`vXd zfo6fk^fNCWJa_;UH$5A+1SrQ@666=m0OW&#In(Sb3=GUNo-U3d6^w7^zRr8CAi&nn zuX*cx{iDZLW^6r^eWz@bDb-$fFy`7}rjzkS=PK<^Ofr#PsOlN#jM;^ZtUb34PXAwuarhzAbFkglMC0{iQc1&QJajeqQQP$7UXF{J?i$xj5Nee{WjlJ}O@OE6U^2%dWgK zx}Kkeq9#32-J){x^qwfo=Qgqo^B=Jp$QNhb%?qwNyK-y9*%ziO!|u2+*e_&#pt^-& zPm5T?^b1TMgrXSYoVX8oXE0W9>N4mfN!l-%T72%a1$VSe46w1^v^BVUh-_04%Ggt z?Qv=68|k_uvRPNZKURuv-hB1e)aD;e*FKz>eo5x%!Pj!RkE)}-Zmn1rYPMx>;?lY= z<<>PRyLVq^UaQ90JX^z4{`t}O2jqW!ti1gydh5OIzv}B&0Yi!t7*a;-!_HayvmaQ# zgYnGqIZ@khPl+qe(hW1N@?9lrE_lEnGc2$M)KmV9rBg1|OgM8UO?k_iGkaR5&73*? zdDQys`Dr=ZPoJ4(l<~`Lt;wvSizT9eg0gw*=I`v&IK0jDo&4(!X*YL=UEcaPe*Uz+ z+CQ(OUO!1G)4O{7SJUE^Tk-@~wQJc|A8(#pF=6em(Bli%PCK(^`R&=ucCOp9E-g*n zZqd~_&p(-5*i}AvdQ#f82bZ)w_n!+}Ui&J`?BfpuW9f$%vxM%=U)#OqOYHsSprGlz z%4!`KHuv{GwR?~6>c4y{9{%UYjG2|3y2i%(3scjcse;_D6Q#M-MD{;JiRv8T77yEP Qz|6zo>FVdQ&MBb@0Lo~`?EnA( literal 0 HcmV?d00001 diff --git a/docs/build/html/classmlx_1_1core_1_1_arc_tanh-members.html b/docs/build/html/classmlx_1_1core_1_1_arc_tanh-members.html new file mode 100644 index 000000000..64caf5fc1 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_arc_tanh-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::ArcTanh Member List
+
+
+ +

This is the complete list of members for mlx::core::ArcTanh, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
ArcTanh(Stream stream)mlx::core::ArcTanhinlineexplicit
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::ArcTanhvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::ArcTanhvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::ArcTanhinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::ArcTanhvirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::ArcTanhinlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::ArcTanhinlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::ArcTanhvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::ArcTanhvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_arc_tanh.html b/docs/build/html/classmlx_1_1core_1_1_arc_tanh.html new file mode 100644 index 000000000..9b55347cc --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_arc_tanh.html @@ -0,0 +1,463 @@ + + + + + + + +MLX: mlx::core::ArcTanh Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::ArcTanh Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::ArcTanh:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 ArcTanh (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ ArcTanh()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::ArcTanh::ArcTanh (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::ArcTanh::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::ArcTanh::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::ArcTanh::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::ArcTanh::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::ArcTanh::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::ArcTanh::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::ArcTanh::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::ArcTanh::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_arc_tanh.png b/docs/build/html/classmlx_1_1core_1_1_arc_tanh.png new file mode 100644 index 0000000000000000000000000000000000000000..59f4ba4f4fc80b3ffcab15b8c7002012d6b8c63c GIT binary patch literal 901 zcmeAS@N?(olHy`uVBq!ia0vp^(}1{xgBeH~F+Z{fQqloFA+G=b{|7Q(y!l$%e`vXd zfo6fk^fNCWJa_;UH$5A+1SrQ@666=m0OW&#In(Sb3=GTyo-U3d6^w7^zRr8CAi&nn zuX*cx{iDYg`xuU$Q}c@I+a(&DtoJ&x@teNoa~ru7PpRGulT`E5d?Qou(y|gs9cqzwDmE_tphrj)K)BN7HxWo3GNlx|3bWg3*t=DXO zCRr`*i#1!jJLF5r<0*nFE1zmbr%q9m4!Sh)SI*W=_Uo_OK3yJldgF>TPpxkM)q5P( zZ(7N{ULG~6q{(}=)u*>Nt#Td}ul*F|acN@smTkOi^B;x1Ut;PRx^mi&Yra#K#4B%6 zIq99@IcYk`rH`|o&iP!xP}3!SprU7G>2BVarpvCSm@c}$AgWN9@lT6b!}JSGAB3V9 z;+(h-cxNzHaOyJXFJyh7x`km6lH|XJuydCE?2(h);HEDQo1|j@`NL_;m;Wwz+g+7+ zos?DQ^g8Lt%_F|rYo8S_ov}fu6chwk^P89VmEBRkC6||}z2AiQUfii2yRXTZdmaxu zuQR)=Ur+XLVQBPs@27XHqMyzH2ZLI4dh+(oSM{D>j%qiK$#+?KWu1cO@lAO@{mhk<{TI-J*ug!ISExqmP_kU^a-*e2%cP$Q8 zU05uv`ttFfALbw0|NY#P`!#g!_qG4#*I%9V<#o$6l}gEQ?dNmanI8n-VK|oTv-Wnb z$9lu9YO4kJvR&<(&GA4TH6UskhGITRZ1pT-8@oz0L2^rLtQ8?AsT9sfmf__-7ZsyFL3y$M<#9X5M@c mccyp7%$cBISURQcu{^WL9M;_u7u|pvhQZU-&t;ucLK6Vpv$(AQ literal 0 HcmV?d00001 diff --git a/docs/build/html/classmlx_1_1core_1_1_arg_partition-members.html b/docs/build/html/classmlx_1_1core_1_1_arg_partition-members.html new file mode 100644 index 000000000..b3ae98b73 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_arg_partition-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::ArgPartition Member List
+
+
+ +

This is the complete list of members for mlx::core::ArgPartition, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
ArgPartition(Stream stream, int kth, int axis)mlx::core::ArgPartitioninlineexplicit
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::ArgPartitionvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::ArgPartitionvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::ArgPartitionvirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)mlx::core::Primitivevirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::ArgPartitioninlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::ArgPartitioninlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs)mlx::core::Primitivevirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::ArgPartitionvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_arg_partition.html b/docs/build/html/classmlx_1_1core_1_1_arg_partition.html new file mode 100644 index 000000000..ecdce6668 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_arg_partition.html @@ -0,0 +1,391 @@ + + + + + + + +MLX: mlx::core::ArgPartition Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::ArgPartition Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::ArgPartition:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 ArgPartition (Stream stream, int kth, int axis)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
void print (std::ostream &os) override
 Print the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)
 The Jacobian-vector product.
 
virtual std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs)
 The vector-Jacobian product.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ ArgPartition()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
mlx::core::ArgPartition::ArgPartition (Stream stream,
int kth,
int axis )
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::ArgPartition::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::ArgPartition::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::ArgPartition::is_equivalent (const Primitive & other) const
+
+overridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::ArgPartition::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::ArgPartition::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::ArgPartition::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_arg_partition.png b/docs/build/html/classmlx_1_1core_1_1_arg_partition.png new file mode 100644 index 0000000000000000000000000000000000000000..8dcfb003d24a60c695b8e610cc069a2949a22060 GIT binary patch literal 936 zcmeAS@N?(olHy`uVBq!ia0vp^(}1{xgBeH~F+Z{fQqloFA+G=b{|7Q(y!l$%e`vXd zfo6fk^fNCWJa_;UH$5A+1SrQ@666=m0OW&#In(Sb3=GW1o-U3d6^w7^zRhbk;Bn2* zTXO3C|0B&JLeHBYm2Ej|;%~|HDUq*l^FP-f!*lFPp6A-KCoU1auH>oBw$@=%*`Z(N zQM>Qlm-d|Zfo&~&_L+r!TMzxqed*SEXVS*_OIs$*H>#BNRP>sY8dbL%w{l{ljBKMyv|OsUd*&ubwr^LqUOiP8mmTr6Jt{gfWVOeuw9A*&UnoS^ zM$XTg_=Hhy?W!lYZ@n`67~?I^J?TxQn|6LitJTdp-rx^r(1yK(k%Z`6j`S9PA3PhEe0^0b5Xd%3kOw|{+^J?-RA$K$$D znt2Mp!`5GkeYN!Vfr_}R<~5V!qSL+4gG0jPrq!(3*{lD~S-o|a)H-t~?XB?wC-%LS zj(fzn^~+C-6W11O%i0~&;h(|lZMS{Z6ZZAhTRPWgzDzD$Hr0HKUEL0Y%)x>S9T7B(xSpD8P`-^fLA@t=;xzlF_!itb&;<-M>s& zr+c^d!&>poP}aHgUwxl`@bs2d2Oj6}YE>?{TK?Z@wbsvV+oHDmd5fR1-twvKntI^e z#gimrzn@iGb?AipRu_xQLV`Qn;*Zw$UwiP|)#c3|&FH`VeH|Th_kRo4Rr+Kb + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::ArgReduce Member List
+
+
+ +

This is the complete list of members for mlx::core::ArgReduce, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
ArgMax enum valuemlx::core::ArgReduce
ArgMin enum valuemlx::core::ArgReduce
ArgReduce(Stream stream, ReduceType reduce_type, int axis)mlx::core::ArgReduceinlineexplicit
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::ArgReducevirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::ArgReducevirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::ArgReducevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)mlx::core::Primitivevirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::ArgReducevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::ArgReduceinlinevirtual
ReduceType enum namemlx::core::ArgReduce
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs)mlx::core::Primitivevirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::ArgReducevirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_arg_reduce.html b/docs/build/html/classmlx_1_1core_1_1_arg_reduce.html new file mode 100644 index 000000000..1600965aa --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_arg_reduce.html @@ -0,0 +1,418 @@ + + + + + + + +MLX: mlx::core::ArgReduce Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::ArgReduce Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::ArgReduce:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + +

+Public Types

enum  ReduceType { ArgMin +, ArgMax + }
 
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 ArgReduce (Stream stream, ReduceType reduce_type, int axis)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)
 The Jacobian-vector product.
 
virtual std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs)
 The vector-Jacobian product.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Member Enumeration Documentation

+ +

◆ ReduceType

+ +
+
+ + + +
Enumerator
ArgMin 
ArgMax 
+ +
+
+

Constructor & Destructor Documentation

+ +

◆ ArgReduce()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
mlx::core::ArgReduce::ArgReduce (Stream stream,
ReduceType reduce_type,
int axis )
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::ArgReduce::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::ArgReduce::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::ArgReduce::is_equivalent (const Primitive & other) const
+
+overridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::ArgReduce::output_shapes (const std::vector< array > & inputs)
+
+overridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::ArgReduce::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::ArgReduce::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_arg_reduce.png b/docs/build/html/classmlx_1_1core_1_1_arg_reduce.png new file mode 100644 index 0000000000000000000000000000000000000000..ac897a69d7edec8e1e8739c3ee6c241cd955b53e GIT binary patch literal 932 zcmeAS@N?(olHy`uVBq!ia0vp^(}1{xgBeH~F+Z{fQqloFA+G=b{|7Q(y!l$%e`vXd zfo6fk^fNCWJa_;UH$5A+1SrQ@666=m0OW&#In(Sb3=GWro-U3d6^w7^=Jr2U;9+ZL zU-{|(|D)HJaWXmu<*(m;`sJRaD^tYN4Jv+5ol`uIf0N2tw+zoo(=SX?(dS#M;#qg; zasHNX|Ni7m`qGqtVOPXszS4D5x9!#C4?STjdHc9g?)md}KMgmv+<1Iv`BeS4TAro5 zFI2SzURtnqvToILE3<b^^<&R$lR?yZgc_IzvG_CF!nd&SNI)ql>s zDq1pW?c=R=D^)$!cWjxZegFR2!;1T^tvc;AsjtR$wM<5I=hMmw*Hlc`-h322Rl{>X z+=UBOJ>#BTJ!Sd)CBqL9^M;=@f?t=-zLf2^cK6kuxXD?rs!TsPbs6**vOZAV!my`B ztYP{ErVm0<3~^4}2fQ;FE084rut!-ww~(Uvm7u#;?CW zZ>p?(s=h@pZH}%^RHkvf+$z@BVcYvtZ|(bS^waLb`D(Y@hu7AAo&9d|>FHMQ)@?oF zduiiV?bqvGEq#5UBJQ$zP4M1rXTxllpKhB{vgGu&n|F$~|9V>!^}3+_cJqpS zJ*V8%_CnSqgB32GzbxYa$p3J*x66;NU489$*#C3>TUGvAQ8$D%mzv0{7;nDGp=)fc zzc4lJ*;A+2ulF4n^$wjGc{L`{e&y6Y?`eJ8zWc6K{js>X)>|_E``%ie%d|1n r=f8F4&3wVAv@~Ec^3q)T@&~iT?OBsH9FJcG%ufuSu6{1-oD!M + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::ArgSort Member List
+
+
+ +

This is the complete list of members for mlx::core::ArgSort, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
ArgSort(Stream stream, int axis)mlx::core::ArgSortinlineexplicit
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::ArgSortvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::ArgSortvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::ArgSortvirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)mlx::core::Primitivevirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::ArgSortinlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::ArgSortinlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs)mlx::core::Primitivevirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::ArgSortvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_arg_sort.html b/docs/build/html/classmlx_1_1core_1_1_arg_sort.html new file mode 100644 index 000000000..510709ed7 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_arg_sort.html @@ -0,0 +1,386 @@ + + + + + + + +MLX: mlx::core::ArgSort Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::ArgSort Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::ArgSort:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 ArgSort (Stream stream, int axis)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
void print (std::ostream &os) override
 Print the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)
 The Jacobian-vector product.
 
virtual std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs)
 The vector-Jacobian product.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ ArgSort()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
mlx::core::ArgSort::ArgSort (Stream stream,
int axis )
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::ArgSort::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::ArgSort::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::ArgSort::is_equivalent (const Primitive & other) const
+
+overridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::ArgSort::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::ArgSort::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::ArgSort::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_arg_sort.png b/docs/build/html/classmlx_1_1core_1_1_arg_sort.png new file mode 100644 index 0000000000000000000000000000000000000000..523bf16a1f393ec78f0cd6f3df8783e4a793ab99 GIT binary patch literal 919 zcmeAS@N?(olHy`uVBq!ia0vp^(}1{xgBeH~F+Z{fQqloFA+G=b{|7Q(y!l$%e`vXd zfo6fk^fNCWJa_;UH$5A+1SrQ@666=m0OW&#In(Sb3=GUlo-U3d6^w7^zMk}0i-&D? z|EekX{&)I6+Gv>YCa`FdPV{Qcsy77_Z#${fpR|1bQ^EUEPRlfvN={wRPmeBmO?skk z`_C=={darcNp%mIqr|iCEIhaM(zm?RzM@6G$I2p?z5RJp-QI1QK=qv8TRts}Rb81U z9@TH)CAGz@?C#dLL8eAM!R`~5%v+svHE>ssho}3kTdz|0uQy%uDR}KxomruiUOl_K z%Kw7rw!YNAtEZ{VeQ+tuYv1~{ht>YA3zg^gwEEN&cJsnFm1(~nqb6V3@=P?|s&mqp z;~YTO2}MnMqPj)pwao;`9>nljK-?+njL(-((LQn5e% z<+SO`kLF!<`~PQ7y7Gy4ZMS4PcX!yWGt*~h9sITGcIwP+`)g;`$iI00*X8!%wY6U- zzngrz-Rj-Ctw($>ZQQE;dflsquMbqjT{f=?+Pm#+nC)}3i-?;rlu-`AfV`+nNYn+^IYuOCHUK9h6e{-5C7v~RIar_Y=LCLGPB dCbIt-S9s0-$L+{}2bhHzJYD@<);T3K0RRE?)z1I` literal 0 HcmV?d00001 diff --git a/docs/build/html/classmlx_1_1core_1_1_as_strided-members.html b/docs/build/html/classmlx_1_1core_1_1_as_strided-members.html new file mode 100644 index 000000000..63e99012a --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_as_strided-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::AsStrided Member List
+
+
+ +

This is the complete list of members for mlx::core::AsStrided, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
AsStrided(Stream stream, std::vector< int > shape, std::vector< size_t > strides, size_t offset)mlx::core::AsStridedinlineexplicit
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::AsStridedvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::AsStridedvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::AsStridedvirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::AsStridedvirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs)mlx::core::Primitivevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::AsStridedinlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::AsStridedvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes)mlx::core::Primitivevirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_as_strided.html b/docs/build/html/classmlx_1_1core_1_1_as_strided.html new file mode 100644 index 000000000..d9002ed44 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_as_strided.html @@ -0,0 +1,413 @@ + + + + + + + +MLX: mlx::core::AsStrided Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::AsStrided Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::AsStrided:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 AsStrided (Stream stream, std::vector< int > shape, std::vector< size_t > strides, size_t offset)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes)
 The primitive must know how to vectorize itself across the given axes.
 
virtual std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs)
 Get the output shapes of the primitive.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ AsStrided()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
mlx::core::AsStrided::AsStrided (Stream stream,
std::vector< int > shape,
std::vector< size_t > strides,
size_t offset )
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::AsStrided::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::AsStrided::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::AsStrided::is_equivalent (const Primitive & other) const
+
+overridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::AsStrided::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::AsStrided::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::AsStrided::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_as_strided.png b/docs/build/html/classmlx_1_1core_1_1_as_strided.png new file mode 100644 index 0000000000000000000000000000000000000000..7224d5d45d763d83f5d9719c6855149d687d13b9 GIT binary patch literal 917 zcmeAS@N?(olHy`uVBq!ia0vp^(}1{xgBeH~F+Z{fQqloFA+G=b{|7Q(y!l$%e`vXd zfo6fk^fNCWJa_;UH$5A+1SrQ@666=m0OW&#In(Sb3=GT)o-U3d6^w7^KAf~#OMq>5 z|Eeo{{tL$^i!d6_eDdT?(QBWszBPQw%!i-pTRyjuPw|}9dts7_{z6sHIGHHVNq?67 zPB$rgX+KM4?*sh|Tgk_MrE8~t+qrhG*3)egzf*awuRZ73XCE0XQ#ozJCwFVFpu6)j z?q1f&ym0Hz*R{3NSE(PiDP{9~RkYR*fA@BMPYdo`}jF0yFEmHX_ z1#+E}_oO{7(^M*FhHF2cvz+-s^c{w0$$o2Z=X$NT+$y$OaLe#-uBsgroeG*0G7L-7xVxvaf% zUcvMEw)`Ix%U6WX@9(}UbtNo2oa-a++QgI7uTA;$p!B@iN8`1>BKBO;lD;)>!zKH9 z<<>Q8%F36qN1avT0R~F#T=W0z|1R{OKOR;e`nq`a|MUJ^RsLFmLh9MoQSQ_;uOrOWA&Fk7l{8_uTqOh3N-oSYQvRKg>~@OHJfej5pup&^0#JUznQq>}l(@ z-{;PD{GPe$UG8GIf5&r_vT3$jgB?_a-`dHP&-?X_QC zxuNcP1w!8YMofDzj$()dD$Kb<3GEz+^LX=Jd*~L0*|w|CpKNW*aYS TeZ3W!ffzhp{an^LB{Ts5 + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::AsType Member List
+
+
+ +

This is the complete list of members for mlx::core::AsType, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
AsType(Stream stream, Dtype dtype)mlx::core::AsTypeinlineexplicit
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::AsTypevirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::AsTypevirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::AsTypevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::AsTypevirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::AsTypeinlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::AsTypeinlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::AsTypevirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::AsTypevirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_as_type.html b/docs/build/html/classmlx_1_1core_1_1_as_type.html new file mode 100644 index 000000000..00bb61d5a --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_as_type.html @@ -0,0 +1,467 @@ + + + + + + + +MLX: mlx::core::AsType Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::AsType Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::AsType:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 AsType (Stream stream, Dtype dtype)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ AsType()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
mlx::core::AsType::AsType (Stream stream,
Dtype dtype )
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::AsType::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::AsType::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::AsType::is_equivalent (const Primitive & other) const
+
+overridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::AsType::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::AsType::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::AsType::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::AsType::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::AsType::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_as_type.png b/docs/build/html/classmlx_1_1core_1_1_as_type.png new file mode 100644 index 0000000000000000000000000000000000000000..4b919c285526347e6d43cb082353c9bcdff54692 GIT binary patch literal 918 zcmeAS@N?(olHy`uVBq!ia0vp^(}1{xgBeH~F+Z{fQqloFA+G=b{|7Q(y!l$%e`vXd zfo6fk^fNCWJa_;UH$5A+1SrQ@666=m0OW&#In(Sb3=GVQo-U3d6^w7^KFoWqz~ja* zAN2Hl{iEAyoQzIJ+h18Nd)>MC(h&>wV{7V9T0Z}&;C(5lWtvJQr>^IxM;E*%JyEy) z=a&8c^Ljfz;8 z{ueyA_T~OvJ56QogG*Ua``)iTtX8)^RG!<@>QhhH%?saDru|&7#xu+9bC>>H7tdd& zPTrIDz+7pxKJ1*OKl_2*I~dO#pA)tH_LR8dth;N|zRU{I_ETz@e}Um4)7We-?wQYIy`m6rC*r)54-iysX zIeW?DjPPsISGBJ*KUR2l>%+;fw&reqHDjGr(4~p%qS7}m-Tv!(<@Rd}+i&--nEK1y z<5FRc^uFh^Sy!JwUJ})OFYB$BvfV<{oS#Osrpep9%My;iysH0bc~q+NZQ-zIzyId7 zf6ppUuUc?bB(Xq9_2pZepY|V|?d|lV|E{?9I;{Tr@~BC5b1zI%(f^$Fbk64zhCSWV z2Wrl&EZyz<(scQ?yIF4QJ-0qmVfukHNYKLT4||m6QWJR<oa_vhQj7Spdcwp{B@{W@uVWKNpdgP?5B ziD&v2|B8^?9IqAreL>i2USnhFg4ngc0;}&>d8b_~-c!8hn)=pR>-6sL@?EpKqLyb* z)a|;Mkkyyp-O9?|QB{2Q%rT3tv!=bTULCIO%Fw! zqM0|XGyk8SF>~%G$F<9f($d6h{8r!dG~Vp}@00)5nK#opK_2*=wRFm*l0VG4*XEe9 Tf9Kl_%s~vEu6{1-oD!M<4;agi literal 0 HcmV?d00001 diff --git a/docs/build/html/classmlx_1_1core_1_1_bitwise_binary-members.html b/docs/build/html/classmlx_1_1core_1_1_bitwise_binary-members.html new file mode 100644 index 000000000..f6c72e0c0 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_bitwise_binary-members.html @@ -0,0 +1,121 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::BitwiseBinary Member List
+
+
+ +

This is the complete list of members for mlx::core::BitwiseBinary, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
And enum valuemlx::core::BitwiseBinary
BitwiseBinary(Stream stream, Op op)mlx::core::BitwiseBinaryinlineexplicit
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::BitwiseBinaryvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::BitwiseBinaryvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::BitwiseBinaryvirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)mlx::core::Primitivevirtual
LeftShift enum valuemlx::core::BitwiseBinary
Op enum namemlx::core::BitwiseBinary
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
Or enum valuemlx::core::BitwiseBinary
output_shapes(const std::vector< array > &inputs) overridemlx::core::BitwiseBinaryinlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::BitwiseBinaryvirtual
RightShift enum valuemlx::core::BitwiseBinary
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs)mlx::core::Primitivevirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::BitwiseBinaryvirtual
Xor enum valuemlx::core::BitwiseBinary
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_bitwise_binary.html b/docs/build/html/classmlx_1_1core_1_1_bitwise_binary.html new file mode 100644 index 000000000..d0a522c28 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_bitwise_binary.html @@ -0,0 +1,422 @@ + + + + + + + +MLX: mlx::core::BitwiseBinary Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::BitwiseBinary Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::BitwiseBinary:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + +

+Public Types

enum  Op {
+  And +, Or +, Xor +, LeftShift +,
+  RightShift +
+ }
 
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 BitwiseBinary (Stream stream, Op op)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
void print (std::ostream &os) override
 Print the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)
 The Jacobian-vector product.
 
virtual std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs)
 The vector-Jacobian product.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Member Enumeration Documentation

+ +

◆ Op

+ +
+
+ + + + + + +
Enumerator
And 
Or 
Xor 
LeftShift 
RightShift 
+ +
+
+

Constructor & Destructor Documentation

+ +

◆ BitwiseBinary()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
mlx::core::BitwiseBinary::BitwiseBinary (Stream stream,
Op op )
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::BitwiseBinary::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::BitwiseBinary::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::BitwiseBinary::is_equivalent (const Primitive & other) const
+
+overridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::BitwiseBinary::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::BitwiseBinary::print (std::ostream & os)
+
+overridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::BitwiseBinary::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_bitwise_binary.png b/docs/build/html/classmlx_1_1core_1_1_bitwise_binary.png new file mode 100644 index 0000000000000000000000000000000000000000..0e74367fc64baff8773ceeb39e1af8c7c8b01490 GIT binary patch literal 937 zcmeAS@N?(olHy`uVBq!ia0vp^(}1{xgBeH~F+Z{fQqloFA+G=b{|7Q(y!l$%e`vXd zfo6fk^fNCWJa_;UH$5A+1SrQ@666=m0OW&#In(Sb3=GUBo-U3d6^w7^zU_Ohz{ASV zzVg#=`;We#8Oe_ZlZIk{diO68?0 z)P>V4w-%kNv|@M=dxs&x?%9=Hcb9#c6`Z|s*2V0gwHBQW57f6X9PrLyY?yw5i9vrM zD?^+UH^ZJ5F@_3GU4{=rQ49}MkrmWCXrl$e^f>K_OXi#YDe~P_?{;(gPcx58pE%cc zcWiET)s}ts+C}$9{MR2*hhN>E{@|&b^!C)%FTZ8~^7JnF>HK_cRQ5fWZ->`inSRCj zyYyp=;@lr9d#{CgA2+$GAY45wYwpo$_p;)j-;IjpzukMm>#elPCi&xS*Y`N<*4p{4 zGJT<&9WM2eb8X`3>4{NwA9k$vt?~7i?{<$$b$-jK{kdxY&4*EMpZV}_@eyQGpQv*G z+~ohv_Fo=<-gLD;Yj57y>bg~)pI*0!c~;E~2YPBbGsC^3ybRBh{np;j^;&PaRcy87 zUbU-TlQ|gb@CFHbn0;d366BR>H^IdvO?iup%bu2L9UapbUYoq^^~IahUQW#w6ID|x zIvKs?)vVQ1wR~Kj2(6vk_dQl!>*%*>PqTS%KYLbrx;Ls{w>>W6@U;_ai`Oo`7`1q^ zhy9i7vo^+Am94Y!&n|u4ul=m;oW+`H=Prjht9YJYuzga8oav6tgOk_%U%u6T(M=c6 zq+`=k`L`vi_I|Q^tqbI(Ze`nB`X>2h(A*W*jc$Hx+jqrC*WM$^yfC0`-H-QIw4xKw z_u9qyUfz;rwd4KT->jFLg#S&SAAEY-Yn`aD$L4=Sx;tv7Ul0=fXtX}eOVjtb{TYp! V{aa^=9R+4822WQ%mvv4FO#pZu!9xH5 literal 0 HcmV?d00001 diff --git a/docs/build/html/classmlx_1_1core_1_1_block_masked_m_m-members.html b/docs/build/html/classmlx_1_1core_1_1_block_masked_m_m-members.html new file mode 100644 index 000000000..a2891b770 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_block_masked_m_m-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::BlockMaskedMM Member List
+
+
+ +

This is the complete list of members for mlx::core::BlockMaskedMM, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
BlockMaskedMM(Stream stream, int block_size)mlx::core::BlockMaskedMMinlineexplicit
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::BlockMaskedMMvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::BlockMaskedMMvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::BlockMaskedMMvirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)mlx::core::Primitivevirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs)mlx::core::Primitivevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::BlockMaskedMMinlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::BlockMaskedMMvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes)mlx::core::Primitivevirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_block_masked_m_m.html b/docs/build/html/classmlx_1_1core_1_1_block_masked_m_m.html new file mode 100644 index 000000000..f447b19b8 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_block_masked_m_m.html @@ -0,0 +1,365 @@ + + + + + + + +MLX: mlx::core::BlockMaskedMM Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::BlockMaskedMM Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::BlockMaskedMM:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 BlockMaskedMM (Stream stream, int block_size)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)
 The Jacobian-vector product.
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes)
 The primitive must know how to vectorize itself across the given axes.
 
virtual std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs)
 Get the output shapes of the primitive.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ BlockMaskedMM()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
mlx::core::BlockMaskedMM::BlockMaskedMM (Stream stream,
int block_size )
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::BlockMaskedMM::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::BlockMaskedMM::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::BlockMaskedMM::is_equivalent (const Primitive & other) const
+
+overridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::BlockMaskedMM::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::BlockMaskedMM::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_block_masked_m_m.png b/docs/build/html/classmlx_1_1core_1_1_block_masked_m_m.png new file mode 100644 index 0000000000000000000000000000000000000000..8e5e7a8a40ca12f5d3e6d44b574fc1deea89beea GIT binary patch literal 966 zcmeAS@N?(olHy`uVBq!ia0vp^OM$q9gBeI}JboKUGDrvbgt-3y{~ySF@#br3|Doj; z2ATyD)6cwk@ZbSZ-1KbN5}+JsNswPK1CS2}=1jA%FfcF&db&7H?n-Zba|%S+OyN{ze?bk-}oi}$(+w`9o0_G7M`Z^(m7(%68`JD!k%&J zd;fiX)_4B1>m>1?wx^pd^Ox;Pn0ZZqzT6cq@0HhXi-fHXZ;sDjCSrPCBJAIW%^sT$ z?|iDQEo?nK){B2_^s0L$pI`8(>^$BY^=gUl+7m}~zr}2OSD$w?cSh{#r#v$wRc=1i z^}ajXdh4CC6Dw6c=Y8l~`f}d4n7pz-cY?|#Jl}n0KE3Jpw&SOZm%r#Vnf^1WUrzm` zy}FL)rS=Ufp7Mt#snk~Vq@J_9FUwGn#?Wy4Gso$~il%4EI_@XgGwcoOh8p|=|Kd;2Dc)y$LZz}8WITuWq%VorjV3-RUH0Ly<;$<@r`OiZ z)1UO_v+L>Rj?G7-PH(BZ{$;^4{nm*5;E1)poO1D73X8YR%9FgT-FxsU|LpSPMNw)K zZ=Y7XYv%{B1(Huk<$gDJR};oyRx# zJ1_w5`RQ6ezjSJv^Pe_V&nP|}-|r=1_P0LHk9@_a9J#(26g1_hi)-)8wC@r>yWaKq zr7KVN9QV95*Sx;k^U^2Y)BURZ6Myb}QeHMk#a;;(Xl@|?Su;F1!@#ipx!|)2T-+= zl3kAO{VnsNHkjwiMYQ{eOR|UL7{4t(t(KpinWsL@blc1@oy)tfd|7bm%u=oGma(xe z(X%d8`sjzpm8ZpQm0=FQtg=YIbarqEdvd1p^Yx+s%`bZ$`(&yXCVJ*d(E6)ayi`t3 z`xdh*`s78ubpJ!|$}1!ElHT@x*9}#i{(hyl&didjxm){nN{!!K*;KA4?H#OCnlVdu zyLA2COC|o*dgj*q>h)$R8-{H>U4Coii!?yEmu$&bJdL`@X%Ob>Pd7ol84T vP|oyp^>bP0l+XkKxcbb4 literal 0 HcmV?d00001 diff --git a/docs/build/html/classmlx_1_1core_1_1_block_sparse_m_m-members.html b/docs/build/html/classmlx_1_1core_1_1_block_sparse_m_m-members.html new file mode 100644 index 000000000..2cee30d11 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_block_sparse_m_m-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::BlockSparseMM Member List
+
+
+ +

This is the complete list of members for mlx::core::BlockSparseMM, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
BlockSparseMM(Stream stream)mlx::core::BlockSparseMMinlineexplicit
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::BlockSparseMMvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::BlockSparseMMvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::BlockSparseMMinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)mlx::core::Primitivevirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs)mlx::core::Primitivevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::BlockSparseMMinlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::BlockSparseMMvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes)mlx::core::Primitivevirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_block_sparse_m_m.html b/docs/build/html/classmlx_1_1core_1_1_block_sparse_m_m.html new file mode 100644 index 000000000..3941e4915 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_block_sparse_m_m.html @@ -0,0 +1,361 @@ + + + + + + + +MLX: mlx::core::BlockSparseMM Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::BlockSparseMM Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::BlockSparseMM:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 BlockSparseMM (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)
 The Jacobian-vector product.
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes)
 The primitive must know how to vectorize itself across the given axes.
 
virtual std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs)
 Get the output shapes of the primitive.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ BlockSparseMM()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::BlockSparseMM::BlockSparseMM (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::BlockSparseMM::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::BlockSparseMM::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::BlockSparseMM::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::BlockSparseMM::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::BlockSparseMM::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_block_sparse_m_m.png b/docs/build/html/classmlx_1_1core_1_1_block_sparse_m_m.png new file mode 100644 index 0000000000000000000000000000000000000000..a9ca56c7b6c18830703f1680ec3a3af4f6770faa GIT binary patch literal 952 zcmeAS@N?(olHy`uVBq!ia0vp^3xT+UgBeKf(5bEhQqloFA+G=b{|7Q(y!l$%e`vXd zfo6fk^fNCWJa_;UH$5A+1SrQ@666=m0OW&#In(Sb3=GW9o-U3d6^w7^zU^CWz{BeA zx8&6O|3{_=F_qb6th{}9&ARX&kAsD5>i6!SuzbGKQSD?Wr}LyQf^8}YS(^N&))pXU6)D9deZN#^QtT5%e&@!6wh4$^_!9BKchsC zNp^=OsMNM7dj9$$+-CXQ#+czlU><{k>1FAAd7HKU@^^6ft+Uvl_{E5!g43DdgJ2uO z17#k@1MU(`4gCfz4DyK_40eYE7;0J+8GdjeDLDQ=V@C10;AWpm!ot81Q0DQxbo_w9 z#3!%UJX~IU>F@gKwUYLFliqwxZc`V0-W_>3WM8`Mik;Dnr*e6%vo^8o>%ZFJzD@b= zH=p%eSl-*`-0OR|%SGw-X|;>}SAN~@uA5VNE!OkA?6rXQa8o7W=v7jEN#Bbj>Tink z&R(2f?sfaB{NxkgwA7uR+nQgIHCvJRTqf$`-wSaDDw*pRE_^xV+^O{4f95TGw~V{( zci(NfyT_0Je#IeE9q>2&g~Y6?Ip2*u>*jR-kM^AP`gfb(!S~Wtto2FLJuf**Oj^=! zpyDatop{djJP$+76ln&ZvtQb_-+t3|Sz51Yvi7-riM>1wao9oxXOMlgms&RE(iMfy z4j&(ht`1-@3JS`vedzjl>)oo$+t_bs3r;llI#U#W_`F0|$i7vOi@5*U@`E}|-+cf=9Z}ZGoFMJ(D zPhAP{Uzw#cX$dfKyjoWL=J=M$ucF(2efge!BS?b%l#MFu- zWqjUucV6f5U0xD&J63MKbFt6rVv8A9es^t5T<>`L+X~Su`>yVdxfi~7PxtnV>1|3% u@lXDS9{TVe=#@S0aQ}2CdTE|LW53#NiSH-Q&Z)pW#^CAd=d#Wzp$PzvKf)mZ literal 0 HcmV?d00001 diff --git a/docs/build/html/classmlx_1_1core_1_1_broadcast-members.html b/docs/build/html/classmlx_1_1core_1_1_broadcast-members.html new file mode 100644 index 000000000..0f8b3cce0 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_broadcast-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Broadcast Member List
+
+
+ +

This is the complete list of members for mlx::core::Broadcast, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
Broadcast(Stream stream, const std::vector< int > &shape)mlx::core::Broadcastinlineexplicit
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Broadcastvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Broadcastvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::Broadcastvirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::Broadcastvirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs)mlx::core::Primitivevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Broadcastinlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Broadcastvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Broadcastvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_broadcast.html b/docs/build/html/classmlx_1_1core_1_1_broadcast.html new file mode 100644 index 000000000..4b455d2c2 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_broadcast.html @@ -0,0 +1,437 @@ + + + + + + + +MLX: mlx::core::Broadcast Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::Broadcast Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Broadcast:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Broadcast (Stream stream, const std::vector< int > &shape)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs)
 Get the output shapes of the primitive.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Broadcast()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
mlx::core::Broadcast::Broadcast (Stream stream,
const std::vector< int > & shape )
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Broadcast::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Broadcast::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Broadcast::is_equivalent (const Primitive & other) const
+
+overridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Broadcast::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Broadcast::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Broadcast::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Broadcast::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_broadcast.png b/docs/build/html/classmlx_1_1core_1_1_broadcast.png new file mode 100644 index 0000000000000000000000000000000000000000..080f3c4554799abdfe0d910acad98ef3142ea320 GIT binary patch literal 905 zcmeAS@N?(olHy`uVBq!ia0vp^(}1{xgBeH~F+Z{fQqloFA+G=b{|7Q(y!l$%e`vXd zfo6fk^fNCWJa_;UH$5A+1SrQ@666=m0OW&#In(Sb3=GU7o-U3d6^w7^zMk}0i-&D? z|EekX{&)62)=qe%ar2GNayQMYHw6=KJNf*ZRDABS zYsNj3tk(9$mgVja=_z?UMNnnsQ@!ZaDQeO|muBwT7Iky}^;cm(mtKn!yQ<}x`Z;r| z`U|D#=bP$Qs(Om=*s^L;e)j7vo$I%rN>=ntzjyJL`4+p*-?i#%Jy%72?UJAC;`uAR zWtvJQC)AaxF>B`(%keeje`Jv8H{2@MuekTvm04@tzDR{=`zbZdzrgfCD2gG@iTi+e z24e-ME`$C;)(5Iv81}S?HB3iHp4Pu`%JTU~1@B84Fw={!c};qfKCf=Knyvlb7nlED zQd#*-eM??ijP9MSm(uL+i*R2HzpM9o%j37-PhNJg&hIR1iT-^`-FNbL{nR_o*?qH@ z6t51S7JJnBdgH%6tJeQsQj|M)ZIRz(Ud@+2lh>}^S)IFe{pZ}Z!MxX{1GTp1PgqiX z+w4!r?Ws+u^W@e0SAE@i_5btbQIq~zE>!i5J9ZW5wTBEp#P}ON zN(8_DC~+ym-UWYAE_|?z!@fJ0rrP!?UW!X{z)@$0)uGg%;^`5j6WBx|GMl` zP1~6>zhZd4?%Kb)bpJE|t(S^lZLZ&^DYow94Pfvn+5di0l(Jt(IX7MOLF+Ygq2Un-ZtJ@rjK->+v?&_Q_t~FWFN=>w<2j&DnA-^t8P3 z=C)c4=w(!cl2yaZabtobk?>IkR%+mLRW8i+{{k9kV&-7S#6wGYx~MtDnm{r-UW| Dxv{?# literal 0 HcmV?d00001 diff --git a/docs/build/html/classmlx_1_1core_1_1_ceil-members.html b/docs/build/html/classmlx_1_1core_1_1_ceil-members.html new file mode 100644 index 000000000..2cce4f65c --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_ceil-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Ceil Member List
+
+
+ +

This is the complete list of members for mlx::core::Ceil, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
Ceil(Stream stream)mlx::core::Ceilinlineexplicit
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Ceilvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Ceilvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::Ceilinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::Ceilvirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::Ceilinlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Ceilinlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Ceilvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Ceilvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_ceil.html b/docs/build/html/classmlx_1_1core_1_1_ceil.html new file mode 100644 index 000000000..4d11251f0 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_ceil.html @@ -0,0 +1,463 @@ + + + + + + + +MLX: mlx::core::Ceil Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::Ceil Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Ceil:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Ceil (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Ceil()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::Ceil::Ceil (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Ceil::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Ceil::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Ceil::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Ceil::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::Ceil::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Ceil::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Ceil::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Ceil::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_ceil.png b/docs/build/html/classmlx_1_1core_1_1_ceil.png new file mode 100644 index 0000000000000000000000000000000000000000..7894fb3ecc936689693918209af7312a1447a1cb GIT binary patch literal 864 zcmeAS@N?(olHy`uVBq!ia0vp^(}1{xgBeH~F+Z{fQqloFA+G=b{|7Q(y!l$%e`vXd zfo6fk^fNCWJa_;UH$5A+1SrQ@666=m0OW&#In(Sb3=B-KJzX3_Dj46+eckt3frqV~ zf90p&_8&br?POTEGJ0!l-!;+2myTEnCzt)7RDABSt60m5zXsKIq}{%0_Q7N?AyG=;*(F}@BXT(ONwGL z&Sp-zB3b(G)1__RN9X3vGip}Z8WXzu%%u=#FRS41VcNfI*XB*Xe>Lj%qA1msGInb} zvD{sGZRc9~EYC}_F;StXx9=^zb>w)|(~T3B7@yC0D;+ieNbNn%>`AKv_dfDg51drj zdts7_{z6sHI4AE(dwNz+Ek1Xc@q=C-!(+o`*WQ{fxjrLm^3{&G=B%ThEFXlU7~-6` z4|r!VR&eSv=r3e_pt^-&Pm5T?^b1TMgby~LMm6GAw&*!u= zKZw13dZ7+iXtlQ^9AZt0Uu?_bxNz30xGtyAV-%KConOwi$-+xrgiXNKzD z|8~tYxBK4@H8J;f#X|n+dABe1&v7x{oG%1(fp^BtnIIo7oznJLUSDSwZtrvvi> NgQu&X%Q~loCID-}tP}tM literal 0 HcmV?d00001 diff --git a/docs/build/html/classmlx_1_1core_1_1_compiled-members.html b/docs/build/html/classmlx_1_1core_1_1_compiled-members.html new file mode 100644 index 000000000..bf0ad0373 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_compiled-members.html @@ -0,0 +1,108 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Compiled Member List
+
+
+ +

This is the complete list of members for mlx::core::Compiled, including all inherited members.

+ + + + + + + + + + + + + + + + + + + +
Compiled(Stream stream, std::vector< array > inputs, std::vector< array > outputs, std::vector< array > tape, std::unordered_set< uintptr_t > constant_ids)mlx::core::Compiledexplicit
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::Compiledvirtual
eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::Compiledvirtual
is_equivalent(const Primitive &other) const overridemlx::core::Compiledvirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::Compiledvirtual
lib_name() constmlx::core::Compiledinline
operator=(const Primitive &other)=deletemlx::core::Primitive
operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::Compiledvirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Compiledvirtual
stream()mlx::core::Primitiveinline
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Compiledvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Compiledvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_compiled.html b/docs/build/html/classmlx_1_1core_1_1_compiled.html new file mode 100644 index 000000000..78b508b6e --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_compiled.html @@ -0,0 +1,493 @@ + + + + + + + +MLX: mlx::core::Compiled Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::Compiled Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Compiled:
+
+
+ + +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Compiled (Stream stream, std::vector< array > inputs, std::vector< array > outputs, std::vector< array > tape, std::unordered_set< uintptr_t > constant_ids)
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
std::string lib_name () const
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Compiled()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + +
mlx::core::Compiled::Compiled (Stream stream,
std::vector< array > inputs,
std::vector< array > outputs,
std::vector< array > tape,
std::unordered_set< uintptr_t > constant_ids )
+
+explicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Compiled::eval_cpu (const std::vector< array > & inputs,
std::vector< array > & outputs )
+
+overridevirtual
+
+ +

A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.

+

To avoid unnecessary allocations, the evaluation function is responsible for allocating space for the array.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Compiled::eval_gpu (const std::vector< array > & inputs,
std::vector< array > & outputs )
+
+overridevirtual
+
+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Compiled::is_equivalent (const Primitive & other) const
+
+overridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Compiled::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ lib_name()

+ +
+
+ + + + + +
+ + + + + + + +
std::string mlx::core::Compiled::lib_name () const
+
+inline
+
+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::Compiled::output_shapes (const std::vector< array > & inputs)
+
+overridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Compiled::print (std::ostream & os)
+
+overridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Compiled::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Compiled::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_compiled.png b/docs/build/html/classmlx_1_1core_1_1_compiled.png new file mode 100644 index 0000000000000000000000000000000000000000..4f12eb20e1be08551a228728beea4ff60e4b1282 GIT binary patch literal 546 zcmeAS@N?(olHy`uVBq!ia0vp^6+j%o!3-pyx;pL#QqloFA+G=b{|7Q(y!l$%e`vXd zfo6fk^fNCWJa_;UH$5A+1SrQ@666=m0OW&#In(Sb3=E8GJzX3_Dj46+?d@w(;AvUz zylLk5`bWiqTmI~*sxIGfGI2)7jcjTQ)JnriANM8G=@ATV0tUF#V5B^YA8F`KK#_o?x9^Z19{vf#ct+H`=uyy9eQu$z^@@L~~=p@)^RL)Un-V3TK1z3&E<(u?ek`}a;39;{_kB!`&8@u zGE3JdPm<3JD}R%`s?R5A<(|hy)>psX`F5YR+WIcbdd=_$pIB5Q7yjzgw?4Ny`oWo) zr?TN6L}D9#KRrGsYyI0S=Ix)|^`Wb?=SHuawj}s`)iOMp)He6?s|g7Id?Urdr& W$?>H>(xQQ}$>8bg=d#Wzp$PzV007DW literal 0 HcmV?d00001 diff --git a/docs/build/html/classmlx_1_1core_1_1_concatenate-members.html b/docs/build/html/classmlx_1_1core_1_1_concatenate-members.html new file mode 100644 index 000000000..9847df5c1 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_concatenate-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Concatenate Member List
+
+
+ +

This is the complete list of members for mlx::core::Concatenate, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
Concatenate(Stream stream, int axis)mlx::core::Concatenateinlineexplicit
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Concatenatevirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Concatenatevirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::Concatenatevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::Concatenatevirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs)mlx::core::Primitivevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Concatenateinlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Concatenatevirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Concatenatevirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_concatenate.html b/docs/build/html/classmlx_1_1core_1_1_concatenate.html new file mode 100644 index 000000000..b96f54f64 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_concatenate.html @@ -0,0 +1,437 @@ + + + + + + + +MLX: mlx::core::Concatenate Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::Concatenate Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Concatenate:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Concatenate (Stream stream, int axis)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs)
 Get the output shapes of the primitive.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Concatenate()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
mlx::core::Concatenate::Concatenate (Stream stream,
int axis )
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Concatenate::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Concatenate::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Concatenate::is_equivalent (const Primitive & other) const
+
+overridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Concatenate::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Concatenate::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Concatenate::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Concatenate::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_concatenate.png b/docs/build/html/classmlx_1_1core_1_1_concatenate.png new file mode 100644 index 0000000000000000000000000000000000000000..3404621456d96f7206bcf97da9f0f3e1ca5feacc GIT binary patch literal 914 zcmeAS@N?(olHy`uVBq!ia0vp^(}1{xgBeH~F+Z{fQqloFA+G=b{|7Q(y!l$%e`vXd zfo6fk^fNCWJa_;UH$5A+1SrQ@666=m0OW&#In(Sb3=GV&o-U3d6^w7^=Jr2U;9+ZL zU-{|(|D)5F_#TRm+WzWoVbuB=+^moNcw+WXDn9qwaoUob9J-#LgrX)rIl6GF%1QsI z|8CjS|Lyg>#Qm@7+GNvrA=jfa>$k5BUQ#2HT%VJeUVH9AX_|N8v&chz{9m;^Q>zzk zT^4w0!PT9o7uHVinwdDys99yJ&FZk3OINl{xe`2oZP>TxTicfZ33^?qvnq7bsb`n9 z@-Mh<{gwPT$a|7rL1w7?KHKQKYWuEft#4KNx@U1|PDZV8|Gx!mJXb}1?UKLi;`uAR zWtvJQC)Aaxack!k&*N*zFJzw4Z@D#ZyXM|!SEAOWeVG-aou||={{qtop(utpC+-8@ z8H^R2x(xaYSs$owVc63m)-W9*d0PL%Da+?S6}&Iyz)ZiCt?Kz{bE|#qI)21ros@K4@Xl@P!Xz(;ID5UyTvaMsJ!kE=(tT^M&V5*=?6ve8@1!f+ zYY(5_ zGe7<17OB+EHUH25XQBN3{j{35E5GEpcKwVmF#`RDduOAPB) zZq@ZjtJRi^zEu3;PM%urWVGUPMW#(%GFmcyUz3;y}tF7^FOE8 za#iMhhx_Wg(#GDre{V`(y-Dew>Mv1`S#-8YaeoAdwy^2<6WDd zRQSBl|7*tXH+8>%y)3e+t2uM#p8A$EXZEyAn>lm(1ta6nn=+P8xm5CpG4J3kyH8d= R!ob|a;OXk;vd$@?2>{QQ($4?@ literal 0 HcmV?d00001 diff --git a/docs/build/html/classmlx_1_1core_1_1_conjugate-members.html b/docs/build/html/classmlx_1_1core_1_1_conjugate-members.html new file mode 100644 index 000000000..3439e670e --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_conjugate-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Conjugate Member List
+
+
+ +

This is the complete list of members for mlx::core::Conjugate, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
Conjugate(Stream stream)mlx::core::Conjugateinlineexplicit
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Conjugatevirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Conjugatevirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::Conjugateinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)mlx::core::Primitivevirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::Conjugateinlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Conjugateinlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs)mlx::core::Primitivevirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Conjugatevirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_conjugate.html b/docs/build/html/classmlx_1_1core_1_1_conjugate.html new file mode 100644 index 000000000..f7f0de1cf --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_conjugate.html @@ -0,0 +1,382 @@ + + + + + + + +MLX: mlx::core::Conjugate Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::Conjugate Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Conjugate:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Conjugate (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)
 The Jacobian-vector product.
 
virtual std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs)
 The vector-Jacobian product.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Conjugate()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::Conjugate::Conjugate (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Conjugate::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Conjugate::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Conjugate::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::Conjugate::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Conjugate::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Conjugate::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_conjugate.png b/docs/build/html/classmlx_1_1core_1_1_conjugate.png new file mode 100644 index 0000000000000000000000000000000000000000..08be44bd1fb347166935e0ebe48e2c691c15f1e8 GIT binary patch literal 929 zcmeAS@N?(olHy`uVBq!ia0vp^(}1{xgBeH~F+Z{fQqloFA+G=b{|7Q(y!l$%e`vXd zfo6fk^fNCWJa_;UH$5A+1SrQ@666=m0OW&#In(Sb3=GUVo-U3d6^w7^zV3Uiz{A$g zzw*;>`;VTNVj3n~Te~YZ?6|g@C&!G=gSYoTne(|~f@-FTllP=OEz?vg3o}$bKTUqN z|K6RK|L&>0bviAx2Ubmx~je}v~$vz zY>?|Xbv-``MNN8=8nbpzu^eB6e<8C(zv0%r?TUMkU75AU?Tb{1wwB-l{|v?oPF)85 zg{%)$w=nE!5o?%!f$4)#6hoX7_W^H&5~5n{AT^H zy14vz_2eb9@-u8D3QY^Id(Hk?6zmgWU%Ff8aNqj-?Ed<<3cqK#@7TINCR1vU{gguf zwW@iO&aIsnSti=OPXE{ARr}s~J-uTU{d5L67}TP}lfQ4iT6b>wR=@Ne_d0@B{%um3 zdwbhM_dDytBIU(zzM8Z4R-S!_enx5kpL?s6??=A7sJ3r;=(MB8YipffbKQR5|NShV z{LQ?qU5i6i7X|}Er7-S~{Ewyb^S7_93wd3;djIEOUC;e~8J?4-7j7*&S82guca%55 z=Gc{8ciX;51!r%Px|kiXc7+RrJ)xk1h1-7z?WI$WHK(MRNi9rGd#1YO%$Yql*Qz2b z^^}Y^cdgdoUsNBKE%|=lmsMf2($D-o?Q2+n_GnpTo9Ii;chl}=9{9V7ZT9c$r_DA` zzw5nLHO|O-ds^?J?AjZDrH*zSmlYkQQtkk z+rD}CD0ppI{uv$9Ip?nluibV3%$lb&!%Cmr&bM2cySV!F`=39L+|=?~7u@N3_N{+u zexb-LuRYUWYER8b`{wvcH0E_}`0K3mXU=HtTM(*OurWO?O|<^g@>^%#gmdT`8|y#L gS~}%Y!5`-3e`kFOIJL?Qn3)(nUHx3vIVCg!05urS_W%F@ literal 0 HcmV?d00001 diff --git a/docs/build/html/classmlx_1_1core_1_1_convolution-members.html b/docs/build/html/classmlx_1_1core_1_1_convolution-members.html new file mode 100644 index 000000000..c09690551 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_convolution-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Convolution Member List
+
+
+ +

This is the complete list of members for mlx::core::Convolution, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
Convolution(Stream stream, const std::vector< int > &kernel_strides, const std::vector< int > &padding, const std::vector< int > &kernel_dilation, const std::vector< int > &input_dilation, const int groups=1, const bool flip=false)mlx::core::Convolutioninlineexplicit
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Convolutionvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Convolutionvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::Convolutionvirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)mlx::core::Primitivevirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs)mlx::core::Primitivevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Convolutioninlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Convolutionvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes)mlx::core::Primitivevirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_convolution.html b/docs/build/html/classmlx_1_1core_1_1_convolution.html new file mode 100644 index 000000000..ac64adff4 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_convolution.html @@ -0,0 +1,390 @@ + + + + + + + +MLX: mlx::core::Convolution Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::Convolution Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Convolution:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Convolution (Stream stream, const std::vector< int > &kernel_strides, const std::vector< int > &padding, const std::vector< int > &kernel_dilation, const std::vector< int > &input_dilation, const int groups=1, const bool flip=false)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)
 The Jacobian-vector product.
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes)
 The primitive must know how to vectorize itself across the given axes.
 
virtual std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs)
 Get the output shapes of the primitive.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Convolution()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
mlx::core::Convolution::Convolution (Stream stream,
const std::vector< int > & kernel_strides,
const std::vector< int > & padding,
const std::vector< int > & kernel_dilation,
const std::vector< int > & input_dilation,
const int groups = 1,
const bool flip = false )
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Convolution::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Convolution::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Convolution::is_equivalent (const Primitive & other) const
+
+overridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Convolution::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Convolution::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_convolution.png b/docs/build/html/classmlx_1_1core_1_1_convolution.png new file mode 100644 index 0000000000000000000000000000000000000000..853ab7ab4adc387b35aedf2bb0fa2701addada34 GIT binary patch literal 907 zcmeAS@N?(olHy`uVBq!ia0vp^(}1{xgBeH~F+Z{fQqloFA+G=b{|7Q(y!l$%e`vXd zfo6fk^fNCWJa_;UH$5A+1SrQ@666=m0OW&#In(Sb3=GU-o-U3d6^w7^KAf~#OMq>5 z|Eeo{{tL$^i!fe1^X}xkhu3C`>=lwZR@U$)zUW+~-HAyi(hF5RK z`STg?F4dcQRvqTgxO=8>`L1hTzjv>ltMxRj=kKOA*{_yO^UJS^xX)8vQBO2YSz`S5-mBF5`0O30^|x-zy?AC)i23}m z+9thKrT4DoZ&i88G(Bwd)8BhbqmF!!`nqw#665%bTHLqh9l0bw#njU`JMLq&df=pg zQZN^`OjD_x8Ls_&&T{4l!8Q!flKs{mPxe}Gxm9enj4k ztoEPyw!ZpY7`XP}y{xyV-2OCOD>yy2#qv+ie|`>~>| zd>MO`uOJ&RQ0&f$*E83@oPU1%+PaX}zr+5|ufIC!%XLm&&re3{!_HayvmaPq!EomI zoT&ceQ{sxVbi<6Rd{>F83m)*t8XlN|^;G_Xm*&!+3NvS(b<3DJvyxNS*jRshRQ$Hu zdGf}arTcz+zc>4l9F_fNwrB6OnKxDUhTV!!c^mS1d(@lhrfanPuSUC1Z@8YC{D0g1 z>ZSQf`CG5eNGZ#^ntHaF=c@On{=Q|sYO))@UtP6)&Z>!L)+}Fp&GX2@(zvVk7gw48 zY`wNk+-}v?^}REf_IG|`nrQcrs+kL9N`?LOxsjRvFN%eb9y$_RIdw;dm z9lpY9hx?f`XZkw8ow_hJ?OAHv+9^S^`u;O`B+NP``}?gOFxxPAy85}Sb4q9e065IR AzyJUM literal 0 HcmV?d00001 diff --git a/docs/build/html/classmlx_1_1core_1_1_copy-members.html b/docs/build/html/classmlx_1_1core_1_1_copy-members.html new file mode 100644 index 000000000..3cb46c818 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_copy-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Copy Member List
+
+
+ +

This is the complete list of members for mlx::core::Copy, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
Copy(Stream stream)mlx::core::Copyinlineexplicit
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Copyvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Copyvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::Copyinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::Copyvirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::Copyinlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Copyinlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Copyvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Copyvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_copy.html b/docs/build/html/classmlx_1_1core_1_1_copy.html new file mode 100644 index 000000000..f233d09c2 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_copy.html @@ -0,0 +1,463 @@ + + + + + + + +MLX: mlx::core::Copy Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::Copy Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Copy:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Copy (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Copy()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::Copy::Copy (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Copy::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Copy::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Copy::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Copy::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::Copy::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Copy::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Copy::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Copy::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_copy.png b/docs/build/html/classmlx_1_1core_1_1_copy.png new file mode 100644 index 0000000000000000000000000000000000000000..2f4f36d04325b9a6ebae742f7adc17e52b2dfca0 GIT binary patch literal 892 zcmeAS@N?(olHy`uVBq!ia0vp^(}1{xgBeH~F+Z{fQqloFA+G=b{|7Q(y!l$%e`vXd zfo6fk^fNCWJa_;UH$5A+1SrQ@666=m0OW&#In(Sb3=GWdo-U3d6^w7^zCQQLiicHQ zTjpoLRM4}0ZhUIR($cvrA0LP>I#+3PVv>o}LRHT=C+|snEVfKi`RRGx z|NYIEf6P=~GRj}r^|{CT>dvsJ`?+?SL3&3c_=|7E&A0tgmF~Ck`OHJx^uGpAI#uZY zdX}5#+`#90`fFE*+)ixuos)vf`iLjcS)$C%z~o z>%M~L@r`;v=51dc`mTR+R%zAhZMUU9a<0ugoxeEi=ed&aVjoS{7Dd}!x+S*dUgGVa zb^9Y9?!9$(-n)!Y*B5VGJgcnz|FPA*JpVj(b${00ys!6vuAa7JzA#lo$V+qSM}?U) zd)+c-&ICq)v9bR74BNj<*L73VzS+#bGV5{t=c0_I*{`y%)CBo#xsqgiCegLBYxB=n zA_dkKF}6{~`%=nejwa44-oK^h=AX2uTKCSp5eeVw)$eRNY39}|S6BV5Ul?lkbLX|k zb=5(uW2KF~d5_=A3jh5!D)h8`TK3gzFLrIWTfS8{W3O-4za{5bpL(fU6d zU!FPh>yN`}k3#(e(E|8LdVGy7aX&Xf$-UOJ`ivAog5xnetnXMpk$gQu&X%Q~lo FCIFzmxAp)4 literal 0 HcmV?d00001 diff --git a/docs/build/html/classmlx_1_1core_1_1_cos-members.html b/docs/build/html/classmlx_1_1core_1_1_cos-members.html new file mode 100644 index 000000000..f07b0309e --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_cos-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + + +
+
mlx::core::Cos Member List
+
+
+ +

This is the complete list of members for mlx::core::Cos, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
Cos(Stream stream)mlx::core::Cosinlineexplicit
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Cosvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Cosvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::Cosinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::Cosvirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::Cosinlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Cosinlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Cosvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Cosvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_cos.html b/docs/build/html/classmlx_1_1core_1_1_cos.html new file mode 100644 index 000000000..1d2e6907d --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_cos.html @@ -0,0 +1,463 @@ + + + + + + + +MLX: mlx::core::Cos Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::Cos Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Cos:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Cos (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Cos()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::Cos::Cos (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Cos::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Cos::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Cos::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Cos::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::Cos::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Cos::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Cos::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Cos::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_cos.png b/docs/build/html/classmlx_1_1core_1_1_cos.png new file mode 100644 index 0000000000000000000000000000000000000000..4724c19a70821cc52ec8dfdb5cc6b7a30befd5a7 GIT binary patch literal 875 zcmeAS@N?(olHy`uVBq!ia0vp^(}1{xgBeH~F+Z{fQqloFA+G=b{|7Q(y!l$%e`vXd zfo6fk^fNCWJa_;UH$5A+1SrQ@666=m0OW&#In(Sb3=B-)JY5_^Dj46+eK={g0*@QN zyz$&U|ApgUhq1cMF3Xy3J)1Ax%k)Sx$L{)*md`&bcwfqBnWj?7sq6Xa(FLzbPt@<) z@7;R&PrAn?VZ94hk&pdK*G!#OK3m=Ol)K@y^_DrS&x_UUH(q|@lTYI7{;H`LruW%JCkTpc!Z>B_b#Tb93D8}{vY)LlLMD^*)uw|WIx)9JN?VMsc{!J>qaMP>Ts(98t`*7Lx<-eDy z^Pkx}Ps-Z2uvDmUOTksI?w>ZAVm5K5KW81@wmrUixpeLw|82T=A4mOMqFM1b@Y&rh z&-QrU>aR+zTv_qrv77z6UDs-Zc5Yi2CV4r;+3VKjkk=YzZ`Xb+wOhOO^}{M6AixpwZw`c9%Dj0v^)wVxYtERoTd6y%6|FTy9QRB6x&aZ{Ho!kEV zEMNZBymKY)tF@e-t9V^HUimTq!SsJW_vDtYy!Jcn|M~q{p1(?+yeIAHSv|G*++oHK znt2S54VPVeYqsS2jHt<1JK~zNj(W0u5XKrFn1S_^IZAV>iJXe@=8GJ<#>V;!Q`4S3 z4P84u{!!hUGiTx+XRWpTSE|1}Y}fLA?&l^ya_o=Ni7$=_oR_qXW!6mFqpMAA3j)2< zo@(rS)}A+A&OU7avq>{=3ckylDz`gJ#W;KA>d+svXWBkyTD$3F^jiP?MUktHXKwwr zckA7q8N#cl}P?&Hvjt~%*yV*FbTXU^ + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Cosh Member List
+
+
+ +

This is the complete list of members for mlx::core::Cosh, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
Cosh(Stream stream)mlx::core::Coshinlineexplicit
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Coshvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Coshvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::Coshinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::Coshvirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::Coshinlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Coshinlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Coshvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Coshvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_cosh.html b/docs/build/html/classmlx_1_1core_1_1_cosh.html new file mode 100644 index 000000000..385f8d6d8 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_cosh.html @@ -0,0 +1,463 @@ + + + + + + + +MLX: mlx::core::Cosh Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::Cosh Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Cosh:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Cosh (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Cosh()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::Cosh::Cosh (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Cosh::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Cosh::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Cosh::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Cosh::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::Cosh::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Cosh::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Cosh::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Cosh::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_cosh.png b/docs/build/html/classmlx_1_1core_1_1_cosh.png new file mode 100644 index 0000000000000000000000000000000000000000..69fffddab9dd68ef284091892a88bda3d72302eb GIT binary patch literal 888 zcmeAS@N?(olHy`uVBq!ia0vp^(}1{xgBeH~F+Z{fQqloFA+G=b{|7Q(y!l$%e`vXd zfo6fk^fNCWJa_;UH$5A+1SrQ@666=m0OW&#In(Sb3=GW7o-U3d6^w7^KAiMQfya$s z-q`Qn|IYQ-rZFweE4lLY+_8r}H`>0KwA|hQWX|WB396Z9PTrIDv`kZ}EX+{#{4_al z{>59b|JbR#WVc)JHKEt~>dvrR_i|^3XkJgsiTAz{H{bS0S-R)O=R4b`%D>g}{93V~ z)O+!ggBiuwr*Dg0^^WKB6-kwq&jYu<(up-)vP3yI|Ldmv*QLsSuDo_j>?~0Ihs>+$ zFP20d-}W!idy-#4rm61z@7EHA>#wcqZ&i`4bX%R1Q7v@3Cg7Ti>EgKO-f<;-yH=W~`bKZuQCc$Vz9_H?S(ddscr!i=kYSBa_%9`Mg#tl-pT&|k>#e^$uRr*-_}uPn-O9-&GH*})3i^8C>oJRU zJI!i>_HJ7nCVM&5*~{uu=xd#AU&8)dy}xGlrC_(<)}a4vo>qOihrjP}Ui+(NPDu8J z=(Vw#N7T2p>BhTV{?vTj_STW>TV5r9oE&!Z!nc;GKkN42ez^D6-I;P3rmhlmTs*6+ z{r~a)xnw_o`dYhHUw2*o|9O4Xq`J8mCaLIu&I0A0{`y0=> V?i_Oq#DRH)!PC{xWt~$(6992>z`_6k literal 0 HcmV?d00001 diff --git a/docs/build/html/classmlx_1_1core_1_1_custom_v_j_p-members.html b/docs/build/html/classmlx_1_1core_1_1_custom_v_j_p-members.html new file mode 100644 index 000000000..eee5452e6 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_custom_v_j_p-members.html @@ -0,0 +1,107 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::CustomVJP Member List
+
+
+ +

This is the complete list of members for mlx::core::CustomVJP, including all inherited members.

+ + + + + + + + + + + + + + + + + + +
CustomVJP(Stream stream, std::function< std::vector< array >(const std::vector< array > &, const std::vector< array > &, const std::vector< array > &)> fun)mlx::core::CustomVJPinlineexplicit
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::CustomVJPvirtual
eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::CustomVJPvirtual
is_equivalent(const Primitive &other) constmlx::core::Primitiveinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)mlx::core::Primitivevirtual
operator=(const Primitive &other)=deletemlx::core::Primitive
operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs)mlx::core::Primitivevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::CustomVJPinlinevirtual
stream()mlx::core::Primitiveinline
vjp(const std::vector< array > &primals, const std::vector< array > &cotan, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::CustomVJPvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes)mlx::core::Primitivevirtual
~Primitive()=defaultmlx::core::Primitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_custom_v_j_p.html b/docs/build/html/classmlx_1_1core_1_1_custom_v_j_p.html new file mode 100644 index 000000000..8ed54c2a1 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_custom_v_j_p.html @@ -0,0 +1,320 @@ + + + + + + + +MLX: mlx::core::CustomVJP Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::CustomVJP Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::CustomVJP:
+
+
+ + +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 CustomVJP (Stream stream, std::function< std::vector< array >(const std::vector< array > &, const std::vector< array > &, const std::vector< array > &)> fun)
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotan, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)
 The Jacobian-vector product.
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes)
 The primitive must know how to vectorize itself across the given axes.
 
virtual bool is_equivalent (const Primitive &other) const
 Equivalence check defaults to false unless overridden by the primitive.
 
virtual std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs)
 Get the output shapes of the primitive.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ CustomVJP()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
mlx::core::CustomVJP::CustomVJP (Stream stream,
std::function< std::vector< array >(const std::vector< array > &, const std::vector< array > &, const std::vector< array > &)> fun )
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::CustomVJP::eval_cpu (const std::vector< array > & inputs,
std::vector< array > & outputs )
+
+overridevirtual
+
+ +

A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.

+

To avoid unnecessary allocations, the evaluation function is responsible for allocating space for the array.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::CustomVJP::eval_gpu (const std::vector< array > & inputs,
std::vector< array > & outputs )
+
+overridevirtual
+
+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::CustomVJP::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::CustomVJP::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_custom_v_j_p.png b/docs/build/html/classmlx_1_1core_1_1_custom_v_j_p.png new file mode 100644 index 0000000000000000000000000000000000000000..32bf6e7e9d196669ef036904122ec5e2bbe1cd42 GIT binary patch literal 575 zcmeAS@N?(olHy`uVBq!ia0vp^tw0>W!3-oBzPn8VlF|V_A+G=b{|7Q(y!l$%e`vXd zfo6fk^fNCWJa_;UH$5A+1SrQ@666=m0OW&#In(Sb3=E8?JY5_^Dj46+eZB680!K@` z;mr^K{A;Q%Gu_HHe}3s%3$Ji4>#xoo`vbi+w{y3uK^O*)0`?KJi;ai`i zrOr5>X0fuBm#bQ}b#G0)%eMIC5@t^nkCq+Nd&SlI_HA?27ugCg$4QRwHDbz-?R(+( zWX_>71J7h}jax5&SYLY9yomGVnYS-a$ts?-4+pzP^(eEZm*&$v#slVGnHp|?<>VH5-#~rpIDCFY`c&&E!EE%LVHk8I}W zZAYIa`W?A5(PN|j6Z;*jvI2rsT32O-1X*wiPTc*Kwcjpp#!0D3pr~f>boFyt=akR{ E0N$Gk-2eap literal 0 HcmV?d00001 diff --git a/docs/build/html/classmlx_1_1core_1_1_depends-members.html b/docs/build/html/classmlx_1_1core_1_1_depends-members.html new file mode 100644 index 000000000..8cb2e5547 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_depends-members.html @@ -0,0 +1,107 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Depends Member List
+
+
+ +

This is the complete list of members for mlx::core::Depends, including all inherited members.

+ + + + + + + + + + + + + + + + + + +
Depends(Stream stream)mlx::core::Dependsinlineexplicit
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::Dependsvirtual
eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::Dependsvirtual
is_equivalent(const Primitive &other) constmlx::core::Primitiveinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)mlx::core::Primitivevirtual
operator=(const Primitive &other)=deletemlx::core::Primitive
operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs)mlx::core::Primitivevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Dependsinlinevirtual
stream()mlx::core::Primitiveinline
vjp(const std::vector< array > &primals, const std::vector< array > &cotan, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Dependsvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes)mlx::core::Primitivevirtual
~Primitive()=defaultmlx::core::Primitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_depends.html b/docs/build/html/classmlx_1_1core_1_1_depends.html new file mode 100644 index 000000000..45bef7a88 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_depends.html @@ -0,0 +1,316 @@ + + + + + + + +MLX: mlx::core::Depends Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::Depends Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Depends:
+
+
+ + +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Depends (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotan, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)
 The Jacobian-vector product.
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes)
 The primitive must know how to vectorize itself across the given axes.
 
virtual bool is_equivalent (const Primitive &other) const
 Equivalence check defaults to false unless overridden by the primitive.
 
virtual std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs)
 Get the output shapes of the primitive.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Depends()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::Depends::Depends (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Depends::eval_cpu (const std::vector< array > & inputs,
std::vector< array > & outputs )
+
+overridevirtual
+
+ +

A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.

+

To avoid unnecessary allocations, the evaluation function is responsible for allocating space for the array.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Depends::eval_gpu (const std::vector< array > & inputs,
std::vector< array > & outputs )
+
+overridevirtual
+
+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Depends::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Depends::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_depends.png b/docs/build/html/classmlx_1_1core_1_1_depends.png new file mode 100644 index 0000000000000000000000000000000000000000..8c1a6319a8271ffeaf01b756e6339b70e85d5ae0 GIT binary patch literal 548 zcmeAS@N?(olHy`uVBq!ia0vp^Wk4Lj!3-pS_Y`sfDd_;85ZC|z{{xvX-h3_XKeXJ! zK(jz%`k5CG9y|bwo1P6@0+iz{3GxeO0P?}WoN4wI1_s9Uo-U3d6^w7^^7g$};AuJD z9P;y@dL5^4^V#!iyLAIz1{!h&%;ukM$>7A1xJSUFXp7i173Ytrcsef+9}s&#jwzK6$lXR#JcZ=E%}Dr3s5~T{4?~qiyNAhdNXD=U#l0 zy=LoO({DLD%0F4I-{R_F9>W=`KKY4(u8wEfwL_Cq($5}l6F5}Bcp&~OL&L{e&9yem z3Z5sJzv6kP@}~aTlUCM-dr=V#FV+AV=ddz%C<{*f*tB9*-?c-lR+U}5o={)e`@%c) zs`kzW-#3;Qzlyl|{d?t{87mggxUSwd^Vhu1GbGBlGYHQQbN6WM6_qnq(^L4}y zhxyYU{x$NC@}OtiOL}%{Sk0&6V@_9oy3-+Lu?GU5i^2nEP{^-G&%dVUrE# zf_~Pn_BGpm^Y*gI`uDFzYzopr0NBh9mjD0& literal 0 HcmV?d00001 diff --git a/docs/build/html/classmlx_1_1core_1_1_div_mod-members.html b/docs/build/html/classmlx_1_1core_1_1_div_mod-members.html new file mode 100644 index 000000000..bd2f77317 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_div_mod-members.html @@ -0,0 +1,107 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::DivMod Member List
+
+
+ +

This is the complete list of members for mlx::core::DivMod, including all inherited members.

+ + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
DivMod(Stream stream)mlx::core::DivModinlineexplicit
eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::DivModvirtual
eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::DivModvirtual
is_equivalent(const Primitive &other) const overridemlx::core::DivModinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::DivModvirtual
operator=(const Primitive &other)=deletemlx::core::Primitive
operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::DivModinlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::DivModinlinevirtual
stream()mlx::core::Primitiveinline
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::DivModvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::DivModvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_div_mod.html b/docs/build/html/classmlx_1_1core_1_1_div_mod.html new file mode 100644 index 000000000..d4455f3ed --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_div_mod.html @@ -0,0 +1,447 @@ + + + + + + + +MLX: mlx::core::DivMod Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::DivMod Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::DivMod:
+
+
+ + +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 DivMod (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ DivMod()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::DivMod::DivMod (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::DivMod::eval_cpu (const std::vector< array > & inputs,
std::vector< array > & outputs )
+
+overridevirtual
+
+ +

A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.

+

To avoid unnecessary allocations, the evaluation function is responsible for allocating space for the array.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::DivMod::eval_gpu (const std::vector< array > & inputs,
std::vector< array > & outputs )
+
+overridevirtual
+
+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::DivMod::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::DivMod::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::DivMod::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::DivMod::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::DivMod::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::DivMod::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_div_mod.png b/docs/build/html/classmlx_1_1core_1_1_div_mod.png new file mode 100644 index 0000000000000000000000000000000000000000..11583cfa1c44bbfd64b891036eb337407e399283 GIT binary patch literal 536 zcmV+z0_XjSP)vTJr#LVva2S`&=-}Ys|Ns9r%~qrU000SeQchC<|NsC0|NsC0Hv*f~0004_ zNkle^^81qci+rc~bC6`K-$P&coP{!|)O#%x=ddPHe~a0c+h#nUEFrcylbJ2N*9M)9 zC=5pH>9^|pxsx!CkI-i7ZFZTE!lT-n?yvG`XxKZa(6wOBD}?&4JIVCeESJ!*k!&=L zw6E#md9R_q*{-1W)NS|LJa`fzo6#f?bT3JcQbI|xn-WTr-EK)LDIi7k8^dEI^LII$k5()s_lu!WZri21OHzgDRx+$Rm&`k*ifNn}C09;5! zWV9fvs$-Q0#Gtf7$vGil>S7&V`TYUl=gnusA-z%uhyYp(9c$+nmn$K+J;;n!> zp2r*1ykqFJUBRO2hx;v&wykw>u(iuO(>PP_c_@&m>RzXrM$nVz{XVqz&K_zwBfbH( zS3A!%tm$u{O(X4V8a}V>G^@92>e@8uUYplo%}G#-7u~C>Iz|aqRqdvPs;YL=T~HB` a!TJIyfM8qPkUH}K0000 + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Divide Member List
+
+
+ +

This is the complete list of members for mlx::core::Divide, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
Divide(Stream stream)mlx::core::Divideinlineexplicit
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Dividevirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Dividevirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::Divideinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::Dividevirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::Divideinlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Divideinlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Dividevirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Dividevirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_divide.html b/docs/build/html/classmlx_1_1core_1_1_divide.html new file mode 100644 index 000000000..be3840b28 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_divide.html @@ -0,0 +1,463 @@ + + + + + + + +MLX: mlx::core::Divide Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::Divide Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Divide:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Divide (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Divide()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::Divide::Divide (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Divide::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Divide::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Divide::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Divide::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::Divide::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Divide::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Divide::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Divide::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_divide.png b/docs/build/html/classmlx_1_1core_1_1_divide.png new file mode 100644 index 0000000000000000000000000000000000000000..f3946b16d3cb369fc9a6e0c8ad4f458ae0854192 GIT binary patch literal 897 zcmeAS@N?(olHy`uVBq!ia0vp^(}1{xgBeH~F+Z{fQqloFA+G=b{|7Q(y!l$%e`vXd zfo6fk^fNCWJa_;UH$5A+1SrQ@666=m0OW&#In(Sb3=GUXo-U3d6^w7^zRmlrz~jcx zAN2J1{zvBb8(Ws-UCrAwOZ}Xe@Z&iWoByglpYyv}$8)b>)TAe>TU1VVJ8OGRns4%N z%Qe2wl^&Ob_bn*hd~EKkDDAcH&3wD2cpF~p*UVXeUe4y;%%C@wZ5!Tt@AV3@%-eFz z#4D@Es=PAu?$Sqo#{M%MC-vP4U45o#Mcb5;W$)H5esezCboZa2?9A@0sY{A}Y&j)d z61cWF+WzXKCEPJlp{M1mUvKF=zx7nIqGx*C#am`u>^g7n3B0Bfy144-cE3d`f29_x zdd4|xc2ioyO|#Z+c5lE?7#N-V(;~KTiI62?X|kvHJRgq`WA*gEn*GR zFED)&ieiX!;y&P=!C1km%b>rI^?@ov@)P@(qH{a#PfWTZ4Ksbt)=4TqXCBJGzU1Tb zskJlCJ5S20b9$}zDC^;}sbQZpgLER|O8?F{^lkh7rgrJ$J1lR{%i9@dm&tK&-;~Po zwVL^p&Ml6MjMI|abtmI^==E9tTGy{#)k@~OI$_c|nN_8tyT4xhSZcSn^wq;EWv``= zrBzrRil-F8hqXZHK= z{MLKIe-|v1jtUi&QV*PDcdomhx&Foc^S8s|uU6fE_5WwMuIGNg49`i^>$U>DX3b!C zlzD;uuPaqHyS}^%&ff6q;_BeF7M%=rIKu=j!0H{emrnWBoRW4;dSPnXGu16;&g|K- zHEzrNzTmxiGj|Hr`o1nY>%QgEj5MB^YOa^IN@V*U+fW&2Qxj*cKi~eDN$QfBH#?U{ zInN8#ojG;!4~gV1@6*$Z-#pp3-pDxn^PZ^GcWH}dcCNhkF!jutTQx!1Z$B-*c8Q@i zOfoaA?2+i&-28Q`ufAEkdD_gIy80QRy7h0?K2ysJzZQT0#o9+FO#L32#m|qsdsD@} zMCk09y2cAe#-D|v($eCbPM + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Equal Member List
+
+
+ +

This is the complete list of members for mlx::core::Equal, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
Equal(Stream stream, bool equal_nan=false)mlx::core::Equalinlineexplicit
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Equalvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Equalvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::Equalinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::Equalvirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::Equalinlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Equalinlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Equalvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Equalvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_equal.html b/docs/build/html/classmlx_1_1core_1_1_equal.html new file mode 100644 index 000000000..0ed4a5122 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_equal.html @@ -0,0 +1,467 @@ + + + + + + + +MLX: mlx::core::Equal Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::Equal Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Equal:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Equal (Stream stream, bool equal_nan=false)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
void print (std::ostream &os) override
 Print the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Equal()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
mlx::core::Equal::Equal (Stream stream,
bool equal_nan = false )
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Equal::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Equal::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Equal::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Equal::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::Equal::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Equal::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Equal::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Equal::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_equal.png b/docs/build/html/classmlx_1_1core_1_1_equal.png new file mode 100644 index 0000000000000000000000000000000000000000..7c77a8836e38f9537adaf24606d4538ed3e10cbd GIT binary patch literal 893 zcmeAS@N?(olHy`uVBq!ia0vp^(}1{xgBeH~F+Z{fQqloFA+G=b{|7Q(y!l$%e`vXd zfo6fk^fNCWJa_;UH$5A+1SrQ@666=m0OW&#In(Sb3=GU1o-U3d6^w7^#_~Nf;9*mb zS9$V#e#OEs$33GBIBu-Drm=Y1wZi3n%Z zcV7Sed%_a+8pqcM=M=0eUb|_}?UF}9=Z;431>Xppulu7c-D~61oo#F5Z$(b}Q?a1b zd-0;?i_gAIpB7qF&iDBT&!j7#S47!)hx4w?@V_0s_s#pQe7paIWao9KO;fQ?imGFu z8*E`O@ce=GKJ=d}A`leRn;U0dpSd+v^T@>8~YoXv>)xOT3FXMOLL zNlP?Vs(J>3T-vibG<)q~#t)i$437<$U3+S@r)J+Rp)n!5eTGOpDL3?XZ}= zBWvo>8mmI#XW_BHAo%tz`@>f^`SSDImd;-PQ$4o%y1w1}Wu{McFV$@^{;T@^!p~!c zF~x5`c-G}b>&jmab=JCdIrOzo+1s`MO6}Heef_XZaZk{HHrHGIsRy_3X^*O_pA+(W z!S<-#TLt4USmpdIHPNmAu*1Fk$2?$|oZoVdHMjgm)xNv$!zxynZC}p6WtNi68yC-C z7W)6#{=GQ=Ja_f_tiAbP-~S2qz7(&%RYmjEj;O8ItYjGGKVmYFOU{bbbFS)LS-Mr~ zVs^mVJDm)LIKu=j!2UOeX)f(yKYAv|BEpbwuPN=i2OZu4`xfe8{(5qCE4T(d)a<%O?Xh zbcUR+ezPj@%$nsd*RH*AyGK6zr&{^dEb)(@qwF4gZY|q;HKR6FaQ3wK)!ExSYV)_= zO;0O3XC3vrX`A)M_nUvOo3mdp==2iHi8F6X?*BQp{-@;1)HJcD_y4apGyW|KVyDKe e4f2xgv-c{QGlQv#-wl{i7(8A5T-G@yGywo^{j~f5 literal 0 HcmV?d00001 diff --git a/docs/build/html/classmlx_1_1core_1_1_erf-members.html b/docs/build/html/classmlx_1_1core_1_1_erf-members.html new file mode 100644 index 000000000..896816975 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_erf-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Erf Member List
+
+
+ +

This is the complete list of members for mlx::core::Erf, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
Erf(Stream stream)mlx::core::Erfinlineexplicit
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Erfvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Erfvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::Erfinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::Erfvirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::Erfinlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Erfinlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Erfvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Erfvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_erf.html b/docs/build/html/classmlx_1_1core_1_1_erf.html new file mode 100644 index 000000000..07e13bab9 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_erf.html @@ -0,0 +1,463 @@ + + + + + + + +MLX: mlx::core::Erf Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::Erf Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Erf:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Erf (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Erf()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::Erf::Erf (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Erf::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Erf::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Erf::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Erf::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::Erf::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Erf::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Erf::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Erf::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_erf.png b/docs/build/html/classmlx_1_1core_1_1_erf.png new file mode 100644 index 0000000000000000000000000000000000000000..d21c1648a4e4c8f7a9b66bd2dd232bd7351a90a3 GIT binary patch literal 861 zcmeAS@N?(olHy`uVBq!ia0vp^(}1{xgBeH~F+Z{fQqloFA+G=b{|7Q(y!l$%e`vXd zfo6fk^fNCWJa_;UH$5A+1SrQ@666=m0OW&#In(Sb3=B*!JY5_^Dj46+eK={g0*@QN zyz$&U|ApgUN3pohF3Xy3J)1Ax%k)Sx$L{)*md`&bcwfqBnWj?7sq6Xa(FLzbPt@<) z@7;R&PrAn?VZ94hk&pdK*G!drZ|1*hinpO$eDRH_`KCYW()~6*n|bI|f7#R}RdX`# zUe;J?eCy8FrN{L`%Xd}2;+eGN>6)n0X}+dQme}Uz=cevofA#Gr=e0|9XN69B_3-j4 z{|k;=zvcc7@t(w2kZG#B|LxklYIfIG^|z|Lt#n%XYPEGRFh;Eev~F#2Thw zVEP~w#SrJjeZV_|v4T^VL4P6Z1673NC-yBx=PGSZOfrF+{$}eWm7g;Y`(I!3d-?R* zS^s%FvwpT+^E`6#(4whw&pt1S+OX%<<7~d@`}JKQO;q;v!*AWLc-=Pl zxKw9jR&)0D)m7(@J8rewd-c{-w;IRQw?Czt=-#iqb5m{K@=*Ei^r%$l+rnYbUjIGI zxBlwAIWPOdf?LjcPFd3b`O*CY`hP!H-hLId_WRm@^Xsop`tq8(Ar$15Y2%TWwpcJK zEiKOJ^qDiKUq|U5`!BQR>-3p7i`QMR%3d3-eYf~t@^jCxElb7v*ITa9IwfT;e(ubf z#`xFsu1zX+`)d0BQdQ>9z;ExnvaZVQ&QdYHed$|R@zpbLdc3#HQd_TnJFVdQ&MBb@0Lg}~zW@LL literal 0 HcmV?d00001 diff --git a/docs/build/html/classmlx_1_1core_1_1_erf_inv-members.html b/docs/build/html/classmlx_1_1core_1_1_erf_inv-members.html new file mode 100644 index 000000000..1b4e58ece --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_erf_inv-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::ErfInv Member List
+
+
+ +

This is the complete list of members for mlx::core::ErfInv, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
ErfInv(Stream stream)mlx::core::ErfInvinlineexplicit
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::ErfInvvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::ErfInvvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::ErfInvinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::ErfInvvirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::ErfInvinlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::ErfInvinlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::ErfInvvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::ErfInvvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_erf_inv.html b/docs/build/html/classmlx_1_1core_1_1_erf_inv.html new file mode 100644 index 000000000..92b8c6bf2 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_erf_inv.html @@ -0,0 +1,463 @@ + + + + + + + +MLX: mlx::core::ErfInv Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::ErfInv Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::ErfInv:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 ErfInv (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ ErfInv()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::ErfInv::ErfInv (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::ErfInv::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::ErfInv::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::ErfInv::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::ErfInv::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::ErfInv::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::ErfInv::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::ErfInv::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::ErfInv::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_erf_inv.png b/docs/build/html/classmlx_1_1core_1_1_erf_inv.png new file mode 100644 index 0000000000000000000000000000000000000000..2ed64aaf6fd6957d8b3f8aa8ecdc8419a58274fd GIT binary patch literal 880 zcmeAS@N?(olHy`uVBq!ia0vp^(}1{xgBeH~F+Z{fQqloFA+G=b{|7Q(y!l$%e`vXd zfo6fk^fNCWJa_;UH$5A+1SrQ@666=m0OW&#In(Sb3=B-aJzX3_Dj46+eckt3frqV~ zf90p&_8&7ZPhn8(jo!LjX05i{PRaD7GwetK@d zpYiVNA3K$oE&3U@Gavhvu9+%#f7_f8&Fe`veBL+0=Ij3Wmg2eb>CU#P-DT63RM}kI zwKVWj!j_${qx(Zw&ExreMN(zu^T@5QbYe}HEV0ea|GLS3z3H1z&TF^o&WfD$>fz;8 z4OkL@+WhY<#QYP6wg_`Fq6MUc~1JH^U(kKlH2)H z{#Mn=PrC9+ILcqCPfcv?7lU}eRh+Mv+W~{%+gH^ezZUfGPQQI=?e|yavCXGl=ayfN z`lOa=e{1H~wXc@GZZ3=|&i}EjDldAi)w0!XQ?@K$eJwEO<=TJ6c5Ao3epsfMwe%nN zq%HFOtoJ|hZGH8(Fkr1?{?@Xkg7Fumw$(`bp5~AHo+Ehw@~Z7eXNE;yC~Kek`EBL) zL%`^>wY`wlWpJZ&(wD-xKhi&(|JU5x_G{I(-(mmH@6Yo5RSJs7GpnZ-pS#TXK~s<6 zv*GeV&4+vm1*aZmbO_iDlIL}>GYX1-SMwo z|Nd^fwJvVq;rp>7?DCV^_^V58bGJ<2GHa`3 zc=na(_-k2Fb>ElO3C_=$`~H;qmQzjJtT)bC|M2+UubjHZo9C- + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Event Member List
+
+
+ +

This is the complete list of members for mlx::core::Event, including all inherited members.

+ + + + + + + + + + +
Event()mlx::core::Eventinline
Event(const Stream &steam)mlx::core::Event
raw_event()mlx::core::Eventinline
set_value(uint64_t v)mlx::core::Eventinline
signal()mlx::core::Event
stream()mlx::core::Eventinline
valid()mlx::core::Eventinline
value()mlx::core::Eventinline
wait()mlx::core::Event
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_event.html b/docs/build/html/classmlx_1_1core_1_1_event.html new file mode 100644 index 000000000..0b8a08aa0 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_event.html @@ -0,0 +1,320 @@ + + + + + + + +MLX: mlx::core::Event Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::Event Class Reference
+
+
+ +

#include <event.h>

+ + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Event ()
 
 Event (const Stream &steam)
 
void wait ()
 
void signal ()
 
bool valid ()
 
uint64_t value ()
 
void set_value (uint64_t v)
 
const Streamstream ()
 
const std::shared_ptr< void > & raw_event ()
 
+

Constructor & Destructor Documentation

+ +

◆ Event() [1/2]

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::Event::Event ()
+
+inline
+
+ +
+
+ +

◆ Event() [2/2]

+ +
+
+ + + + + + + +
mlx::core::Event::Event (const Stream & steam)
+
+ +
+
+

Member Function Documentation

+ +

◆ raw_event()

+ +
+
+ + + + + +
+ + + + + + + +
const std::shared_ptr< void > & mlx::core::Event::raw_event ()
+
+inline
+
+ +
+
+ +

◆ set_value()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Event::set_value (uint64_t v)
+
+inline
+
+ +
+
+ +

◆ signal()

+ +
+
+ + + + + + + +
void mlx::core::Event::signal ()
+
+ +
+
+ +

◆ stream()

+ +
+
+ + + + + +
+ + + + + + + +
const Stream & mlx::core::Event::stream ()
+
+inline
+
+ +
+
+ +

◆ valid()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Event::valid ()
+
+inline
+
+ +
+
+ +

◆ value()

+ +
+
+ + + + + +
+ + + + + + + +
uint64_t mlx::core::Event::value ()
+
+inline
+
+ +
+
+ +

◆ wait()

+ +
+
+ + + + + + + +
void mlx::core::Event::wait ()
+
+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_exp-members.html b/docs/build/html/classmlx_1_1core_1_1_exp-members.html new file mode 100644 index 000000000..ff9c86910 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_exp-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Exp Member List
+
+
+ +

This is the complete list of members for mlx::core::Exp, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Expvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Expvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
Exp(Stream stream)mlx::core::Expinlineexplicit
is_equivalent(const Primitive &other) const overridemlx::core::Expinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::Expvirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::Expinlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Expinlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Expvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Expvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_exp.html b/docs/build/html/classmlx_1_1core_1_1_exp.html new file mode 100644 index 000000000..7a5acdfdf --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_exp.html @@ -0,0 +1,463 @@ + + + + + + + +MLX: mlx::core::Exp Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::Exp Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Exp:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Exp (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Exp()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::Exp::Exp (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Exp::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Exp::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Exp::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Exp::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::Exp::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Exp::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Exp::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Exp::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_exp.png b/docs/build/html/classmlx_1_1core_1_1_exp.png new file mode 100644 index 0000000000000000000000000000000000000000..5072482beda1fdef29a6eb7ae79e18a80c2ce713 GIT binary patch literal 875 zcmeAS@N?(olHy`uVBq!ia0vp^(}1{xgBeH~F+Z{fQqloFA+G=b{|7Q(y!l$%e`vXd zfo6fk^fNCWJa_;UH$5A+1SrQ@666=m0OW&#In(Sb3=B-)JY5_^Dj46+eK={cf&km> z{#94@{1={Y*tKAWPf=n2>bW9&g<2QNww;YHI#+3PVv>o}LRHT=C+|snEVfKi`RV!H zKmXm8dK1s8gRD{h&vrP^&0hNT?lV8pBF|%0Guv~2-ekK!|BT``i_056EuE{nGS7d@ zyTmD1GIQVjy0**v>g2pRR!u5zE2pisI;AHabZO$YZQnN8ufMwXQ}EiQIMJj)d zL9T0=rc%kN>-ouOUD!EGe)a>~cQE!GpAohF_Jp|PEZ;EYD&AG1>VgOSGZ-s4bs6** zvOZAV!my`BtYP{ErVm0<3~^4}2fPuIPsKBy&iPz1K{e9^X1ZmTs^=$T{(39*yY_ou z1pmLNvhu0=mU%rEvm>@#%Cak8#d&)5y6KfSzI{vm@Fl_ionS8a_B*S?XR%G!FTLZN z-8Xwla(4Q)?Wf*E?p>UuVRcEm@+xJ7;dM?$)jKbMv!hvv%C=h|2uW z=4rL9^zil_v1@!JpcC|36nwTT<_r;W=q~;Z|VOS}^Q8 z%9~(w?8>gYZC|8Q{8gr%$_}~ zzdrM}>6<$9W_8Tk)Ul?XAVLdu?-bZU61q{*`I%t`pk(S8sha*W%>0g->Sdnj2Ri{2k?c z@_qK}4V}3s<0GQ4PJBJ*$ffx1ntPX4`z@Mz^Zw&_tu1X=jf_7FMx~|2Ih{Up=Jb-O eL0*SHvdc2hO + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Expm1 Member List
+
+
+ +

This is the complete list of members for mlx::core::Expm1, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Expm1virtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Expm1virtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
Expm1(Stream stream)mlx::core::Expm1inlineexplicit
is_equivalent(const Primitive &other) constmlx::core::Primitiveinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::Expm1virtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::Expm1inlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Expm1inlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Expm1virtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Expm1virtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_expm1.html b/docs/build/html/classmlx_1_1core_1_1_expm1.html new file mode 100644 index 000000000..7254a12ad --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_expm1.html @@ -0,0 +1,434 @@ + + + + + + + +MLX: mlx::core::Expm1 Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::Expm1 Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Expm1:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Expm1 (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual bool is_equivalent (const Primitive &other) const
 Equivalence check defaults to false unless overridden by the primitive.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Expm1()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::Expm1::Expm1 (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Expm1::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Expm1::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Expm1::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::Expm1::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Expm1::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Expm1::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Expm1::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_expm1.png b/docs/build/html/classmlx_1_1core_1_1_expm1.png new file mode 100644 index 0000000000000000000000000000000000000000..da566929f85f62750ede4d52a3cd0d93330543f5 GIT binary patch literal 883 zcmeAS@N?(olHy`uVBq!ia0vp^(}1{xgBeH~F+Z{fQqloFA+G=b{|7Q(y!l$%e`vXd zfo6fk^fNCWJa_;UH$5A+1SrQ@666=m0OW&#In(Sb3=B;FJY5_^Dj46+eK_fr0*@QN zys_WC|DE}3bz7%pm0V4Ie(a&o#?}I=Nxa=@mwqzHR*}!7L}9T&f1=n=9m2a z^33mV^yDS%H4d-O$&{}sUb|)A?K>--^iD?Tf8L;LfBV7f$(BdN(xt-oZP~1H^=;oZ z)1FCIYkOnwN0qOvy3`x&K4FRN_f4VhH&dOxtS()<7WeIU)Lkw6t5K&G%~V~PXSenj z%iNUL5!d*)s=Q>G9v1xR_bscON6Bk{MR;7An7-vy+cx!$U#v5IgD%esuX(BBb!k1! zh29ySlcpDLEjm|e!Eo;=Yl6+OE4yOazDNaUZ<4y09kAA-lcCOu`+#=_V+E%!gZ@I+ z2dY~b_OysKOuxYNK`4qL4oUKVW7s*%e)h;oZpuKvgH2u>Hc7?)^M~EVm%d)^w)>jz zJ1J`4!qW63H;*iu8uoc-ux~_M>E1Jkw{6#NUe5lzaC=7dj;C>ZFPZG9S@}3Gs`!qI zY_#>}eWAL$Y+m$dzbg%2^?cjSRmb0KQdE__nff{>XXVy+mU$6sF6e^hzxriE`gwg0~Tnb*EO z=l!`7_tlRagAT_?2n`)7w%7f)O3 zeDeMFKYycqchu&}nr+nmnG&|0KVvFyu7A?Y^;d24U;b0s=Y9Ijo98zBT|ln+`MCaS oT3WoymNRGe^sJs5FVdQ&MBb@0R70a)Bpeg literal 0 HcmV?d00001 diff --git a/docs/build/html/classmlx_1_1core_1_1_f_f_t-members.html b/docs/build/html/classmlx_1_1core_1_1_f_f_t-members.html new file mode 100644 index 000000000..1b6c9e55c --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_f_f_t-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::FFT Member List
+
+
+ +

This is the complete list of members for mlx::core::FFT, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::FFTvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::FFTvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
FFT(Stream stream, const std::vector< size_t > &axes, bool inverse, bool real)mlx::core::FFTinlineexplicit
is_equivalent(const Primitive &other) const overridemlx::core::FFTvirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::FFTvirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs)mlx::core::Primitivevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::FFTinlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::FFTvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::FFTvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_f_f_t.html b/docs/build/html/classmlx_1_1core_1_1_f_f_t.html new file mode 100644 index 000000000..de279f0ed --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_f_f_t.html @@ -0,0 +1,447 @@ + + + + + + + +MLX: mlx::core::FFT Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::FFT Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::FFT:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 FFT (Stream stream, const std::vector< size_t > &axes, bool inverse, bool real)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs)
 Get the output shapes of the primitive.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ FFT()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
mlx::core::FFT::FFT (Stream stream,
const std::vector< size_t > & axes,
bool inverse,
bool real )
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::FFT::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::FFT::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::FFT::is_equivalent (const Primitive & other) const
+
+overridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::FFT::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::FFT::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::FFT::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::FFT::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_f_f_t.png b/docs/build/html/classmlx_1_1core_1_1_f_f_t.png new file mode 100644 index 0000000000000000000000000000000000000000..aa05d735a5e91509583bd17df54ff4448f9e0141 GIT binary patch literal 847 zcmeAS@N?(olHy`uVBq!ia0vp^(}1{xgBeH~F+Z{fQqloFA+G=b{|7Q(y!l$%e`vXd zfo6fk^fNCWJa_;UH$5A+1SrQ@666=m0OW&#In(Sb3=B-SJY5_^Dj46+O+5cffya%X zzw_n)|3}NuolHoM>h;UJC*{{BUD6)(MceYZja-VSRPTjJD*6jmJ>z7eJSY8G@_M;> z`K$Z0RBRt|NAW+~;VhfI^!Kf2exgO5$Nn1eM}OYTx6f{-^P0kCiBG+)y@T$`WyBuV zSb6!@jjwBe>xKI7DtyH;Y0J~VZLc(D=XiKl-@0|{)V%dq*M4$dyHsaZ6F0b;x z;JEEu>faFWNqhyFrmFkbuDz>fcYRfVtIFF-*VQ>0)k5C;Ctp)BT^so*T0L-5om+ic7fD{q_E~!}%VoXe)<-H#KR9(6^cS)|P~F0?r$ww` z`UR#BLQxEHPTU8)GZ-t7B>yn3ol`8wze%MRZhG}v70Pz-@Pxj8uo&Gs?MG<<0evc)9+(x%Hm3>f&YAQNBV} zz)-O}*S)_{{qN5`xnD!qeqZ}<{{Ae_UzVVFq+SS336kQUH1lRg%e0v@r(ZBK{`_gt zweIX4b5|9``pukqQ+VH+t+Su;Jh(hdtg73heAR=~x86^$jF{S&pB-y{U{YyR>CV2D zdAjlYX05%I{C!eNTJ*JBtJc|HS=IgPNKR;V@szcnz4mQc_GJ6kt+$VC-g>HSoBGBt z{+gA>n}a!Zjg9perlvhp-E!v4o}SfHgS-xZ + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Floor Member List
+
+
+ +

This is the complete list of members for mlx::core::Floor, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Floorvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Floorvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
Floor(Stream stream)mlx::core::Floorinlineexplicit
is_equivalent(const Primitive &other) const overridemlx::core::Floorinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::Floorvirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::Floorinlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Floorinlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Floorvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Floorvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_floor.html b/docs/build/html/classmlx_1_1core_1_1_floor.html new file mode 100644 index 000000000..60f7e423d --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_floor.html @@ -0,0 +1,463 @@ + + + + + + + +MLX: mlx::core::Floor Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::Floor Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Floor:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Floor (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Floor()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::Floor::Floor (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Floor::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Floor::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Floor::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Floor::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::Floor::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Floor::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Floor::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Floor::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_floor.png b/docs/build/html/classmlx_1_1core_1_1_floor.png new file mode 100644 index 0000000000000000000000000000000000000000..2b602e649d0743547bceab9a20478cbbb48a191f GIT binary patch literal 866 zcmeAS@N?(olHy`uVBq!ia0vp^(}1{xgBeH~F+Z{fQqloFA+G=b{|7Q(y!l$%e`vXd zfo6fk^fNCWJa_;UH$5A+1SrQ@666=m0OW&#In(Sb3=B+fJzX3_Dj46+eK={g0*@QN zyz$&U|AphXiaK8Owa%SxJ)1B6R944vfo=6CEuVi>@V=DMGEJqDQ`hs;qYGY>o~WOl zfA3aio&2OP%<>m@MLhN^T{rbwd9+&SDR;wb>s51BpBJ;a9~qQWxopEH_uXDWwt8Eh znRvbGF)Odlyu0+3pSk~B$4TE}LRX*Lw7P9d$&$Bg_rCeQbysZNs;?Q{X;YU}{n+wK zxFl$8a`gWds-EgQw#=G#|NFH>;reT<`dd|GE8SM-WK;{C-V=OH#dLAy(dmARRQ^gW zRP~H=@}9H@AMnm#tl-pT&|k>umdS?O*A=wWTi}mMLc~{VT2V z(&8HL{ZD0EUwtkNTI-j;wQQ;2{0pzP)kylDE|)9c+POaS>hVYFYd0-?%c=eO?Z3SC z?OE^7mAJ3gbb7Ahb?J8Hr}_ue>uTb*{aSJDb=d#&{##Z4T7g39*ws^(&mS^Wh?_Tj zln8!ZCVeT}XYI)>m-UWYAE_|?zzhrQ0riJ*?UW!X{z)@$0#nG$nbR*A8GruNb8WJI zfqjs_{_h>vvQO`c%Jj2w*PZ;z&|7!gdet{}_8Fy4XV3V3%-?t1*2hEF-g+F+o03+&?cS^GnKNTQrLAp0Sto42HdKDi + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Full Member List
+
+
+ +

This is the complete list of members for mlx::core::Full, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Fullvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Fullvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
Full(Stream stream)mlx::core::Fullinlineexplicit
is_equivalent(const Primitive &other) const overridemlx::core::Fullinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::Fullvirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs)mlx::core::Primitivevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Fullinlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Fullvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Fullvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_full.html b/docs/build/html/classmlx_1_1core_1_1_full.html new file mode 100644 index 000000000..7ad00b10e --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_full.html @@ -0,0 +1,433 @@ + + + + + + + +MLX: mlx::core::Full Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::Full Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Full:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Full (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs)
 Get the output shapes of the primitive.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Full()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::Full::Full (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Full::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Full::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Full::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Full::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Full::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Full::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Full::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_full.png b/docs/build/html/classmlx_1_1core_1_1_full.png new file mode 100644 index 0000000000000000000000000000000000000000..51e2557809e41495c77b03b62b452248d0b92381 GIT binary patch literal 852 zcmeAS@N?(olHy`uVBq!ia0vp^(}1{xgBeH~F+Z{fQqloFA+G=b{|7Q(y!l$%e`vXd zfo6fk^fNCWJa_;UH$5A+1SrQ@666=m0OW&#In(Sb3=B;7JzX3_Dj46+eK_fr0*@QN zys_WC|DE}3r!g&+D!=md+_8r~k!c@rY2nYXjd<)=x^C*W@@VzYQ`LssPD^dOp3cATUS#l_%4r)ux$pK0vX#q- zy{wUGaO=+3@Y{N!ydtV}+B&JAOB1(kue$j@n|Jr0kk|Lb&T4sn{g8Q8 z{l$`~RQJ`MxXD>ZJy|{oM=``XaUbx` zV65QOWzb*9`apFH!=4tghUpiWJ_sQse=@ z=KVI6tu@ZC9go~Rv}mf?=h7hE9s6IEMIQdPUH{>+xyD)sINaQR$g1MJ!Z|StZysH1G@!Cxb-*Rey{#&>G zaP`*R#Bjsr-ET=-QCi&%^$o_us1W*9sI$&#s=beEyQ5LR`Gz z=ZxUjMKdpD`>j2i<+k2)>mwDWA9#ZWN%9Zl+9^S^_$ST0nbR_D=FI6AjEp~jT6L{k zztB4DYgT-dfA-8s(|c>DZtHoy#H()oQDD&A`ZVRWF3_3N*z>YZu93U?RO+%+P1Nd- z79q+rBhTN9F3z~>`|oJh)t&yUzHYsD%l-1F+I(5FjZyzaLB90V`Mht<)kZ_(YNtOx yX3VVQ)HOENUznQqOm)kdGkbKRG?$vl{%5#yYVH*6l@CC + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Gather Member List
+
+
+ +

This is the complete list of members for mlx::core::Gather, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Gathervirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Gathervirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
Gather(Stream stream, const std::vector< int > &axes, const std::vector< int > &slice_sizes)mlx::core::Gatherinlineexplicit
is_equivalent(const Primitive &other) const overridemlx::core::Gathervirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::Gathervirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs)mlx::core::Primitivevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Gatherinlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Gathervirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Gathervirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_gather.html b/docs/build/html/classmlx_1_1core_1_1_gather.html new file mode 100644 index 000000000..09cb2f1de --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_gather.html @@ -0,0 +1,442 @@ + + + + + + + +MLX: mlx::core::Gather Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::Gather Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Gather:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Gather (Stream stream, const std::vector< int > &axes, const std::vector< int > &slice_sizes)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs)
 Get the output shapes of the primitive.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Gather()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
mlx::core::Gather::Gather (Stream stream,
const std::vector< int > & axes,
const std::vector< int > & slice_sizes )
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Gather::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Gather::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Gather::is_equivalent (const Primitive & other) const
+
+overridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Gather::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Gather::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Gather::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Gather::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_gather.png b/docs/build/html/classmlx_1_1core_1_1_gather.png new file mode 100644 index 0000000000000000000000000000000000000000..7840ba3e82ac51e0cf4e3a1471222031e98cc0e2 GIT binary patch literal 893 zcmeAS@N?(olHy`uVBq!ia0vp^(}1{xgBeH~F+Z{fQqloFA+G=b{|7Q(y!l$%e`vXd zfo6fk^fNCWJa_;UH$5A+1SrQ@666=m0OW&#In(Sb3=GU1o-U3d6^w7^<`zFz;9;BI zKC|cFe|@{-AGs6eXjm?P`fSQ3(}PZN+!^~P6`y)tzBc_jBKgXiiUBGu<+0{du{X{l?30eD+DaJ-uq$l1;W3 zcP$SLT7Grs>4l|JyR7fpq_TNtRferK@($+JeCe@wd(_SO*I&K+6udSyI%KoT)U(S& z=Pj6=CAdBQsCcGETW!q?@eW-VD2y8g#S&nZjd zmA9yz^v>{{H2uOP75&FqPv?9tV2C-&n_zS7%C5U@U!;PwH%VR04p?i^$x!FSeZV_| zv4T^VL4P6Z1Jx}Ids@UAre9$CAQZ(Aha~yGG3=aWKYQdPH@NAG!zQWN7yr0y`SRb( z)cL>q)jUn@7k*WFl=EoW)UeMxgMA|6O81^g{CzwB;N@PkcN25{xA#o-pVhQoKlNU0 z_PNu6;M{e$>{d*8SDnI_76iJ^pCwwQUQ&a^C*+ z`)^w7_bl`DU5i6S7ZwYtzAU=;NBW2J{~Eh(zgAs)9`^s-{w&X5rB2?H_VlctT72#> zV}Yo7!$*nW*JaX|vVGQ`%yL=pxb=|=(+`3{0}Z!7>`|IaP2^OJHv^N1v9bQb)U;=KWu6uZ&kT*;w^?axRQd7RWWjFb>z19-^|~h>e-Lfo5b + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Greater Member List
+
+
+ +

This is the complete list of members for mlx::core::Greater, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Greatervirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Greatervirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
Greater(Stream stream)mlx::core::Greaterinlineexplicit
is_equivalent(const Primitive &other) const overridemlx::core::Greaterinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::Greatervirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::Greaterinlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Greaterinlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Greatervirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Greatervirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_greater.html b/docs/build/html/classmlx_1_1core_1_1_greater.html new file mode 100644 index 000000000..4e738bd6d --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_greater.html @@ -0,0 +1,463 @@ + + + + + + + +MLX: mlx::core::Greater Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::Greater Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Greater:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Greater (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Greater()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::Greater::Greater (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Greater::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Greater::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Greater::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Greater::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::Greater::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Greater::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Greater::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Greater::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_greater.png b/docs/build/html/classmlx_1_1core_1_1_greater.png new file mode 100644 index 0000000000000000000000000000000000000000..ed485df38fc62e9ab723449fd9db6bac6392be4e GIT binary patch literal 910 zcmeAS@N?(olHy`uVBq!ia0vp^(}1{xgBeH~F+Z{fQqloFA+G=b{|7Q(y!l$%e`vXd zfo6fk^fNCWJa_;UH$5A+1SrQ@666=m0OW&#In(Sb3=GVYo-U3d6^w7^zU^CVz~h>q zx8&8n|HAqK9Ou}t$d;yUlIky-<{(pCFzIdllR2Mj6ed5(X%X|R;?$k=<lg3z{(sZs6YKp2rN+nSzKYV8y}zy0=hR+_wCz^?Tcjw252JE>HEp z;JWo!?%$wp75#$D(CB@4qwgy1ySC|glS=iyi(1@K^N(DTpJF;ObhX__Z}o{P^}QFA zJoOj4PKtBtR@r-I_0-~Xml+x6KVoF~Y`FZ|W5cD_=R|d1?TMS5b<~rE!G0ktL!1*g z!=4r~h6+wyh7Uqf3=dSdFdXpCU~HIvfr&vMS;0@{D9h(I@+VX*rC}C)iRx7OoAYpc z@ugqpr}llB->Z^c=k!|bQC8ALulAqU428e>+v-K%`1w_}LNep=yqo7rejEQ@rn=+K zjXle+O*5ZZyE^@r`?ICL4_K_LGOL+9H#+?EF_Wtb!qu~~VvkPSd$sP|@~yjUWAs}> zv-T@YthE;VbH@70rst0rMK$jYy*;(}N87c6)6>Ipe;$1OPV8fL)VGK|mq3BC;gbFF z*`ju<@6LJI7#7-c&{O4P|L14_8|}Y5{+zlx|LU(@SO0$w_MNofFJt18>6KfH&Q)45 zJP5X7NU(c$rOIa6ms!Et8)se24q9u`$?!lOSBPK@vU&&YrBj|UdrnGI-lB5SJHvC5 z`1EVHpS_Npv_vj_RsO|WkB(2aTQK#S%Igo;3^l(;$xPQhGv)ZVa|_=V{R`gY6?A;< zHKUIK-ji~lZ8|o?{kQMQ3ohoLRWgs%Zp~S=PRcXrwB?4Q=FYv-RAN79>h3-{e{R`} z%F1<#h0~S@#_8@h6Th0Z<+qn@SX$kke%b<{P***PKY~|+8 zd*=d7{U@;<>F-ym2BosC2i7L>&rx#5P^1A$y|Nrti!rG^&TLN + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::GreaterEqual Member List
+
+
+ +

This is the complete list of members for mlx::core::GreaterEqual, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::GreaterEqualvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::GreaterEqualvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
GreaterEqual(Stream stream)mlx::core::GreaterEqualinlineexplicit
is_equivalent(const Primitive &other) const overridemlx::core::GreaterEqualinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::GreaterEqualvirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::GreaterEqualinlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::GreaterEqualinlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::GreaterEqualvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::GreaterEqualvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_greater_equal.html b/docs/build/html/classmlx_1_1core_1_1_greater_equal.html new file mode 100644 index 000000000..16aec8e67 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_greater_equal.html @@ -0,0 +1,463 @@ + + + + + + + +MLX: mlx::core::GreaterEqual Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::GreaterEqual Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::GreaterEqual:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 GreaterEqual (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ GreaterEqual()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::GreaterEqual::GreaterEqual (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::GreaterEqual::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::GreaterEqual::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::GreaterEqual::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::GreaterEqual::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::GreaterEqual::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::GreaterEqual::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::GreaterEqual::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::GreaterEqual::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_greater_equal.png b/docs/build/html/classmlx_1_1core_1_1_greater_equal.png new file mode 100644 index 0000000000000000000000000000000000000000..3b6862e595a2eaaf7272cbf9f62d73d8e0bfd547 GIT binary patch literal 945 zcmeAS@N?(olHy`uVBq!ia0vp^(}1{xgBeH~F+Z{fQqloFA+G=b{|7Q(y!l$%e`vXd zfo6fk^fNCWJa_;UH$5A+1SrQ@666=m0OW&#In(Sb3=GURo-U3d6^w7^-p+fYz~gEz zo%!Vd|3{N2u$<9-JkvOL<*Tl%F$U+De13OHB**Z(Op4*XuH>n`&UI24PgJAI+m^rS zCS|+rr9AB(a&O5yHY51#)n@B9iin&qwCDR05!Lfz)8tSp`(SIrF8kN#xpn zYvyN7e8Mt0Ec8=uZt3hFG2ZgrlipOiYUgLv3Vr^Y5Xl_5HLELsuFIr9hB4hLYj=ow zMiuH#+Ola~SdXC|J3~zuJHwjev!eQsPm0UV$_>-5(p~jbU6A3;4l#y^LS2RpkD?e7 zj&5Z*(3Q>DAbOpNL3n^Ev=1$sMjX6^8ujAFP zY>)a?;}e!05WO{4?}+%8Hr4sjp`V)1-_3ng9rY{1=JF}&sQDXe;(yJS)mwdc&dbKI ztsKphJtoONKf1s1{+IL5b64wMt-Al}`=8ZnC*_ZC^|*9uSCrwg9Wo3bH02mJ$Yp2Q z%6V6~MRQm;@P8_cr6_hqfty5J>I=a=xWlhi4xUjqY z{T&|upMP1TD=FE&`kKBa=ubt!YE>np#ATx2=KZz2yJAOi({m6oq}Uz)1-^!mk> zLG%1OMGUUqV$$ELuxR$KH-@jS{+wWJaCfYUD13 zSHo^j*7~x4YsphN^W5-%b#)7u1!ap~O^>q5*%Z}Vwm-;s-pXmU!mB6cXRi)Ey?V98 zm9V*=bRL&)`P6nzJ@9%^-KqG{i!XUz7XEx6Hf`6Z)xp0u*FVkQ_a;h7Nwhnwv%{xA nS5QzJbP0l+XkKR(`(x literal 0 HcmV?d00001 diff --git a/docs/build/html/classmlx_1_1core_1_1_inverse-members.html b/docs/build/html/classmlx_1_1core_1_1_inverse-members.html new file mode 100644 index 000000000..1fb2c948d --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_inverse-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Inverse Member List
+
+
+ +

This is the complete list of members for mlx::core::Inverse, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &output) overridemlx::core::Inversevirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &output) overridemlx::core::Inversevirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
Inverse(Stream stream)mlx::core::Inverseinlineexplicit
is_equivalent(const Primitive &other) constmlx::core::Primitiveinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)mlx::core::Primitivevirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs)mlx::core::Primitivevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Inverseinlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs)mlx::core::Primitivevirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Inversevirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_inverse.html b/docs/build/html/classmlx_1_1core_1_1_inverse.html new file mode 100644 index 000000000..ec23a1051 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_inverse.html @@ -0,0 +1,323 @@ + + + + + + + +MLX: mlx::core::Inverse Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::Inverse Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Inverse:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Inverse (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &output) override
 
void eval_gpu (const std::vector< array > &inputs, array &output) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
void print (std::ostream &os) override
 Print the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)
 The Jacobian-vector product.
 
virtual std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs)
 The vector-Jacobian product.
 
virtual bool is_equivalent (const Primitive &other) const
 Equivalence check defaults to false unless overridden by the primitive.
 
virtual std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs)
 Get the output shapes of the primitive.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Inverse()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::Inverse::Inverse (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Inverse::eval_cpu (const std::vector< array > & inputs,
array & output )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Inverse::eval_gpu (const std::vector< array > & inputs,
array & output )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Inverse::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Inverse::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_inverse.png b/docs/build/html/classmlx_1_1core_1_1_inverse.png new file mode 100644 index 0000000000000000000000000000000000000000..c59ec21c0952e25cd20892230ab773cbd6e43350 GIT binary patch literal 884 zcmeAS@N?(olHy`uVBq!ia0vp^(}1{xgBeH~F+Z{fQqloFA+G=b{|7Q(y!l$%e`vXd zfo6fk^fNCWJa_;UH$5A+1SrQ@666=m0OW&#In(Sb3=B;FJzX3_Dj46+y<60*AmGL? zZ#=X1zxq0) zbe-xJm6P5Xo|C56Z7n*t)0*MlQQigizpm`MyX(uV;Oq^rF0KwcNxes^+UA%M>gA20gpGW!WB2 zt8G;&m8)lz)LrOX{mwL9>-e?Rp~v2YIeS@cie9_<#@21$)_=}h>&$yyI#BEFeFe|s z+j4$P+rBb%UjO8)Sy#exw~IXzjygLjK05SM@5^;&A5GVOim!>Ho2VvBJz!^$E*`qXV;!Q`4RuUH)~+ zukZK2=w~f0Z2I`~ismxux`npAD@7vuwv*y@ZQ~%PfQM(J5 zUQ2w?9Q^R=DzASj*-@*v{?)9A&lP1~StxY<-^L}+gH|c+uM|3armp>hk@07tsI;^= ir_*Q7c<%=}S2q6W-8p&9o(-UU#o+1c=d#Wzp$P!vbGG^b literal 0 HcmV?d00001 diff --git a/docs/build/html/classmlx_1_1core_1_1_less-members.html b/docs/build/html/classmlx_1_1core_1_1_less-members.html new file mode 100644 index 000000000..e3f748124 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_less-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Less Member List
+
+
+ +

This is the complete list of members for mlx::core::Less, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Lessvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Lessvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::Lessinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::Lessvirtual
Less(Stream stream)mlx::core::Lessinlineexplicit
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::Lessinlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Lessinlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Lessvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Lessvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_less.html b/docs/build/html/classmlx_1_1core_1_1_less.html new file mode 100644 index 000000000..076e7bfb9 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_less.html @@ -0,0 +1,463 @@ + + + + + + + +MLX: mlx::core::Less Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::Less Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Less:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Less (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Less()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::Less::Less (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Less::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Less::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Less::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Less::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::Less::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Less::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Less::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Less::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_less.png b/docs/build/html/classmlx_1_1core_1_1_less.png new file mode 100644 index 0000000000000000000000000000000000000000..5fde4667de943f1ce94b7377e7b1d10bf4c4687d GIT binary patch literal 867 zcmeAS@N?(olHy`uVBq!ia0vp^(}1{xgBeH~F+Z{fQqloFA+G=b{|7Q(y!l$%e`vXd zfo6fk^fNCWJa_;UH$5A+1SrQ@666=m0OW&#In(Sb3=B-~JY5_^Dj46+eb{$ML4d8j z|Hzx)_8+g_3d-2H&dpYDdz)ouWJ1Hy-`?je``IHWxhZc^Iq99@IZ3=dY?6xo%U>_Q z&V99A$&>#_!!=L&?*ZSpUfQ*7XAtksoMXRC+jDE}v7y;(4ohTl80r-FGKU>f5q&>-4|-!pu*<&(h7j_+?Uv`24Wi z=67$d8LpMj^1LJ)6XkpAeevspj{dE$vK2ku=Uvodzjg0O(EFV1NwY5Ld|WzP!*jnI z(0S7@Oj6NbsOlMa?CL4Y=MNbw#KjvvN(8?ylDw4dv-V_`%X-JHkJOlcaOyJXFJyh7 zx`km+i&(?-3rrt`q8Q?wxDR+|FjgQ*{$XA_r&x}ElS(h#boaF?o^{VYT($!SLa*J` zeAh`)`xcf86|PUd=(YN1%t|XA@7Vp%HspR)t;o7@{I1*Ww&;6TjeQS3T_{(6dDd~? zOB=V?z7D^-#M{|^UFG$;E8pFI7bd&;Hm~N(6M5U;9i4XPmj3g*QSCB${ENbFtyl0o zzDe$fSzh?+ee&gHt0zaVjrBaDzGatYyw~JUb6=k0`)IoMSCq{qP{l=0%cH8CT z9vv@Qy{l;M?1)$16_$JZ@3 + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::LessEqual Member List
+
+
+ +

This is the complete list of members for mlx::core::LessEqual, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::LessEqualvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::LessEqualvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::LessEqualinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::LessEqualvirtual
LessEqual(Stream stream)mlx::core::LessEqualinlineexplicit
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::LessEqualinlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::LessEqualinlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::LessEqualvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::LessEqualvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_less_equal.html b/docs/build/html/classmlx_1_1core_1_1_less_equal.html new file mode 100644 index 000000000..ea0474c76 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_less_equal.html @@ -0,0 +1,463 @@ + + + + + + + +MLX: mlx::core::LessEqual Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::LessEqual Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::LessEqual:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 LessEqual (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ LessEqual()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::LessEqual::LessEqual (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::LessEqual::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::LessEqual::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::LessEqual::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::LessEqual::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::LessEqual::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::LessEqual::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::LessEqual::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::LessEqual::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_less_equal.png b/docs/build/html/classmlx_1_1core_1_1_less_equal.png new file mode 100644 index 0000000000000000000000000000000000000000..861844408310aba3ff8e8f7c874c0ab2c3ded067 GIT binary patch literal 926 zcmeAS@N?(olHy`uVBq!ia0vp^(}1{xgBeH~F+Z{fQqloFA+G=b{|7Q(y!l$%e`vXd zfo6fk^fNCWJa_;UH$5A+1SrQ@666=m0OW&#In(Sb3=GVgo-U3d6^w7^J}g>nz{7TZ ze!5ZBe|0;j12+K2ug-Ok#cljfKF zf4O$d2Z{aUwNNhrxxAp`I^jk`&&BqzVl}^x7h?o{JgYoQb@mg z)OQoFS3PFswb`~yvzERy-D z!(8Z{;W=q~;nt#al@<(kM|l%$j$PSxx9y8maP}sti`fBdU$`*XFJyh7x`km+i&(?- z3rrt`q8Q?wxDR+|FjjEtGUy{o{&&zu^Z)5UkpC_JyqpCLgVgz;Fo@c>@T ztAEx6>PEzs{+)4n+jjkfmyg)Jo0#jp{m$z7Calx-OYiw+pPRkpaYp#H?W@}Lg#Ye% z7Wdur>#efwuV&oM^YC;p+g&Oaot6FHYW_8=D+Sd8TQB|Nom3+8x-q|UcT{QpoHt?1 zw?*v^>d?<9_1@RB?1^=Mby>&z%&WPDOGEj#yxmZ=_w0MWira5v{mZw^@_kXVNTqg; z_G$fpVvoCs+)UZl8XN0tfzB67ckUxNgt@_Sy{T9_oeBwYbmCS zt}lqX*vU|bJv=Z2>wjaImu4^f$ul`7PN&bD@y?hzv$AsC*F~Qye=Od6=1kw0yn64` zXWp=E58Ks#@BMl4tB0QUUhlVDbF1a+s+OtuL)%ne`dF8*G+i5WbyB45s{Wsix7R*B zZIwRl)DqxY9)MjHE-t^Q{5#?5H6?cATow?F+J6n$GTXQTb1(CMG7_TPCeleBv4 z%2|=ddRt$Ws_(vfdENS3&Ud!_dn7P>(*5h#o+{3}9u_@4EsfW2ZP9{U^=o&(yg2pr z3*X!b*__aCOZCsLk=*|w{NCpS7o=i + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Load Member List
+
+
+ +

This is the complete list of members for mlx::core::Load, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Loadvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Loadvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) constmlx::core::Primitiveinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)mlx::core::Primitivevirtual
Load(Stream stream, std::shared_ptr< io::Reader > reader, size_t offset, bool swap_endianness=false)mlx::core::Loadinlineexplicit
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs)mlx::core::Primitivevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Loadinlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs)mlx::core::Primitivevirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes)mlx::core::Primitivevirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_load.html b/docs/build/html/classmlx_1_1core_1_1_load.html new file mode 100644 index 000000000..f37a2c781 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_load.html @@ -0,0 +1,303 @@ + + + + + + + +MLX: mlx::core::Load Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::Load Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Load:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Load (Stream stream, std::shared_ptr< io::Reader > reader, size_t offset, bool swap_endianness=false)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
void print (std::ostream &os) override
 Print the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)
 The Jacobian-vector product.
 
virtual std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs)
 The vector-Jacobian product.
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes)
 The primitive must know how to vectorize itself across the given axes.
 
virtual bool is_equivalent (const Primitive &other) const
 Equivalence check defaults to false unless overridden by the primitive.
 
virtual std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs)
 Get the output shapes of the primitive.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Load()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
mlx::core::Load::Load (Stream stream,
std::shared_ptr< io::Reader > reader,
size_t offset,
bool swap_endianness = false )
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Load::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Load::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Load::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_load.png b/docs/build/html/classmlx_1_1core_1_1_load.png new file mode 100644 index 0000000000000000000000000000000000000000..cb43b85d486bc02ea789d571682796299adfc197 GIT binary patch literal 872 zcmeAS@N?(olHy`uVBq!ia0vp^(}1{xgBeH~F+Z{fQqloFA+G=b{|7Q(y!l$%e`vXd zfo6fk^fNCWJa_;UH$5A+1SrQ@666=m0OW&#In(Sb3=B-4JzX3_Dj46+eZ6kAf`FU2 zyr=Q||3}={iaX>i_uXx7ozA1Zi0AR6YyVWA&pF<#t>E)M(pnp|UgQ1#PzLX!h9z-}o2seNxCY_NZbh z)l#pyyKA$(=i&Sv-CaH_r0c$Nf8S0$44|r!VR&eSv=r3e_ zpt^-&Pm5T?^b1TMgrXSYkR<;%hMlwQXOEoZrVR8u*yPP&lT_@ze;l^U+*fsT`LDem zL1+ClY&&+m+L68VY|Tm$*=q~)wyVC8_$y~)x~1@Yj{B~!u|Jn;R@7^JmW%Sfr!v?3 z?&g}%nv&cX$3x%mir4ymZFA@`lg)~%b2HOkYn9!Pep_I-w)EEFuTE3E-f~PT(YxIw zU%5N#>CYWA)}Fh!>g~7QA5CH9r{6EV^=a;T+o&Vox137;=#w3P;aAJsuZ#CaK1|*= z+xYIqRcZ^1g;ZZY-t$BHhjYDMespc<>)O@(KZomj?)Rf^2zhDtvY$MYW8!rB%o*>D znKLUZJzpOcjA^q!|G9nZr8((*z=Uz_<@cp~bQ8ZMs^&-HbL*i>tQfFy5SPx7IX1{H^!p@-2N^cg|VM{Wk67*)!`Z z16SYr^!8rWo{sx>qK{5|J?F^YeYJJ9R)^jN$^W}+Wc*ta?h;O2V`KftSxcu}D)_@Z Ych_8PB@qu6V4h&`boFyt=akR{01ghOZvX%Q literal 0 HcmV?d00001 diff --git a/docs/build/html/classmlx_1_1core_1_1_log-members.html b/docs/build/html/classmlx_1_1core_1_1_log-members.html new file mode 100644 index 000000000..b66e45b93 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_log-members.html @@ -0,0 +1,119 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Log Member List
+
+
+ +

This is the complete list of members for mlx::core::Log, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Base enum namemlx::core::Log
device()mlx::core::Primitiveinline
e enum valuemlx::core::Log
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Logvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Logvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::Loginlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::Logvirtual
Log(Stream stream, Base base)mlx::core::Loginlineexplicit
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::Loginlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Loginlinevirtual
stream()mlx::core::Primitiveinline
ten enum valuemlx::core::Log
two enum valuemlx::core::Log
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Logvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Logvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_log.html b/docs/build/html/classmlx_1_1core_1_1_log.html new file mode 100644 index 000000000..0728bb815 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_log.html @@ -0,0 +1,496 @@ + + + + + + + +MLX: mlx::core::Log Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::Log Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Log:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + +

+Public Types

enum  Base { two +, ten +, e + }
 
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Log (Stream stream, Base base)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
void print (std::ostream &os) override
 Print the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Member Enumeration Documentation

+ +

◆ Base

+ +
+
+ + + + +
enum mlx::core::Log::Base
+
+ + + + +
Enumerator
two 
ten 
+ +
+
+

Constructor & Destructor Documentation

+ +

◆ Log()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
mlx::core::Log::Log (Stream stream,
Base base )
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Log::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Log::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Log::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Log::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::Log::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Log::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Log::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Log::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_log.png b/docs/build/html/classmlx_1_1core_1_1_log.png new file mode 100644 index 0000000000000000000000000000000000000000..cc9ba7d8505ae19adcd566c6d03b83559df66dd3 GIT binary patch literal 866 zcmeAS@N?(olHy`uVBq!ia0vp^(}1{xgBeH~F+Z{fQqloFA+G=b{|7Q(y!l$%e`vXd zfo6fk^fNCWJa_;UH$5A+1SrQ@666=m0OW&#In(Sb3=B+fJzX3_Dj46+eLHEh0*@O% zzwz8V|3&v_{be|6_V$L}?>wtH6TLdzIo`$>o!e=DV$v1qg{q!$PTrIDSZtZ3^3(J2 za?|n`_h+lvK4gyKf3_odZuZircbEB1EQ;%SYSfzhdNXUka&VXKd4X+pE7Lu-ZZoen z?42aF_FU|}sIrw3na?zhJuc02Umcoub8Xv{k|pcb?tSxp>#WeaRi@KjO}&F|)sqF%RC4Nie!8?i?3`se`+?9q48Iopuf4t2d%fLOw$*Zbt*&-W=6ImKg<(&NSi|%S zOdo`z7~-6`4|r!VR&eSv=r3e_po)+szdwW{mzHqxtj=23(D$Unw=1=NdJ1_E{mYfy; zmdC5A&#F%?etlIdxoowfs_f0R*CO9UMgP9OGwOA~;jeB}yndT|T-vBBz3*vm*3{>Z zQ?^QN4ZXd#^+!|K<|o&~rq(^J*pyY^Fr&1F!4Bxcu3-+%jI+pXQkcQ33GTUaQh z`m!kQ5AUB#^UvQ7lfPPZ|JChtgXJs2vrlYap}KjK=ybznJ51{%kG*>vwf5BG4N@o1d^0(= zTG4oOu-yXFx;lsK+p{lUwYvCkNm%Ho-?yrEb;_?^`{;zJ-=kacI?wXwTXv?UUB9sK z+S9G6X=RL0_1Dh4srmn8{MMN_#W{41jrIR#0o_*qhncN?Zg1;LNo` + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Log1p Member List
+
+
+ +

This is the complete list of members for mlx::core::Log1p, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Log1pvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Log1pvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) constmlx::core::Primitiveinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::Log1pvirtual
Log1p(Stream stream)mlx::core::Log1pinlineexplicit
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::Log1pinlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Log1pinlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Log1pvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Log1pvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_log1p.html b/docs/build/html/classmlx_1_1core_1_1_log1p.html new file mode 100644 index 000000000..0eeaf6c9e --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_log1p.html @@ -0,0 +1,434 @@ + + + + + + + +MLX: mlx::core::Log1p Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::Log1p Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Log1p:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Log1p (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual bool is_equivalent (const Primitive &other) const
 Equivalence check defaults to false unless overridden by the primitive.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Log1p()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::Log1p::Log1p (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Log1p::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Log1p::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Log1p::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::Log1p::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Log1p::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Log1p::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Log1p::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_log1p.png b/docs/build/html/classmlx_1_1core_1_1_log1p.png new file mode 100644 index 0000000000000000000000000000000000000000..fc2853680b27706d0125a21079bf0595a721d1e1 GIT binary patch literal 884 zcmeAS@N?(olHy`uVBq!ia0vp^(}1{xgBeH~F+Z{fQqloFA+G=b{|7Q(y!l$%e`vXd zfo6fk^fNCWJa_;UH$5A+1SrQ@666=m0OW&#In(Sb3=B;FJzX3_Dj46+eLZQlf&km> z{#94@{1=vgtksk>XUQa;^HGp{PkuQe)Nv9nRNKUdSxbZ@4vYyW-wsS7xno`yv&h?Wfc*{{qtop(utp zC+-8@8H^R2x(xaYSs$owVc63m)-W9*d0PI$Da+>{6}&HHz)UZ?;x*|>I$!+_^}F_a z-vs|%?h$m_e2bjRoKl^rOw;+=q3usYub-CLwzrmZp47KO`Y_sh1unsGPB!_&QNbE#PLw`>1O@2xGparmv%6|djs z9+x(*;okT3ZPwN2k5jf;Rj=NvE&Y*m?Yn9EOK*Of`%q5zQFheVjXN%Bao?J^A?W_C za_v25zrFJ*-!jYd#T*yUDr^6LY;`ZsKTloVpS3sd>;FHY-j~)3Q#FLVG?#uDqCZ0%pM-~((Zt5%DRxly3T8+Tdt8RT|0Hds^1~| zFYUfEt5{`o@=CqW7Pl8aP0RhQIrZALytHeNtF~qJAI?!RzI{1U-}mY3?x^qQ^)s#( z*BeEv-zopr0GOt*YXATM literal 0 HcmV?d00001 diff --git a/docs/build/html/classmlx_1_1core_1_1_log_add_exp-members.html b/docs/build/html/classmlx_1_1core_1_1_log_add_exp-members.html new file mode 100644 index 000000000..3c05b01b8 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_log_add_exp-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::LogAddExp Member List
+
+
+ +

This is the complete list of members for mlx::core::LogAddExp, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::LogAddExpvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::LogAddExpvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::LogAddExpinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::LogAddExpvirtual
LogAddExp(Stream stream)mlx::core::LogAddExpinlineexplicit
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::LogAddExpinlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::LogAddExpinlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::LogAddExpvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::LogAddExpvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_log_add_exp.html b/docs/build/html/classmlx_1_1core_1_1_log_add_exp.html new file mode 100644 index 000000000..1cc6326ec --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_log_add_exp.html @@ -0,0 +1,463 @@ + + + + + + + +MLX: mlx::core::LogAddExp Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::LogAddExp Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::LogAddExp:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 LogAddExp (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ LogAddExp()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::LogAddExp::LogAddExp (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::LogAddExp::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::LogAddExp::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::LogAddExp::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::LogAddExp::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::LogAddExp::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::LogAddExp::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::LogAddExp::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::LogAddExp::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_log_add_exp.png b/docs/build/html/classmlx_1_1core_1_1_log_add_exp.png new file mode 100644 index 0000000000000000000000000000000000000000..28cb8ab01bfe36d2661453a2d8e022c8cb461383 GIT binary patch literal 943 zcmeAS@N?(olHy`uVBq!ia0vp^(}1{xgBeH~F+Z{fQqloFA+G=b{|7Q(y!l$%e`vXd zfo6fk^fNCWJa_;UH$5A+1SrQ@666=m0OW&#In(Sb3=GUxo-U3d6^w7^zFxQ5iib@- ze#xo#|Btj6adNV}T6g#A+{JITlr=2UJl=fMw|s6Rm*Oeadts7_{z6sHIGHHVNq?3+ zUv6G*72l?^_ksA9yk|RtW3!jOz58sk*3(@Qw-fnpUr*<+yEilVP35u;pO(&5T`6ZC zb>775RgYPDZRYKzul&sY=Q>X6yR$ZI?$ZrYL6>H3+t#<~{`FU3KbKyM5<9Eo`SnBQ zRoRk_QO`HktyJ|C-?3%ZB>QXEGKJ%_ujVRxrr*1$C4H+b$@0BU_M}-C_dM#i3Y_#$ z3gkK`?@4=Frm0j)hHF2c)6V=L*oNU)vd`M%$sX$sx2mlc+{<>g%a7%Qa1=wF6ZZk{ z48{shT?YMytPfPTFzjg&YnXn4>4OkL@+bS2qH~otCnlM|O`o%MlFHAShx4BU!(e(X zC=9mNIKNhVl$CVRYxd8YAm51iQtLB`zi&@}@bZ{p-pOstqt{>cpVf4Fy4Cx2rq6vZ zRc2&!WD2e!Ww+_0^2SSstG0e5*@EqqAQBu$q6(>T1Dm!L64*@=hv| zyWRG_a(7f|eaxD$<=dim2X*9UlzQ*$S@y)bfA{Sp^E0pJe)L+MlW{xoY0dun8xL!@ z^4s3Mm?gIGa_6Kkh4=pO{<$>&eE-_I(AUqy{?Ff^<@xJ1bwen~E7Qg!Ep4%2R9ae` z)9Ev3PH(+7d)NBQYtk<#W||pq)~<`zerv9!YrHwyFRFOn^R1>2=ap9Of4gDdSI^Zg zxtFI)zyAJo1GgUI^QqR^X9G{Zu~~I^PUikAc3FD^oVL$h&42aK)$qKq)!$x(wyADz zJNGvwO}6V{RCMpJ<6B-G+I!*Fy!%<9(Vv;usyWwh)!Q=r`FyVj8#?o|OjYYQTi34L zn=`tuCHTvzEG~0bE zuKu|+Yv$b#zCah-D{MJ)2AHB|&YbS4y>v?3WBc0svwWJbE!zOhS`419elF{r5}E+& C-_iyE literal 0 HcmV?d00001 diff --git a/docs/build/html/classmlx_1_1core_1_1_logical_and-members.html b/docs/build/html/classmlx_1_1core_1_1_logical_and-members.html new file mode 100644 index 000000000..b96dfe648 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_logical_and-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::LogicalAnd Member List
+
+
+ +

This is the complete list of members for mlx::core::LogicalAnd, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::LogicalAndvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::LogicalAndvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::LogicalAndinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::LogicalAndvirtual
LogicalAnd(Stream stream)mlx::core::LogicalAndinlineexplicit
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::LogicalAndinlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::LogicalAndinlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::LogicalAndvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::LogicalAndvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_logical_and.html b/docs/build/html/classmlx_1_1core_1_1_logical_and.html new file mode 100644 index 000000000..1ea446dab --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_logical_and.html @@ -0,0 +1,463 @@ + + + + + + + +MLX: mlx::core::LogicalAnd Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::LogicalAnd Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::LogicalAnd:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 LogicalAnd (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ LogicalAnd()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::LogicalAnd::LogicalAnd (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::LogicalAnd::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::LogicalAnd::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::LogicalAnd::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::LogicalAnd::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::LogicalAnd::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::LogicalAnd::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::LogicalAnd::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::LogicalAnd::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_logical_and.png b/docs/build/html/classmlx_1_1core_1_1_logical_and.png new file mode 100644 index 0000000000000000000000000000000000000000..65d5f0bb4faeb52263141ede1bd0892ba0dcebdc GIT binary patch literal 930 zcmeAS@N?(olHy`uVBq!ia0vp^(}1{xgBeH~F+Z{fQqloFA+G=b{|7Q(y!l$%e`vXd zfo6fk^fNCWJa_;UH$5A+1SrQ@666=m0OW&#In(Sb3=GV=o-U3d6^w7^KAiMQfya$s z-q`Qn|IYl}Z0^=;FLv#IbH`%N#!CYJiWm1kne(|~f@-FTllP=OEz?vg3o}$bKTUqO z{@$&ZRg#mwwB%pd74g`wblp_B^JeiPn%B=(>^YnG``?>`->r+gY|o2qt6dT9sdc>d znr-hStHtMH{cQbK?znt5IBmibsD=@=DO7@$g(CX zm3{7;u;^?4TUB1NOb-kFbo^Ep-^ZO{H^nAh`P6l-Y(c4d-u1b;zCos|+-p_~tG>L> zq3iicC~DFZ)h#L~Pw$Abd~PGdF#i#ofn0Ld-8|>2-j$}2y)R5xgq0{Y%)h|&K`4qL z&WZbgcLrkxr!Is3Le>YWTNw7Vh&4<{NS>}=FtzyHW5;PrF2YP-dev*vlXSlQF~Bg0 z`vwYwkk7(V;)3bi=4-zg%-8Q?eZAQ3-m?vVUv7W!@|a=X$!*J{pP$N~)wG>I^>wfz_y;ZgS)QrP<9-g}GcdfebX1)GlG5=at)(*RlpqD>b zJhP5JKJa}{?AlHCer2nJqqp7->(I|A^{$f+Jz2i*-7A6nm$hy`nh6S!_T1~g{~qRi zzq|Iu%f7I!oOAuAEb0IJDE|Tf-=BMOzpjmXAN_ak{w&X5rJ%^{Sq=2sVa5-dc?^#Y zmtA{nw&ePZsL5A5;+nHu)tP?a3=^~f`@VT(7pA5?d+PMs^?qK3 z@lKyPKL2M-n;CgjW@;3#v9|ka)Ac+3*L83IeR;j-PqTL2>!Htdb4~}Z)84vlyV8|2 z-!5L9wRm!*b=P*w$97xmD(*zP<=@>_M5nx8$Dzj{-&`mNq;g^O&%w*!;Q-bXuZj%>d5Dr0Y0wCm|otBr5gE!Cdx zFFtq2S=rssTkq={Z`S|VdglJptJ)IA#_IRCML*SFJM(7tpS}CDj5ps`+H&R$D2Rf* c4u9ltdo*k7_gVVvz}&>(>FVdQ&MBb@0JL(_#Q*>R literal 0 HcmV?d00001 diff --git a/docs/build/html/classmlx_1_1core_1_1_logical_not-members.html b/docs/build/html/classmlx_1_1core_1_1_logical_not-members.html new file mode 100644 index 000000000..1b749bde7 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_logical_not-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::LogicalNot Member List
+
+
+ +

This is the complete list of members for mlx::core::LogicalNot, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::LogicalNotvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::LogicalNotvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::LogicalNotinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::LogicalNotvirtual
LogicalNot(Stream stream)mlx::core::LogicalNotinlineexplicit
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::LogicalNotinlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::LogicalNotinlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::LogicalNotvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::LogicalNotvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_logical_not.html b/docs/build/html/classmlx_1_1core_1_1_logical_not.html new file mode 100644 index 000000000..a4650332d --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_logical_not.html @@ -0,0 +1,463 @@ + + + + + + + +MLX: mlx::core::LogicalNot Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::LogicalNot Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::LogicalNot:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 LogicalNot (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ LogicalNot()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::LogicalNot::LogicalNot (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::LogicalNot::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::LogicalNot::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::LogicalNot::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::LogicalNot::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::LogicalNot::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::LogicalNot::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::LogicalNot::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::LogicalNot::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_logical_not.png b/docs/build/html/classmlx_1_1core_1_1_logical_not.png new file mode 100644 index 0000000000000000000000000000000000000000..51f652755cbd67a1549845fe366564f6d1c64051 GIT binary patch literal 918 zcmeAS@N?(olHy`uVBq!ia0vp^(}1{xgBeH~F+Z{fQqloFA+G=b{|7Q(y!l$%e`vXd zfo6fk^fNCWJa_;UH$5A+1SrQ@666=m0OW&#In(Sb3=GVQo-U3d6^w7^zFzmpfQL;z zUggR6`bWB~QOkY16Cci5b4_D$Tg)BLM+f4w)|T3=m}DZgQq?ms$a~TvmMvW>M>Vhe z=f4ZtFX~x!NPLUmvleIB?5W3ZJ?rAGyx(&?S#I0A^Hx7@8Mj9~_BlLF|E-SazBw7O zhgYl!zI9{Y%R;Yp%cYOIPgtUxzWeFooHA#vTbG|*3;Xtb>#oSU6{Y^&Y135VlcN5y z$-cQ}w|0J(=OsqPwQ5hd@7!vANV>NaSsWED`| z%CKmOn8TDSOdXS=7y^U16}++-U63UIut#mZW+fNmDb))z`CF9hq(^5C`#+!ZdbxP* zs{36kTmQ6NJK3?gvpcME#(7cmhx@L8f?!>pq)lng@wlctdsjw$neiT+?g6-TcApwx~#E zyd3+!V{fy5{w~}QJ@-!5+i5N}j;m*X^3Dv}_j%W5=NkXh@!jcBryIAK=d4`+ceZd` zd3Et}>!`D8C3%ZfYUha8GuOX3|2%i~`mDYAU*G=;^}e)za?~V|)R?tV*|K~M{)LPZ z{f1lhrYr6}c4gKYw=Ys5+WATjb8v + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::LogicalOr Member List
+
+
+ +

This is the complete list of members for mlx::core::LogicalOr, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::LogicalOrvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::LogicalOrvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::LogicalOrinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::LogicalOrvirtual
LogicalOr(Stream stream)mlx::core::LogicalOrinlineexplicit
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::LogicalOrinlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::LogicalOrinlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::LogicalOrvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::LogicalOrvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_logical_or.html b/docs/build/html/classmlx_1_1core_1_1_logical_or.html new file mode 100644 index 000000000..205a8832d --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_logical_or.html @@ -0,0 +1,463 @@ + + + + + + + +MLX: mlx::core::LogicalOr Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::LogicalOr Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::LogicalOr:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 LogicalOr (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ LogicalOr()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::LogicalOr::LogicalOr (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::LogicalOr::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::LogicalOr::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::LogicalOr::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::LogicalOr::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::LogicalOr::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::LogicalOr::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::LogicalOr::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::LogicalOr::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_logical_or.png b/docs/build/html/classmlx_1_1core_1_1_logical_or.png new file mode 100644 index 0000000000000000000000000000000000000000..79dcbcb6f02248a88ec94f481a4bd7b3eafc9650 GIT binary patch literal 920 zcmeAS@N?(olHy`uVBq!ia0vp^(}1{xgBeH~F+Z{fQqloFA+G=b{|7Q(y!l$%e`vXd zfo6fk^fNCWJa_;UH$5A+1SrQ@666=m0OW&#In(Sb3=GW5o-U3d6^w7^zV2IXz{92< zzvR^W|3}z0rn4}W?#?w|{W#QZ>4O=ak8}P{Dn9quaoUo59J-#LgrX)rIl6GF%1M8z z|53&5uYXTi!d~O>I;qe4>dvsJYi6+`n$MGTxQTp!FwN2hfC#TIbYF2r>b5+>P)No$SmmYVct8c!~Hr@RvBs;G=3#k6bmQ%td zL2Hwv?XOH)A{-MHdis6!>kXa#TTdk`db-cM7-hc2uJd*2;x(SD+_oN>?zc$guhc?S z&p0RVNqbtRsZ_oU*M2@{H}iwoI}E=T`>(yd)_c9(R<_l0d#$c^P3Cx@zJ*~=i&(?- z3rrt`q8Q?wxDR+|FjjEtGUzX4eV~ev{3M_8bWU;o1l5)1Fq18_R6Rc#x5l6Id>pU+ zyP#HnQpjiCwcQ%f^UsYP=>WY_t7RZL*&EB_iPIUO{o|~c0URIk{UyFQmEBg1-J<+c(9RBJy#p}1I z%1euBy!StuZ9Vn5Flen?{?)QeP1G{%H zesP~4wLN-zTx}L_m~B<*Dp7U81O8aU12eFm+F$U}T>4L8=FG2d88c^Aa_Slz>+^4! zSCZEzpXX~UJ!wP)N>X=$R#S7OheS>qY4-D~`A^Xs3Vg{Jpwu8H_sw6$$*<*CZg zdqX!Q7A!J7qNh1I(z2`HaM_Nn^|$ZtG|U&jA!4yI>}BB!H>1tAz0cXZwMw_nO$+Ot zpFMT&Z;#h$OU6`?o}?XZ>Xli8ym#&+pNs zqMN_fx12e1?{N1``<11)EYs4i*KCQJzF$8rZTX+G^RtXM-&X+2@A + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Matmul Member List
+
+
+ +

This is the complete list of members for mlx::core::Matmul, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Matmulvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Matmulvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::Matmulinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)mlx::core::Primitivevirtual
Matmul(Stream stream)mlx::core::Matmulinlineexplicit
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs)mlx::core::Primitivevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Matmulinlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Matmulvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Matmulvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_matmul.html b/docs/build/html/classmlx_1_1core_1_1_matmul.html new file mode 100644 index 000000000..490ff158a --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_matmul.html @@ -0,0 +1,395 @@ + + + + + + + +MLX: mlx::core::Matmul Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::Matmul Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Matmul:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Matmul (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)
 The Jacobian-vector product.
 
virtual std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs)
 Get the output shapes of the primitive.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Matmul()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::Matmul::Matmul (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Matmul::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Matmul::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Matmul::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Matmul::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Matmul::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Matmul::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_matmul.png b/docs/build/html/classmlx_1_1core_1_1_matmul.png new file mode 100644 index 0000000000000000000000000000000000000000..eddf94c6689334b6ca0636a6709feb8af328d0ee GIT binary patch literal 885 zcmeAS@N?(olHy`uVBq!ia0vp^(}1{xgBeH~F+Z{fQqloFA+G=b{|7Q(y!l$%e`vXd zfo6fk^fNCWJa_;UH$5A+1SrQ@666=m0OW&#In(Sb3=GT+o-U3d6^w7^KAiMQfya$s z-q`Qn|IYkq-PUPYC0C!GJN7W=rrDBqhF{a?6wC2%Qt5Te@SHUL!Xy=azO^czb(j9U z?4A4Tx|%2dkA`b=4c~=4kII~T`_jswdtDpjrfpbj7yaQ+is#0sGY_4zf2-kHx_hzd z^1z_wTX(*WeytVizpL^U&!jESxi*`4L#>3(bP*6Dxug_TdApLKVetFc$m-I}Pc zoK`EZ6-L`%nY2VWChF|8`|sBtRZh9W!Ih>ExA4;YVy^NxaO>*o-7}PqZs0xxDR+| zFjjEtGUzX4eW1F9VNZ)#!}JSGAA}H+KbfN}pWDc#cuK)d?~C%B^hf7m`SB$`FQ2yi zy54tE)V_tK>W^{~E_$v0d1)ovon^a#L6G}RwW4Z6|L(Ngm+rp1nm?=Qw0-G4-|UmK zmn3JqU)#QF`aQWXk7rfiHNTpCJL{?;I2g*VT6ODIXWP#$&zAjGuv=j3r8+iGt8JwR zx9^Ev`>SS7$m-48qjql+oPXg}-p@<3bnjQ*-RoN8yV_2yf6J++ZR%^jY=6F6^xmqy z)|X}1F5{RxS;I5``N{VO^#A_Ylk+uj?e(?)=FiXa{8dWb5Yk*~BBx@!`67p|v9bQb z)U;< zUzioN_U5F>yR*towoeOt7Ib(0q`v{HKU(N`8f~_Xo_bwq_EP8H^RE5LJT^6Pds3a- zYQK`8wPkniJyeFIipNUtk|#x8zPD!UTAyNInpo-1nd|Sj_4C;?RY&^!XU?2{!N~Zt qP*hr4oYUzuXHGAf8sv5OBfoR(?0+Hrs@H%Sgu&C*&t;ucLK6U3NwshQ literal 0 HcmV?d00001 diff --git a/docs/build/html/classmlx_1_1core_1_1_maximum-members.html b/docs/build/html/classmlx_1_1core_1_1_maximum-members.html new file mode 100644 index 000000000..d185625c2 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_maximum-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Maximum Member List
+
+
+ +

This is the complete list of members for mlx::core::Maximum, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Maximumvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Maximumvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::Maximuminlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::Maximumvirtual
Maximum(Stream stream)mlx::core::Maximuminlineexplicit
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::Maximuminlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Maximuminlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Maximumvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Maximumvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_maximum.html b/docs/build/html/classmlx_1_1core_1_1_maximum.html new file mode 100644 index 000000000..a2374f87c --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_maximum.html @@ -0,0 +1,463 @@ + + + + + + + +MLX: mlx::core::Maximum Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::Maximum Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Maximum:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Maximum (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Maximum()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::Maximum::Maximum (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Maximum::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Maximum::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Maximum::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Maximum::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::Maximum::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Maximum::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Maximum::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Maximum::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_maximum.png b/docs/build/html/classmlx_1_1core_1_1_maximum.png new file mode 100644 index 0000000000000000000000000000000000000000..d888f70916cb45d503beaa60665bcbe9b4d8d205 GIT binary patch literal 901 zcmeAS@N?(olHy`uVBq!ia0vp^(}1{xgBeH~F+Z{fQqloFA+G=b{|7Q(y!l$%e`vXd zfo6fk^fNCWJa_;UH$5A+1SrQ@666=m0OW&#In(Sb3=GTyo-U3d6^w7^KI~g;z{92< zzvR^W|3}o}N>$ImAn!?wShjSj9M$}O z{^FgNf3iI;aqn|{eR590s^Ybq_S~*{6lB-6QRdl(wRX`T{+^qYblSLeZTzjsNq;I9 zmHIDU)O_*Tx9QVDi^};v|KOQ)<@1^-JMVDbl^OoGqxZf!pDkPVbLF+gy0d`lJ2HRD zmISS>joyD{(h^~(DBsiH_il|oay_ap+T+r^@-5d`bITu9)dgHrGF@DGw7+bT%3rt( zmrPS}l?-2fEvlXQfmR;FW5Z?F-kL4BJ|k-K)sDF4tfzh~9h0LN0)w~}ys{WwJarp1 zR`sM9g8z6{e0!2+5z!QCqKB$%S}I^#a|Y0Sbd}QLd97oq6d0e9G_V8ox{a znX0UOs=j5O%N*aBt*4Ug=I`L_50B+5zVYo__J@+i{l8CbTUs6S*?(5kb$z?{%S@l^ zUaH$-{8#n+h2y=CcRaiGaPpr!+oJ52t!@k1vV8Tm$Tzp5|Gth{yY=jm4Rc5fBDf59r}XQ_#9{f8at-9KhdyDzpL93VHU_T7CSR!yMFrz!^%XZ)*3jR2Z%f>v z-aNBqSzCVz72;@vuBph$?}~% zt@7xZezof2%*wFj4ZPDg{me_dRv4%I{^6>PvQ^u(^i|6?ch8g$%aKk?lMT8)hl{qCpz?N)|azyI-4Jljou+6PRHbJYD@<);T3K0RS1%zPbPa literal 0 HcmV?d00001 diff --git a/docs/build/html/classmlx_1_1core_1_1_minimum-members.html b/docs/build/html/classmlx_1_1core_1_1_minimum-members.html new file mode 100644 index 000000000..34236c406 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_minimum-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Minimum Member List
+
+
+ +

This is the complete list of members for mlx::core::Minimum, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Minimumvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Minimumvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::Minimuminlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::Minimumvirtual
Minimum(Stream stream)mlx::core::Minimuminlineexplicit
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::Minimuminlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Minimuminlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Minimumvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Minimumvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_minimum.html b/docs/build/html/classmlx_1_1core_1_1_minimum.html new file mode 100644 index 000000000..3407dff49 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_minimum.html @@ -0,0 +1,463 @@ + + + + + + + +MLX: mlx::core::Minimum Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::Minimum Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Minimum:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Minimum (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Minimum()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::Minimum::Minimum (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Minimum::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Minimum::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Minimum::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Minimum::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::Minimum::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Minimum::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Minimum::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Minimum::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_minimum.png b/docs/build/html/classmlx_1_1core_1_1_minimum.png new file mode 100644 index 0000000000000000000000000000000000000000..46ca64b2d779f71310a4c06ab7a53f43090b24aa GIT binary patch literal 892 zcmeAS@N?(olHy`uVBq!ia0vp^(}1{xgBeH~F+Z{fQqloFA+G=b{|7Q(y!l$%e`vXd zfo6fk^fNCWJa_;UH$5A+1SrQ@666=m0OW&#In(Sb3=GWdo-U3d6^w7^zU}*@AmGN& z-#PR5{zu!xynkktnyfoLdv4q0OJZ%zJin&TDW1o_N#(3thUcW|7bdCb^Q~3!th@Bp z>ecQm`L2`ZePE8_|Fy$;ZuZjLyzHf1mGM1qx0&6(o}ORxcT?Mq$2wq}=mT^{v)OGcIFs&4<) zc8=<+O_zNQ*7dYHxb*6hn)%U()$FdX>Tgx)-LojvY>RE@(z*?6JhQxvk8bx{r1Dp4 zp{i$`llP=OEz?vgXNGG(pR=6#L2wMivt+-ur&GPwTW%FwExA|iYS(0r2kKiG_OysK zOuxYNK`4qL&WZbgcLrkxr!Is3Le>YW2+2?ETZ+zA+MSqWA`LUWC(3iupEYOYzj0sq zzN6B+^ljw+Rcw3fE%L7Kv6fl+?)JK{yUU_^H8Yn)uid;&b^EWEHBs3o9+oL*Ev;ko zlsdibVf7v7wZCfSthwfxx3x@FF#dv-&(EM)Q|u+)T~yn*JXF3rJu21tws6?1-{14v zzh8ZK?n`6XW)61sz)5!Jy7xE6|N2~c`<46Fd)t52|68T{@;axk=O?4}VdpIU*$*t= z!EomIoT%xir^FR!>4q6s`K}UG7d+sPH9Rl_>nVT6(kYi}CY(8wro83M8DQYgoH_k@ zlzvwJ+3COUeSW&KKJN9ke|5&j)$PG+x123s8&Y<<->WuycBs(QYf;ypN9T55x0^ka zcTV`eA?f;&?0IL;u*t0nduzDF(`fU&+NtGh^)4OMHs0LWpAox$ z+5MG=-M7B_r@s2?x8Gf1lRwqw_0GQ0n{%@5ljqt^3*UO`e06&gv)bM>Esgo(^W?N= ns&Lm`Ff#tU3FPXMKg_>NXZLNL%5(sjPZ&I1{an^LB{Ts5fF-&+ literal 0 HcmV?d00001 diff --git a/docs/build/html/classmlx_1_1core_1_1_multiply-members.html b/docs/build/html/classmlx_1_1core_1_1_multiply-members.html new file mode 100644 index 000000000..9754017cd --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_multiply-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Multiply Member List
+
+
+ +

This is the complete list of members for mlx::core::Multiply, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Multiplyvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Multiplyvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::Multiplyinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::Multiplyvirtual
Multiply(Stream stream)mlx::core::Multiplyinlineexplicit
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::Multiplyinlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Multiplyinlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Multiplyvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Multiplyvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_multiply.html b/docs/build/html/classmlx_1_1core_1_1_multiply.html new file mode 100644 index 000000000..85362801c --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_multiply.html @@ -0,0 +1,463 @@ + + + + + + + +MLX: mlx::core::Multiply Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::Multiply Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Multiply:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Multiply (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Multiply()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::Multiply::Multiply (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Multiply::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Multiply::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Multiply::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Multiply::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::Multiply::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Multiply::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Multiply::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Multiply::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_multiply.png b/docs/build/html/classmlx_1_1core_1_1_multiply.png new file mode 100644 index 0000000000000000000000000000000000000000..518065cfbc82a3186a616f79145caccbddf2a3ee GIT binary patch literal 909 zcmeAS@N?(olHy`uVBq!ia0vp^(}1{xgBeH~F+Z{fQqloFA+G=b{|7Q(y!l$%e`vXd zfo6fk^fNCWJa_;UH$5A+1SrQ@666=m0OW&#In(Sb3=GT?o-U3d6^w7^KI~hqAmGL? z?`i!0|B>raq5)ej``%71UjA@NMEgUI=D*(OEc@9bC%GwaQ90?I;WK`)QtvEj07Z_SompAj|rYDZji)=^KE55iFlaZcO^ zyfYXpICUBH7qUK3-NLY^MXX`^1*Q)|2+5!PTZ+zA+MJkVA_X&9X6q!CpED2ZA7Ao& z`SjXX_On#B);PabdX)2E+0?MlJA-*6LeuJy~^S{ALiZ^2dFpKDE3@6UXfC0xHORK7bM93VNbr~lq9 ztG9Zu^<~+$X@XY3P`P)m`~3m^zd!%nc(pjne)})$`m2+^yl$DMQYjg({d`V4^Mlws z49Ajv*51zbSZ}yhZMEQDwyRx}IUcBE4G+w~`XryRbjqcQ31`j#lgODfds?Q=oGCv2 zTKBW;3Q4oCa%X3MP>$;V@pN^@-U%PPwAa3T<~KE`YWdTu#A@S%$xHzol=kTU056U`d{X;QWLA| zX=%25qDrStzM6I1_Hq>WwKH!zf+4BIeB(#Yw53q+hW@pyU;Br&s5^} ztM2 + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Negative Member List
+
+
+ +

This is the complete list of members for mlx::core::Negative, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Negativevirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Negativevirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::Negativeinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::Negativevirtual
Negative(Stream stream)mlx::core::Negativeinlineexplicit
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::Negativeinlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Negativeinlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Negativevirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Negativevirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_negative.html b/docs/build/html/classmlx_1_1core_1_1_negative.html new file mode 100644 index 000000000..8d4d4cdb3 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_negative.html @@ -0,0 +1,463 @@ + + + + + + + +MLX: mlx::core::Negative Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::Negative Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Negative:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Negative (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Negative()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::Negative::Negative (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Negative::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Negative::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Negative::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Negative::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::Negative::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Negative::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Negative::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Negative::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_negative.png b/docs/build/html/classmlx_1_1core_1_1_negative.png new file mode 100644 index 0000000000000000000000000000000000000000..f7a0d33a0efb549f57aef90f5e4447729e2f42d9 GIT binary patch literal 929 zcmeAS@N?(olHy`uVBq!ia0vp^(}1{xgBeH~F+Z{fQqloFA+G=b{|7Q(y!l$%e`vXd zfo6fk^fNCWJa_;UH$5A+1SrQ@666=m0OW&#In(Sb3=GUVo-U3d6^w7^KHRj*ipMQ{ z|2d1w|LXfXm^4Lte($oiEq+v@=_DnVxvGBU)@vUZcwNd^GEK$RQ`fUg=z`WHk>z*d z_wKy>G1=o1cb((wBtMI*JHsO9<=Pc!&OaJq_iV#jyXX&pPt8d>ZQQyx{#N9qKRYh$ zS{@jfzGdg%m8nzjOwE&OmOOz7$Ko1xBHR+qkAi~IIF>aL#s)v9f-#@>_U4R!Z7 z`>iU?c|AR9Qi+qt>dc?D`Lbdkw@$k+HmT%^?AlVt+jDowm7j`Ql4jEXWAR)K&-$J# zla^?#RP_uD@}9J)XLV@y+QW<=H1il98!o%{)@;f38Bvq3cEmMjJ@sShm>k6r7{smM zmBr}dsoS8jl2t%;E5o8CVh&TTFm+5qNd9EslC`$fX2m2EDVWJJTf0<_&OEIDe9G_D z(`&8jcs#Rywq0{Qa`Q;&YMp1F8)tv0x$-ob@AkYp-g$a&9?RX_R&u`hpI`8YUq_$c z-STXYt5sjwGW(}_zxG_{&pucB{>k%gS3eo<&hl_|FS}~h9b29K-(vnXt1kt+1>!RQ zvjCN+9^AesR(HR>->R!G*IvuZJz~CPm+t<)WlycQ?~Xb$Kl5kqN3ZPI3uWzVkKg`! zmn;A3y|*v>!h&1Q`Au2U|M^+|gX#Zj>T$E?07xgBlPxL+O1Bb$Bu@^(!E!VjJqdArKJVN zW$#`2m&1Ox$z1-YzxP~|sb83y_H9z|+D$e$YaR&s*=zkv=Bt*;|177|czgTr+q1UK zO;5|4lkLknP3GvCezVQ?ru}9~(=3meSe`jk787#%`G>0!_TNgLnTqdUTU!6~_32Z) zjLyc-J(Dwe+sv0L2s?iW6L#_s6v?LSJ_iuU_b?LsKv^4cUTkEf;rNygn jJ#z+>h`clpf0RGoG}~Wnbx}MpGckC&`njxgN@xNA5{=7r literal 0 HcmV?d00001 diff --git a/docs/build/html/classmlx_1_1core_1_1_not_equal-members.html b/docs/build/html/classmlx_1_1core_1_1_not_equal-members.html new file mode 100644 index 000000000..93f8795dc --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_not_equal-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::NotEqual Member List
+
+
+ +

This is the complete list of members for mlx::core::NotEqual, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::NotEqualvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::NotEqualvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::NotEqualinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::NotEqualvirtual
NotEqual(Stream stream)mlx::core::NotEqualinlineexplicit
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::NotEqualinlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::NotEqualinlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::NotEqualvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::NotEqualvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_not_equal.html b/docs/build/html/classmlx_1_1core_1_1_not_equal.html new file mode 100644 index 000000000..c394128dc --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_not_equal.html @@ -0,0 +1,463 @@ + + + + + + + +MLX: mlx::core::NotEqual Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::NotEqual Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::NotEqual:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 NotEqual (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ NotEqual()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::NotEqual::NotEqual (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::NotEqual::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::NotEqual::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::NotEqual::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::NotEqual::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::NotEqual::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::NotEqual::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::NotEqual::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::NotEqual::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_not_equal.png b/docs/build/html/classmlx_1_1core_1_1_not_equal.png new file mode 100644 index 0000000000000000000000000000000000000000..2067d9710fa43a18393232e768dc93f6d029c187 GIT binary patch literal 916 zcmeAS@N?(olHy`uVBq!ia0vp^(}1{xgBeH~F+Z{fQqloFA+G=b{|7Q(y!l$%e`vXd zfo6fk^fNCWJa_;UH$5A+1SrQ@666=m0OW&#In(Sb3=GWjo-U3d6^w7^KI~g<#lxl^ zzvR^W|3{_=ad>j9-M#$w#Q zapyJLo=I$rdt>8om9N}o+!O3RVadGoo2wR=+?+7!+?K6ZPuIm~M?9S#6&+~;RDXDR zX!!-#ZRb+|26<1?E66ley?^~$rf_`r)m%l-^m`Yzq;HiaS;{Xl^*oyq_tAZ>hG%{6 zg-I&<3spViKrTJAdTQ~x%Zwj1^B6uGF2DBHZ0YqmQQcR2;wERgsx$rI)Me0L$ofEa z3&Wlkv4-gvm_7(aF~m7>AMnm#tU!|d!yIM#+(tgda~9n6b5Wj?{;at?f1T&|^BTX3 z{z*)_@`-n?y74+Ov9()fuUEg?^fvVSVXL^E!wZWuR3p+eq4EW>%+;fw&reoWf^O>WQlR?wsp^@-Fg-O{BqQCsXX&VVXy8h zcs}2l|3k2RMd z$=`feRkh%%O5%miNnZ-@{o(y{$$q|k^xsw2&aeGv|9_S0%k{!hlb)o;t({XmkFUYM zkaAl4cxf*EsW5ZqS+|UtGb=fDjg9r! zZ@qV=&c|L%H@*L!k+F4g>CeYeum6OdKJ%ujeCs8-0^60>^7Z%0Z!6jITJFyJd;dec zf8Tzet;<-pH6^V7)dscaJGS(dxlLVr^;6C5Gi#oEZp-RV3;yPOJ|Q+Z+WlR>!v2+^ z;U$&d45Htv{aYBe_m5I^?>6_>=gl{2{@k~XZMMaV9ox3ux}BD0o1Zb2H`jV&-l|V? zR$YC!cJte!iMR5@W3$q({`j_j&*#jArkkhDoY{QA$oR8RR9ae`*xD&Uv- + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::NumberOfElements Member List
+
+
+ +

This is the complete list of members for mlx::core::NumberOfElements, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::NumberOfElementsvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::NumberOfElementsvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::NumberOfElementsvirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)mlx::core::Primitivevirtual
NumberOfElements(Stream stream, std::vector< int > axes, bool inverted, Dtype dtype)mlx::core::NumberOfElementsinlineexplicit
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::NumberOfElementsinlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::NumberOfElementsinlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs)mlx::core::Primitivevirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::NumberOfElementsvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_number_of_elements.html b/docs/build/html/classmlx_1_1core_1_1_number_of_elements.html new file mode 100644 index 000000000..018531ce9 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_number_of_elements.html @@ -0,0 +1,396 @@ + + + + + + + +MLX: mlx::core::NumberOfElements Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::NumberOfElements Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::NumberOfElements:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 NumberOfElements (Stream stream, std::vector< int > axes, bool inverted, Dtype dtype)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)
 The Jacobian-vector product.
 
virtual std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs)
 The vector-Jacobian product.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ NumberOfElements()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
mlx::core::NumberOfElements::NumberOfElements (Stream stream,
std::vector< int > axes,
bool inverted,
Dtype dtype )
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::NumberOfElements::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::NumberOfElements::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::NumberOfElements::is_equivalent (const Primitive & other) const
+
+overridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::NumberOfElements::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::NumberOfElements::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::NumberOfElements::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_number_of_elements.png b/docs/build/html/classmlx_1_1core_1_1_number_of_elements.png new file mode 100644 index 0000000000000000000000000000000000000000..2364d3242ef674772f7577d13debc63a620692ac GIT binary patch literal 991 zcmeAS@N?(olHy`uVBq!ia0vp^8-TcjgBeH`I80mvq@)9ULR|m<{|{uoc=NTi|Il&^ z1I+@7>1SR%c<=xyZhAIs2~du+B*-tA0mugfbEer>7#NtdJY5_^Dj46+eSPk=0uRe{ z*RDVR{A(_mG|Zh4y?>XL$2tj9@nkmkAH8#mWkJfo{#E=ghZ%A6j`Od1pZLT2`HU?}*iz+;w{+gSOZdN9^6W*W**i?nCPt zXWYmv3;DaErugf!+3x8ICxi3CBc9(6nc#CL`TB=4ssED;%IC{+t?xebMd_rnoxru- zzZV?(-CL^ZGpW2nFURa|8SndPN49RbH$$cJ;PFGHl{bPH#g*H+D)|~u{{KnSangTw zMW0FYS`0j^gbt~^JlS!`^0|!w!-GRC49bSHKA7;P+06Rz#(LVvf89$7#?_XFgK_Z5SSi+&a$7y$g|IxWAYMpkX11p>XY`I7pl9ZzPE1go6lc&Pkh2@ zFK*Yl!!#n^V(LvjajmOS^FD2F+1hpW!Ox>+@8+&u8a@Bz#Agqxl0V0B>np1j-Ous+ z>+y1p{;$Gkk>x6Pr!UEL7rnOETvSE)dyZC(NHj297VQ_l_g3k5lM| z=YFW_ycy2npbi>UO*@>G~Kh6XNPv%bTnCb0z|9;)I_x_sc-&)O{ z`?}wl^86w&jN+d^oar}7wod(z?EU+k^(Vhi_ij@8?+5eI8LEewmu4@^q>d97EgL#I z)D?A<9-WN)@sQ=-!?}KUZPw@Ri_^Kje)nlc-vjqsr~jCJ$aGsl_JMsLY#vWKuC>Ng zv)=ng)Usa-zM58Ulf50VYm@f}p^)F(ej7!}ZsgR@pL}B9%KhbSzgBR(Ixy?_m-Vu_ zR@JLzbGWYV42_U}#kI|-WZ!a?Oy5Z}(pr4iZC$nW((~H0R{uSA>qEY?AWZb zuJv}1@$A^QiPnGb&bwQ_=J|&--uFhcQm@#ny12M0w79z1a3%^0`a4Cq9GNWG8tj#+ b@Q*P)Wb?{5itP=+tjOT$>gTe~DWM4f + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Pad Member List
+
+
+ +

This is the complete list of members for mlx::core::Pad, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Padvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Padvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::Padvirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::Padvirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs)mlx::core::Primitivevirtual
Pad(Stream stream, const std::vector< int > &axes, const std::vector< int > &low_pad_size, const std::vector< int > &high_pad_size)mlx::core::Padinlineexplicit
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Padinlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Padvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Padvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_pad.html b/docs/build/html/classmlx_1_1core_1_1_pad.html new file mode 100644 index 000000000..97fa86571 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_pad.html @@ -0,0 +1,447 @@ + + + + + + + +MLX: mlx::core::Pad Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::Pad Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Pad:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Pad (Stream stream, const std::vector< int > &axes, const std::vector< int > &low_pad_size, const std::vector< int > &high_pad_size)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs)
 Get the output shapes of the primitive.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Pad()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
mlx::core::Pad::Pad (Stream stream,
const std::vector< int > & axes,
const std::vector< int > & low_pad_size,
const std::vector< int > & high_pad_size )
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Pad::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Pad::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Pad::is_equivalent (const Primitive & other) const
+
+overridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Pad::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Pad::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Pad::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Pad::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_pad.png b/docs/build/html/classmlx_1_1core_1_1_pad.png new file mode 100644 index 0000000000000000000000000000000000000000..13b64cbff9539015259689bb35502cdb3ba44dce GIT binary patch literal 874 zcmeAS@N?(olHy`uVBq!ia0vp^(}1{xgBeH~F+Z{fQqloFA+G=b{|7Q(y!l$%e`vXd zfo6fk^fNCWJa_;UH$5A+1SrQ@666=m0OW&#In(Sb3=B+PJzX3_Dj46+eckuUfQL;z zUggR6`bWXLww&!v64+oFwYZ04?PEXL^zZtX&u!#WJZJS@n53e=P}MU|CdzZtpCzx; zx0KEN&*NEjfPIVKwH?8CuY29zy>_zF)3p+}uS=C(dCsx#USzOL<+Kfd-S>J0S^8~R zmpR2mqV(OTOWVAU&P|+W)U2{~PUz}0mqMJqUahQNJ@xOtYa33#&$@eSQIzUR8N0Qg zSnjS2yMK*;tIA70y|t^J-oE$B?Bm9;N*$Gz&n4H^I=`O0<2-w+Z;6KfH&Q)45+&jveVE63Gu9#(CW(8+&oOLlfXstyjL!A@%0q+dP z3Qk=H{e`R#RJSnfX%TCfeu3$OP!vNPl4QMu_VYQpki=Wav-8)+^rS0?UV%;8o7wDD8Z@)iy>L&et@>`Gc-C=f_ocHRdRPwJi z&7bruaJ`PcDF5EO630WYxA|*bzjjqCc~+K(r>*bScSqmty}Iw*?ybAz)|oE~$-1xL z`S*s|pA*(sw9X%&o5fqTdfRR3kDO~0PrqNB_4DB4b7mh+*M5qyxpd2Xi(TUHKRfQf z`4IT_nGgS#RZ2X5i&XZ`RsYXc_j3OE(_!*gtM0z~|1(_IbH5*TLul!gXU!>TX3`5& z)1IkrIdf)@&FZgDyw@y0b*Ag#sSlZJ>&{wly<~dZv}mr#&itb{fPu5fuKI)A_qesH zcb82&b1JYVHo41tyL@TP?&WD|w(qj0%K4g}ne??}Z|IBJGv|I{T)XMC_Wtu(xz)u% zYuD|ceb&s_`e`LJNo?%?9g;8Vo>w6By)ydR?as>go+ie>)!=SvnKpAKFf{~uWm^4X Z-12>HmaDA(T41hV@O1TaS?83{1OUy`u-E_q literal 0 HcmV?d00001 diff --git a/docs/build/html/classmlx_1_1core_1_1_partition-members.html b/docs/build/html/classmlx_1_1core_1_1_partition-members.html new file mode 100644 index 000000000..101ee7dc2 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_partition-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Partition Member List
+
+
+ +

This is the complete list of members for mlx::core::Partition, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Partitionvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Partitionvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::Partitionvirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::Partitionvirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::Partitioninlinevirtual
Partition(Stream stream, int kth, int axis)mlx::core::Partitioninlineexplicit
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Partitioninlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Partitionvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Partitionvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_partition.html b/docs/build/html/classmlx_1_1core_1_1_partition.html new file mode 100644 index 000000000..018687d23 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_partition.html @@ -0,0 +1,472 @@ + + + + + + + +MLX: mlx::core::Partition Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::Partition Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Partition:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Partition (Stream stream, int kth, int axis)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Partition()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
mlx::core::Partition::Partition (Stream stream,
int kth,
int axis )
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Partition::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Partition::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Partition::is_equivalent (const Primitive & other) const
+
+overridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Partition::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::Partition::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Partition::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Partition::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Partition::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_partition.png b/docs/build/html/classmlx_1_1core_1_1_partition.png new file mode 100644 index 0000000000000000000000000000000000000000..4259b6ba026b79860e910b79b230a581b9e7ee5f GIT binary patch literal 888 zcmeAS@N?(olHy`uVBq!ia0vp^(}1{xgBeH~F+Z{fQqloFA+G=b{|7Q(y!l$%e`vXd zfo6fk^fNCWJa_;UH$5A+1SrQ@666=m0OW&#In(Sb3=GW7o-U3d6^w7^zV3Uiz{A$g zzw*;>`;VEMc^e&qwrA~@S*z`~Q!+j2!@KyRbCq@{CYeYtRP~H=@}9KEV#_3zpPt+A zXT1CRC*9)`v)zJHqvLa5MQE?Rzim#4=KG@&{N6Xh=Ij3Wmg2eb>CU#P-DT63RL#k_ zdr2eH;MSe5YfGoDQa@}{%I0~iIBc!e>2*>;LH@h9f7^8by41T*&TD^f%SiKF)$YIg z-Gb(ZtD<=37XExkS`Dtxu%UbU-Tek>n^qZs0xxDR+|FjjEt zGUzX4eW1F9VNZ)#!}JSGAA}H+KbfN}pWDc%c+P^GUbR-mvu@gl%a$*{NvH0g_n*fz z>u1|F&m$L)ESehj`RS5b1^=&n-O!R7AK$c||68H`&21&$tG~`(ckt<=bGtKjD<@y+ zFH5UkUG?(sL5p=e&1-^fx33M8-MqSO%9YKluLb7ZjQ)MTa(meJ_S>y1rhbc`u%vDc z_r9mHSy%rS2CO}JFYE0ow;HGH9ZyeZZ2R+|G`;)BtgyRc{aa2oZIcZCOXRIl}xTi1mdSNX0Ioy+k+9cy@C2G%F`EkRzHb{=VIn+2oN(&C&>pE=X*{@U~7 z(#dDm$X$PJbmolQN6xhe=gjx6UnIPCOWoa~Al1F%$+kL)p|xHg#0spV`L6!Gy>;fz z3C3$vlZ;-wmqtv=-+rz174xp2>E2aQYpu?&lS(^TwQFyvt<5hTz^Hb4Ddy*m0xe)d$>T$a^~ntbhEru9ctUE|I7mA0HYv!`X+ o%$d_K7#V*yS|8@6>3iJ%!SmVsi{ckD0rLohr>mdKI;Vst05o;I;{X5v literal 0 HcmV?d00001 diff --git a/docs/build/html/classmlx_1_1core_1_1_power-members.html b/docs/build/html/classmlx_1_1core_1_1_power-members.html new file mode 100644 index 000000000..6bb8d6502 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_power-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Power Member List
+
+
+ +

This is the complete list of members for mlx::core::Power, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Powervirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Powervirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::Powerinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::Powervirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::Powerinlinevirtual
Power(Stream stream)mlx::core::Powerinlineexplicit
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Powerinlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Powervirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Powervirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_power.html b/docs/build/html/classmlx_1_1core_1_1_power.html new file mode 100644 index 000000000..7f1bd4ba4 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_power.html @@ -0,0 +1,463 @@ + + + + + + + +MLX: mlx::core::Power Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::Power Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Power:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Power (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Power()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::Power::Power (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Power::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Power::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Power::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Power::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::Power::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Power::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Power::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Power::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_power.png b/docs/build/html/classmlx_1_1core_1_1_power.png new file mode 100644 index 0000000000000000000000000000000000000000..7ae727a295912364d435a4d757bfc1ebac0c8d35 GIT binary patch literal 900 zcmeAS@N?(olHy`uVBq!ia0vp^(}1{xgBeH~F+Z{fQqloFA+G=b{|7Q(y!l$%e`vXd zfo6fk^fNCWJa_;UH$5A+1SrQ@666=m0OW&#In(Sb3=GWto-U3d6^w7^zU^CVz~h>q zx8&8n|HAr9RW0O2g0~t+_NEtK5s@h_nDaLN$(+wM3X`AYw1|0Daq3R`^5}w>%1iaT z|J|~u|4g3ngfsp^)y89UUqxu=zTf8Ob1L0%Tfb?}`tx!%`!_GU@mVMFsrTNlC9?Wk z-eu0PkSNWo&b+N!oIc}tn!?Fx;a4|B&dQq5XI@-*QQ5@toFF|@bV`0 z7Ym|lqvvN$e8Q-=cGc6{_g^f8JG_sXe2kw0otv*quzW0KX zr~X3MNpViyDtphYo?3kFG9$zMM{EqA4VPbgYqs?IoT%=rJ#mw>j(V~%*e_&dh;!m* z*wZ4$P{FCo@Ifew;eqNFh6COij1AK-Ffr&OEBMJAW%=Aj{)CF9G|Yl8QJpG(*F4m} zzU23Pjo(J~Jd@V^W^CxlMj&CeJ&1!|s|v@O+^?aWyIt8@W|7&#b(*sV>-i zlI*j`a>;tO`>*6*IpZf+9i{vFVu-WXtCdx&OW!i7u6#Z@YTN1Up3759R9`k|N8P@+ zXGK`KQ`fb}@ujho_;xG^-CA=y`mWlwLa(5971@GVP?!jv-lO|k?t<=wC59g#D|^N{ tc~9EYGEJqDQ`hs;ri`UiE|vUYp160;?rBM1yn*?J!PC{xWt~$(699*qtVsX> literal 0 HcmV?d00001 diff --git a/docs/build/html/classmlx_1_1core_1_1_primitive-members.html b/docs/build/html/classmlx_1_1core_1_1_primitive-members.html new file mode 100644 index 000000000..0e781c800 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_primitive-members.html @@ -0,0 +1,106 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Primitive Member List
+
+
+ +

This is the complete list of members for mlx::core::Primitive, including all inherited members.

+ + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs)=0mlx::core::Primitivepure virtual
eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs)=0mlx::core::Primitivepure virtual
is_equivalent(const Primitive &other) constmlx::core::Primitiveinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)mlx::core::Primitivevirtual
operator=(const Primitive &other)=deletemlx::core::Primitive
operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs)mlx::core::Primitivevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os)=0mlx::core::Primitivepure virtual
stream()mlx::core::Primitiveinline
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs)mlx::core::Primitivevirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes)mlx::core::Primitivevirtual
~Primitive()=defaultmlx::core::Primitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_primitive.html b/docs/build/html/classmlx_1_1core_1_1_primitive.html new file mode 100644 index 000000000..12c2e073c --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_primitive.html @@ -0,0 +1,631 @@ + + + + + + + +MLX: mlx::core::Primitive Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::Primitive Class Referenceabstract
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Primitive:
+
+
+ + +mlx::core::Compiled +mlx::core::CustomVJP +mlx::core::Depends +mlx::core::DivMod +mlx::core::QRF +mlx::core::SVD +mlx::core::Split +mlx::core::UnaryPrimitive +mlx::core::fast::Custom + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs)=0
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
virtual void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs)=0
 
virtual std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)
 The Jacobian-vector product.
 
virtual std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs)
 The vector-Jacobian product.
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes)
 The primitive must know how to vectorize itself across the given axes.
 
virtual void print (std::ostream &os)=0
 Print the primitive.
 
virtual bool is_equivalent (const Primitive &other) const
 Equivalence check defaults to false unless overridden by the primitive.
 
virtual std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs)
 Get the output shapes of the primitive.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Primitive() [1/3]

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::Primitive::Primitive (Stream stream)
+
+inlineexplicit
+
+ +
+
+ +

◆ ~Primitive()

+ +
+
+ + + + + +
+ + + + + + + +
virtual mlx::core::Primitive::~Primitive ()
+
+virtualdefault
+
+ +
+
+ +

◆ Primitive() [2/3]

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::Primitive::Primitive (const Primitive & other)
+
+delete
+
+ +
+
+ +

◆ Primitive() [3/3]

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::Primitive::Primitive (Primitive && other)
+
+delete
+
+ +
+
+

Member Function Documentation

+ +

◆ device()

+ +
+
+ + + + + +
+ + + + + + + +
const Device & mlx::core::Primitive::device ()
+
+inline
+
+ +

The device the primitive will run on.

+ +
+
+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual void mlx::core::Primitive::eval_cpu (const std::vector< array > & inputs,
std::vector< array > & outputs )
+
+pure virtual
+
+ +

A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.

+

To avoid unnecessary allocations, the evaluation function is responsible for allocating space for the array.

+ +

Implemented in mlx::core::fast::RMSNorm, mlx::core::fast::RMSNormVJP, mlx::core::fast::LayerNorm, mlx::core::fast::LayerNormVJP, mlx::core::fast::RoPE, mlx::core::fast::ScaledDotProductAttention, mlx::core::UnaryPrimitive, mlx::core::Compiled, mlx::core::CustomVJP, mlx::core::Depends, mlx::core::DivMod, mlx::core::Split, mlx::core::QRF, and mlx::core::SVD.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual void mlx::core::Primitive::eval_gpu (const std::vector< array > & inputs,
std::vector< array > & outputs )
+
+pure virtual
+
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
virtual bool mlx::core::Primitive::is_equivalent (const Primitive & other) const
+
+inlinevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented in mlx::core::fast::ScaledDotProductAttention, mlx::core::Abs, mlx::core::Add, mlx::core::AddMM, mlx::core::Arange, mlx::core::ArcCos, mlx::core::ArcCosh, mlx::core::ArcSin, mlx::core::ArcSinh, mlx::core::ArcTan, mlx::core::ArcTan2, mlx::core::ArcTanh, mlx::core::ArgPartition, mlx::core::ArgReduce, mlx::core::ArgSort, mlx::core::AsType, mlx::core::AsStrided, mlx::core::BitwiseBinary, mlx::core::BlockMaskedMM, mlx::core::BlockSparseMM, mlx::core::Broadcast, mlx::core::Ceil, mlx::core::Compiled, mlx::core::Concatenate, mlx::core::Conjugate, mlx::core::Convolution, mlx::core::Copy, mlx::core::Cos, mlx::core::Cosh, mlx::core::Divide, mlx::core::DivMod, mlx::core::Select, mlx::core::Remainder, mlx::core::Equal, mlx::core::Erf, mlx::core::ErfInv, mlx::core::Exp, mlx::core::FFT, mlx::core::Floor, mlx::core::Full, mlx::core::Gather, mlx::core::Greater, mlx::core::GreaterEqual, mlx::core::Less, mlx::core::LessEqual, mlx::core::Log, mlx::core::LogicalNot, mlx::core::LogicalAnd, mlx::core::LogicalOr, mlx::core::LogAddExp, mlx::core::Matmul, mlx::core::Maximum, mlx::core::Minimum, mlx::core::Multiply, mlx::core::Negative, mlx::core::NotEqual, mlx::core::NumberOfElements, mlx::core::Pad, mlx::core::Partition, mlx::core::Power, mlx::core::QuantizedMatmul, mlx::core::RandomBits, mlx::core::Reshape, mlx::core::Reduce, mlx::core::Round, mlx::core::Scan, mlx::core::Scatter, mlx::core::Sigmoid, mlx::core::Sign, mlx::core::Sin, mlx::core::Sinh, mlx::core::Slice, mlx::core::SliceUpdate, mlx::core::Softmax, mlx::core::Sort, mlx::core::Split, mlx::core::Square, mlx::core::Sqrt, mlx::core::StopGradient, mlx::core::Subtract, mlx::core::Tan, mlx::core::Tanh, mlx::core::Uniform, and mlx::core::Transpose.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
virtual std::vector< array > mlx::core::Primitive::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+virtual
+
+
+ +

◆ operator=() [1/2]

+ +
+
+ + + + + +
+ + + + + + + +
Primitive & mlx::core::Primitive::operator= (const Primitive & other)
+
+delete
+
+ +
+
+ +

◆ operator=() [2/2]

+ +
+
+ + + + + +
+ + + + + + + +
Primitive & mlx::core::Primitive::operator= (Primitive && other)
+
+delete
+
+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
virtual std::vector< std::vector< int > > mlx::core::Primitive::output_shapes (const std::vector< array > & inputs)
+
+virtual
+
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
virtual void mlx::core::Primitive::print (std::ostream & os)
+
+pure virtual
+
+ +

Print the primitive.

+ +

Implemented in mlx::core::Abs, mlx::core::Add, mlx::core::AddMM, mlx::core::Arange, mlx::core::ArcCos, mlx::core::ArcCosh, mlx::core::ArcSin, mlx::core::ArcSinh, mlx::core::ArcTan, mlx::core::ArcTan2, mlx::core::ArcTanh, mlx::core::ArgPartition, mlx::core::ArgReduce, mlx::core::ArgSort, mlx::core::AsType, mlx::core::AsStrided, mlx::core::BitwiseBinary, mlx::core::BlockMaskedMM, mlx::core::BlockSparseMM, mlx::core::Broadcast, mlx::core::Ceil, mlx::core::Compiled, mlx::core::Concatenate, mlx::core::Conjugate, mlx::core::Convolution, mlx::core::Copy, mlx::core::Cos, mlx::core::Cosh, mlx::core::CustomVJP, mlx::core::Depends, mlx::core::Divide, mlx::core::DivMod, mlx::core::Select, mlx::core::Remainder, mlx::core::Equal, mlx::core::Erf, mlx::core::ErfInv, mlx::core::Exp, mlx::core::Expm1, mlx::core::FFT, mlx::core::Floor, mlx::core::Full, mlx::core::Gather, mlx::core::Greater, mlx::core::GreaterEqual, mlx::core::Less, mlx::core::LessEqual, mlx::core::Load, mlx::core::Log, mlx::core::Log1p, mlx::core::LogicalNot, mlx::core::LogicalAnd, mlx::core::LogicalOr, mlx::core::LogAddExp, mlx::core::Matmul, mlx::core::Maximum, mlx::core::Minimum, mlx::core::Multiply, mlx::core::Negative, mlx::core::NotEqual, mlx::core::NumberOfElements, mlx::core::Pad, mlx::core::Partition, mlx::core::Power, mlx::core::QuantizedMatmul, mlx::core::RandomBits, mlx::core::Reshape, mlx::core::Reduce, mlx::core::Round, mlx::core::Scan, mlx::core::Scatter, mlx::core::Sigmoid, mlx::core::Sign, mlx::core::Sin, mlx::core::Sinh, mlx::core::Slice, mlx::core::SliceUpdate, mlx::core::Softmax, mlx::core::Sort, mlx::core::Split, mlx::core::Square, mlx::core::Sqrt, mlx::core::StopGradient, mlx::core::Subtract, mlx::core::Tan, mlx::core::Tanh, mlx::core::Uniform, mlx::core::Transpose, mlx::core::QRF, mlx::core::SVD, and mlx::core::Inverse.

+ +
+
+ +

◆ stream()

+ +
+
+ + + + + +
+ + + + + + + +
const Stream & mlx::core::Primitive::stream ()
+
+inline
+
+ +

The stream the primitive will run on.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
virtual std::vector< array > mlx::core::Primitive::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+virtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented in mlx::core::CustomVJP, mlx::core::Depends, mlx::core::fast::Custom, mlx::core::fast::RMSNorm, mlx::core::fast::LayerNorm, mlx::core::fast::RoPE, mlx::core::Abs, mlx::core::Add, mlx::core::AddMM, mlx::core::ArcCos, mlx::core::ArcCosh, mlx::core::ArcSin, mlx::core::ArcSinh, mlx::core::ArcTan, mlx::core::ArcTan2, mlx::core::ArcTanh, mlx::core::AsType, mlx::core::AsStrided, mlx::core::BlockMaskedMM, mlx::core::BlockSparseMM, mlx::core::Broadcast, mlx::core::Ceil, mlx::core::Compiled, mlx::core::Concatenate, mlx::core::Convolution, mlx::core::Copy, mlx::core::Cos, mlx::core::Cosh, mlx::core::Divide, mlx::core::DivMod, mlx::core::Select, mlx::core::Remainder, mlx::core::Equal, mlx::core::Erf, mlx::core::ErfInv, mlx::core::Exp, mlx::core::Expm1, mlx::core::FFT, mlx::core::Floor, mlx::core::Full, mlx::core::Gather, mlx::core::Greater, mlx::core::GreaterEqual, mlx::core::Less, mlx::core::LessEqual, mlx::core::Log, mlx::core::Log1p, mlx::core::LogicalNot, mlx::core::LogicalAnd, mlx::core::LogicalOr, mlx::core::LogAddExp, mlx::core::Matmul, mlx::core::Maximum, mlx::core::Minimum, mlx::core::Multiply, mlx::core::Negative, mlx::core::NotEqual, mlx::core::Pad, mlx::core::Partition, mlx::core::Power, mlx::core::QuantizedMatmul, mlx::core::Reshape, mlx::core::Reduce, mlx::core::Round, mlx::core::Scan, mlx::core::Scatter, mlx::core::Sigmoid, mlx::core::Sign, mlx::core::Sin, mlx::core::Sinh, mlx::core::Slice, mlx::core::SliceUpdate, mlx::core::Softmax, mlx::core::Sort, mlx::core::Split, mlx::core::Square, mlx::core::Sqrt, mlx::core::Subtract, mlx::core::Tan, mlx::core::Tanh, and mlx::core::Transpose.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Primitive::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+virtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented in mlx::core::fast::Custom, mlx::core::Abs, mlx::core::Add, mlx::core::AddMM, mlx::core::ArcCos, mlx::core::ArcCosh, mlx::core::ArcSin, mlx::core::ArcSinh, mlx::core::ArcTan, mlx::core::ArcTan2, mlx::core::ArcTanh, mlx::core::ArgPartition, mlx::core::ArgReduce, mlx::core::ArgSort, mlx::core::AsType, mlx::core::BitwiseBinary, mlx::core::Broadcast, mlx::core::Ceil, mlx::core::Compiled, mlx::core::Concatenate, mlx::core::Conjugate, mlx::core::Copy, mlx::core::Cos, mlx::core::Cosh, mlx::core::Divide, mlx::core::DivMod, mlx::core::Select, mlx::core::Remainder, mlx::core::Equal, mlx::core::Erf, mlx::core::ErfInv, mlx::core::Exp, mlx::core::Expm1, mlx::core::FFT, mlx::core::Floor, mlx::core::Full, mlx::core::Gather, mlx::core::Greater, mlx::core::GreaterEqual, mlx::core::Less, mlx::core::LessEqual, mlx::core::Log, mlx::core::Log1p, mlx::core::LogicalNot, mlx::core::LogicalAnd, mlx::core::LogicalOr, mlx::core::LogAddExp, mlx::core::Matmul, mlx::core::Maximum, mlx::core::Minimum, mlx::core::Multiply, mlx::core::Negative, mlx::core::NotEqual, mlx::core::NumberOfElements, mlx::core::Pad, mlx::core::Partition, mlx::core::Power, mlx::core::QuantizedMatmul, mlx::core::RandomBits, mlx::core::Reshape, mlx::core::Reduce, mlx::core::Round, mlx::core::Scan, mlx::core::Sigmoid, mlx::core::Sign, mlx::core::Sin, mlx::core::Sinh, mlx::core::Slice, mlx::core::SliceUpdate, mlx::core::Softmax, mlx::core::Sort, mlx::core::Split, mlx::core::Square, mlx::core::Sqrt, mlx::core::StopGradient, mlx::core::Subtract, mlx::core::Tan, mlx::core::Tanh, mlx::core::Uniform, mlx::core::Transpose, mlx::core::SVD, and mlx::core::Inverse.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_primitive.png b/docs/build/html/classmlx_1_1core_1_1_primitive.png new file mode 100644 index 0000000000000000000000000000000000000000..515b36a662977995aa06baedae2bf65ed90237d9 GIT binary patch literal 3642 zcmdT{YgAKL8Vzc35RuN8>R>Cd*EP*S0L577y?6FuZ@a@(EuW&nERiXgC6XUOK#16g`5+yc7KC?uZO4P5*dG znQXRG*c_`1M<|Y*8i%8Y3@okj7kh#VZ1j4USXf~kqLbVCh z9DNvA<63?WZe#FnQ<*Wn9^;wD&T2hVsU$o-nD3&Nl_o`S7K4_;oU>sQw}%qGnyoa3vnHi4eca;nnKtt?p06p z(%HCJKlJEjBa!cWAe!nr*p@kEKxezLGbf~bp&6U>vV^eu${=XNi5y`Y4zFKF76m_M zlL_nv966_bC`lWUZr2pabOekP$SDY_s~{xQR&q{h>H+;zt=tdX%@&DR(1!5Qiv_7W z;!9YXIcC4gF7AM|h*}(^J3@3WQV$+NN&RrIE~O2Xt9+JZrQiyyt(vKC6vNkil-I&5 z7CnboeY`Q{JP%*K(y2l$qL$yVho%=ads<_J)By?KPFv>awZT4h<9e-nXV}~>LN9Z( zfAQ0_9z%<2{n4`;rC~J4F}2a<5aFLHyjJ2KL3QZXGxmd0<>X$;eNO3(jz-k0-N*_m0HNY=-32h9OQ&it3lyk4$JTVuyXT+*`dad&A@OEr=!qH9GsPA~ZO6BQ}8%`g%bhj?}7GI(Lq$2vE zhn5JAen#va?@-^?5!u@O@`@_j@B=WlFL&GO{<_ToX^?{B8V61SKDN=ambsD~xw`k* zb?#)-GVD@Dc49tEaWe#>lPwc7e~sv$T~h{lZY6h!r0i6Fx)u;B3?r?;+3niSkfuvp z9Vp6rk8y*B-c=Bx+rY5Mai*&0yLD0{kI=q?{iKcTZnC43TUJz>11hWce?#?#G^PJE z;Mg%ey`tfuJ|>(|`)W;FW!BO`FS@ysS5)7VK5uTJ{6c+eY(kZw>m#?TPB74Y$+6QH zw4WDtJW?gt{%8GzX8NX<>S7-AAegthXxtj_c2@deSTHdgn?6~*3of+ICk=!onl z4fKF|PD3O-9IWZ8Y0KBFrJ{_hm-U2)8jCX|*WyiJz$#>0rBV826vhUkKvMQ>bG&Y% ztK~Vsw1ENzXkUmF%>|qyMsq$e;I*=TDrR1VP76#Y0sVv4=2pvhnj~+>#1ggmxt|HQ4-smt z6iwu1JNvA@j8Jdv^+uyBdcGvXxR3PhHZ;% zBmGhT|J}MUd27)?GhhQXR$oIOd9|ivaf_p=S4|e_xZ5T$1 z1o+q}@rsI}XFSPn9^eB&O zYKw$oeTf#?31@*;E}5(6z1+f!NS0(R%G3}nQ|U1)Gc<1?N&PM)75F7pc868j2_=OC z?+aO3swoc$Gch_Otme393qqMP$`57SH+|fCpKTn-U!v_m3Ef4`+D@V!bSn(ji<->WS0F|2{qtcUBnuX4R8I^PO2SRZF3)xoE) z>bzLN2Q3^vnHv);k7!3fg1M}_27jEZWqp6fa3B%bad)?mp1dO-QeyG0tOk>?+VnU@ z(YKkk102$RjV>LB18}qCn*ulhyJe4&wDj&7*~kx#hNt6K3FMT~Qo^tOIWErDfllB5xgh#)2dHCKlxw24>$Y( literal 0 HcmV?d00001 diff --git a/docs/build/html/classmlx_1_1core_1_1_q_r_f-members.html b/docs/build/html/classmlx_1_1core_1_1_q_r_f-members.html new file mode 100644 index 000000000..b7b810ed1 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_q_r_f-members.html @@ -0,0 +1,107 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::QRF Member List
+
+
+ +

This is the complete list of members for mlx::core::QRF, including all inherited members.

+ + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::QRFvirtual
eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::QRFvirtual
is_equivalent(const Primitive &other) constmlx::core::Primitiveinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)mlx::core::Primitivevirtual
operator=(const Primitive &other)=deletemlx::core::Primitive
operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs)mlx::core::Primitivevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::QRFinlinevirtual
QRF(Stream stream)mlx::core::QRFinlineexplicit
stream()mlx::core::Primitiveinline
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs)mlx::core::Primitivevirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes)mlx::core::Primitivevirtual
~Primitive()=defaultmlx::core::Primitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_q_r_f.html b/docs/build/html/classmlx_1_1core_1_1_q_r_f.html new file mode 100644 index 000000000..f65f45b1e --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_q_r_f.html @@ -0,0 +1,273 @@ + + + + + + + +MLX: mlx::core::QRF Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::QRF Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::QRF:
+
+
+ + +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 QRF (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
void print (std::ostream &os) override
 Print the primitive.
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)
 The Jacobian-vector product.
 
virtual std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs)
 The vector-Jacobian product.
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes)
 The primitive must know how to vectorize itself across the given axes.
 
virtual bool is_equivalent (const Primitive &other) const
 Equivalence check defaults to false unless overridden by the primitive.
 
virtual std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs)
 Get the output shapes of the primitive.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ QRF()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::QRF::QRF (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::QRF::eval_cpu (const std::vector< array > & inputs,
std::vector< array > & outputs )
+
+overridevirtual
+
+ +

A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.

+

To avoid unnecessary allocations, the evaluation function is responsible for allocating space for the array.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::QRF::eval_gpu (const std::vector< array > & inputs,
std::vector< array > & outputs )
+
+overridevirtual
+
+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::QRF::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_q_r_f.png b/docs/build/html/classmlx_1_1core_1_1_q_r_f.png new file mode 100644 index 0000000000000000000000000000000000000000..29056e86a6af766743e90c1c5ee33e33ffd857c4 GIT binary patch literal 520 zcmeAS@N?(olHy`uVBq!ia0vp^B|sd&!3-q-S1vRKQqloFA+G=b{|7Q(y!l$%e`vXd zfo6fk^fNCWJa_;UH$5A+1SrQ@666=m0OW&#In(Sb3=E8uJzX3_Dj46+W$asRz~gdU zc;&3$^M43OHSs3ioTj$sOK@rncVGOQGmI00y!N)bOsX|qpmNivo+V4=XLaPitKa-S z7EesMQu>0wy3b2=h3@lN%-bJcJ@IMAJqy#`@+gQOlBwbzn`#tceU$>fRx5o zzm%X0+k)$tPZB8=eYLu6+v7@=ouzDFjjdjkyk30zF|XULdoTHF&hCxQkSs09@ANMI u`96KFS5PqT3Ki4KAgk}Pni{`eQ~x$2aas_w<8@%1F?hQAxvX + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::QuantizedMatmul Member List
+
+
+ +

This is the complete list of members for mlx::core::QuantizedMatmul, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::QuantizedMatmulvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::QuantizedMatmulvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::QuantizedMatmulvirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::QuantizedMatmulvirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs)mlx::core::Primitivevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::QuantizedMatmulinlinevirtual
QuantizedMatmul(Stream stream, int group_size, int bits, bool transpose)mlx::core::QuantizedMatmulinlineexplicit
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::QuantizedMatmulvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::QuantizedMatmulvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_quantized_matmul.html b/docs/build/html/classmlx_1_1core_1_1_quantized_matmul.html new file mode 100644 index 000000000..dcd295798 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_quantized_matmul.html @@ -0,0 +1,447 @@ + + + + + + + +MLX: mlx::core::QuantizedMatmul Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::QuantizedMatmul Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::QuantizedMatmul:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 QuantizedMatmul (Stream stream, int group_size, int bits, bool transpose)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs)
 Get the output shapes of the primitive.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ QuantizedMatmul()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
mlx::core::QuantizedMatmul::QuantizedMatmul (Stream stream,
int group_size,
int bits,
bool transpose )
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::QuantizedMatmul::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::QuantizedMatmul::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::QuantizedMatmul::is_equivalent (const Primitive & other) const
+
+overridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::QuantizedMatmul::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::QuantizedMatmul::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::QuantizedMatmul::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::QuantizedMatmul::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_quantized_matmul.png b/docs/build/html/classmlx_1_1core_1_1_quantized_matmul.png new file mode 100644 index 0000000000000000000000000000000000000000..6b7d0c3461c419ee15f44d7df60a7ff9e16ef71b GIT binary patch literal 975 zcmeAS@N?(olHy`uVBq!ia0vp^OM$q9gBeI}JboKUGDrvbgt-3y{~ySF@#br3|Doj; z2ATyD)6cwk@ZbSZ-1KbN5}+JsNswPK1CS2}=1jA%FfcGjdAc};R4~4s`@e6C0*|Zy z^ps2A>mQw;e}k!S*X6laYwi{-;%4n@J$L_v<@4LkK9kJVbv!S%Z&2}+KNqbw$;Mym ze^qh&_ni};c-Ec&baqa+_uLIWuc!O#XSJ(_zE0-4y=gmZzw%|F+g8pIKaZbNI>~o@ z+T&F|Z%lSieHYaq(pz%Ne^3p97ob>|6Ws(>>2BC12zCH-V2;em-we z_539;P35I?#H1zs{Tt6%KHtsEaBdI7gKsv6K7FlfdZw(yev&=I?yymo48tA{Z-x(w zI*bR}H?T0s9};4y!7uPH{`8#UeYPi5DvJ+IQmN(ep7bU0y3xcZmCHVCFTV8SdUS2Y zd}+^bKNmfn-Ld#=*6A&E*S{=!rr#MMAADolY#|$-D}|r4ZWnpazn96O_y0y|xz6%j zp_^vAC%g;dJMCb9?n`O;rQ-X~g8Wy4f+J+-oo?OloBr*J-XkaFl>slk57C0CHL*L@40(+@!gIL z*)JyRJ$dqDDRs|sy?@^(EqN+EP2Rii@TcOF@8?)}*13QJE75b39nUn&=QZUF2WE;h z+?nSdu`ka5LQTt4g`W&((<9t^nH$6tIT`F)R2hB-I61eIXm=1Y}Q<>B2FE_ zvU^`NEDf2DF3Qf_QnhoEO45y!sf$C->hqTtDXq+$AN8tgquJ|%_4RX?yPv94`}$m` zH1XY&O?vab9kHJFw*0D>I4}gSeBWevl|3t1)pJ|g*IQDaN~PARJLk5oEDSw*+i~6Z zl(`$?b_JcEE495m<6oe`^mWOaN=4C!f?STYZx9jusHo%WQd7~B>ZN&iZoNL=3a_tE RjgNt{sHdx+%Q~loCIA&l&Po6P literal 0 HcmV?d00001 diff --git a/docs/build/html/classmlx_1_1core_1_1_random_bits-members.html b/docs/build/html/classmlx_1_1core_1_1_random_bits-members.html new file mode 100644 index 000000000..fbd98bb2c --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_random_bits-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::RandomBits Member List
+
+
+ +

This is the complete list of members for mlx::core::RandomBits, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::RandomBitsvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::RandomBitsvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::RandomBitsvirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)mlx::core::Primitivevirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs)mlx::core::Primitivevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::RandomBitsinlinevirtual
RandomBits(Stream stream, const std::vector< int > &shape, int width)mlx::core::RandomBitsinlineexplicit
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs)mlx::core::Primitivevirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::RandomBitsvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_random_bits.html b/docs/build/html/classmlx_1_1core_1_1_random_bits.html new file mode 100644 index 000000000..08a208df6 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_random_bits.html @@ -0,0 +1,361 @@ + + + + + + + +MLX: mlx::core::RandomBits Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::RandomBits Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::RandomBits:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 RandomBits (Stream stream, const std::vector< int > &shape, int width)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)
 The Jacobian-vector product.
 
virtual std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs)
 The vector-Jacobian product.
 
virtual std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs)
 Get the output shapes of the primitive.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ RandomBits()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
mlx::core::RandomBits::RandomBits (Stream stream,
const std::vector< int > & shape,
int width )
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::RandomBits::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::RandomBits::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::RandomBits::is_equivalent (const Primitive & other) const
+
+overridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::RandomBits::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::RandomBits::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_random_bits.png b/docs/build/html/classmlx_1_1core_1_1_random_bits.png new file mode 100644 index 0000000000000000000000000000000000000000..59b478af106299b15f4fc209fce70504a5768320 GIT binary patch literal 920 zcmeAS@N?(olHy`uVBq!ia0vp^(}1{xgBeH~F+Z{fQqloFA+G=b{|7Q(y!l$%e`vXd zfo6fk^fNCWJa_;UH$5A+1SrQ@666=m0OW&#In(Sb3=GW5o-U3d6^w7^zRmlrz~jcx zAN2J1{zvA%pBWpMK2ug-Ok#cljfKF z|MJc6dcUHl{|~mc-L~%no=0WYZvVN6W2Z{;-&xb&{=6xE|NR-oZx)v~y>;K~9b~Du zGE0U#ein45%KJqj7lXRTa7n8cWr|3l6lq*5kuLXX4zI9h<-KwoqT(<(%*F>eV z&rMma`+9lQq>?7@)mlH3b4}ZS#H`NY_ROku&9=UXW$NTT zX-~^El}b)s&re3{!p>RpvmaQ#gR$rMjHvCmC&VRZ`GzT1@vah`%ke;c3&Wlkv4-gv zm_7(aF~m7>AMnm#tl-pT&|k>I$At!3V%TY>L&mUp%5ng4pdYUjJjr}ma^J7wCQo-=pt-K^I?tma?K%GzPq z5%ls0n`hSXst46~)`i`aSHJn{N?7i8(MQ}-cPHI<3;i_rc&zNB>ZsIco6EQOw!GW$ zwrYR9#Y5ZM?epGcgnGUxS)@`sNBlqAzYF&Bx3Aq7`nq=Y|L5^rRsLFmJag>oDa+>% z8GeYEH++-`eqAPgDcfi5$t;)kj$0QAGyWkactBzIuOZA!vzPtknH&?R(`U|jXUv>g zX?eXWu>768@#eO3Z+2BYN;!Xgi{JMzTfS8;4~vRV+^4Zk_hiKWXQ{4r+m2?-U3>cV ziBz@k)tbe=#j8!(ebU})G@rY;ZRJ;ku9JV=Om`kN^ZshExwd3eXqC*S#A$1mu3Y;} zw@*Lddf59Z`xk|V{j4+VFWz}=)x0NGE}KL8Pv@9!Eq#`Ja#f)HgQ?DIy#7T*pH-Zf z`E}I?@2#O*VwMYc|I?ap;;;AidRfx)Z8dW<(!Tv}nKpCg^b1DDpM|2*(trWurMdLs a4|dzqSq3|FZ=V9@AqG!ZKbLh*2~7aXhR6p1 literal 0 HcmV?d00001 diff --git a/docs/build/html/classmlx_1_1core_1_1_reduce-members.html b/docs/build/html/classmlx_1_1core_1_1_reduce-members.html new file mode 100644 index 000000000..0d81bee6f --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_reduce-members.html @@ -0,0 +1,122 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Reduce Member List
+
+
+ +

This is the complete list of members for mlx::core::Reduce, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
And enum valuemlx::core::Reduce
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Reducevirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Reducevirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::Reducevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)mlx::core::Primitivevirtual
Max enum valuemlx::core::Reduce
Min enum valuemlx::core::Reduce
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
Or enum valuemlx::core::Reduce
output_shapes(const std::vector< array > &inputs) overridemlx::core::Reducevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Reduceinlinevirtual
Prod enum valuemlx::core::Reduce
Reduce(Stream stream, ReduceType reduce_type, const std::vector< int > &axes)mlx::core::Reduceinlineexplicit
ReduceType enum namemlx::core::Reduce
stream()mlx::core::Primitiveinline
Sum enum valuemlx::core::Reduce
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Reducevirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Reducevirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_reduce.html b/docs/build/html/classmlx_1_1core_1_1_reduce.html new file mode 100644 index 000000000..40838cb85 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_reduce.html @@ -0,0 +1,472 @@ + + + + + + + +MLX: mlx::core::Reduce Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::Reduce Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Reduce:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + +

+Public Types

enum  ReduceType {
+  And +, Or +, Sum +, Prod +,
+  Min +, Max +
+ }
 
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Reduce (Stream stream, ReduceType reduce_type, const std::vector< int > &axes)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)
 The Jacobian-vector product.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Member Enumeration Documentation

+ +

◆ ReduceType

+ +
+
+ + + + + + + +
Enumerator
And 
Or 
Sum 
Prod 
Min 
Max 
+ +
+
+

Constructor & Destructor Documentation

+ +

◆ Reduce()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
mlx::core::Reduce::Reduce (Stream stream,
ReduceType reduce_type,
const std::vector< int > & axes )
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Reduce::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Reduce::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Reduce::is_equivalent (const Primitive & other) const
+
+overridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::Reduce::output_shapes (const std::vector< array > & inputs)
+
+overridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Reduce::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Reduce::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Reduce::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_reduce.png b/docs/build/html/classmlx_1_1core_1_1_reduce.png new file mode 100644 index 0000000000000000000000000000000000000000..3c46700dbf54fb2d8ebfbe8738d00178b38a3c6c GIT binary patch literal 895 zcmeAS@N?(olHy`uVBq!ia0vp^(}1{xgBeH~F+Z{fQqloFA+G=b{|7Q(y!l$%e`vXd zfo6fk^fNCWJa_;UH$5A+1SrQ@666=m0OW&#In(Sb3=GU%o-U3d6^w7^=Jqc(;9*nu zUvldIe{uO#?kBvjzU|s=Id^r>tfpkP<2V0$owGd89yuvZd5g+P?+njL;_YFRRP0~= zlUy77EZ%w2ybrul;@Njrp4)oqTi)v>9zQjXecRe5`_;1Pefc#J_j#&0{{z=g3hCos zTP>x^>o#}yo-J>K3e#sSKcnEO8-7)5v&q#7ld4|b%Sx?}%f9h+c~t+l6=9xQ?f$E6 z9M!MBy8UbAG?lpzF0G2%_jYZb@citl$DJls#a+B*zQwNdcTK=GmC(giN4NVeQu!;j zP}MWe$$QeCmT4-LGsCr?&som=Aovc$vt+-uw{yMLTW%FwExA|iYS(0r2kKiG_OysK zOuxYNK`4qL&WZbgcLrkxr!Is3Le>YW2+2?K8BgbYu9={kX$CX-<PhC`YCnYgTMQY z-Y?wx%s0sX)zqg;A1&;ce!Sz=tq-10Z|#ao^$F(HT)BMZwa7R3qTjx*iCX>oVU@Di z(r>(zuJq>~?%(6Q_EXKAkn9W5TkqcM(9htFu9FWv%`aEJwex-E)Z>q4UfZ|uE9dR! z$N#3a_h*@>S1kw?O)L;neffCL5AzT0|9L-*>i2l`Lg|c{krb-C*{OvM+DBdt?0KF+S+~g3|mZ= z?#pdGC)db)Ek5PlFJBrX8>DAuEL{-0_R`{&ZIPSit=itEbT-L&GvoYATD$k`w7B-F z|EP78`=QHOrF%nIs(RbD6eGQAZHr_0_KJ%?~)chm0aocNZh5Fa; ux9Lnv+wTH*;KJ0jXQ^>(rv%OFvu~81!@5YVf*qJu7(8A5T-G@yGywqcmb0e- literal 0 HcmV?d00001 diff --git a/docs/build/html/classmlx_1_1core_1_1_remainder-members.html b/docs/build/html/classmlx_1_1core_1_1_remainder-members.html new file mode 100644 index 000000000..e2fefec65 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_remainder-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Remainder Member List
+
+
+ +

This is the complete list of members for mlx::core::Remainder, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Remaindervirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Remaindervirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::Remainderinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::Remaindervirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::Remainderinlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Remainderinlinevirtual
Remainder(Stream stream)mlx::core::Remainderinlineexplicit
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Remaindervirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Remaindervirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_remainder.html b/docs/build/html/classmlx_1_1core_1_1_remainder.html new file mode 100644 index 000000000..f0fc852ce --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_remainder.html @@ -0,0 +1,463 @@ + + + + + + + +MLX: mlx::core::Remainder Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::Remainder Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Remainder:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Remainder (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Remainder()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::Remainder::Remainder (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Remainder::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Remainder::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Remainder::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Remainder::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::Remainder::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Remainder::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Remainder::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Remainder::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_remainder.png b/docs/build/html/classmlx_1_1core_1_1_remainder.png new file mode 100644 index 0000000000000000000000000000000000000000..898cd6373753c3bfca950f19c754d686cb8c04dd GIT binary patch literal 917 zcmeAS@N?(olHy`uVBq!ia0vp^(}1{xgBeH~F+Z{fQqloFA+G=b{|7Q(y!l$%e`vXd zfo6fk^fNCWJa_;UH$5A+1SrQ@666=m0OW&#In(Sb3=GT)o-U3d6^w7^KAf~!fyY&U z`b@uj|2yrsa&azLHaBm2zM1u$K*2|fEpOwW%=uiSF!@PNir|zULk1lwryi{LX zzj){AzndpM;k92-YIt1sRgAW5_U*io6Sk7aQ>DW`2eYrw-zH#s&S;IzjYyT!-4~(` z8x&s1$+=tlW>S@L&tl~XDW^lXl}u5SPD$!tyEg3G@2JCC_E&G6K6FNDwNQT9^U}Df;$jT7s z#Lcj$MU0_>QB5 zzs$S#-LaSUjM}%b)bLSG(nYWKpJj{H-gw*U-QKY4mr+IK#lC8{)epbf{au<`PftcH+o9wdnL@@7%4Ye{Km2KU}7~C1`h( z%eS{TK1h|XSbF}r=Qg&jsc*d$_bo8h`FY9bB)?tx*3S2tFOv(GhniiPx1lKhzPWYG z=DTxVHcsmN+7tGT!BSK!+3YyZukpEc=EsZ+Pg-ZQJG7N5J!$T0sA zBg1FI<<}k?F1V zYu(3g3knHJCdbDf|F`0rc5&Qwuh(<7R__iJ{FWj!J$J?_f5|^@bP`W}zO}V{Md9)E zysg`AdM?R-wW)StnP}^+m4{@MU)@lf(DB0T+_Y~Qn-nK4IesPVy;WR<_=Do8U8iRI z$u72=U2-k8qa&|8s&w!2o3FA~`%X9+S8TcRpR{i9r{+7ieRHy2d!4>At)t`U@=)`` zb6l!V{F=Pt`O8P|{T6j}DAv`uxa?^G%1*x^B=}JX#} literal 0 HcmV?d00001 diff --git a/docs/build/html/classmlx_1_1core_1_1_reshape-members.html b/docs/build/html/classmlx_1_1core_1_1_reshape-members.html new file mode 100644 index 000000000..77ec832ed --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_reshape-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Reshape Member List
+
+
+ +

This is the complete list of members for mlx::core::Reshape, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Reshapevirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Reshapevirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::Reshapevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::Reshapevirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs)mlx::core::Primitivevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Reshapeinlinevirtual
Reshape(Stream stream, const std::vector< int > &shape)mlx::core::Reshapeinlineexplicit
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Reshapevirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Reshapevirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_reshape.html b/docs/build/html/classmlx_1_1core_1_1_reshape.html new file mode 100644 index 000000000..c495d9df9 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_reshape.html @@ -0,0 +1,437 @@ + + + + + + + +MLX: mlx::core::Reshape Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::Reshape Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Reshape:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Reshape (Stream stream, const std::vector< int > &shape)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs)
 Get the output shapes of the primitive.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Reshape()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
mlx::core::Reshape::Reshape (Stream stream,
const std::vector< int > & shape )
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Reshape::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Reshape::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Reshape::is_equivalent (const Primitive & other) const
+
+overridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Reshape::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Reshape::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Reshape::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Reshape::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_reshape.png b/docs/build/html/classmlx_1_1core_1_1_reshape.png new file mode 100644 index 0000000000000000000000000000000000000000..1c30abb02f8afafc67a857ffd76cf9025c6a10e6 GIT binary patch literal 910 zcmeAS@N?(olHy`uVBq!ia0vp^(}1{xgBeH~F+Z{fQqloFA+G=b{|7Q(y!l$%e`vXd zfo6fk^fNCWJa_;UH$5A+1SrQ@666=m0OW&#In(Sb3=GVYo-U3d6^w7^UhI3Uz~ibf zzVg%m`j5*@6wM|aS$jz`>~Un@Bmrl>V>$5?i_d*-^qI6-FiPd6>K2bn-Okz*m(2h2 zN4humU3{;K{*R_>Cndgz9FNM3j@}sSZ{xKtITt+ZiYQAVhk0W zx(pwLq8J{iZecj!ox#{J{Q?t%{z6uUIAjH%*tZm&tF)WoS!4#Yb~iZs)d^m|MOF#B zp3I)S + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Round Member List
+
+
+ +

This is the complete list of members for mlx::core::Round, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Roundvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Roundvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::Roundinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::Roundvirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::Roundinlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Roundinlinevirtual
Round(Stream stream)mlx::core::Roundinlineexplicit
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Roundvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Roundvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_round.html b/docs/build/html/classmlx_1_1core_1_1_round.html new file mode 100644 index 000000000..61f17205c --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_round.html @@ -0,0 +1,463 @@ + + + + + + + +MLX: mlx::core::Round Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::Round Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Round:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Round (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Round()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::Round::Round (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Round::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Round::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Round::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Round::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::Round::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Round::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Round::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Round::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_round.png b/docs/build/html/classmlx_1_1core_1_1_round.png new file mode 100644 index 0000000000000000000000000000000000000000..b24499cb7318ebbba97ad65dbc2da147727e43fe GIT binary patch literal 881 zcmeAS@N?(olHy`uVBq!ia0vp^(}1{xgBeH~F+Z{fQqloFA+G=b{|7Q(y!l$%e`vXd zfo6fk^fNCWJa_;UH$5A+1SrQ@666=m0OW&#In(Sb3=B+vJY5_^Dj46+eckuUfQL;z zUggR6`bWWD`kqw>H?*F;D3$7xAzS{+Vt4&X%jf?Tyo2&vrm0kN>Uw^9bir%V6ZO0H z)!FSolRYj8+bt;FbZqXc2<@})&HTEi1RFkEZZ zZBWnSx$0-zHP0g#54mceefD@sgu&ide@xrf=I>|WH@~&ReVggslC?FL1or&V_`EJE z`JPIy``%4^LTpNQZIlm>yx@)(-6zmmT8uXjZ z^HyKw;p=;LM?L*%vErIv{?#&X<$Vjb-u`rYMO5AM!t~ic=7m-2%*)i3j+(zAb>F@G zvWm53#mmaKtWtV2$HjBk9P$5bbuZ_iza1ujwd($>|3AZZJ@@-%cutyLw-x9$Yld}? z*e=M|X5Gz;t@^ssG~(-vt1H9ixG>nG1_bU-*YoD-6D2Vr*>vpmM8M|1znxliNb}Ud_0wpPaVs{M4w_>E`R^Uzzp&*^;Q)_kQgO zyZR|NKi8~q@zpK&uST72yQaR;ro{F1`ogoEJ72t?HuL6p4qano{e`J%&s4XZIkV@> fYM@i^e&m1pVNOldgTe~DWM4f6qmMw literal 0 HcmV?d00001 diff --git a/docs/build/html/classmlx_1_1core_1_1_s_v_d-members.html b/docs/build/html/classmlx_1_1core_1_1_s_v_d-members.html new file mode 100644 index 000000000..034664345 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_s_v_d-members.html @@ -0,0 +1,107 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::SVD Member List
+
+
+ +

This is the complete list of members for mlx::core::SVD, including all inherited members.

+ + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::SVDvirtual
eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::SVDvirtual
is_equivalent(const Primitive &other) constmlx::core::Primitiveinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)mlx::core::Primitivevirtual
operator=(const Primitive &other)=deletemlx::core::Primitive
operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs)mlx::core::Primitivevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::SVDinlinevirtual
stream()mlx::core::Primitiveinline
SVD(Stream stream)mlx::core::SVDinlineexplicit
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs)mlx::core::Primitivevirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::SVDvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_s_v_d.html b/docs/build/html/classmlx_1_1core_1_1_s_v_d.html new file mode 100644 index 000000000..86fe0bf9d --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_s_v_d.html @@ -0,0 +1,307 @@ + + + + + + + +MLX: mlx::core::SVD Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::SVD Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::SVD:
+
+
+ + +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 SVD (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
void print (std::ostream &os) override
 Print the primitive.
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)
 The Jacobian-vector product.
 
virtual std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs)
 The vector-Jacobian product.
 
virtual bool is_equivalent (const Primitive &other) const
 Equivalence check defaults to false unless overridden by the primitive.
 
virtual std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs)
 Get the output shapes of the primitive.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ SVD()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::SVD::SVD (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::SVD::eval_cpu (const std::vector< array > & inputs,
std::vector< array > & outputs )
+
+overridevirtual
+
+ +

A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.

+

To avoid unnecessary allocations, the evaluation function is responsible for allocating space for the array.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::SVD::eval_gpu (const std::vector< array > & inputs,
std::vector< array > & outputs )
+
+overridevirtual
+
+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::SVD::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::SVD::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_s_v_d.png b/docs/build/html/classmlx_1_1core_1_1_s_v_d.png new file mode 100644 index 0000000000000000000000000000000000000000..428bbfa87a55b0f396c4a0078a822bf30cd49d9b GIT binary patch literal 520 zcmeAS@N?(olHy`uVBq!ia0vp^B|sd&!3-q-S1vRKQqloFA+G=b{|7Q(y!l$%e`vXd zfo6fk^fNCWJa_;UH$5A+1SrQ@666=m0OW&#In(Sb3=E8uJzX3_Dj46+P0Vjo;AvU! z{^{Aj|N8T+8k;P>?`rfucF1>Gu6$J9eL*LV#jOsLl1&$?-1Mzy$x``Qd-K=jQ?cLX zcdK~+mHFj*E^;Ny){1qxkA=RuJxb<2z4z&ZucyP5u16laByXqfd|g3Gdi9wv>mU8O zGd)c|wjE4$w|xO|#XesO)~JnZcpRBgSy&3O9q>Yt9$1 zzcPH~d~ry|zn0}!{~hav3>W-`xft%602#hm8RAC9(z-KLdat&41)2Y~eX;v_Wbe#N z)9%j>R$X~1-)Z}YsZ$rIyG8|cbKX9-^zHM${=G9FOD6B`)H}IetwdA&pxV!;a@jrA zqWAB#ZR2{HGdbk++;WC52~%x;{=V>}`?DXv+b_*620^vUO@A;{o!@I}(z~l9|JL*O xc~a*8{Z&`CtA$M3mFc9qa`wEYSu=UBv4?~t9-sSFI2IUZ44$rjF6*2UngEb^_ALMa literal 0 HcmV?d00001 diff --git a/docs/build/html/classmlx_1_1core_1_1_scan-members.html b/docs/build/html/classmlx_1_1core_1_1_scan-members.html new file mode 100644 index 000000000..55a51b51c --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_scan-members.html @@ -0,0 +1,120 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Scan Member List
+
+
+ +

This is the complete list of members for mlx::core::Scan, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Scanvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Scanvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::Scanvirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::Scanvirtual
Max enum valuemlx::core::Scan
Min enum valuemlx::core::Scan
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs)mlx::core::Primitivevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Scaninlinevirtual
Prod enum valuemlx::core::Scan
ReduceType enum namemlx::core::Scan
Scan(Stream stream, ReduceType reduce_type, int axis, bool reverse, bool inclusive)mlx::core::Scaninlineexplicit
stream()mlx::core::Primitiveinline
Sum enum valuemlx::core::Scan
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Scanvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Scanvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_scan.html b/docs/build/html/classmlx_1_1core_1_1_scan.html new file mode 100644 index 000000000..75ead8c56 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_scan.html @@ -0,0 +1,483 @@ + + + + + + + +MLX: mlx::core::Scan Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::Scan Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Scan:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + +

+Public Types

enum  ReduceType { Max +, Min +, Sum +, Prod + }
 
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Scan (Stream stream, ReduceType reduce_type, int axis, bool reverse, bool inclusive)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs)
 Get the output shapes of the primitive.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Member Enumeration Documentation

+ +

◆ ReduceType

+ +
+
+ + + + +
enum mlx::core::Scan::ReduceType
+
+ + + + + +
Enumerator
Max 
Min 
Sum 
Prod 
+ +
+
+

Constructor & Destructor Documentation

+ +

◆ Scan()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + +
mlx::core::Scan::Scan (Stream stream,
ReduceType reduce_type,
int axis,
bool reverse,
bool inclusive )
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Scan::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Scan::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Scan::is_equivalent (const Primitive & other) const
+
+overridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Scan::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Scan::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Scan::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Scan::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_scan.png b/docs/build/html/classmlx_1_1core_1_1_scan.png new file mode 100644 index 0000000000000000000000000000000000000000..6926bd27e6b7ffb21fd93cf477c978b21d4846e5 GIT binary patch literal 884 zcmeAS@N?(olHy`uVBq!ia0vp^(}1{xgBeH~F+Z{fQqloFA+G=b{|7Q(y!l$%e`vXd zfo6fk^fNCWJa_;UH$5A+1SrQ@666=m0OW&#In(Sb3=B;FJzX3_Dj46+eVg}Lfya%X zKj`W2{g2FjKQp$JK-sh`9IqcJU4r3Y3^s=iA9>ns%9Sk_Up~T@8-o_w&z8*{S7)lDWp$5 z>br^8s~)TJpBZ6SiF(X+C2H}6Z+W!F+n7hPWvRVc*xr$ww``UR#B zLQxEHPTU8)GZ-s4bs6**vOZAV!mtNP@?S&PIm>?b$VqN+(-((LQn4@oaoO_azsslX z?%KDhT&;0_?Rn(p5nt`K&p!JaCPuyj1;N$)gO|I_-c8K)-)^J5--I=P|CGw|wU+sl z`WDZNw9~TNwdck0(DS?EwVq$Qs+Bw|%fr)sR@U63d3&$wJ--{(F1OBnQAk$!ge8w7 z%xliyzN&TpxaU@@y`i^rdw(>AZGXBwaBI!;;`3%7XN6Vj)Lpt|w#6>-_TPW|Za&<5 z>+H;T8KJH(=D2uPS?mAf{d38F{^>CJtGn{P{{ORj+LCy6U`Un63z$;7mSQQ3q_@+#l>B$DrNjG zoSIhF$e(rgq4-{zsF(A8%bDI6iI3{q824<$BD=pkG|w!2b#a!l>gLB9dY`X(jc(yt8M{6^6e4v2u~r$um)1qM`ggTe~DWM4f-?Fg> literal 0 HcmV?d00001 diff --git a/docs/build/html/classmlx_1_1core_1_1_scatter-members.html b/docs/build/html/classmlx_1_1core_1_1_scatter-members.html new file mode 100644 index 000000000..f6db8360d --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_scatter-members.html @@ -0,0 +1,121 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Scatter Member List
+
+
+ +

This is the complete list of members for mlx::core::Scatter, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Scattervirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Scattervirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::Scattervirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::Scattervirtual
Max enum valuemlx::core::Scatter
Min enum valuemlx::core::Scatter
None enum valuemlx::core::Scatter
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs)mlx::core::Primitivevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Scatterinlinevirtual
Prod enum valuemlx::core::Scatter
ReduceType enum namemlx::core::Scatter
Scatter(Stream stream, ReduceType reduce_type, const std::vector< int > &axes)mlx::core::Scatterinlineexplicit
stream()mlx::core::Primitiveinline
Sum enum valuemlx::core::Scatter
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Scattervirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes)mlx::core::Primitivevirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_scatter.html b/docs/build/html/classmlx_1_1core_1_1_scatter.html new file mode 100644 index 000000000..a46a7ff99 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_scatter.html @@ -0,0 +1,444 @@ + + + + + + + +MLX: mlx::core::Scatter Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::Scatter Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Scatter:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + +

+Public Types

enum  ReduceType {
+  Max +, Min +, Sum +, Prod +,
+  None +
+ }
 
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Scatter (Stream stream, ReduceType reduce_type, const std::vector< int > &axes)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes)
 The primitive must know how to vectorize itself across the given axes.
 
virtual std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs)
 Get the output shapes of the primitive.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Member Enumeration Documentation

+ +

◆ ReduceType

+ +
+
+ + + + + + +
Enumerator
Max 
Min 
Sum 
Prod 
None 
+ +
+
+

Constructor & Destructor Documentation

+ +

◆ Scatter()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
mlx::core::Scatter::Scatter (Stream stream,
ReduceType reduce_type,
const std::vector< int > & axes )
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Scatter::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Scatter::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Scatter::is_equivalent (const Primitive & other) const
+
+overridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Scatter::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Scatter::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Scatter::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_scatter.png b/docs/build/html/classmlx_1_1core_1_1_scatter.png new file mode 100644 index 0000000000000000000000000000000000000000..bac72d5cb6bd37cdfaea676b634f7946db599a0c GIT binary patch literal 901 zcmeAS@N?(olHy`uVBq!ia0vp^(}1{xgBeH~F+Z{fQqloFA+G=b{|7Q(y!l$%e`vXd zfo6fk^fNCWJa_;UH$5A+1SrQ@666=m0OW&#In(Sb3=GTyo-U3d6^w7^zV3Uiz{A$g zzw*;>`;VTNV;Ux0TYD+C@0#f1O`U~m$8!EpDn9qvaoUoL9J-#LgrX)rIl6GF%1QrS z^;Ox+etuiPH)bd8aXS=!_)oMy{y#wxYf6xE|2n#4q2@-_3-l0 zc?%|I$=?38behWC2bWew?R&fSu-d+BTK%mmvX$=Hw=eus$@}aVIXR^Fm8iW{=cF&$ zEz?vgIiaphjafUVSdOn@^$y0K<1?bR-<}YcoaGy)T*bReR9*0Ze+FX(r!Is3Le>YW zTNw7Vh&4>V!1O^ViXqO4`+zq>@+p7D(>b3jCa7kbIC)Rn1G2#~OV#s}al3u&C%69yo&nn`SjMVs8pX|Ud@*t!D~0)SiAk!^_@|#FC6~rHpT0$ zw8~41YrOYAmu)@uxiDz0U;fpyXyttiujc;D4bi<{d1tflAD`7WV(YhT<6Y~2!|eWh zbL&0fyDcxvu3g13cd~}(^v}=!ACUj`@#l?KtE1*e|FyTjI_b-6kUu2DwV%&vXMUiS z$MD#2*|oQ3ORmp|ntZh*t~u+dC(8$6jIh8SP(PWYG?$vlsTgktrVwLe{e`J%&yKGA zy6*R1gPAi!eZu!t?ptiyE`Kd%+0<(xpO>82C6fGi*5R&SC6DF)Pk7~;mS!85;hP=4 zH6^U?(QCQjn5gw@%x70WRWjbJw|-S=O-NvGN}BFb=}_^$sI;{7nB}D}J}sI$?QV6C z>IbE@n-cw3m1g|3TIW~P8+P?mZvNJ+rPng^XU?4aaox3T3%`2adG`CSR)U>HKXZt# zk@0Imd%Lu>IH%KR&Uk0coC!=UKnG6_^U~x!ZvSN895YGV+(=-CVeoYIb6Mw<&;$T$ C_q0#| literal 0 HcmV?d00001 diff --git a/docs/build/html/classmlx_1_1core_1_1_select-members.html b/docs/build/html/classmlx_1_1core_1_1_select-members.html new file mode 100644 index 000000000..3b6c6ecf1 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_select-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Select Member List
+
+
+ +

This is the complete list of members for mlx::core::Select, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Selectvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Selectvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::Selectinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::Selectvirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::Selectinlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Selectinlinevirtual
Select(Stream stream)mlx::core::Selectinlineexplicit
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Selectvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Selectvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_select.html b/docs/build/html/classmlx_1_1core_1_1_select.html new file mode 100644 index 000000000..40727d22d --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_select.html @@ -0,0 +1,463 @@ + + + + + + + +MLX: mlx::core::Select Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::Select Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Select:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Select (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Select()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::Select::Select (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Select::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Select::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Select::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Select::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::Select::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Select::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Select::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Select::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_select.png b/docs/build/html/classmlx_1_1core_1_1_select.png new file mode 100644 index 0000000000000000000000000000000000000000..86b98868b90c6824ae0df2da65743272b576a954 GIT binary patch literal 884 zcmeAS@N?(olHy`uVBq!ia0vp^(}1{xgBeH~F+Z{fQqloFA+G=b{|7Q(y!l$%e`vXd zfo6fk^fNCWJa_;UH$5A+1SrQ@666=m0OW&#In(Sb3=B;FJzX3_Dj46+eb~3yfQL;z ze#xo#|BtY5aS(EP{q04Z-)pVKX$K9PALsm^RDABSa=B2L6;`(+8%Xt{`FU7KUZE`th*|7(y51+ zwfrx5Ztctc8{|F7uOKs2_x`tQiNf*OQ;$1M>f5t;Yi>rZ(DusZYdlxEeeIH;>*Dz< zy=9t8B`4ICsWEHk6wC27tlq)cb9_eB_S+NUlCylnl&g4GiK+`8@XuhZ;M8T%UE zbqm9u7O{ru7nnW>MKQ!VaUbwTNIn(Mcsl2E#RSz%6DRLUdq6f=W~q99+T3CvJNcQv z{%@20{GO|Ra;{YuOy@RVyJeRCDFdu7pU{2OFt^ioZ0J^F>@v`0*sCI z^VXKXeg61M%Gd6a`2{?U31kP8>TNn{I_0pL&50_s1av)T9 z=G4VMW+iudzn=OvE&XVG`0F1l7D=5v6V+vIb#Ygc@ZYG?>Gopx^}{xIZ#{Kue#q+! zpQ>Zt7S9d4`YAVmYu3_h<<+Y#isnv^klR)2ecnImCVyI*?#IW;Y0p%*oH+xGxS2Dj lUobNMY?Qfl%B6xojNu#Sbk(`8p9;($44$rjF6*2UngHxev^f9( literal 0 HcmV?d00001 diff --git a/docs/build/html/classmlx_1_1core_1_1_sigmoid-members.html b/docs/build/html/classmlx_1_1core_1_1_sigmoid-members.html new file mode 100644 index 000000000..e598387d7 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_sigmoid-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Sigmoid Member List
+
+
+ +

This is the complete list of members for mlx::core::Sigmoid, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Sigmoidvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Sigmoidvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::Sigmoidinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::Sigmoidvirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::Sigmoidinlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Sigmoidinlinevirtual
Sigmoid(Stream stream)mlx::core::Sigmoidinlineexplicit
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Sigmoidvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Sigmoidvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_sigmoid.html b/docs/build/html/classmlx_1_1core_1_1_sigmoid.html new file mode 100644 index 000000000..66b45144b --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_sigmoid.html @@ -0,0 +1,463 @@ + + + + + + + +MLX: mlx::core::Sigmoid Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::Sigmoid Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Sigmoid:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Sigmoid (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Sigmoid()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::Sigmoid::Sigmoid (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Sigmoid::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Sigmoid::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Sigmoid::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Sigmoid::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::Sigmoid::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Sigmoid::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Sigmoid::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Sigmoid::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_sigmoid.png b/docs/build/html/classmlx_1_1core_1_1_sigmoid.png new file mode 100644 index 0000000000000000000000000000000000000000..31bcd54a18ee3d376419e36e2c1dbbb72cd30266 GIT binary patch literal 906 zcmeAS@N?(olHy`uVBq!ia0vp^(}1{xgBeH~F+Z{fQqloFA+G=b{|7Q(y!l$%e`vXd zfo6fk^fNCWJa_;UH$5A+1SrQ@666=m0OW&#In(Sb3=GVoo-U3d6^w7^zV3Tvz{92< zukz%3{i9$neNU;bhYYi?N~OANFTLZLq+fs1^7%&v?@Jji(^M)sbv-{ly5KeGiF&O4 z-kq1f=6YOWwp&nYbbRisePL18&EiEgr=N}3|7=5T{kI3NCtDt!m(F$j-->ikt>c~7 zOnWC;Ej|}}FUo&q#ig^sX%m)A+k97Rv&n8nRoRu`z+nmz_XJ=_~wV* zbSy7>9dV6+tIA70y|t^J+`eTc^C)@kCS8@6Cyb+RH(r~&Ber~-spr)d(|!c0dtK@m zjGFXBb&JYL?+njL(+js2ovXB9xObE_!RFYNT`_H6q=K_INnOkiSo^|-!G0m@1Jx}I zds@UAre9$CAQZ(A=fr)$JA<);QNz#77)Z%lG9j7h1$N}`b5YUFFE2gTPJTFsc z1qy>VpfCtJYre%#>7DPLEtj(F;zfj~EnOF1d}H4)-HNIU$L?-=`*K?K)7i@&yzMu7 zzs~fz@1?>k#;2o?dQWHnyW>^dci*RXsz*(y_L^2M&4!R zm30c9$2ZFTn74g(=(po@@4mVcmb+c_k#N-AN%!4CKh3>9SN2hMRBE)%rCWSk-fehW zRex{Jhkdtw{mZwk@_kXVNM-L__5bX3FONS@UEQCxH}7k8-73|W$3gLEv@YzNB|rOt z-8&e1j?ajietJS&a+YtHaux3?(YYKC)QJflV9Q|>HO;5r8S|!lDbE`W{Gamt_rBhzZA7kIEI4UIl-;OW;;Viw`2T9*vuExd(ANL4 z;aVl@%$Z+*e#zU)o0azLf%*TC*)waWUobNM3=9h|P2S`7yk}?cU-r#@7BJT^c)I$z JtaD0e0swQ2z-j;h literal 0 HcmV?d00001 diff --git a/docs/build/html/classmlx_1_1core_1_1_sign-members.html b/docs/build/html/classmlx_1_1core_1_1_sign-members.html new file mode 100644 index 000000000..0bb890485 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_sign-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Sign Member List
+
+
+ +

This is the complete list of members for mlx::core::Sign, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Signvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Signvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::Signinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::Signvirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::Signinlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Signinlinevirtual
Sign(Stream stream)mlx::core::Signinlineexplicit
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Signvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Signvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_sign.html b/docs/build/html/classmlx_1_1core_1_1_sign.html new file mode 100644 index 000000000..bfa86de85 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_sign.html @@ -0,0 +1,463 @@ + + + + + + + +MLX: mlx::core::Sign Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::Sign Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Sign:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Sign (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Sign()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::Sign::Sign (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Sign::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Sign::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Sign::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Sign::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::Sign::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Sign::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Sign::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Sign::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_sign.png b/docs/build/html/classmlx_1_1core_1_1_sign.png new file mode 100644 index 0000000000000000000000000000000000000000..1489dbc9cae5c9fc8487cd6c2bbc13582559764e GIT binary patch literal 890 zcmeAS@N?(olHy`uVBq!ia0vp^(}1{xgBeH~F+Z{fQqloFA+G=b{|7Q(y!l$%e`vXd zfo6fk^fNCWJa_;UH$5A+1SrQ@666=m0OW&#In(Sb3=GVyo-U3d6^w7^J`{SS#pA{= zZ|ryPe`omWr4E_BeY@`LF0=^FQM7GYP=C_$`A-G!OF1pmR4O@jJwH9V;5F%q`dz!l zJFoxS>v>6d--1%3)v$uDi-})+9(X%@4qL%cnc}FhEFERD>&7St-n&*@y``t1; zCr!UFNkxC5s%PA@tEViVzhw9!BHr+GM)2#RnU}Kt)~>$V6E`{Qs3*$@;V6bUC+-8@ z8H^R2x(xaYSs$owVc63m)-e47(+44h!KMz+mA z=ehm7#;>jQ`ID}E;$Ayh&|SDYEb8oZKhuXBuYBFab2}`aEnfc1<98Rg74%nsoxQI4 z^mMCt%eEfzy|izu@$2ZT-tzp%9?z;iY<{))cFxu0T{#|})w^S9zU3^Ja?q)>CcU)Hi=vUorEh;s5UaS;m{^3xHht?5daM(w9G&>uct!tZPj;0?Z{0p00i_ I>zopr00K6<`~Uy| literal 0 HcmV?d00001 diff --git a/docs/build/html/classmlx_1_1core_1_1_sin-members.html b/docs/build/html/classmlx_1_1core_1_1_sin-members.html new file mode 100644 index 000000000..44d282d97 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_sin-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Sin Member List
+
+
+ +

This is the complete list of members for mlx::core::Sin, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Sinvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Sinvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::Sininlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::Sinvirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::Sininlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Sininlinevirtual
Sin(Stream stream)mlx::core::Sininlineexplicit
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Sinvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Sinvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_sin.html b/docs/build/html/classmlx_1_1core_1_1_sin.html new file mode 100644 index 000000000..d4d021f32 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_sin.html @@ -0,0 +1,463 @@ + + + + + + + +MLX: mlx::core::Sin Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::Sin Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Sin:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Sin (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Sin()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::Sin::Sin (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Sin::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Sin::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Sin::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Sin::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::Sin::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Sin::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Sin::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Sin::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_sin.png b/docs/build/html/classmlx_1_1core_1_1_sin.png new file mode 100644 index 0000000000000000000000000000000000000000..a532b6c8d31d753c5cce38bc268cc0c5dc6416ec GIT binary patch literal 864 zcmeAS@N?(olHy`uVBq!ia0vp^(}1{xgBeH~F+Z{fQqloFA+G=b{|7Q(y!l$%e`vXd zfo6fk^fNCWJa_;UH$5A+1SrQ@666=m0OW&#In(Sb3=B-KJzX3_Dj46+ec1O(LBNe) zzH{dH`bW{(-V2vnnXW#4w(p_K5(8VF!|&pY&Q;o+m}DZgP}MWe$$Qcsi!GB>etLfQ z&wqQR-o&%&Ap4eiYYN+~u6w<{_smbU$n)4QQ~%qaHv8_gpXr=axGeLj`);qGyYn*c z9@bcS`PPlE7plCjPR^TS)ui&aa@ty}Q+m=tmnLr8_HEPs>#xFouDli`c2>)?^mFD_ z^A{@7$2ZljRP|Kfv1Qhz``52M6o}8hdfjPKU(7`<_FMOkTzbF6)YCV6+mCC$QoVy}**o zHRXq*x0RRgT6|SyVX?64%i}wLn1As8UvqEUuNBvRhy6dlKg;u1DJUL$R!=QHcbM@* z@EwL@$v$gu=X$I++^V)(a4*}{uGt(9)CmO*G~7P1Zwd0swDCwwTPzrrmKNu9`plWr zrE5=rUVi>*s{Hj?%j4|?w}ECp!=3-TJLrKtq7}?P`_z4<QU*L Qz`VfV>FVdQ&MBb@005o0xc~qF literal 0 HcmV?d00001 diff --git a/docs/build/html/classmlx_1_1core_1_1_sinh-members.html b/docs/build/html/classmlx_1_1core_1_1_sinh-members.html new file mode 100644 index 000000000..73748f4b1 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_sinh-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Sinh Member List
+
+
+ +

This is the complete list of members for mlx::core::Sinh, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Sinhvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Sinhvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::Sinhinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::Sinhvirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::Sinhinlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Sinhinlinevirtual
Sinh(Stream stream)mlx::core::Sinhinlineexplicit
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Sinhvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Sinhvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_sinh.html b/docs/build/html/classmlx_1_1core_1_1_sinh.html new file mode 100644 index 000000000..dbef15193 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_sinh.html @@ -0,0 +1,463 @@ + + + + + + + +MLX: mlx::core::Sinh Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::Sinh Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Sinh:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Sinh (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Sinh()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::Sinh::Sinh (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Sinh::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Sinh::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Sinh::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Sinh::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::Sinh::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Sinh::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Sinh::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Sinh::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_sinh.png b/docs/build/html/classmlx_1_1core_1_1_sinh.png new file mode 100644 index 0000000000000000000000000000000000000000..dcfa33426314ef66a6ffe30706d9f8c1bcb34e96 GIT binary patch literal 870 zcmeAS@N?(olHy`uVBq!ia0vp^(}1{xgBeH~F+Z{fQqloFA+G=b{|7Q(y!l$%e`vXd zfo6fk^fNCWJa_;UH$5A+1SrQ@666=m0OW&#In(Sb3=B*kJzX3_Dj46+ec1O(LBNe) zzH{dH`bVp`1})fPEmgYnrs13p)@AL9H~xB`vpmlpIVnwfi^@sw49`j8?O~Ht>|cJ9 zT)X$%d*4ZQ516CWv+peI+j=QE-+MXNPPNA$xBA@vwApuGua4q2i|d(B-FJHh-Syj& zmpH{FBlpd(i;t(QnwU7ps!8SRoYd8Qmr|X*Ze6|>rvJNkZQ@kjTv-40~F{8m3=h`XCg=5QijL@1XsB&hus+&&`4`ljF2KC(Td&^>S9p@18~R zzrU+_>fXO#HM4NbhK!}{KQFBmTN7PdqkDPV-RUQc<+a-&)pn){^ce`)weGFi8F^3FX_eH*x9y-g5Ol zsk?VyW?rkt**sgrGynO?{f+Uzo`2qS)qQK+_Fwh?LcK4o7Y6zzHE!*k;(2@x<%Nth z`YpHWP1oG}?8>gFvoB0nhTU;tu*VxLNRsvorUrR^b~G}c-Fv~v__I({T3Xz^%T;CE z+pVwad}jNx;+oOk$j@PGOSsGLm;GK=ZFK6o(c63ZD$}lJ-&>VvyqUQs%6m)7w*0Ht zKOK8}@$t!D3AH=;X3U)X@XWP;p`m7zW_}H^mU?^O%$uI=*=s}XBK`!f{UriS8QMS3 zXPc_tzaF;!Tj|=O1#f+I{^tH^-kO^BE!OGunKRxQGiO$E>KYsCf6iJu + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Slice Member List
+
+
+ +

This is the complete list of members for mlx::core::Slice, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Slicevirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Slicevirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::Slicevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::Slicevirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs)mlx::core::Primitivevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Sliceinlinevirtual
Slice(Stream stream, const std::vector< int > &start_indices, const std::vector< int > &end_indices, const std::vector< int > &strides)mlx::core::Sliceinlineexplicit
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Slicevirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Slicevirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_slice.html b/docs/build/html/classmlx_1_1core_1_1_slice.html new file mode 100644 index 000000000..6f6c64555 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_slice.html @@ -0,0 +1,447 @@ + + + + + + + +MLX: mlx::core::Slice Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::Slice Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Slice:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Slice (Stream stream, const std::vector< int > &start_indices, const std::vector< int > &end_indices, const std::vector< int > &strides)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs)
 Get the output shapes of the primitive.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Slice()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
mlx::core::Slice::Slice (Stream stream,
const std::vector< int > & start_indices,
const std::vector< int > & end_indices,
const std::vector< int > & strides )
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Slice::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Slice::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Slice::is_equivalent (const Primitive & other) const
+
+overridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Slice::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Slice::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Slice::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Slice::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_slice.png b/docs/build/html/classmlx_1_1core_1_1_slice.png new file mode 100644 index 0000000000000000000000000000000000000000..965022c8191a02cbad431215834a7a7780b45ed6 GIT binary patch literal 884 zcmeAS@N?(olHy`uVBq!ia0vp^(}1{xgBeH~F+Z{fQqloFA+G=b{|7Q(y!l$%e`vXd zfo6fk^fNCWJa_;UH$5A+1SrQ@666=m0OW&#In(Sb3=B;FJzX3_Dj46+eK={g0*@QN zyz$&U|AphTS2Hb=w$7byJ)1Ax%k+pd&+ht@md`&bcwfqBnWj?7sq6Xa(FLzbPt@<) z@7;R&kD=!vXq&KTuf3Az?FHdNG=Z6ZGh5a3lX3g}b)Mh-^?#ex^LVcM z*>=tI$i+jU+Pcp^FNyeY=hY|E_T1g^&C8{8?`5MHSyudF#G+ z)1DBUFO`es!((6Dt&E8dU)__rN>TM~=BiSW-QU7?&n?e>TeYK3aB0x)W|g zCwA?p8kvyQm!r4dz11PUWtZ;$zEw}H`>m@x*Jn;W{z!f8*Vb$5Idk*B&of literal 0 HcmV?d00001 diff --git a/docs/build/html/classmlx_1_1core_1_1_slice_update-members.html b/docs/build/html/classmlx_1_1core_1_1_slice_update-members.html new file mode 100644 index 000000000..f4abb0ccf --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_slice_update-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::SliceUpdate Member List
+
+
+ +

This is the complete list of members for mlx::core::SliceUpdate, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::SliceUpdatevirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::SliceUpdatevirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::SliceUpdatevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::SliceUpdatevirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs)mlx::core::Primitivevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::SliceUpdateinlinevirtual
SliceUpdate(Stream stream, const std::vector< int > &start_indices, const std::vector< int > &end_indices, const std::vector< int > &strides)mlx::core::SliceUpdateinlineexplicit
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::SliceUpdatevirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::SliceUpdatevirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_slice_update.html b/docs/build/html/classmlx_1_1core_1_1_slice_update.html new file mode 100644 index 000000000..b2014c218 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_slice_update.html @@ -0,0 +1,447 @@ + + + + + + + +MLX: mlx::core::SliceUpdate Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::SliceUpdate Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::SliceUpdate:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 SliceUpdate (Stream stream, const std::vector< int > &start_indices, const std::vector< int > &end_indices, const std::vector< int > &strides)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs)
 Get the output shapes of the primitive.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ SliceUpdate()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
mlx::core::SliceUpdate::SliceUpdate (Stream stream,
const std::vector< int > & start_indices,
const std::vector< int > & end_indices,
const std::vector< int > & strides )
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::SliceUpdate::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::SliceUpdate::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::SliceUpdate::is_equivalent (const Primitive & other) const
+
+overridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::SliceUpdate::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::SliceUpdate::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::SliceUpdate::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::SliceUpdate::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_slice_update.png b/docs/build/html/classmlx_1_1core_1_1_slice_update.png new file mode 100644 index 0000000000000000000000000000000000000000..25254654e697a978d4225a1dc99bb0259c96b348 GIT binary patch literal 918 zcmeAS@N?(olHy`uVBq!ia0vp^(}1{xgBeH~F+Z{fQqloFA+G=b{|7Q(y!l$%e`vXd zfo6fk^fNCWJa_;UH$5A+1SrQ@666=m0OW&#In(Sb3=GVQo-U3d6^w7^KAiMPfyb49 zx}^I1|3~82ZgW^DC42Yj?B2t7m3$8=s=SUbI#+2oQ8m+yb(%`0XvCx^Q<=3rC&j<~ zCHYhD^X&;s)PFHuW0o&pQM_i$+}yb#itkU}kXs#j(%*33zDVaO7Tq`hE&Vzvq@O+F zx`CHek6GE?Eq8;yrb#UCQ}A@RJrySVbONvD%JArD_Z#z%ONIT^ytY?IO5gMAhs3M0 zC6m_fTQfh&^HOI_gzriD?A5!EJdY@i_6S-ZzTq2pu65yj{!6o#q?x$atQJ;%d0hnP zxX=X8N#TvEo^fZprxu^P%=sXA7sHw5mK*h^Ywmn@Bx*(4lUX9#R)PoCA7J{RwSi#| zmkz^vChh~PadZDMgq^cIuf9p;?6L!sRMsaw_EG1n zR^IyZFDd8x8s73%wYPJwnziS6c<$x9X|?%o_Uno{{>i1Lcj8=nOKcn`{n{8?(P@2T z)$hkD(apEAUVd|`Va$&B^gF|=Zo}Qh-ZlKI^SaX`ez$Iu-FE);@58*)bL6Y48m_7& zUg(_krEu>L`w#7PHTRbN(z^CLJT)<-!M+moU!oFHVZp3obg_ zxi#SI2~*9m5~YUtM%D+R35*q@5e#vxVh!QAxqlhgObMFhuVTFUrbt9uS{&=NnKQ$e zN36dTcl-IXYnAJ`-)8+ZyZ1ae;`TM=?feN#b-!7y+>>|mhWzTAx%XV|@5z4Mv}di@ z>`%wfoO-b#?0(9%{#P51>3iSy+I_($@A!@T-`}N$_1Z=M4GT4!G*dQo?us%!v9GM# zMdAp60dM z_-=ENs`|CtyVnWld!M(yc;-!XRpRmkvqVpyIkWPBk@07(4QI~miP{k4m1*^t`Iz-A VO@F-`3cwu1;OXk;vd$@?2>=6_y-ffB literal 0 HcmV?d00001 diff --git a/docs/build/html/classmlx_1_1core_1_1_softmax-members.html b/docs/build/html/classmlx_1_1core_1_1_softmax-members.html new file mode 100644 index 000000000..c3e37e9ff --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_softmax-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Softmax Member List
+
+
+ +

This is the complete list of members for mlx::core::Softmax, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Softmaxvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Softmaxvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::Softmaxvirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::Softmaxvirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::Softmaxinlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Softmaxinlinevirtual
Softmax(Stream stream, bool precise)mlx::core::Softmaxinlineexplicit
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Softmaxvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Softmaxvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_softmax.html b/docs/build/html/classmlx_1_1core_1_1_softmax.html new file mode 100644 index 000000000..5a96e8cbb --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_softmax.html @@ -0,0 +1,467 @@ + + + + + + + +MLX: mlx::core::Softmax Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::Softmax Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Softmax:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Softmax (Stream stream, bool precise)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Softmax()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
mlx::core::Softmax::Softmax (Stream stream,
bool precise )
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Softmax::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Softmax::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Softmax::is_equivalent (const Primitive & other) const
+
+overridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Softmax::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::Softmax::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Softmax::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Softmax::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Softmax::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_softmax.png b/docs/build/html/classmlx_1_1core_1_1_softmax.png new file mode 100644 index 0000000000000000000000000000000000000000..643fb0d3d4f172977271cbb763d23f17e0f6a592 GIT binary patch literal 894 zcmeAS@N?(olHy`uVBq!ia0vp^(}1{xgBeH~F+Z{fQqloFA+G=b{|7Q(y!l$%e`vXd zfo6fk^fNCWJa_;UH$5A+1SrQ@666=m0OW&#In(Sb3=GVio-U3d6^w7^<`zFz;9;BI zKC|cFfBk(oK5{3#(X!nB^x2fe?4;&B92aX;pU*kotmC;@Fly2h)h#L~yPdT?C(Sph ztty`XH+=FE{U2;=yDi@ZJdetZy?sa1N$+HY{^t$0_P-yzJSmgZ|5;{R?aFjdt=p~F zMBF^*h8@?-U%NZR=H)R>WsjiKq1#Sr_?j(QqFkE0b(8)2t7|_6uU)FMDss}PXP347 zFL-Y2OZ~fan#$YB`(%keeru3+dnJ|n9C_=LFREZeniU!+2`{gfK!Utszm6vYtd z#C^a!gRz2BmqC9a>jTv-40~F{8m1#8PuDM)T72%YI25==9Yxms6a*tS+Uz&Uq6S{kG`cHLEKHdj*#U z?QT+$y}j#U^__KLm3DrcuC82rD^L4~`IcR(`?s!oV!eL%tt02RoJ#&U6%-(xw}1Wq zoYww5$2@)4;!xFv#X_nt%kKS=|KV(JCm&tA`r7lb|L6R-s{FNDsOlMa?CL4Y=MNcn zbh96*=vn!-g6E~_vTK*KT-G~oeWb$l1810^1=t_+sQh7<{_5^ZQ~gh(Tk0-P>iKHKul*^wfA+sKYc5W^J^l6GE3+O& zZt~OCyFBC3wW@Ug%T=0-M3c{+>5EB!{bN;9@5wV!U81sEye6wHni+XCHp+Ku`KGA0 zs>u5bLYMtfiEf`{e{ol#@M@;q?A4*gnSWhR%5V7=yTb3$we{th`y*xNKFz6HlbZI; wyJgzUnbR*A8GjavN=pNJ-%E4p!ynB0duHFhm)iIdm{S-$UHx3vIVCg!09NU_-T(jq literal 0 HcmV?d00001 diff --git a/docs/build/html/classmlx_1_1core_1_1_sort-members.html b/docs/build/html/classmlx_1_1core_1_1_sort-members.html new file mode 100644 index 000000000..67dcaa64c --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_sort-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Sort Member List
+
+
+ +

This is the complete list of members for mlx::core::Sort, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Sortvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Sortvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::Sortvirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::Sortvirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::Sortinlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Sortinlinevirtual
Sort(Stream stream, int axis)mlx::core::Sortinlineexplicit
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Sortvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Sortvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_sort.html b/docs/build/html/classmlx_1_1core_1_1_sort.html new file mode 100644 index 000000000..197d6fdce --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_sort.html @@ -0,0 +1,467 @@ + + + + + + + +MLX: mlx::core::Sort Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::Sort Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Sort:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Sort (Stream stream, int axis)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Sort()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
mlx::core::Sort::Sort (Stream stream,
int axis )
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Sort::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Sort::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Sort::is_equivalent (const Primitive & other) const
+
+overridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Sort::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::Sort::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Sort::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Sort::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Sort::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_sort.png b/docs/build/html/classmlx_1_1core_1_1_sort.png new file mode 100644 index 0000000000000000000000000000000000000000..fa624d11082dcef9d7f30c9508a9e6a59d44ea82 GIT binary patch literal 870 zcmeAS@N?(olHy`uVBq!ia0vp^(}1{xgBeH~F+Z{fQqloFA+G=b{|7Q(y!l$%e`vXd zfo6fk^fNCWJa_;UH$5A+1SrQ@666=m0OW&#In(Sb3=B*kJzX3_Dj46+eK={g0*@QN zyz$&U|Ap_r3S)JdU6wQ5dNyCWm+6sYj@|VqEuVi>@V=DMGEJqDQ`hs;qYGY>o~ZlI zzjy2RpLCB)!gdQvHyxY%DneWLy_tX4l>Q{0_~IL3^K~nB885%_>CWM6@?W()pITg; zbxC8T@vA#em!?kdvc7AR%I2A6xjJm-(v@vfwk&_QHtgH)sJnXhSE{zSZuJVXtciNc zX|>|ozP0nSJTEcotzGr>_PtxVkBZlRitxBJFMP{2-nIUZtlm#C^}Mp+_9^tET2DQ*wH0@prU7G>2BVarpvC~&2m}qxb=|=(+^Hv2K|Ms z4^+1>>}e5en0|rjgHRMhoD=r}?+nHYB*{OFYv&Zp@o!S;g_~ZzR>iaK*@w%fnZLel zy8qnXc~aKCg{4A;TMDjvb^o-{6tjsd{W-g4Bp5GsQb)Q*& zdDdsOO#4?;pDum0uwVM`j#p9NJ)hp%6_x4}%&VE{AG~()jjh{nUEdj%ef{uPw<%t4 zrByQRr}5tZRJQfh=fZ%se)(6+vX%EOxSIR3G(`7)<(->q`<92wb*D$YZo8(QGdKVH zY}tFO?#_AH7Z%!b&U4C=)0H3dA58!Eb5HKqz_s7k{+nNab<&sDptzI_*M2^yo%w-Q z9>Zh9W!K)CExA4;s(iKJUbd@UlQ|x!6ABt=xP4;Z66BR>`T(aWK8kuz>^RLc2Jlj+}YpLc<5exsWE8qkY5ja0KW_?Xsny%kk&H1XYdmd$X zY3k{QPn+%9F!QG7y{xzLp&`mMBfUSmrJc+=bH;CHT=sU~R7?Biq3;emZY|UOUkOeU zS@Zvtt}S(bJ$c8u?Z0ymw5ne`^JaC+w3#!fUobNMECh1Uv8!I1OCSDV7u-L$aIKu~ Rb6}2O@O1TaS?83{1OWA;uay7* literal 0 HcmV?d00001 diff --git a/docs/build/html/classmlx_1_1core_1_1_split-members.html b/docs/build/html/classmlx_1_1core_1_1_split-members.html new file mode 100644 index 000000000..9c7cd09dd --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_split-members.html @@ -0,0 +1,107 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Split Member List
+
+
+ +

This is the complete list of members for mlx::core::Split, including all inherited members.

+ + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::Splitvirtual
eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::Splitvirtual
is_equivalent(const Primitive &other) const overridemlx::core::Splitvirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::Splitvirtual
operator=(const Primitive &other)=deletemlx::core::Primitive
operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs)mlx::core::Primitivevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Splitinlinevirtual
Split(Stream stream, const std::vector< int > &indices, int axis)mlx::core::Splitinlineexplicit
stream()mlx::core::Primitiveinline
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Splitvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Splitvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_split.html b/docs/build/html/classmlx_1_1core_1_1_split.html new file mode 100644 index 000000000..2889694e8 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_split.html @@ -0,0 +1,426 @@ + + + + + + + +MLX: mlx::core::Split Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::Split Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Split:
+
+
+ + +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Split (Stream stream, const std::vector< int > &indices, int axis)
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs)
 Get the output shapes of the primitive.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Split()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
mlx::core::Split::Split (Stream stream,
const std::vector< int > & indices,
int axis )
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Split::eval_cpu (const std::vector< array > & inputs,
std::vector< array > & outputs )
+
+overridevirtual
+
+ +

A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.

+

To avoid unnecessary allocations, the evaluation function is responsible for allocating space for the array.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Split::eval_gpu (const std::vector< array > & inputs,
std::vector< array > & outputs )
+
+overridevirtual
+
+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Split::is_equivalent (const Primitive & other) const
+
+overridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Split::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Split::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Split::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Split::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_split.png b/docs/build/html/classmlx_1_1core_1_1_split.png new file mode 100644 index 0000000000000000000000000000000000000000..5b7fd768bd10378cc696f15ae67363ff8e630223 GIT binary patch literal 527 zcmeAS@N?(olHy`uVBq!ia0vp^B|sd&!3-q-S1vRKQqloFA+G=b{|7Q(y!l$%e`vXd zfo6fk^fNCWJa_;UH$5A+1SrQ@666=m0OW&#In(Sb3=E93JY5_^Dj46+?d@w(;AvUz zylLk5`bWiqjau{G@mHMO`0~iQEj1#~83mj;_O?1qsx@7pa?_`tB}=6;_S>JURm*Gm zCVlxO^UM4E#W2pS*E^GfOmv#p^hZofKYjH4)E%z((zU$nPp#vr6}uS0c|USr)cf@j zE8hjbd!VkHZ8mwrvAuuH@5(*BePLBO-@?=E+PhX=xKzqM`Q6zfuC-E{t;UzW&3?D; zyG+!qw>-1f-P&#bbgH)DrB>+$kK!i1S<0w-GEFLAQqHp73lv)}L^6C>CdY6gy4*p& zUwTEkgL+#1E8PqAPwp|YHk3&PFk~zPGJxu_Ff1m{jLbi%vT{9J$fQ{>`JLizFI=9v zuA0wNZy#6h)1B$();ZmdKI;Vst E0I`_&mjD0& literal 0 HcmV?d00001 diff --git a/docs/build/html/classmlx_1_1core_1_1_sqrt-members.html b/docs/build/html/classmlx_1_1core_1_1_sqrt-members.html new file mode 100644 index 000000000..4a2e605ce --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_sqrt-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Sqrt Member List
+
+
+ +

This is the complete list of members for mlx::core::Sqrt, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Sqrtvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Sqrtvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::Sqrtvirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::Sqrtvirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::Sqrtinlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Sqrtinlinevirtual
Sqrt(Stream stream, bool recip=false)mlx::core::Sqrtinlineexplicit
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Sqrtvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Sqrtvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_sqrt.html b/docs/build/html/classmlx_1_1core_1_1_sqrt.html new file mode 100644 index 000000000..c178fcb3d --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_sqrt.html @@ -0,0 +1,467 @@ + + + + + + + +MLX: mlx::core::Sqrt Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::Sqrt Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Sqrt:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Sqrt (Stream stream, bool recip=false)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
void print (std::ostream &os) override
 Print the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Sqrt()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
mlx::core::Sqrt::Sqrt (Stream stream,
bool recip = false )
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Sqrt::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Sqrt::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Sqrt::is_equivalent (const Primitive & other) const
+
+overridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Sqrt::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::Sqrt::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Sqrt::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Sqrt::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Sqrt::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_sqrt.png b/docs/build/html/classmlx_1_1core_1_1_sqrt.png new file mode 100644 index 0000000000000000000000000000000000000000..f30bd2b340d1865b74d6e6fdcac5e73e736c2132 GIT binary patch literal 887 zcmeAS@N?(olHy`uVBq!ia0vp^(}1{xgBeH~F+Z{fQqloFA+G=b{|7Q(y!l$%e`vXd zfo6fk^fNCWJa_;UH$5A+1SrQ@666=m0OW&#In(Sb3=GUno-U3d6^w7^zFxQ3fQL;z ze#xo#|BrCX=y7shS$k`9aO1YD^a%>%>t^oTU*2) z8?rk0_424mC5_&zvwjxmn|A+*S$#)r(v?qL*Y+*=>UZb7b*gWW>8kLWY;~_o@yc6N zPI_l}PMUsUl8XN0tfzB67ckUxu^*`DS^2es=cVbgYbmCSt}lq%D9re$MXX`^1*Q)| zQ4DcT+y}ff7%Mn+8T1#jK2Y7lum?%fq+;*=0#Hs!ljYillv?5WrMEEnZ{ zPi5=kd698ia=Y%mI39X^m%rBYYge_B{jN@!6ygvn2xAO9jh;P}YyT5lnmE?KBt}eL9h+O zv1FgM$CEwQ8*WuwEx4EMYS&zj2kNK+fisjo$!9E`a;aj%nKN$6Th5#TCXtylr-yHi ztL>}|pBZTwBe`$Q)|TtVb}Oeo%PD?R6!avnd!y{M4U6nt@A$WEt4d4bjazHEQ1x}s zBZK22)3<5AZsj%JoW5iARl8Ly+>ACGAC)%!m6o=>#&LDz@@=bttUU`u^^Ul1oi*+M zzObvGp5L=tEgI!|f9A})pSM9%MgQ;DNBy{G%>4TEX#CG`IgzO5teG<_IdzSV^%th5 iJxjghrMdLs59ZKabN!~@-6sIdA`G6celF{r5}E+9kG7}) literal 0 HcmV?d00001 diff --git a/docs/build/html/classmlx_1_1core_1_1_square-members.html b/docs/build/html/classmlx_1_1core_1_1_square-members.html new file mode 100644 index 000000000..a447d33e6 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_square-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Square Member List
+
+
+ +

This is the complete list of members for mlx::core::Square, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Squarevirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Squarevirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::Squareinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::Squarevirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::Squareinlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Squareinlinevirtual
Square(Stream stream)mlx::core::Squareinlineexplicit
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Squarevirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Squarevirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_square.html b/docs/build/html/classmlx_1_1core_1_1_square.html new file mode 100644 index 000000000..24c7cd1c6 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_square.html @@ -0,0 +1,463 @@ + + + + + + + +MLX: mlx::core::Square Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::Square Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Square:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Square (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Square()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::Square::Square (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Square::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Square::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Square::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Square::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::Square::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Square::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Square::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Square::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_square.png b/docs/build/html/classmlx_1_1core_1_1_square.png new file mode 100644 index 0000000000000000000000000000000000000000..06ae832b851a9a36af05257a13d4b81c8020b4da GIT binary patch literal 906 zcmeAS@N?(olHy`uVBq!ia0vp^(}1{xgBeH~F+Z{fQqloFA+G=b{|7Q(y!l$%e`vXd zfo6fk^fNCWJa_;UH$5A+1SrQ@666=m0OW&#In(Sb3=GVoo-U3d6^w7^KFnKfz~kn> z&tvNS|DF1&0zw>D-tN-7Tsn1e+JhO*kJr?nw0!GWk1lvkdZK>U ze(%=Hf3iI;G1ob~PU^M3x-%^5e(t;w&FM*N;)`#D&DZ@=m+rUm>C8i?Tgw%t#n(RlTj^ndQb2*71PC)N2mKOQu!;j zP}MWe$$QcskV_@QwV%&vXMPZShv8VV&)VC$9_tOas;w5>%XYPEGRFh;Eev~F#2Thw zVEP~w#SrJjeZV_|v4T^VL4P6Z1673NC;5!0b3Rv0P|Y-fnQWP*>iNmIJ${|%cYpog zCjX@-UHQbhw!33EAon=S1&%8ufRH=8ivs&+l$| zw#U<|uPn88Wz~zK#q!rP>N4zUc{nnW^rx$0*%7=z3&%8PBRqrdO`D+Y}Uw^b)>pb23_x`pyZnGmz z|3sY*Ubm_= + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::StopGradient Member List
+
+
+ +

This is the complete list of members for mlx::core::StopGradient, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::StopGradientvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::StopGradientvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::StopGradientinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)mlx::core::Primitivevirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::StopGradientinlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::StopGradientinlinevirtual
StopGradient(Stream stream)mlx::core::StopGradientinlineexplicit
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs)mlx::core::Primitivevirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::StopGradientvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_stop_gradient.html b/docs/build/html/classmlx_1_1core_1_1_stop_gradient.html new file mode 100644 index 000000000..2913a73b7 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_stop_gradient.html @@ -0,0 +1,382 @@ + + + + + + + +MLX: mlx::core::StopGradient Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::StopGradient Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::StopGradient:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 StopGradient (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)
 The Jacobian-vector product.
 
virtual std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs)
 The vector-Jacobian product.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ StopGradient()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::StopGradient::StopGradient (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::StopGradient::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::StopGradient::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::StopGradient::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::StopGradient::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::StopGradient::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::StopGradient::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_stop_gradient.png b/docs/build/html/classmlx_1_1core_1_1_stop_gradient.png new file mode 100644 index 0000000000000000000000000000000000000000..082cc974a4636d84354974c7330b916159b316d1 GIT binary patch literal 934 zcmeAS@N?(olHy`uVBq!ia0vp^(}1{xgBeH~F+Z{fQqloFA+G=b{|7Q(y!l$%e`vXd zfo6fk^fNCWJa_;UH$5A+1SrQ@666=m0OW&#In(Sb3=GVMo-U3d6^w7^KHT(5fv079 z>axk->mNmDFXX%+wO8)xna;yeT#7Xb9^?$|)7qeY>=iZRU;B&%(HWSFTB%Hb3cYW@6aHD`(r6YQIp4 zs*SQgqI6Pt=LVls+jm}#eUz*dE2eVuwd9(rh1-O(>jkd4%u4I~5u`fl$#szkm6xFj z6PJWHx=xBa+dZ}T++|LNdrMgvo?V_3aoR}r`kV;ut35H3lU!Aq8sZyS8RA&Q7%D^~ z7#@TsFgAqa6}T#Y;FRU_pDjL+Hfe>JYX4%z>F{VbLJ>p&iOb`JrO!+IuehCN(53?H;MFdSHoSKt+Y!qO?1YMg`wXZkj(Dm@BK=hGS88Qc0w4_7_VjhOv2W4V>G(lXmyR}z*!s(!6;`0oi|P_5K>vp09ek+q!j z<2Kh#ubr}ybM?_%Jd-@`UwQm)3Exz~J-vM=>mpl(ua;SaZ=P*$ z>ZivmT;>`*_41B;?X_qBEAKlCckQ_yS^p#aU;SLRX&oKm`(Le(a(S|vOGi);7*bxE dOJ9EBcS)YPleJR&ATUQUc)I$ztaD0e0swe`zn}mB literal 0 HcmV?d00001 diff --git a/docs/build/html/classmlx_1_1core_1_1_subtract-members.html b/docs/build/html/classmlx_1_1core_1_1_subtract-members.html new file mode 100644 index 000000000..b6f161513 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_subtract-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Subtract Member List
+
+
+ +

This is the complete list of members for mlx::core::Subtract, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Subtractvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Subtractvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::Subtractinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::Subtractvirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::Subtractinlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Subtractinlinevirtual
stream()mlx::core::Primitiveinline
Subtract(Stream stream)mlx::core::Subtractinlineexplicit
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Subtractvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Subtractvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_subtract.html b/docs/build/html/classmlx_1_1core_1_1_subtract.html new file mode 100644 index 000000000..d41645a17 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_subtract.html @@ -0,0 +1,463 @@ + + + + + + + +MLX: mlx::core::Subtract Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::Subtract Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Subtract:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Subtract (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Subtract()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::Subtract::Subtract (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Subtract::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Subtract::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Subtract::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Subtract::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::Subtract::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Subtract::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Subtract::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Subtract::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_subtract.png b/docs/build/html/classmlx_1_1core_1_1_subtract.png new file mode 100644 index 0000000000000000000000000000000000000000..9a227b3b310114e9007a93c332bf6a8201ca50a1 GIT binary patch literal 903 zcmeAS@N?(olHy`uVBq!ia0vp^(}1{xgBeH~F+Z{fQqloFA+G=b{|7Q(y!l$%e`vXd zfo6fk^fNCWJa_;UH$5A+1SrQ@666=m0OW&#In(Sb3=GUdo-U3d6^w7^KI~g<#lxl^ zzvR^W|3|h5aB#A``c_a@TpE*lMMTEf;@rFVqH~otCnlLlEmZZ4bMl_F$70JQm7kuM zmv8@OS+3SxuBLO(dv`{2o6Y$RpBBfeu9Opx z3YJnWb(_6=-3-#`pr;hFRM%6uEl+OzI9h@-KwoyTsM0K+15mT z<+NIPZR6VeS)P|T_0}$X`h9QdjU(TqzHXebM0tM3RNh?wq?zTXW-UoGS^wjz=aeP! z%3D-UdS`e}ntoxDivHuQr*l3RFx)xHoM3b8N|jC97pdUanWl@bFNi7>V*Jx0)-e47 z(+8m_hBzng1Kt^o6`Z;Z`U_bfsBU4{gCzNnaqXO9IsQ#5y>2j*-Pfvk);;@h+BEam zmreH1*DHGJ+Ak~>EZnvsW9ikJc_G}>!ejS8+wk_S_lM8N?A}ey_1+${dj1z4J^NRC zW?!3Q?s@#;x;6Tt{IPjo9?#0Yt9~{6cGgwH-B})YG($kMDXIeQ#Y@t)1Vht1s7HTW9i!J?ic>{l%F-&plpe_R)InuZTUDwWM#&+hBD6 z-TShN)!*Lvly8}(RN}Wt}bvv&E_j=1Kmqn<1ugb4)=G~9l&Zwd0swDCwwTPzrrmKNu9`plVb z=hwmSUnK{ZF$ym@{Pr+_1B$GFa7GZ^y(u+ zRjv0HhhG1^B&vPd`|Pc|zo$p}p8h_6^X0g(cU7x*Z%s>Mx4Q@mk|V$%(HGx7IpS6= zpKkky?#UTx-?Ce#&73*?f|2oOA&`TQUG>sj`tS$8+1J@S_$6QL2WA-tPgg&ebxsLQ E0JKQK7XSbN literal 0 HcmV?d00001 diff --git a/docs/build/html/classmlx_1_1core_1_1_tan-members.html b/docs/build/html/classmlx_1_1core_1_1_tan-members.html new file mode 100644 index 000000000..92188e869 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_tan-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Tan Member List
+
+
+ +

This is the complete list of members for mlx::core::Tan, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Tanvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Tanvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::Taninlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::Tanvirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::Taninlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Taninlinevirtual
stream()mlx::core::Primitiveinline
Tan(Stream stream)mlx::core::Taninlineexplicit
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Tanvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Tanvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_tan.html b/docs/build/html/classmlx_1_1core_1_1_tan.html new file mode 100644 index 000000000..b8be27324 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_tan.html @@ -0,0 +1,463 @@ + + + + + + + +MLX: mlx::core::Tan Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::Tan Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Tan:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Tan (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Tan()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::Tan::Tan (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Tan::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Tan::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Tan::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Tan::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::Tan::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Tan::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Tan::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Tan::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_tan.png b/docs/build/html/classmlx_1_1core_1_1_tan.png new file mode 100644 index 0000000000000000000000000000000000000000..613c47aec160ab22a445a9fd7be7f45f490b37d1 GIT binary patch literal 875 zcmeAS@N?(olHy`uVBq!ia0vp^(}1{xgBeH~F+Z{fQqloFA+G=b{|7Q(y!l$%e`vXd zfo6fk^fNCWJa_;UH$5A+1SrQ@666=m0OW&#In(Sb3=B-)JY5_^Dj46+ec1O(LBNe) zzH{dH`bW{La~m(cyRzc+*}jJ^OAKsz4!?^pI#+3PVv>o}LRHT=C+|snEVfKi`RV!H zKmYBOdK1rIjrB+V!Tb8d|8}{vY)Lkw6D^=56w|WKLtciNc zYjxtT!v2!35AeJR^#?a3^c^^RK~sWJWF)Me0L$ofEa z3&Wlkv4-gvm_7(aF~m7>AMnm#tU!|d!?<=%u^j&z=(_o_*=>Ug{xJXG{l6w|+piVZUWffZzdy_KS1BkSdsa^^K6jY$gI*rPW5Z?F z-kL4BJ|k-K)sDF4tfRgxAB3@n2WDXXWRB8YY9gm%y!j%BuCcNH!ql{9Pgh+#db7WO z=1o(()n89FX-B=3`)61DPDH=@SrRa4LSI&Xkx#yyynJbRcHfLM-`4(^b$=Icu5`q) zX^zI5gYVqbvfZY7KF( + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Tanh Member List
+
+
+ +

This is the complete list of members for mlx::core::Tanh, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Tanhvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Tanhvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::Tanhinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::Tanhvirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs) overridemlx::core::Tanhinlinevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Tanhinlinevirtual
stream()mlx::core::Primitiveinline
Tanh(Stream stream)mlx::core::Tanhinlineexplicit
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Tanhvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Tanhvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_tanh.html b/docs/build/html/classmlx_1_1core_1_1_tanh.html new file mode 100644 index 000000000..d89fd4ce7 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_tanh.html @@ -0,0 +1,463 @@ + + + + + + + +MLX: mlx::core::Tanh Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::Tanh Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Tanh:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Tanh (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs) override
 Get the output shapes of the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Tanh()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::Tanh::Tanh (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Tanh::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Tanh::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Tanh::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Tanh::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ output_shapes()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< std::vector< int > > mlx::core::Tanh::output_shapes (const std::vector< array > & inputs)
+
+inlineoverridevirtual
+
+ +

Get the output shapes of the primitive.

+

This is not required to be implemented by derived classes, in which case it will throw.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Tanh::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Tanh::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Tanh::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_tanh.png b/docs/build/html/classmlx_1_1core_1_1_tanh.png new file mode 100644 index 0000000000000000000000000000000000000000..8e330c32fba7d8c3f5289e90c42b1130d81805b1 GIT binary patch literal 879 zcmeAS@N?(olHy`uVBq!ia0vp^(}1{xgBeH~F+Z{fQqloFA+G=b{|7Q(y!l$%e`vXd zfo6fk^fNCWJa_;UH$5A+1SrQ@666=m0OW&#In(Sb3=B-aJY5_^Dj46+eVzAOL4d8D zU-Q=Y`bWvu`xuVBRPzez+a;@$AF3bF_|5eK~wR(YR zdtlJ=s~b-*d^@Gf{H{eRi)U8ll(j})lX*2?`ou=Bee-?muF$%bTc^2h^$NOK6ZMpL z*NJO8*VJE`w1hh*D)i*`t)*5+=5IZft>~FP@8T`~E&DpF_fNj261w)*BX9G-Np-y! zCaLHzRP~GlxwL2X)Z%l889(UdF+4V0cI~a%lIt^~CSUD{YtB0A%kn`uiXqO4`+#=_ zV+E%!gZ@I+2dY~b_OysKOuxYNK?ouFlR3)rxs6NP35cV3oVtHn8YwuWc^^OO4<<9|K>yy>d@*0}Az>i>m$Us^95HR(xe%-T7{a(oTt zh0GHDhFkNtEABmZW!4(EFH#}ec}flQQ3C>JD4mwS;HA0rqr%LYy>1yZXI66R8XN2T zXWDLW{^~a~@@9>p-HNO6AMDne9G|sZcdpW>S>3?ES$(Pa{r%(1t{**o<_)9XTGbFE z>+O#=*1VIjGTYF*_VI0NjP$NfdhR=8X6)y@wYyGycmA^_tdd3i zW>)szbU4i12%!q=yG&m@0ent#>E__xu*)U;=+Th5#TrrVh_rx$Jw d^2)UM$9ObM + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Transpose Member List
+
+
+ +

This is the complete list of members for mlx::core::Transpose, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Transposevirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Transposevirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::Transposevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::Transposevirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs)mlx::core::Primitivevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Transposeinlinevirtual
stream()mlx::core::Primitiveinline
Transpose(Stream stream, const std::vector< int > &axes)mlx::core::Transposeinlineexplicit
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::Transposevirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Transposevirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_transpose.html b/docs/build/html/classmlx_1_1core_1_1_transpose.html new file mode 100644 index 000000000..839676ac3 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_transpose.html @@ -0,0 +1,437 @@ + + + + + + + +MLX: mlx::core::Transpose Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::Transpose Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Transpose:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Transpose (Stream stream, const std::vector< int > &axes)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs)
 Get the output shapes of the primitive.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Transpose()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
mlx::core::Transpose::Transpose (Stream stream,
const std::vector< int > & axes )
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Transpose::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Transpose::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Transpose::is_equivalent (const Primitive & other) const
+
+overridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Transpose::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Transpose::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::Transpose::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Transpose::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_transpose.png b/docs/build/html/classmlx_1_1core_1_1_transpose.png new file mode 100644 index 0000000000000000000000000000000000000000..77c3b22881f393c2459b0d6d2f2a2bc43e59bdbd GIT binary patch literal 914 zcmeAS@N?(olHy`uVBq!ia0vp^(}1{xgBeH~F+Z{fQqloFA+G=b{|7Q(y!l$%e`vXd zfo6fk^fNCWJa_;UH$5A+1SrQ@666=m0OW&#In(Sb3=GV&o-U3d6^w7^zRr8CAi&nn zuX*cx{iDYg`xuU$Q}d1P+abWO~gW7MRwH70d+&!rS+FRP&IVcNexN8Q!3zZ!LFQIzV+9J{rj z*yg6Zj=1K(Rplki^svxRxw*1@A9sdT>Z-gvp&fO*@tS$z>+(ypmaGbmu6eEIb!oj| z)TAe>TU1VZXLwGUUbwaBT%`qr-BI2In`2jY-EI3K6`Z|E>SA`l+8vz?bxzy|yfYXp zICUBH7qUK3-NLY^MXX`^1*Q)|Q4DcNlJySS&*vO(*6~~nGyQa+w&$ezmA{_OD*69q zk^f)!$xCL$Coel#zZcBq-)^IQ-^BM`-IT)e zwU&94`qsup+G)k@;`#DC^!%=Tt?SorYCSi(Ibl+t$xW--yRU}rp6j12o3+EPBWUGa z$4Py;wjXTs!dLH+SHJn{N?7i8(MQ}-cPHJC3jH+qdaUfD>Zqq%E0%4Q&8R)R?VsJf zIUnY|edbrbWtB=v-XfLSIpY7>{#~%2zkTh#(ATxA|38o4s`A%rp{i%xv8$&npFd>y zA!6R}Q6l(tne?S>pS35mT-G~oeWb?pgJ94=!|e~_+9^R&{F7$h%xIZ5bLR95M#i6u zR(_rJv8esbnSE#2x893PNV~Rp`P6q0-}gWF`?yI}{PlI4wO;2J$Mg8yP4)bmTz`#a z$CX*v)i)<=?|Zi4_u?exrw&H9qotQ|q@|_4PG23fxnF;8N}8_kwo?|fi#M-*HFIX@ z+tpE}6UC$Y8;v&a&DqYlC+haLbDOWOT-)rlp5xVnw6HDeQK2Wd@7`;5w&q4!PC{xWt~$(69C3N#i#%P literal 0 HcmV?d00001 diff --git a/docs/build/html/classmlx_1_1core_1_1_unary_primitive-members.html b/docs/build/html/classmlx_1_1core_1_1_unary_primitive-members.html new file mode 100644 index 000000000..494b399d9 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_unary_primitive-members.html @@ -0,0 +1,114 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::UnaryPrimitive Member List
+
+
+ +

This is the complete list of members for mlx::core::UnaryPrimitive, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &output)=0mlx::core::UnaryPrimitivepure virtual
eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &output)=0mlx::core::UnaryPrimitivepure virtual
eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) constmlx::core::Primitiveinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)mlx::core::Primitivevirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs)mlx::core::Primitivevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os)=0mlx::core::Primitivepure virtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs)mlx::core::Primitivevirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes)mlx::core::Primitivevirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_unary_primitive.html b/docs/build/html/classmlx_1_1core_1_1_unary_primitive.html new file mode 100644 index 000000000..5644a664e --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_unary_primitive.html @@ -0,0 +1,532 @@ + + + + + + + +MLX: mlx::core::UnaryPrimitive Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::UnaryPrimitive Class Referenceabstract
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::UnaryPrimitive:
+
+
+ + +mlx::core::Primitive +mlx::core::Abs +mlx::core::Add +mlx::core::AddMM +mlx::core::Arange +mlx::core::ArcCos +mlx::core::ArcCosh +mlx::core::ArcSin +mlx::core::ArcSinh +mlx::core::ArcTan +mlx::core::ArcTan2 +mlx::core::ArcTanh +mlx::core::ArgPartition +mlx::core::ArgReduce +mlx::core::ArgSort +mlx::core::AsStrided +mlx::core::AsType +mlx::core::BitwiseBinary +mlx::core::BlockMaskedMM +mlx::core::BlockSparseMM +mlx::core::Broadcast +mlx::core::Ceil +mlx::core::Concatenate +mlx::core::Conjugate +mlx::core::Convolution +mlx::core::Copy +mlx::core::Cos +mlx::core::Cosh +mlx::core::Divide +mlx::core::Equal +mlx::core::Erf +mlx::core::ErfInv +mlx::core::Exp +mlx::core::Expm1 +mlx::core::FFT +mlx::core::Floor +mlx::core::Full +mlx::core::Gather +mlx::core::Greater +mlx::core::GreaterEqual +mlx::core::Inverse +mlx::core::Less +mlx::core::LessEqual +mlx::core::Load +mlx::core::Log +mlx::core::Log1p +mlx::core::LogAddExp +mlx::core::LogicalAnd +mlx::core::LogicalNot +mlx::core::LogicalOr +mlx::core::Matmul +mlx::core::Maximum +mlx::core::Minimum +mlx::core::Multiply +mlx::core::Negative +mlx::core::NotEqual +mlx::core::NumberOfElements +mlx::core::Pad +mlx::core::Partition +mlx::core::Power +mlx::core::QuantizedMatmul +mlx::core::RandomBits +mlx::core::Reduce +mlx::core::Remainder +mlx::core::Reshape +mlx::core::Round +mlx::core::Scan +mlx::core::Scatter +mlx::core::Select +mlx::core::Sigmoid +mlx::core::Sign +mlx::core::Sin +mlx::core::Sinh +mlx::core::Slice +mlx::core::SliceUpdate +mlx::core::Softmax +mlx::core::Sort +mlx::core::Sqrt +mlx::core::Square +mlx::core::StopGradient +mlx::core::Subtract +mlx::core::Tan +mlx::core::Tanh +mlx::core::Transpose +mlx::core::Uniform + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
virtual void eval_cpu (const std::vector< array > &inputs, array &output)=0
 
virtual void eval_gpu (const std::vector< array > &inputs, array &output)=0
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)
 The Jacobian-vector product.
 
virtual std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs)
 The vector-Jacobian product.
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes)
 The primitive must know how to vectorize itself across the given axes.
 
virtual void print (std::ostream &os)=0
 Print the primitive.
 
virtual bool is_equivalent (const Primitive &other) const
 Equivalence check defaults to false unless overridden by the primitive.
 
virtual std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs)
 Get the output shapes of the primitive.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ UnaryPrimitive() [1/3]

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::UnaryPrimitive::UnaryPrimitive (Stream stream)
+
+inlineexplicit
+
+ +

An abstract base class for a primitive with a single output.

+ +
+
+ +

◆ ~UnaryPrimitive()

+ +
+
+ + + + + +
+ + + + + + + +
virtual mlx::core::UnaryPrimitive::~UnaryPrimitive ()
+
+virtualdefault
+
+ +
+
+ +

◆ UnaryPrimitive() [2/3]

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::UnaryPrimitive::UnaryPrimitive (const UnaryPrimitive & other)
+
+delete
+
+ +
+
+ +

◆ UnaryPrimitive() [3/3]

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::UnaryPrimitive::UnaryPrimitive (UnaryPrimitive && other)
+
+delete
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu() [1/2]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual void mlx::core::UnaryPrimitive::eval_cpu (const std::vector< array > & inputs,
array & output )
+
+pure virtual
+
+ +

Implemented in mlx::core::Abs, mlx::core::Add, mlx::core::AddMM, mlx::core::Arange, mlx::core::ArcCos, mlx::core::ArcCosh, mlx::core::ArcSin, mlx::core::ArcSinh, mlx::core::ArcTan, mlx::core::ArcTan2, mlx::core::ArcTanh, mlx::core::ArgPartition, mlx::core::ArgReduce, mlx::core::ArgSort, mlx::core::AsType, mlx::core::AsStrided, mlx::core::BitwiseBinary, mlx::core::BlockMaskedMM, mlx::core::BlockSparseMM, mlx::core::Broadcast, mlx::core::Ceil, mlx::core::Concatenate, mlx::core::Conjugate, mlx::core::Convolution, mlx::core::Copy, mlx::core::Cos, mlx::core::Cosh, mlx::core::Divide, mlx::core::Select, mlx::core::Remainder, mlx::core::Equal, mlx::core::Erf, mlx::core::ErfInv, mlx::core::Exp, mlx::core::Expm1, mlx::core::FFT, mlx::core::Floor, mlx::core::Full, mlx::core::Gather, mlx::core::Greater, mlx::core::GreaterEqual, mlx::core::Less, mlx::core::LessEqual, mlx::core::Load, mlx::core::Log, mlx::core::Log1p, mlx::core::LogicalNot, mlx::core::LogicalAnd, mlx::core::LogicalOr, mlx::core::LogAddExp, mlx::core::Matmul, mlx::core::Maximum, mlx::core::Minimum, mlx::core::Multiply, mlx::core::Negative, mlx::core::NotEqual, mlx::core::NumberOfElements, mlx::core::Pad, mlx::core::Partition, mlx::core::Power, mlx::core::QuantizedMatmul, mlx::core::RandomBits, mlx::core::Reshape, mlx::core::Reduce, mlx::core::Round, mlx::core::Scan, mlx::core::Scatter, mlx::core::Sigmoid, mlx::core::Sign, mlx::core::Sin, mlx::core::Sinh, mlx::core::Slice, mlx::core::SliceUpdate, mlx::core::Softmax, mlx::core::Sort, mlx::core::Square, mlx::core::Sqrt, mlx::core::StopGradient, mlx::core::Subtract, mlx::core::Tan, mlx::core::Tanh, mlx::core::Uniform, mlx::core::Transpose, and mlx::core::Inverse.

+ +
+
+ +

◆ eval_cpu() [2/2]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::UnaryPrimitive::eval_cpu (const std::vector< array > & inputs,
std::vector< array > & outputs )
+
+inlineoverridevirtual
+
+ +

A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.

+

To avoid unnecessary allocations, the evaluation function is responsible for allocating space for the array.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ eval_gpu() [1/2]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual void mlx::core::UnaryPrimitive::eval_gpu (const std::vector< array > & inputs,
array & output )
+
+pure virtual
+
+ +

Implemented in mlx::core::Abs, mlx::core::Add, mlx::core::AddMM, mlx::core::Arange, mlx::core::ArcCos, mlx::core::ArcCosh, mlx::core::ArcSin, mlx::core::ArcSinh, mlx::core::ArcTan, mlx::core::ArcTan2, mlx::core::ArcTanh, mlx::core::ArgPartition, mlx::core::ArgReduce, mlx::core::ArgSort, mlx::core::AsType, mlx::core::AsStrided, mlx::core::BitwiseBinary, mlx::core::BlockMaskedMM, mlx::core::BlockSparseMM, mlx::core::Broadcast, mlx::core::Ceil, mlx::core::Concatenate, mlx::core::Conjugate, mlx::core::Convolution, mlx::core::Copy, mlx::core::Cos, mlx::core::Cosh, mlx::core::Divide, mlx::core::Select, mlx::core::Remainder, mlx::core::Equal, mlx::core::Erf, mlx::core::ErfInv, mlx::core::Exp, mlx::core::Expm1, mlx::core::FFT, mlx::core::Floor, mlx::core::Full, mlx::core::Gather, mlx::core::Greater, mlx::core::GreaterEqual, mlx::core::Less, mlx::core::LessEqual, mlx::core::Load, mlx::core::Log, mlx::core::Log1p, mlx::core::LogicalNot, mlx::core::LogicalAnd, mlx::core::LogicalOr, mlx::core::LogAddExp, mlx::core::Matmul, mlx::core::Maximum, mlx::core::Minimum, mlx::core::Multiply, mlx::core::Negative, mlx::core::NotEqual, mlx::core::NumberOfElements, mlx::core::Pad, mlx::core::Partition, mlx::core::Power, mlx::core::QuantizedMatmul, mlx::core::RandomBits, mlx::core::Reshape, mlx::core::Reduce, mlx::core::Round, mlx::core::Scan, mlx::core::Scatter, mlx::core::Sigmoid, mlx::core::Sign, mlx::core::Sin, mlx::core::Sinh, mlx::core::Slice, mlx::core::SliceUpdate, mlx::core::Softmax, mlx::core::Sort, mlx::core::Square, mlx::core::Sqrt, mlx::core::StopGradient, mlx::core::Subtract, mlx::core::Tan, mlx::core::Tanh, mlx::core::Uniform, mlx::core::Transpose, and mlx::core::Inverse.

+ +
+
+ +

◆ eval_gpu() [2/2]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::UnaryPrimitive::eval_gpu (const std::vector< array > & inputs,
std::vector< array > & outputs )
+
+inlineoverridevirtual
+
+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ operator=() [1/2]

+ +
+
+ + + + + +
+ + + + + + + +
UnaryPrimitive & mlx::core::UnaryPrimitive::operator= (const UnaryPrimitive & other)
+
+delete
+
+ +
+
+ +

◆ operator=() [2/2]

+ +
+
+ + + + + +
+ + + + + + + +
UnaryPrimitive & mlx::core::UnaryPrimitive::operator= (UnaryPrimitive && other)
+
+delete
+
+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_unary_primitive.png b/docs/build/html/classmlx_1_1core_1_1_unary_primitive.png new file mode 100644 index 0000000000000000000000000000000000000000..191d394b31de10715621038ababedb088a3b7120 GIT binary patch literal 31591 zcmdUYd0bQ1*0$Oz*IJ~lZMlxLx3{$|O0*~-QyhyUL53g*#Gyh#As~cEAc?J2uNM_t zt5G3|iV#A8SR(`oBr2uUGDIaYVhAZ9L{1olaDWUX-#$qIsrUDO-yiS&e%ODGCl-xs z@4eQup7pG?Px<-!*uJ&kt!dMy*>2gq(SO>s=`qu$y?66(Z-7r)FND;B$3nmDftySw z6ZqxxZkUj9kL9-Fi}=6O>Yo4k=bzy7)%DmL;P+-H`TGV;1MlYlTxVoTrcEOaY}vRj z@W}KYh5D!KyJl=^t^UWoykTnFmj2(WR~56G!gJ+&KlqzXj7}qIMYX$~?R)=Drc1sM zlY>1I7$e_wk~HzMv_N_{ zwzilw`o_0RjO`*s@+VTv3?%uu*5Pm=%R*1`qsR~{`_NQ-Lv&&>Gt zPOQhYRYxlwSFgQZb?x!A@3z{Wd%msw@icvK%u&DNv;Gwrb98gg|NdiB)rz;xb~b1= zDzQEgLQ{=~Q~&WXc#w@VEY5Asa5!RUPSwf!qxIlNm^K%5(U8q23|%c{Asl~sZ{yv! zTRlHT38J6GUW~AzhVRPo@ah>V)s$(KgR4h$2ei$~?Oq|aRO2Cj&Tj_E+XHE>#?tCC z0kzT40{-%4B}!M0%B0|g=vCquQSulp)pe*G%cWlTH*K#>5ihs@ZjDA4Jl|+&ilbFW z8k_B|2@M?T3Qk{vbX&jro^b}u1b<9(mwf;#BAxg*hvgrcPLqs~0r{>6L`=&XJ(g?w zdvq~`vQf@1!?#4XJ-T}oZkO#eAyyEzh&1*WbYk%5EEz_PdgU0ApFKIco4SK-0(M&J z|35q)?GCW3T2Zlcnr&&U&Cyvmi8-@2C%H)$eB{TT{vLUYY{4GWR@-;-A58lZJlK-| z{YNwvN?24wd}|MFh9ti)fF1P7h0xXDkRXpoC;jd6QH1h#SiiMUx3#M4qCdyG@y+CC zp|!SEUHTMPTaU$DHtLNaoVqG+Y5qFC{E+Lq7T<7jOHr0M-&@J5ht?=}sMA*Ese2=| zKI+YEQMNZF(zM7|*~!}tUr4^;y=d-9asZjVE?8a_QmOkQGJT8vJ5rzebN1Y7IJB}; zxRY3jC0}0-6IQMvqx<4&5@n5KReC-b4w2-WLM9*eiMPSW65tT<@85Ci!Pq_kXCwc= zV7dRe8vmm?qs-ZU)2K)YLvX0yvn^+F><1DSQED(**#(>m7dkhe8};Z zM(w1|4llK#<9UjNx^4;Ft*z4@Vh4$0qf?36aN_VE1J@zv3%Z3+d(@D)b+8AYV>M{b zm_bV%&p;Y7R1i_|q3j@hCDm6^917QzNMzN|21hFdi^VKciu7_%AhRz!FUTbLk4^`C z*j31`Nm!)({Al=7=7Osi7JL-Mp8gN;@c*wr9vP9P1U|=UW?qnz<}+2~Z1+Ye@q?cUK}#*FRaSdZIcd-EVl3{@XYsKj#1RFd>!U2kDNdhpa_ zlb7{?4L}2b@K~Ih*t|&a!td}p+dlI3^ltFapH?WD#{Msz{#W<#5Nk0)c0C5TA_u;ELf3ETxN#xI=$X?(>sWV`(dp&BywI^D^e#$&SKN~FJZNr^ZW%_1> zoemx?$+*fmGD{K1Q*1oqLGVhY+=CzW3^jcYR_1hwmr8IyD`OTxG4Ka0=^Xx|*j*lR z&FmnG4%TYrCrB$JVEvex$=AosjHU5Cq*1z?A(atxA9Du2*V{lv()?QUroQ!6UGS;u z%!K|p&lb1TQw>Y5fyG=~Snk6uSdxPs^{Kpd7zAtP@V!5F#z?<;EsXviV)WtOzYD$d zQDDR17xXrNMc4W~c3A86-Qt9X!@(@gUGM;sF~ZBkg)N1;?aZ~KpAuS^Bnj)2 z&o}je&3!!4MdvQSZwR77pG3l0c$ek&k=vT9FKD|wB>Cb%Xr`u(P8=ONs_WaecJ9fl zy~;pp+ym)QJeEuOw=~C=V1G_8Ntd;e3L_HnY?b3yS&TGjH>BQon9`-7eAn2Y6+aZ` z$=szFa$WY(G{Mk6iy@`oq+nu=DafO8HSBh_ec{ZT+w=i%-1dLt#+R$t-td283NFTJN}DGh+SD*EnRXLVx^8sx2{OJ#1Hr3Ia8J{-fhv)4*2A(@L0Pj7-H=Tf9i+2xKQdtelYI1w|U`qUOY%( zi>?WsK2y9 zp#Kg;d=+-yMkCAZ0hW7ujC6Q1sO{qA*+^;Si8U`^?NafV#Sq_CzB_jw*A(fn-s?Mh zWZQetDl`|M!>%KnyaY*Qo?eGn?t<4b*LhujyIc_|xZjfuk%kOERNvI(X+oq-{W-*| zv|7E2ZV+*|)kn;Yj>5g!o0QLOtyIM zeNR#(DbKXmp?2@RSh3b6z58m>?uZa(gz@Xz^n7Iyl!EL{+ao7+nbZ&RGvq&7_jb+1 z-Y$PY#Bz!A@h;*>`%k|ItvG{1P4k$4vJ*v<}X3A->JZfH)UIVZn)1 zzTp#<-8IIYvumTuIMNhSL~`TJ>eH2eC?~opOGpks5^DS;Z|ONRe7>dt@R`9}p+BtM z-$tcer2baG&?vrX@HtiQDX7@F}tsH!#oiK}$!f}~=v zEGdV%FERR;qe)l}KlVh$k>y1)2%p=h1Ae!UpJVYIr{moj7*ms=Ke=tNXRrcb(4fhI ze&P+}CmHQ-!kyCux6!NDf^;+g|BIhSBEhH`a(;0%{4GDuU$c4#kktz%oo8DIciPM8 z0r{C*3rXD8mgUMU64WZ9lV*5CZ(kDpvnI?O7`sxgMTEJV){yn(@$+%emIF_do*d#S zez5N8?%G33a890jRHxZn*FQrN7OeVS+n`}j7gty6FOD2an^{H?Y#^849!&`p;eD(o zcbH`I=Hz_z&^jKeKZ9I}r8<*kH1MrN=G;Xz82LkEf0|C*azggnrj*F9jG;QgyuGIMfX;C+JR zM`qVAuIkEe;IkaU06exGj!WtN87fiw0#LFVgvv`3p>oXjE97kNUReO6wrnpDt z^m2w5dhbxjK+UgcCHO6~_0O2Ce=+zXOJ&2sZG-pbFzUOtZyg>@KB-IULS45cIHT9b zz@Q85!#SIyib?Xw-_7vE7^y2lfpbT(f|L5cSsgcisjt#k>{txpf?qRSB4>jT`Lt!_ zA~cky*HILK%sdH{wy=c1eJM5ygTs+@s2t#3-T~g_ehS6xU4E|gW1_;B_72`uhu36B zyLwO!MY?d&1(L3BNs<6y!g>S~YG+4vJHk*={RZkTUrJ8_!?}5y)zeU-g;+6$s?oDs zIFedg3oL&$nB32`ChV2UtDK(J0->fvu%XgR5mHzcyW5&zZt%h#-Bws0w~ZXWy#WMW zRDAdcMfn_=xBIv*t!p5wvC7F+$q-VmLkbEC zru{ban+MvGZ>}(rL(qDfoWnTN>9ejVX2bj8)0LDxSG%T)?`OxtgZ@wg7X$wGQgFld1!_@;qIrDl^bCB*2=ZJ3ueji z2OFgYrardaUvs98IpnBfsC@oFTXAUwM;5;jQkMfy=@H8wTLa%3O&6N#8u9p&G@Q5E z2Nrk(IHByi-^C#Tp2Pd$9LM7j9nb{ZRUd4CnL)2`1yhtHn1$r4=^iJz;_l`YQQC7c zbrU0@YaAi%A;#B>P}@~b0?jdRhHl%R_q?8` zW%hg~mb2AOO;>sP*zQ63vR>T?0`2$KeJCf$9c!B^4HvMY%h2!z-C=_8`!JEBVYy0w zfVg0o8Q=gCSRvaxW-NbIGc~fdWOrriQ79L4S!#|GEz9iX)OouBbC7{7NJ>hn1>%be z-Blgm2(eiae7+4?SAZQ&eJLJK&Z?$pbUuYd=O!dN>+C?hbG5DAZFEqHuYwq{9mI$y zDLc`Ee^$+3UjjmAB|p z;OBcr^MsohnRe5Yl=Ff2~C(duuc;Cu0)8WX3y~d4Fg6J?-9E;^)ZTN~1 z4xoqmcR8PqX0q~&bn%e0JZ5@A{4aFI;UyJ|lF4=5e80-Zu{Dxg(lC0l?$(L0;8E?d z5=lCXj(aDcLm5V>hQ2$??9X{L5;Pe%B;Z`s-T^E>(L-4AEKd2jS#wujtUI#rmm8co z-Ely_@6EyP7%WU;2brpcG^gU2+}GiBP!dRX=DfD{F!W=pNqx)~Q897ANI9*Y)?A z@ZslOOEX8y%6oLbsxi0z?xB}bl?9Arg^`7G<92hUJOw=8k*DscmEtK`{y39VIjfl| zO{|mu>}~p-FKRX1(5B%U^zXy1@owF6SBb{-8$J@QEGBKLyi~wj{cZ8lWwIYl|5~s| zCLH%(A1S#~j{dQV#_Byg`aUE_>op=cv-=Q%Q7EmwrlGa<41~a2NoDOu@qhRqSQWa^H|-sC#@|l1l~zoiPJajWxG~It>vh~oH*W)A$1XvSEbzI6;_dZY`!=J zi@3GeNA85(GKqskr5O9!YarTPYJn+6mr4*=?|}Mo{cRbDll?82SM0 z6*}SD{)Kg)4N=zv;$PVSFjxQChli(YDTNWsIyl9;%*v zhUXuY;r4jwKFJFr9_q-|?aKj`F7pwBGXKI%PR(+bkDpz>*mV&)e^}r7D1*9%9Yotv zB@(eD`L!q7sm7gq+%N6TTH15K)i$%fUHzx_g1R^90QqqLG-<5piY~pwdV>690V))+ zD6FW(OWV&Q#R;wyyO+k6VYx-i&e&78*q>9eRrq_hnghR{(51<`?4&;9xc90#3nb$U zg*xOiVrSlzufBT>JpNbOB(iwjmnRFGkAD{0$Wc;QFvFP3yL}uA(7v2qYc#x&7VWy30cfbvIG0 z_S~vS3##obO6TRH2BQ%pUoHlW(i+Y-057?IV%!g! zt;uyuJnn1a^qu6~iGr-Xr1tPFPmV5$xv5b1A3&;=8l9S0bu`Xm%?#+6P+)8a^kw~s zL)@3}fCzYa{Rq>0aozhbTU5V#&q`ydVDy&?PY*7M1zgX^W)^xjlnfa+4S%!hdGhYEdSALItbS=t4?Q!J!e|zyV#S}j@RI#x%cI}{*CylU z9OF`n#=f;t!pihj*N7c|A45iV<2&94>PV5EB>!$UePUu2c|JK~&^<$dMc)AQDXUMV z=o$HH9ZG!{?mjd>wEL`n%cu%&4-ViM)0eMghB`~+*cd4*)M-TyR>3OkaaJFPEN*?g zYz>_J6s^Ya1MKq63*nF?FcJn0IKLP|$&PkvY4o~6U6|-+Qt9^6u04^589bU|)ghNN zQ9g}L-03VuSuMY|xUt=Se^B4hZZ^x*YbLXsntkhO$7-8W`3gao3l@$OC1b_&$JS%k z@NCEd`?J4;@)&w2Q+7>s(ON5@SlA2!R%E~WPLY=@2Z%w={2BWED-uXM+J~O`Q#XKc z>Z=qB4OlhgUNx zC#sJ~Rk|A?UVKH}Q;@G^No$oiB>BUs$4SP27#O;0ICYG;W?3+dyqHw^$goKQ3bq+h zFd!=44NJs*jgmo#i~?csUNUlRlcsIp`k>8Dj&WhEeS(=Ja;&#sB4vT^>sTL=oiBb> z&M-CIKKXNy18fBh5o`iVJG9Fg6S@$cPI5cExEcV-mXrnrNH(ovEDoQQFMND*C-vh9 ziwNh<6tk&ubboR_CR3UyuP<0EA=HL=>30O;@+Hchz`GxWM6uJ?Owu)lbMV3CNAoC3 zYk6^DZ|!rC2#aqc^%#po`JNQwTjU{?>6XPvO$Cz^sfo8K>NVZ~TxjqDp*OTIRr@jE zbgxAlbRZd|G{!$$^VSX<=K@U@qOasddwP_=?P;4Dq(7Folx}trvs=s_8y}?#4y5*; zkA$0(oe$a&6d=K+oR_{MWjexMkTyMj2RUS|@iQwr8Q;CS7IqPO9;$_}Y@Qt9wNE=qQ1e8=eZQc#}KmpD7Hb)}eV##?0;ZBzTDCbovw^Q3FIRFR|G=-|vz zMG);<2zHS$WaAGq4tyR8e{VIM(tmlU*eBmRZp=rhED@0?gO(Qq4Oxp7iYR72BXth= z0R0F4sXxyCJW#FmCa1j{9#(dbxpc%)UBdz#j_ft_zDyq^+gw2o^|A3CysuC=daP6d zWJWv80Gk(588)$nqbv?MvzW9_F>@m<7z5bLJ_0rIvMU_#CYyk9<1V29=kS1**FLzA>TOsqQOPJ!|~T1iME!|^Bl0T?3|v+$8(|bj?hXskagr38KnMy zWH}fJR!s-W#mqwX6-w#Se6;HZi|{dW`WGNbnW=X(qvdVa7h0(M>^0k^4%S^7r;Ml+ zN%Aqpf2SDyW28Pi8v8e{-dq0}WY**gcCt@xh*0XQQ0L7L76L-MI-_GZDwmfOYC!S-tOO>sz zrmm_&pC*@3}Xe;F@+ro&AL}( z{bch_(emhCGY8)215;35AnGdNBO_I~b%drU0BLV$*CM7p7>>|&e`=qT63UavyFjrl zA9dYAR(y6}&USnb%Lzo%E_P7x^4PwHr-a&KohO}eDHfC<;J0AICO2mj$69t{T1q{5 zpeX9d1$;%E^VOLTNrf%TKB|?qUDbdRf~DwSmOSHHjy4`XTp2|H1gHo6UI7H0sJBHE z056_g>zSf>%NgJ;)xumcf%>#^Nw8AwIF7rMa3Zud1{GWS$rkK7BAt1FUx|~J&Y8$XE zE@P_zhgjNrebwT+I~JB^Ijf&QC9P)%ISl3yo6SuOq)S@plDlffF(YowuP?ya? z6qutOOM&H7x&3^gThrZa{T|rvCFb_|k)e?qnaPEpXKX&{hgoZuORcY2gLy1#L~Pwmz4#>yuXOcOMJsnrwMaG*REK3zXK7XTbKq9Z!9rb&)TwsY6my z=FJ6_J)~uF4-eK5u#C;dey$7&}Fwt-Hm&LCrKzdF`c(Hgo zCxlmt;J7x+s55!J1t9Go{BgV&1lQT!FAqv|A(6yRy#b?ayQ|pIg)i;oiPRtYaBA*r zK)zbo6oMQYO!)Ks21Nd~fNIKJ4YUX3_uzN~*UM;wvEiA9G#e2;Ca-J_x;*?!G9T@0 zAC=)r&Ac)z?R{8AHT-n}$SR?Ro2cVp#lhjjpj5Xg%FGS)+#~F3a4O7;ItceD8}{V^ zQ3s15Qux}aw*@z> zqzm*QDXFUtCw~n-+>aIS{g5|wYqzVQZwX8YWJ{ZX9HgI@1?9PiX4p&n{&MGmjU5> zpdgsjZ&f+{DhKpT?drvmnNTD7@i@Fl93AX&SoUR(5JNT;aGB!K<8yvT+@aLo)Za-%jQ}{2+BvcSB8$lxS^9`Qg*-* zpvJtd{VztoDDQaGb)QqirpWskWBe_#h$*C-Jh@Un^-KOQIh3DB#&SUznknENUB<>6 z@|7i@n3Wls&@qdArjogjayx8ooFsg0QfaJ3cAQ`?`4?&OmBS&3d$0Jcnev?`Gyz*t zu!x`2UU9VU!wXHEgs0S1#Smq5Gsmd};P&%Y7ZUYaxR8G!E<`z(+P?-cPrKMSum&zZ zjIKSOt$rIJu~&5`RlE~!>w(?`knJhm9F@iPh@7$k=Op_OQNwes>`khShCU6QzW}0R zzeWLfJyLG6N?Fs)&RUnr+*jNTnB6EVvx|;cTaM)zA55glbGo$h*eA2gLAi7PWb;kD zP0?FV@t`Qjj`FE4{hyfz6(kJe)zg8Kgk=VxYS+*Eor3D69Z z3HzmYYvl((){V%vdLsaxyIn=t2p@k%CNxF$L6~fi%=KlqqP%lJ2`7m-vI0PAO4YkP ze*<9pZ{kCUPxvflP5InRg^^%*Ps{SdIGKU;ErHdJK>ADoNcz$^l2+;9M{tHe$BA@Y zH8vA%9JFRI|}-E~mY)*C=DyuMJ^cRNHC z!~_5oQxZl8h$vV^0JS)3!o=o8INPZS7Dh=(jwf75@t}I^p4tKyhrPA1p#*=zjz7ma zMyk>mO3v3mY&)HF-o3u6w%60cRU+F8qiGAaza5wV-b(_^G z>=<_nd&@&8SkUkR4-}8=IZ%or@BZFi8Re+yD1K6`+nNwQ$Jf*eR91w9(gKlYZ+ZyX zNLfQ(kVZBhqWgyrBe0FWL|?hWOi8utb_2zGj>i4EE&v=vkIA@uilU5RP@xwSi6foQ zx_)(K1@#5S5QJW?X$w{#nnnE#eim=QppS)pFy4)A9}0q^s;)0HfZ7weA*gRSmA&*k zGavkg0!ed<=q0Jc;dqf(n`a36%MZerbwho|ig&^eaOd3pm0(Ps zpR5w9%ix7W&gP}ivCS<8JvB3TAuR?vvg*JF*z_tTGDXE$ZzAWUcdjza;GW=46bw`; z17f6`#I|UE%qQE(1{7Lc0m9?TtTRr5=bvJ`-Me*JT5@BV=5Ck7Fh5^Safa=yk4Ga} z=Z-{8GF0wVCsqjXzFu-RrIGs*K0#5Ar)D@6&*k+jh1yjiP~#JA+u)eaonmRmF?O<^ z_@OX{#Yi>ZU@g{V%=vxzARNgK9xcka4(O+-8=xV_UHCgg6n;Wc)<=q?94Oh$F49*6 zIj=8zd$75xf6QfD8HD7{8Sbcs`KDLJL{l`Z0&dxQG?Mm#9&YcQfI9YaHMl{)7U~j3 zS4eJOW^KIqnTma+X-8EgPjfhFaWc8VAk>sO>M(7hOd$C(-CKop9w3nA@%JocYXh0{ z9lD_h=gm=-n3Y_;Vanf~mQ?QA8Yn81$Z}d;{#*(cL-a2zb-HuRPqi6IAo{kHSi3){ zVQYoEIZ+wQ;eU^pOU&{230?1$QBmOG;0q$gpz{r=9U+8N*nQF(Q0+~itzK`IDEx&W zoT8^;8FCt&s_=Y8%AH2>GNki>)vpR?2T{LuBQ4sJd+c&Uygxz!BtDrBCmBa-y7va; zMASFj|NSmp62Cj^N{PD9*#n)ri!H4#xvf}cbF92~h}$}VS$DD?8${EG?^%C;P_SEB!D6N zTW^>*_{~J9%W&vP&yNAJzWfYbj|BIQefm0;eEp#Ws15W4#PWQcgg`|t6GLS^haR^i ztK&=yoJ{XW62VB6MHR968DGM&kz(w*>h%fkCfBj18p}`!RG$F#G^DURws<6B(__$G zHa+CbWzmQB>SVSkM_Bk4OV$Yn6maLsm46Pl$^E#vQt%Zz0RGuF8UIYYO;MjxWA-WY zjm08`w#`|h1=7CQs%S$BOGh1+vrpDmWX+9itJ>0!h;H{pNg7?q=*8g-KIpbA#bLa@ zuq{`7_-%TC{XNHLL23OJ)kJB{s1X62%LJCq{dFbj5$IpeaD)#bwr+a*VwSa#jkXLY zpbb~rg0d?pc3Vawg|Ee4nHyK)kIo&hCVZ#&f_1SRU8xdy5!z%az)NpabW0=SO~97s z%Y|osUMG{NJCLX4LJflFHcUPrl_<2JL0T;t3-k>wILm8i&mV7HoOr76lgtQD&8zG z74Jst>6DdTb3GO-dp!vAkLZ&P?Ybu&< z2y}3l)uug;AgX(t`K7zSS>gIQpfze*&;b`Zi|hPnJ^1n4W4_zYV*YAOLijoS8z0X3 z{m{B#K!IPeH0TL=M-O3Dws22|(?{;jCA?ZN0_G)Su0pl5;7UVFbj{;J< zUp@*;MwVh12XFd1&h85fg4FKrgHniCl>Ipq!wOuiqi5P!M^2Q_`Q6K+rc;Kj?Jf8d z3^+^)f%h_fkQSCJ=II>7O53amuIhu%x5#B_loIF|#iJOS>$Sm5Y6-*41tXX{8O2Xg z8bP@c5r0IF&@+3zqegUDcDe2CdZ6!k#&75?Am>ZBg`k6dQA_Xt+g)GnS~8BMb9brXnQ8BI@uWM)`vVn+PVa)I@3@ zNT}s%Om?+C&2h#LKo5oovzX5&s&JxX{QA4qC08Rn)t={3Jbiy! zyTrUG^y5Y3ezL~&two&2Uk;}&AUkvH(bXHE>Qb9Izw1;&MWSxV(f%B0K*}WS(FTO& zKD)|~EygNmahLQzn=z2d#&3hRB%UCIy!T;1oCd3CaVvbA-_JX+5B|0Ky0h(D$lRvn zTrBzIWdAksHbs}vE(a;=yHO-fLv8Z_eBRjjqQeX1^TO88lQx2L6QSCYiXeW?i2H~; zNc`>G-+O7Mp0@UIGs$b`=ryyR*F4e_i1LQc=Ygh8Q8q&~+rE?fv6cSaW^q{c9`VHw zz^|jkQEH0iVtZiQkMR86!abgLzz?y&pHrzHN6RM73<~+An=O<>(|GPMJ>| zR(EYlCv1XGQqW+g2gvna8eI93;(v-(siQ4D!WobVOi2^PE-<1(sDFdw2i<=H5Wdnf zkCvB`w%C*I>9r?@atxXWus`ko%N5dP)Ew!9#YrV zjy)T^^s@t*(SK3&X7Zn3&<%>JQ@{|tiU}dwBGpx(VC~wEw8NXLpb;38{Aj&xK4$P^ zKo1Zmu=-1&5_~vCkvkU<#JFg{1GMz0|ENsoHS%kTvz&a%p71K8JX2T}f5SWc%IJ3H zdatm^r%5aH8Hm=e_r~r1vX-Py?>@7L1%%faCRd6xM~1KNnQ{0eN#9-&j*hj8uByCn4j0cqzAGl%WG# zG4?{O>V=xU^ahB#i(=j@XOu5{lXsp!C;hxEN`13A5b&M|9x|Vj$if!GS89g?Wmwbg z6F-x*pH&pjVXbam3Af>OcJ}8i!?f@gb~CeB=ds_$UOs>t8JzizIX3_koIRja4oL>- zB{FIU4Q~BR5D|}R_Z?=)PRQ$lTs{#%WQ1wHH4!@hGW2e0PE$ojr;GpO1#$9q;A+ z1({%@C;%*b#0s(U2hMr({2HlXto2csXi;eJu&A0?paXRVb?`CWwl=P?r6B>-X*MIn zCe0!+8n;?rTKnGZN@Q#W7@30qzmChEViGztZG55~h+}KFp*#hW{Pf&baZ3Pn{vNIB z$9?XX*e`WWH{Mt8FH&Z(8kR=3J5kSo;X%&uKfijO4CbR01Cgw{=UMkk+PnN;I5uGD zxS2CHW0>?}lA3C&dFr_q_66dM>HO)O^O~q;Au4yYhqZIGXiTmo!kdSr;2%Tq#)!HB zY42s%XHm%Hp#BDtzGKF%P>#b+)Z-Yb4Ve8J`>`AMzH3wMdn;0x$N4T1iE-ytvW~WO zV7sS-?VkIZZTF08*6m(BG{)>SzJL+rSk(GF@t*)WhQ5p(dsoFIHx6V%?G!KgP0Iq@ zMx+}ZrCt18S-aB$$jyCVsDpgW{6puLeSL_!Tzc{X$%PlS?XK=wGueGjW;3SfqcMSu z@qq>J6iCw@qSj>C&x-(K;lCXh-8P4FXevaXhisSbaLl`p{jL5h5p;UTcRy*z%^np@j#oUctFpfvZC-h)rHMW8P0 zem%_6hxbpU#W2qgR)?dwt7`Wax0`;rjin1Nmon@5ey{vwF$ta%CW*Y7%^TM{xpc$eSREA3d3=4Uy3LKn{@Zg@98>N<`X^blp|wq~Y+By&gA@hJC8TgKy|p#7}#XoaZdhHgJ|-DpRGg_S{? z!)VU5bk0FdCTMTm5D8zzyCCh2-NB(*Pk<6Xh#g#)>DqcI+MBzG)#LesJEq#rwWxMM zzevq;het>$>9W`|NUS*#!UO36OXoKp5b-kmH*#s}u1^J}TYzHq`%qZ)>aZlMcb=kj zB+}^!cCPrC0G9#?B#;qN1YnRz#=^6o`_Q`fe4WI&ZB}_@1h?l#5|`yxK-u2M$3=8V z!Okhl$-b`{RNczIjkZj;7;+?U0L^s-g`=d_wN=~Mc)$7~ZlhIF5N3fKK#6~5_X04G z*6GUFKw2+WFVNAwBs#~-XiI6tf{P2Jc>Oz|AE$gmiuhVc2F`UE1m zTtsr41m40^{Y-N4a82V}NJh5t=$C3-SO31EmJ)x|B{BY3ekD^s>;C7m>gB4UoSaXS^9y35G#Zq2nKfz!cy+e$o6R zTbJ>w1DTqlXz!maoHf9-ZY!{IAIfo{9Q zCa>9`UKALPmm)3!Yu>!NWV$u=+?INsXPSQ$%6Ghjb%vDBI27s++?a2Nmiz)o7iP_6X?=H z2`Y0(mB~(JcVo%f2v%2_#AC-(wfX)G!&i+8o+ckNlpvWWG0#w2fd}rPVQE8DAg)XE zvJ*Be%I5|`eQfi_HN{TuA$8?egPJmH+%bfMCcFuXFHO(_TFU>@F!qu^nA-c2SK1>* z+ILRnwx%QT>}u=gq$`YVm5nzld(Ha8?xf)qiWrR9YV_HOS#K6{`B%`(v(_-(n4UO4 za=G}5bhoexnSa2;b+vG>gI-*qRi?~qDbwP|?sc(F=}lRlj0Lxuj1O1$H_aLEfKV18 zH@pCC=zqBF2z)q2IcE%LCBuCS>IHGBurFK!v>JOGUzlWCUf#)cq?zocy+{+-XCrVZ zPxKC|dBT}SrnI?)!uP#nuy}1G(uZ45${fU7tTw`Gc}KjrappHGcht(9&1d zS6rGU1w))JTT7K-R6PN}0h^b4c*1d%%rl@j+zvEJpE*61}g_b0^zl!Mn-$y7=5 z&W5gE2fBOY^QaiMRG!>gK?Hu()z+R1BvlDlZp$o!V?nc*16e=cDGI+oj+)of$M?r9 zG4Ix|1Yoy@Fm`qOqU&MZP_h)SE9Q}ET{DrvS3{2Vc;T5#D5q^r{-DT1nMX;~#Qq9o zjOlV4HZOLJFx6WK|5R>uM?pm8SA9JclC#lhrc9%rT}j;rn$-@xgu-ab&TN#-3CG>< z%UZ&95M2l*a4Dgm2lfH}oMz3>q_0K$F49Cb#}>z|meH73L1^c+Ab?DD1TepSnV+@K zj$sqW02!_QEOmQ0BiVb@fFGGH@%hB256o6at}E)Oi*?5(^H}x2j4~$Vf*vrb(ho)^O@1E{E zXP{nP@!2x$zvTC;)dY7hK@hyx0o+*QI+>e(>1~S2A3n3}cQCA&>wO&=Mq`Up01+X6 z;n}r5*omBwxW z4LAtSg$|{{vjJi(iyj$Z0CGVY|GxV_Kv{)5Mk;tNB$DnWU*vfAR^?@E9`B8=y{tL4 zEguDg`FP`QB9`)=ROJZT`7>AV^Uwz|$4CELC*)-UWDA;z$1f~n>=ZvvklO2#UUCcK zrG3lLpo)TAS%&9gRG{(Wm?hp(iIa;eQ*@7hXx*bQ4nRXCxIIx;fd2j!T7hxKL+K9J ze+#aynFZU|PL!&`8?rEA?7q`rMeHt5fQD$JU0bF{Ahkf^SX!SM|I1L=i8PN;ICDJW zP2|9NriZ7EzaoB@a*S{IqJhXze(zqo1LV3tkLxMru;#v&)35t?WeeM)?lTQN3cXei zzB0z*2FTWSn=v)z6venwtMAqG>0EFTk-E8akWY;Ns5 zBN~ahrLu|SeZ@Dtb2#RTxKFFo3=G%JQZkLAizHo!dEo!jOKNuQ&; z5x5Ui*(IN!uZsbPw8)>f;8LvUyHFT6nQfb*$WIJdV+8pf>J^&Ksg{78z+*=h9x0L} zRK45wEtplftKqzP24`n~$MYn0zwW@DwC7#Frs=@^{ow%2W(1t{zM0}hSmT@>5wWX9 z5kkK9^hF`|2Lwkifes1LVOU8u7Sh&5pk@+PeFr*5(yg2ub$ralFq`J+69r!Zaibm z$jpJR6QB>pJg;ro+N$L%wQr|W-zTh zgQ;O9;%HIeo<03Pp;J*0dO(Z95BvbWL&xFXG~?kNpv5l+^ornp&6D9HmY0mR^9^h| z628`m9c!Ml@IPRtRdh4xMhmxgqfyNhMam{7ihv>Ujb)96z#pW)W(fhxh_P=UkM@^NWLpT*F`$B?bM za^|<%c#!U#wj?SY*)zZ4?E>9@;pY3FXvKSlzIC~(!_(I&(Yx1P92Sf6C^A$oR*@a` z;T$QDd5}R%A^?664oLT364gxA9c@tRT>+iDE&AW~O{%-P=}f6PR2?U$93R!L;aBDN4|6tu(HDiFPCUMIl2{ k>_J8Sf4uH<)NVzc_?!2evfIFw*weOb^4(ape$V&+2c?alh5!Hn literal 0 HcmV?d00001 diff --git a/docs/build/html/classmlx_1_1core_1_1_uniform-members.html b/docs/build/html/classmlx_1_1core_1_1_uniform-members.html new file mode 100644 index 000000000..97eece0ca --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_uniform-members.html @@ -0,0 +1,115 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::Uniform Member List
+
+
+ +

This is the complete list of members for mlx::core::Uniform, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, array &out) overridemlx::core::Uniformvirtual
mlx::core::UnaryPrimitive::eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out) overridemlx::core::Uniformvirtual
mlx::core::UnaryPrimitive::eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::UnaryPrimitiveinlinevirtual
is_equivalent(const Primitive &other) const overridemlx::core::Uniforminlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)mlx::core::Primitivevirtual
operator=(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
operator=(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
mlx::core::Primitive::operator=(const Primitive &other)=deletemlx::core::Primitive
mlx::core::Primitive::operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs)mlx::core::Primitivevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os) overridemlx::core::Uniforminlinevirtual
stream()mlx::core::Primitiveinline
UnaryPrimitive(Stream stream)mlx::core::UnaryPrimitiveinlineexplicit
UnaryPrimitive(const UnaryPrimitive &other)=deletemlx::core::UnaryPrimitive
UnaryPrimitive(UnaryPrimitive &&other)=deletemlx::core::UnaryPrimitive
Uniform(Stream stream)mlx::core::Uniforminlineexplicit
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs)mlx::core::Primitivevirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::Uniformvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
~UnaryPrimitive()=defaultmlx::core::UnaryPrimitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_uniform.html b/docs/build/html/classmlx_1_1core_1_1_uniform.html new file mode 100644 index 000000000..45442f1a6 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1_uniform.html @@ -0,0 +1,352 @@ + + + + + + + +MLX: mlx::core::Uniform Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::Uniform Class Reference
+
+
+ +

#include <primitives.h>

+
+Inheritance diagram for mlx::core::Uniform:
+
+
+ + +mlx::core::UnaryPrimitive +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Uniform (Stream stream)
 
void eval_cpu (const std::vector< array > &inputs, array &out) override
 
void eval_gpu (const std::vector< array > &inputs, array &out) override
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
void print (std::ostream &os) override
 Print the primitive.
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
- Public Member Functions inherited from mlx::core::UnaryPrimitive
 UnaryPrimitive (Stream stream)
 An abstract base class for a primitive with a single output.
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
virtual ~UnaryPrimitive ()=default
 
 UnaryPrimitive (const UnaryPrimitive &other)=delete
 
 UnaryPrimitive (UnaryPrimitive &&other)=delete
 
UnaryPrimitiveoperator= (const UnaryPrimitive &other)=delete
 
UnaryPrimitiveoperator= (UnaryPrimitive &&other)=delete
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)
 The Jacobian-vector product.
 
virtual std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs)
 The vector-Jacobian product.
 
virtual std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs)
 Get the output shapes of the primitive.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Uniform()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::Uniform::Uniform (Stream stream)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Uniform::eval_cpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::Uniform::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+overridevirtual
+
+ +

Implements mlx::core::UnaryPrimitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::Uniform::is_equivalent (const Primitive & other) const
+
+inlineoverridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ print()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::Uniform::print (std::ostream & os)
+
+inlineoverridevirtual
+
+ +

Print the primitive.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::Uniform::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1_uniform.png b/docs/build/html/classmlx_1_1core_1_1_uniform.png new file mode 100644 index 0000000000000000000000000000000000000000..1fe539076e6fc06a10e3b315eb56c94130e41238 GIT binary patch literal 876 zcmeAS@N?(olHy`uVBq!ia0vp^(}1{xgBeH~F+Z{fQqloFA+G=b{|7Q(y!l$%e`vXd zfo6fk^fNCWJa_;UH$5A+1SrQ@666=m0OW&#In(Sb3=B-)JzX3_Dj46+y(_fHfXB^$ z|2d1w|LSorjoUvu$>dw=psZ+bG@7koYd1hIz4x71jW!sc3%ipaH`}RBPuAcpsUE5tvy@M=kqMq_u zt+=*tt$dc}B}ToqtDfGzcT4wC^4g+kk4y8?w_IagTmGo(o@Vx>Re^gSd0Pigs&mWm zoHYHyBo+OIs-AJjuAZ`d{*a+UT)g3wi*4||m5a~ru7PpMv*$={+pC;d6|aCz~iua{5T z&As2Ia<#_!waTN6N2{iaeg3&*MnV16rv~l0`TN=Y&2R09-llu^Z`98vniYS8p55K@ zY>(%y{;IUfl@%}ky4kPWb*(nYa{Ib4$;%dHLUG zi{4vx_sz?`u+Wxsep8lg|NJEX!SsJMaoc{axb`~i|GD_BDu3sI{PH;K>7366412n{ z4^;H5EZxcT(sbFi6w^i57es9oX8eOUSdb+DF|VBxB*i~z=FNFVdQ&MBb@0Ax?K{{R30 literal 0 HcmV?d00001 diff --git a/docs/build/html/classmlx_1_1core_1_1allocator_1_1_allocator-members.html b/docs/build/html/classmlx_1_1core_1_1allocator_1_1_allocator-members.html new file mode 100644 index 000000000..b98657fd7 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1allocator_1_1_allocator-members.html @@ -0,0 +1,98 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::allocator::Allocator Member List
+
+
+ +

This is the complete list of members for mlx::core::allocator::Allocator, including all inherited members.

+ + + + + + + + + +
Allocator()=defaultmlx::core::allocator::Allocator
Allocator(const Allocator &other)=deletemlx::core::allocator::Allocator
Allocator(Allocator &&other)=deletemlx::core::allocator::Allocator
free(Buffer buffer)=0mlx::core::allocator::Allocatorpure virtual
malloc(size_t size, bool allow_swap=false)=0mlx::core::allocator::Allocatorpure virtual
operator=(const Allocator &other)=deletemlx::core::allocator::Allocator
operator=(Allocator &&other)=deletemlx::core::allocator::Allocator
~Allocator()=defaultmlx::core::allocator::Allocatorvirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1allocator_1_1_allocator.html b/docs/build/html/classmlx_1_1core_1_1allocator_1_1_allocator.html new file mode 100644 index 000000000..4d6d58379 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1allocator_1_1_allocator.html @@ -0,0 +1,338 @@ + + + + + + + +MLX: mlx::core::allocator::Allocator Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::allocator::Allocator Class Referenceabstract
+
+
+ +

#include <allocator.h>

+
+Inheritance diagram for mlx::core::allocator::Allocator:
+
+
+ + +mlx::core::allocator::CommonAllocator +mlx::core::metal::MetalAllocator + +
+ + + + + + + + + + + + + + + + + + + +

+Public Member Functions

virtual Buffer malloc (size_t size, bool allow_swap=false)=0
 Abstract base class for a memory allocator.
 
virtual void free (Buffer buffer)=0
 
 Allocator ()=default
 
 Allocator (const Allocator &other)=delete
 
 Allocator (Allocator &&other)=delete
 
Allocatoroperator= (const Allocator &other)=delete
 
Allocatoroperator= (Allocator &&other)=delete
 
virtual ~Allocator ()=default
 
+

Constructor & Destructor Documentation

+ +

◆ Allocator() [1/3]

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::allocator::Allocator::Allocator ()
+
+default
+
+ +
+
+ +

◆ Allocator() [2/3]

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::allocator::Allocator::Allocator (const Allocator & other)
+
+delete
+
+ +
+
+ +

◆ Allocator() [3/3]

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::allocator::Allocator::Allocator (Allocator && other)
+
+delete
+
+ +
+
+ +

◆ ~Allocator()

+ +
+
+ + + + + +
+ + + + + + + +
virtual mlx::core::allocator::Allocator::~Allocator ()
+
+virtualdefault
+
+ +
+
+

Member Function Documentation

+ +

◆ free()

+ +
+
+ + + + + +
+ + + + + + + +
virtual void mlx::core::allocator::Allocator::free (Buffer buffer)
+
+pure virtual
+
+
+ +

◆ malloc()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual Buffer mlx::core::allocator::Allocator::malloc (size_t size,
bool allow_swap = false )
+
+pure virtual
+
+ +

Abstract base class for a memory allocator.

+ +

Implemented in mlx::core::allocator::CommonAllocator, and mlx::core::metal::MetalAllocator.

+ +
+
+ +

◆ operator=() [1/2]

+ +
+
+ + + + + +
+ + + + + + + +
Allocator & mlx::core::allocator::Allocator::operator= (Allocator && other)
+
+delete
+
+ +
+
+ +

◆ operator=() [2/2]

+ +
+
+ + + + + +
+ + + + + + + +
Allocator & mlx::core::allocator::Allocator::operator= (const Allocator & other)
+
+delete
+
+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1allocator_1_1_allocator.png b/docs/build/html/classmlx_1_1core_1_1allocator_1_1_allocator.png new file mode 100644 index 0000000000000000000000000000000000000000..a57dd94719d8f701ff258f2887d9d6d2ee95036f GIT binary patch literal 1087 zcmeAS@N?(olHy`uVBq!ia0y~yU_1t72XHV0$t_=0KLRP~0G|-o|Ns93nJ?aaE$u(F z+`>S!Kw|ot7Y`mh0E(NQ4O;?~<17jC3uXZF!N8np_7w&O=2M<7jv*C{Z|6Row8=_< z)w*5l^`8I2@vJQZP1-fl_oT{oqK=$V@tM`KpFw{4l%Rd$4fc*3-{z~D-YTmITAbUi zj>NbpbG_vL=02hL|MiL&w@C*retY(#W^ZbiM=!(s&xe=vZ2r{|xx742TF}Hs1TbY_Gad?rv36J+otn^>#0|vQ1jGzfZr+xAa@lO}0HgTiz-MJx&Xc zx#qTev*>MI_MNkj?cKiW{<>rHg>tLiv-T|b_;IEARps8&%YKjVoP6nW``$z;YpK_z zg&U-9hhE)$@uu|b<&T{f<-Ygqt?o^JvdLgpd9wBT=l6Q^GyVWw`hCeU=ibw7*0JZ8 ze=WaSCt_o2d+XYps;$>=|LE`BR_l_r`mp;BgHXeX zC2kEb?&Q4`RIpens9;sz4P=?E6nwz#(aG@8vzH*=f9h@7?!TN2`=;;TWE7ldEcL^e=khdK8=}W>-qKL{!E$ww#H$RK}?x*bY-*lgiKK})sZ^R zCvTa3{`9Jq{>c%0`<8X5e=_?~7`1=ezU@!X{QUYMRQ>zw=fct9f1WwK&)5@HS-kw} zhuGe`3%p%DgAT7MoU>l(ZT~Bidl!FQ|9EuW%5&@F_oZJjswtKK{B_p7vL8hYwwuPS z-Lv=d-xMF=dWn*a8lG7z!jk81xy{FY z-$noZd;O$muRgog%8xlu=b9(IKl=X>YgOSlmE1mYsq3xw`X5Wbt)82($AW9#tYx}O zXU?oVf1~W$oBXpnZ=~b-N@5>ZPksNcL~dVH{n@rVoV$X+p{e@vcZK_66<2jruOMJ1 sQ)$J)cqi4KD6wyv*JEHTU;iX;ll}6onX+dMFsn0oy85}Sb4q9e0E06ZWdHyG literal 0 HcmV?d00001 diff --git a/docs/build/html/classmlx_1_1core_1_1allocator_1_1_buffer-members.html b/docs/build/html/classmlx_1_1core_1_1allocator_1_1_buffer-members.html new file mode 100644 index 000000000..e3bb6d0e3 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1allocator_1_1_buffer-members.html @@ -0,0 +1,94 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::allocator::Buffer Member List
+
+
+ +

This is the complete list of members for mlx::core::allocator::Buffer, including all inherited members.

+ + + + + +
Buffer(void *ptr)mlx::core::allocator::Bufferinline
ptr() constmlx::core::allocator::Bufferinline
ptr()mlx::core::allocator::Bufferinline
raw_ptr()mlx::core::allocator::Buffer
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1allocator_1_1_buffer.html b/docs/build/html/classmlx_1_1core_1_1allocator_1_1_buffer.html new file mode 100644 index 000000000..263092690 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1allocator_1_1_buffer.html @@ -0,0 +1,201 @@ + + + + + + + +MLX: mlx::core::allocator::Buffer Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::allocator::Buffer Class Reference
+
+
+ +

#include <allocator.h>

+ + + + + + + + + + +

+Public Member Functions

 Buffer (void *ptr)
 
void * raw_ptr ()
 
const void * ptr () const
 
void * ptr ()
 
+

Constructor & Destructor Documentation

+ +

◆ Buffer()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::allocator::Buffer::Buffer (void * ptr)
+
+inline
+
+ +
+
+

Member Function Documentation

+ +

◆ ptr() [1/2]

+ +
+
+ + + + + +
+ + + + + + + +
void * mlx::core::allocator::Buffer::ptr ()
+
+inline
+
+ +
+
+ +

◆ ptr() [2/2]

+ +
+
+ + + + + +
+ + + + + + + +
const void * mlx::core::allocator::Buffer::ptr () const
+
+inline
+
+ +
+
+ +

◆ raw_ptr()

+ +
+
+ + + + + + + +
void * mlx::core::allocator::Buffer::raw_ptr ()
+
+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1allocator_1_1_common_allocator-members.html b/docs/build/html/classmlx_1_1core_1_1allocator_1_1_common_allocator-members.html new file mode 100644 index 000000000..7a70d6422 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1allocator_1_1_common_allocator-members.html @@ -0,0 +1,99 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::allocator::CommonAllocator Member List
+
+
+ +

This is the complete list of members for mlx::core::allocator::CommonAllocator, including all inherited members.

+ + + + + + + + + + +
allocatormlx::core::allocator::CommonAllocatorfriend
Allocator()=defaultmlx::core::allocator::Allocator
Allocator(const Allocator &other)=deletemlx::core::allocator::Allocator
Allocator(Allocator &&other)=deletemlx::core::allocator::Allocator
free(Buffer buffer) overridemlx::core::allocator::CommonAllocatorvirtual
malloc(size_t size, bool allow_swap=false) overridemlx::core::allocator::CommonAllocatorvirtual
operator=(const Allocator &other)=deletemlx::core::allocator::Allocator
operator=(Allocator &&other)=deletemlx::core::allocator::Allocator
~Allocator()=defaultmlx::core::allocator::Allocatorvirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1allocator_1_1_common_allocator.html b/docs/build/html/classmlx_1_1core_1_1allocator_1_1_common_allocator.html new file mode 100644 index 000000000..8d6363adb --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1allocator_1_1_common_allocator.html @@ -0,0 +1,219 @@ + + + + + + + +MLX: mlx::core::allocator::CommonAllocator Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::allocator::CommonAllocator Class Reference
+
+
+ +

#include <allocator.h>

+
+Inheritance diagram for mlx::core::allocator::CommonAllocator:
+
+
+ + +mlx::core::allocator::Allocator + +
+ + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

virtual Buffer malloc (size_t size, bool allow_swap=false) override
 A general CPU allocator.
 
virtual void free (Buffer buffer) override
 
- Public Member Functions inherited from mlx::core::allocator::Allocator
 Allocator ()=default
 
 Allocator (const Allocator &other)=delete
 
 Allocator (Allocator &&other)=delete
 
Allocatoroperator= (const Allocator &other)=delete
 
Allocatoroperator= (Allocator &&other)=delete
 
virtual ~Allocator ()=default
 
+ + + +

+Friends

Allocatorallocator ()
 
+

Member Function Documentation

+ +

◆ free()

+ +
+
+ + + + + +
+ + + + + + + +
virtual void mlx::core::allocator::CommonAllocator::free (Buffer buffer)
+
+overridevirtual
+
+
+ +

◆ malloc()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual Buffer mlx::core::allocator::CommonAllocator::malloc (size_t size,
bool allow_swap = false )
+
+overridevirtual
+
+ +

A general CPU allocator.

+ +

Implements mlx::core::allocator::Allocator.

+ +
+
+

Friends And Related Symbol Documentation

+ +

◆ allocator

+ +
+
+ + + + + +
+ + + + + + + +
Allocator & allocator ()
+
+friend
+
+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1allocator_1_1_common_allocator.png b/docs/build/html/classmlx_1_1core_1_1allocator_1_1_common_allocator.png new file mode 100644 index 0000000000000000000000000000000000000000..8b609c844a056c54cc34351188ec30309feafcae GIT binary patch literal 724 zcmeAS@N?(olHy`uVBq!ia0vp^_kcKngBeIVGF$5cDd_;85ZC|z{{xvX-h3_XKeXJ! zK(jz%`k5CG9y|bwo1P6@0+iz{3GxeO0P?}WoN4wI1_q{hPZ!6K3dXl{U-!LM;Av@h zT=nOle@#*UPF2O|w{ptiGlQDgKdV0Y=h|b~r|mLHCbZ9eQrOc9mCVrdADKMMvWvF% zT>l;~;c0aFf4lg##`r_go6ZW)%iP|$;oJ5!&glBrAMeg#JAd|P>6zGruMH-ptgGIW z*(Pf^;aAXEnZpe5i?Ld*zhcBPssHpxKy<+XAO~+TIZa(cXX=&AmNtY(g-)j?p|7LYV zd9upMApH-oOxH%~L?3)TL9gNEy9>u3UIW?k|Gu(kagoiWEn3HwJy(Mmt3G~A>M4w1 zJdpp2nc;C(&Ytt>8T&SVRNFQ2C)@GQr#KnPo>nj^T^H=kxpGS&uber*U3B zU;FmWGn0zEbL*4erD*PsUmiZ$=kMRGnS0th{r05JQwWM%6nZ29=pFv-)Qdp zk|i7cf6H^>IJHyv=Jf^en@74X%1K}L^3%dIp$|G+9-Ueny63dZ`VW&{ZzRB%8wpzY(Q_(+(_eK|4 + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::array Member List
+
+
+ +

This is the complete list of members for mlx::core::array, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
array(T val, Dtype dtype=TypeToDtype< T >())mlx::core::arrayexplicit
array(const std::complex< float > &val, Dtype dtype=complex64)mlx::core::arrayexplicit
array(It data, std::vector< int > shape, Dtype dtype=TypeToDtype< typename std::iterator_traits< It >::value_type >())mlx::core::array
array(std::initializer_list< T > data, Dtype dtype=TypeToDtype< T >())mlx::core::array
array(std::initializer_list< float > data)mlx::core::array
array(std::initializer_list< int > data, Dtype dtype)mlx::core::array
array(std::initializer_list< T > data, std::vector< int > shape, Dtype dtype=TypeToDtype< T >())mlx::core::array
array(allocator::Buffer data, std::vector< int > shape, Dtype dtype, deleter_t deleter=allocator::free)mlx::core::array
array(const array &other)=defaultmlx::core::array
array(array &&other)=defaultmlx::core::array
array(std::vector< int > shape, Dtype dtype, std::shared_ptr< Primitive > primitive, std::vector< array > inputs)mlx::core::array
attach_event(Event e) constmlx::core::arrayinline
available enum valuemlx::core::array
begin() constmlx::core::arrayinline
buffer()mlx::core::arrayinline
buffer() constmlx::core::arrayinline
copy_shared_buffer(const array &other, const std::vector< size_t > &strides, Flags flags, size_t data_size, size_t offset=0)mlx::core::array
copy_shared_buffer(const array &other)mlx::core::array
data()mlx::core::arrayinline
data() constmlx::core::arrayinline
data_shared_ptr() constmlx::core::arrayinline
data_size() constmlx::core::arrayinline
detach()mlx::core::array
dtype() constmlx::core::arrayinline
end() constmlx::core::arrayinline
eval()mlx::core::array
event() constmlx::core::arrayinline
flags() constmlx::core::arrayinline
has_primitive() constmlx::core::arrayinline
id() constmlx::core::arrayinline
inputs() constmlx::core::arrayinline
inputs()mlx::core::arrayinline
is_available() constmlx::core::arrayinline
is_donatable() constmlx::core::arrayinline
is_tracer() constmlx::core::array
item()mlx::core::array
item() constmlx::core::array
itemsize() constmlx::core::arrayinline
make_arrays(std::vector< std::vector< int > > shapes, const std::vector< Dtype > &dtypes, const std::shared_ptr< Primitive > &primitive, const std::vector< array > &inputs)mlx::core::arraystatic
move_shared_buffer(array other, const std::vector< size_t > &strides, Flags flags, size_t data_size, size_t offset=0)mlx::core::array
move_shared_buffer(array other)mlx::core::array
nbytes() constmlx::core::arrayinline
ndim() constmlx::core::arrayinline
operator=(const array &other) &&=deletemlx::core::array
operator=(array &&other) &&=deletemlx::core::array
operator=(array &&other) &=defaultmlx::core::array
operator=(const array &other) &mlx::core::arrayinline
outputs() constmlx::core::arrayinline
overwrite_descriptor(const array &other)mlx::core::arrayinline
primitive() constmlx::core::arrayinline
primitive_id() constmlx::core::arrayinline
primitive_ptr() constmlx::core::arrayinline
scheduled enum valuemlx::core::array
set_data(allocator::Buffer buffer, deleter_t d=allocator::free)mlx::core::array
set_data(allocator::Buffer buffer, size_t data_size, std::vector< size_t > strides, Flags flags, deleter_t d=allocator::free)mlx::core::array
set_siblings(std::vector< array > siblings, uint16_t position)mlx::core::arrayinline
set_status(Status s) constmlx::core::arrayinline
set_tracer(bool is_tracer)mlx::core::arrayinline
shape() constmlx::core::arrayinline
shape(int dim) constmlx::core::arrayinline
siblings() constmlx::core::arrayinline
siblings()mlx::core::arrayinline
size() constmlx::core::arrayinline
status() constmlx::core::arrayinline
Status enum namemlx::core::array
strides() constmlx::core::arrayinline
strides(int dim) constmlx::core::arrayinline
unscheduled enum valuemlx::core::array
~array()mlx::core::array
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1array.html b/docs/build/html/classmlx_1_1core_1_1array.html new file mode 100644 index 000000000..25f218b83 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1array.html @@ -0,0 +1,2000 @@ + + + + + + + +MLX: mlx::core::array Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+ +
+ +

#include <array.h>

+ + + + + + + + +

+Classes

struct  ArrayIterator
 
struct  Data
 
struct  Flags
 
+ + + +

+Public Types

enum  Status { unscheduled +, scheduled +, available + }
 
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

template<typename T >
 array (T val, Dtype dtype=TypeToDtype< T >())
 Construct a scalar array with zero dimensions.
 
 array (const std::complex< float > &val, Dtype dtype=complex64)
 
template<typename It >
 array (It data, std::vector< int > shape, Dtype dtype=TypeToDtype< typename std::iterator_traits< It >::value_type >())
 
template<typename T >
 array (std::initializer_list< T > data, Dtype dtype=TypeToDtype< T >())
 
 array (std::initializer_list< float > data)
 
 array (std::initializer_list< int > data, Dtype dtype)
 
template<typename T >
 array (std::initializer_list< T > data, std::vector< int > shape, Dtype dtype=TypeToDtype< T >())
 
 array (allocator::Buffer data, std::vector< int > shape, Dtype dtype, deleter_t deleter=allocator::free)
 
arrayoperator= (const array &other) &&=delete
 Assignment to rvalue does not compile.
 
arrayoperator= (array &&other) &&=delete
 
arrayoperator= (array &&other) &=default
 Default copy and move constructors otherwise.
 
 array (const array &other)=default
 
 array (array &&other)=default
 
arrayoperator= (const array &other) &
 
size_t itemsize () const
 The size of the array's datatype in bytes.
 
size_t size () const
 The number of elements in the array.
 
size_t nbytes () const
 The number of bytes in the array.
 
size_t ndim () const
 The number of dimensions of the array.
 
const std::vector< int > & shape () const
 The shape of the array as a vector of integers.
 
int shape (int dim) const
 Get the size of the corresponding dimension.
 
const std::vector< size_t > & strides () const
 The strides of the array.
 
size_t strides (int dim) const
 Get the stride of the corresponding dimension.
 
Dtype dtype () const
 Get the arrays data type.
 
void eval ()
 Evaluate the array.
 
template<typename T >
item ()
 Get the value from a scalar array.
 
template<typename T >
item () const
 
ArrayIterator begin () const
 
ArrayIterator end () const
 
 array (std::vector< int > shape, Dtype dtype, std::shared_ptr< Primitive > primitive, std::vector< array > inputs)
 The following methods should be used with caution.
 
std::uintptr_t id () const
 A unique identifier for an array.
 
std::uintptr_t primitive_id () const
 A unique identifier for an arrays primitive.
 
Primitiveprimitive () const
 The array's primitive.
 
std::shared_ptr< Primitive > & primitive_ptr () const
 A shared pointer to the array's primitive.
 
bool has_primitive () const
 Check if the array has an attached primitive or is a leaf node.
 
const std::vector< array > & inputs () const
 The array's inputs.
 
std::vector< array > & inputs ()
 
bool is_donatable () const
 True indicates the arrays buffer is safe to reuse.
 
const std::vector< array > & siblings () const
 The array's siblings.
 
std::vector< array > & siblings ()
 The array's siblings.
 
void set_siblings (std::vector< array > siblings, uint16_t position)
 
std::vector< arrayoutputs () const
 The outputs of the array's primitive (i.e.
 
void detach ()
 Detach the array from the graph.
 
const Flagsflags () const
 Get the Flags bit-field.
 
size_t data_size () const
 The size (in elements) of the underlying buffer the array points to.
 
allocator::Bufferbuffer ()
 
const allocator::Bufferbuffer () const
 
std::shared_ptr< Datadata_shared_ptr () const
 
template<typename T >
T * data ()
 
template<typename T >
const T * data () const
 
bool is_available () const
 
const Status status () const
 
void set_status (Status s) const
 
Eventevent () const
 
void attach_event (Event e) const
 
void set_tracer (bool is_tracer)
 
bool is_tracer () const
 
void set_data (allocator::Buffer buffer, deleter_t d=allocator::free)
 
void set_data (allocator::Buffer buffer, size_t data_size, std::vector< size_t > strides, Flags flags, deleter_t d=allocator::free)
 
void copy_shared_buffer (const array &other, const std::vector< size_t > &strides, Flags flags, size_t data_size, size_t offset=0)
 
void copy_shared_buffer (const array &other)
 
void move_shared_buffer (array other, const std::vector< size_t > &strides, Flags flags, size_t data_size, size_t offset=0)
 
void move_shared_buffer (array other)
 
void overwrite_descriptor (const array &other)
 
 ~array ()
 
+ + + +

+Static Public Member Functions

static std::vector< arraymake_arrays (std::vector< std::vector< int > > shapes, const std::vector< Dtype > &dtypes, const std::shared_ptr< Primitive > &primitive, const std::vector< array > &inputs)
 
+

Member Enumeration Documentation

+ +

◆ Status

+ +
+
+ + + + +
enum mlx::core::array::Status
+
+ + + + +
Enumerator
unscheduled 
scheduled 
available 
+ +
+
+

Constructor & Destructor Documentation

+ +

◆ array() [1/11]

+ +
+
+
+template<typename T >
+ + + + + +
+ + + + + + + + + + + +
mlx::core::array::array (T val,
Dtype dtype = TypeToDtype<T>() )
+
+explicit
+
+ +

Construct a scalar array with zero dimensions.

+ +
+
+ +

◆ array() [2/11]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
mlx::core::array::array (const std::complex< float > & val,
Dtype dtype = complex64 )
+
+explicit
+
+ +
+
+ +

◆ array() [3/11]

+ +
+
+
+template<typename It >
+ + + + + + + + + + + + + + + + +
mlx::core::array::array (It data,
std::vector< int > shape,
Dtype dtype = TypeToDtype<typename std::iterator_traits<It>::value_type>() )
+
+ +
+
+ +

◆ array() [4/11]

+ +
+
+
+template<typename T >
+ + + + + + + + + + + +
mlx::core::array::array (std::initializer_list< T > data,
Dtype dtype = TypeToDtype<T>() )
+
+ +
+
+ +

◆ array() [5/11]

+ +
+
+ + + + + + + +
mlx::core::array::array (std::initializer_list< float > data)
+
+ +
+
+ +

◆ array() [6/11]

+ +
+
+ + + + + + + + + + + +
mlx::core::array::array (std::initializer_list< int > data,
Dtype dtype )
+
+ +
+
+ +

◆ array() [7/11]

+ +
+
+
+template<typename T >
+ + + + + + + + + + + + + + + + +
mlx::core::array::array (std::initializer_list< T > data,
std::vector< int > shape,
Dtype dtype = TypeToDtype<T>() )
+
+ +
+
+ +

◆ array() [8/11]

+ +
+
+ + + + + + + + + + + + + + + + + + + + + +
mlx::core::array::array (allocator::Buffer data,
std::vector< int > shape,
Dtype dtype,
deleter_t deleter = allocator::free )
+
+ +
+
+ +

◆ array() [9/11]

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::array::array (const array & other)
+
+default
+
+ +
+
+ +

◆ array() [10/11]

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::array::array (array && other)
+
+default
+
+ +
+
+ +

◆ array() [11/11]

+ +
+
+ + + + + + + + + + + + + + + + + + + + + +
mlx::core::array::array (std::vector< int > shape,
Dtype dtype,
std::shared_ptr< Primitive > primitive,
std::vector< array > inputs )
+
+ +

The following methods should be used with caution.

+

They are intended for use by the backend implementation and the API may change.

+ +
+
+ +

◆ ~array()

+ +
+
+ + + + + + + +
mlx::core::array::~array ()
+
+ +
+
+

Member Function Documentation

+ +

◆ attach_event()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::array::attach_event (Event e) const
+
+inline
+
+ +
+
+ +

◆ begin()

+ +
+
+ + + + + +
+ + + + + + + +
ArrayIterator mlx::core::array::begin () const
+
+inline
+
+ +
+
+ +

◆ buffer() [1/2]

+ +
+
+ + + + + +
+ + + + + + + +
allocator::Buffer & mlx::core::array::buffer ()
+
+inline
+
+ +
+
+ +

◆ buffer() [2/2]

+ +
+
+ + + + + +
+ + + + + + + +
const allocator::Buffer & mlx::core::array::buffer () const
+
+inline
+
+ +
+
+ +

◆ copy_shared_buffer() [1/2]

+ +
+
+ + + + + + + +
void mlx::core::array::copy_shared_buffer (const array & other)
+
+ +
+
+ +

◆ copy_shared_buffer() [2/2]

+ +
+
+ + + + + + + + + + + + + + + + + + + + + + + + + + +
void mlx::core::array::copy_shared_buffer (const array & other,
const std::vector< size_t > & strides,
Flags flags,
size_t data_size,
size_t offset = 0 )
+
+ +
+
+ +

◆ data() [1/2]

+ +
+
+
+template<typename T >
+ + + + + +
+ + + + + + + +
T * mlx::core::array::data ()
+
+inline
+
+ +
+
+ +

◆ data() [2/2]

+ +
+
+
+template<typename T >
+ + + + + +
+ + + + + + + +
const T * mlx::core::array::data () const
+
+inline
+
+ +
+
+ +

◆ data_shared_ptr()

+ +
+
+ + + + + +
+ + + + + + + +
std::shared_ptr< Data > mlx::core::array::data_shared_ptr () const
+
+inline
+
+ +
+
+ +

◆ data_size()

+ +
+
+ + + + + +
+ + + + + + + +
size_t mlx::core::array::data_size () const
+
+inline
+
+ +

The size (in elements) of the underlying buffer the array points to.

+ +
+
+ +

◆ detach()

+ +
+
+ + + + + + + +
void mlx::core::array::detach ()
+
+ +

Detach the array from the graph.

+ +
+
+ +

◆ dtype()

+ +
+
+ + + + + +
+ + + + + + + +
Dtype mlx::core::array::dtype () const
+
+inline
+
+ +

Get the arrays data type.

+ +
+
+ +

◆ end()

+ +
+
+ + + + + +
+ + + + + + + +
ArrayIterator mlx::core::array::end () const
+
+inline
+
+ +
+
+ +

◆ eval()

+ +
+
+ + + + + + + +
void mlx::core::array::eval ()
+
+ +

Evaluate the array.

+ +
+
+ +

◆ event()

+ +
+
+ + + + + +
+ + + + + + + +
Event & mlx::core::array::event () const
+
+inline
+
+ +
+
+ +

◆ flags()

+ +
+
+ + + + + +
+ + + + + + + +
const Flags & mlx::core::array::flags () const
+
+inline
+
+ +

Get the Flags bit-field.

+ +
+
+ +

◆ has_primitive()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::array::has_primitive () const
+
+inline
+
+ +

Check if the array has an attached primitive or is a leaf node.

+ +
+
+ +

◆ id()

+ +
+
+ + + + + +
+ + + + + + + +
std::uintptr_t mlx::core::array::id () const
+
+inline
+
+ +

A unique identifier for an array.

+ +
+
+ +

◆ inputs() [1/2]

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< array > & mlx::core::array::inputs ()
+
+inline
+
+ +
+
+ +

◆ inputs() [2/2]

+ +
+
+ + + + + +
+ + + + + + + +
const std::vector< array > & mlx::core::array::inputs () const
+
+inline
+
+ +

The array's inputs.

+ +
+
+ +

◆ is_available()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::array::is_available () const
+
+inline
+
+ +
+
+ +

◆ is_donatable()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::array::is_donatable () const
+
+inline
+
+ +

True indicates the arrays buffer is safe to reuse.

+ +
+
+ +

◆ is_tracer()

+ +
+
+ + + + + + + +
bool mlx::core::array::is_tracer () const
+
+ +
+
+ +

◆ item() [1/2]

+ +
+
+
+template<typename T >
+ + + + + + + +
T mlx::core::array::item ()
+
+ +

Get the value from a scalar array.

+ +
+
+ +

◆ item() [2/2]

+ +
+
+
+template<typename T >
+ + + + + + + +
T mlx::core::array::item () const
+
+ +
+
+ +

◆ itemsize()

+ +
+
+ + + + + +
+ + + + + + + +
size_t mlx::core::array::itemsize () const
+
+inline
+
+ +

The size of the array's datatype in bytes.

+ +
+
+ +

◆ make_arrays()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
static std::vector< array > mlx::core::array::make_arrays (std::vector< std::vector< int > > shapes,
const std::vector< Dtype > & dtypes,
const std::shared_ptr< Primitive > & primitive,
const std::vector< array > & inputs )
+
+static
+
+ +
+
+ +

◆ move_shared_buffer() [1/2]

+ +
+
+ + + + + + + +
void mlx::core::array::move_shared_buffer (array other)
+
+ +
+
+ +

◆ move_shared_buffer() [2/2]

+ +
+
+ + + + + + + + + + + + + + + + + + + + + + + + + + +
void mlx::core::array::move_shared_buffer (array other,
const std::vector< size_t > & strides,
Flags flags,
size_t data_size,
size_t offset = 0 )
+
+ +
+
+ +

◆ nbytes()

+ +
+
+ + + + + +
+ + + + + + + +
size_t mlx::core::array::nbytes () const
+
+inline
+
+ +

The number of bytes in the array.

+ +
+
+ +

◆ ndim()

+ +
+
+ + + + + +
+ + + + + + + +
size_t mlx::core::array::ndim () const
+
+inline
+
+ +

The number of dimensions of the array.

+ +
+
+ +

◆ operator=() [1/4]

+ +
+
+ + + + + +
+ + + + + + + +
array & mlx::core::array::operator= (array && other) &&
+
+delete
+
+ +
+
+ +

◆ operator=() [2/4]

+ +
+
+ + + + + +
+ + + + + + + +
array & mlx::core::array::operator= (array && other) &
+
+default
+
+ +

Default copy and move constructors otherwise.

+ +
+
+ +

◆ operator=() [3/4]

+ +
+
+ + + + + +
+ + + + + + + +
array & mlx::core::array::operator= (const array & other) &
+
+inline
+
+ +
+
+ +

◆ operator=() [4/4]

+ +
+
+ + + + + +
+ + + + + + + +
array & mlx::core::array::operator= (const array & other) &&
+
+delete
+
+ +

Assignment to rvalue does not compile.

+ +
+
+ +

◆ outputs()

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< array > mlx::core::array::outputs () const
+
+inline
+
+ +

The outputs of the array's primitive (i.e.

+

this array and its siblings) in the order the primitive expects.

+ +
+
+ +

◆ overwrite_descriptor()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::array::overwrite_descriptor (const array & other)
+
+inline
+
+ +
+
+ +

◆ primitive()

+ +
+
+ + + + + +
+ + + + + + + +
Primitive & mlx::core::array::primitive () const
+
+inline
+
+ +

The array's primitive.

+ +
+
+ +

◆ primitive_id()

+ +
+
+ + + + + +
+ + + + + + + +
std::uintptr_t mlx::core::array::primitive_id () const
+
+inline
+
+ +

A unique identifier for an arrays primitive.

+ +
+
+ +

◆ primitive_ptr()

+ +
+
+ + + + + +
+ + + + + + + +
std::shared_ptr< Primitive > & mlx::core::array::primitive_ptr () const
+
+inline
+
+ +

A shared pointer to the array's primitive.

+ +
+
+ +

◆ set_data() [1/2]

+ +
+
+ + + + + + + + + + + +
void mlx::core::array::set_data (allocator::Buffer buffer,
deleter_t d = allocator::free )
+
+ +
+
+ +

◆ set_data() [2/2]

+ +
+
+ + + + + + + + + + + + + + + + + + + + + + + + + + +
void mlx::core::array::set_data (allocator::Buffer buffer,
size_t data_size,
std::vector< size_t > strides,
Flags flags,
deleter_t d = allocator::free )
+
+ +
+
+ +

◆ set_siblings()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::array::set_siblings (std::vector< array > siblings,
uint16_t position )
+
+inline
+
+ +
+
+ +

◆ set_status()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::array::set_status (Status s) const
+
+inline
+
+ +
+
+ +

◆ set_tracer()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::array::set_tracer (bool is_tracer)
+
+inline
+
+ +
+
+ +

◆ shape() [1/2]

+ +
+
+ + + + + +
+ + + + + + + +
const std::vector< int > & mlx::core::array::shape () const
+
+inline
+
+ +

The shape of the array as a vector of integers.

+ +
+
+ +

◆ shape() [2/2]

+ +
+
+ + + + + +
+ + + + + + + +
int mlx::core::array::shape (int dim) const
+
+inline
+
+ +

Get the size of the corresponding dimension.

+

This function supports negative indexing and provides bounds checking.

+ +
+
+ +

◆ siblings() [1/2]

+ +
+
+ + + + + +
+ + + + + + + +
std::vector< array > & mlx::core::array::siblings ()
+
+inline
+
+ +

The array's siblings.

+ +
+
+ +

◆ siblings() [2/2]

+ +
+
+ + + + + +
+ + + + + + + +
const std::vector< array > & mlx::core::array::siblings () const
+
+inline
+
+ +

The array's siblings.

+ +
+
+ +

◆ size()

+ +
+
+ + + + + +
+ + + + + + + +
size_t mlx::core::array::size () const
+
+inline
+
+ +

The number of elements in the array.

+ +
+
+ +

◆ status()

+ +
+
+ + + + + +
+ + + + + + + +
const Status mlx::core::array::status () const
+
+inline
+
+ +
+
+ +

◆ strides() [1/2]

+ +
+
+ + + + + +
+ + + + + + + +
const std::vector< size_t > & mlx::core::array::strides () const
+
+inline
+
+ +

The strides of the array.

+ +
+
+ +

◆ strides() [2/2]

+ +
+
+ + + + + +
+ + + + + + + +
size_t mlx::core::array::strides (int dim) const
+
+inline
+
+ +

Get the stride of the corresponding dimension.

+

This function supports negative indexing and provides bounds checking.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1fast_1_1_custom-members.html b/docs/build/html/classmlx_1_1core_1_1fast_1_1_custom-members.html new file mode 100644 index 000000000..2acd0e4fb --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1fast_1_1_custom-members.html @@ -0,0 +1,107 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::fast::Custom Member List
+
+
+ +

This is the complete list of members for mlx::core::fast::Custom, including all inherited members.

+ + + + + + + + + + + + + + + + + + +
Custom(Stream stream, std::function< std::vector< array >(std::vector< array >)> fallback)mlx::core::fast::Custominlineexplicit
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs)=0mlx::core::Primitivepure virtual
eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs)=0mlx::core::Primitivepure virtual
is_equivalent(const Primitive &other) constmlx::core::Primitiveinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::fast::Customvirtual
operator=(const Primitive &other)=deletemlx::core::Primitive
operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs)mlx::core::Primitivevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os)=0mlx::core::Primitivepure virtual
stream()mlx::core::Primitiveinline
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::fast::Customvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::fast::Customvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1fast_1_1_custom.html b/docs/build/html/classmlx_1_1core_1_1fast_1_1_custom.html new file mode 100644 index 000000000..35642a556 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1fast_1_1_custom.html @@ -0,0 +1,306 @@ + + + + + + + +MLX: mlx::core::fast::Custom Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::fast::Custom Class Reference
+
+
+ +

#include <fast_primitives.h>

+
+Inheritance diagram for mlx::core::fast::Custom:
+
+
+ + +mlx::core::Primitive +mlx::core::fast::LayerNorm +mlx::core::fast::LayerNormVJP +mlx::core::fast::RMSNorm +mlx::core::fast::RMSNormVJP +mlx::core::fast::RoPE +mlx::core::fast::ScaledDotProductAttention + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Custom (Stream stream, std::function< std::vector< array >(std::vector< array >)> fallback)
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
virtual std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
virtual std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs)=0
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
virtual void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs)=0
 
virtual void print (std::ostream &os)=0
 Print the primitive.
 
virtual bool is_equivalent (const Primitive &other) const
 Equivalence check defaults to false unless overridden by the primitive.
 
virtual std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs)
 Get the output shapes of the primitive.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ Custom()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
mlx::core::fast::Custom::Custom (Stream stream,
std::function< std::vector< array >(std::vector< array >)> fallback )
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ jvp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
virtual std::vector< array > mlx::core::fast::Custom::jvp (const std::vector< array > & primals,
const std::vector< array > & tangents,
const std::vector< int > & argnums )
+
+overridevirtual
+
+ +

The Jacobian-vector product.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
virtual std::vector< array > mlx::core::fast::Custom::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::Primitive.

+ +

Reimplemented in mlx::core::fast::RMSNorm, mlx::core::fast::LayerNorm, and mlx::core::fast::RoPE.

+ +
+
+ +

◆ vmap()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual std::pair< std::vector< array >, std::vector< int > > mlx::core::fast::Custom::vmap (const std::vector< array > & inputs,
const std::vector< int > & axes )
+
+overridevirtual
+
+ +

The primitive must know how to vectorize itself across the given axes.

+

The output is a pair containing the output arrays representing the vectorized computation and the axes which corresponds to the vectorized dimensions of each output.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1fast_1_1_custom.png b/docs/build/html/classmlx_1_1core_1_1fast_1_1_custom.png new file mode 100644 index 0000000000000000000000000000000000000000..6d384f81c061b79bb2110491cefc4b5c09b618b1 GIT binary patch literal 2664 zcmd5;dsGuw8b@STnn}SQL>|!b3uWT~nk~ zFjA_55?vMSL=0BsErZEI+64>{!$Zu55RqpfA%p}nnb`?p|KZqk_Uu2q-?=k$=lk7z z=X=kW9}?teh1-lXGc&XD|Mc(?Gqc|*vH1ndwb-3q+L556X4v~O5zWh0^YFDD*T2Tm~xwkA&T3i*6Jg;uJX8rjgAll1&=YAw1dp*$4HAxHi;&JGJ>XCwM zi{x=`gW$GHyu94pT!Qy4rM0giNk-nH<@2~;q(9El+}ymenl76K7YBr6+qkz*$SRo( zq%L%#Qn?*71wPmXj_Ar5#tY+-? ztjNSEUGa=F_067Hfq%yU2_E^AlCiHd zmC`ybSE`7CqLIfl@y@25nSnA!&v?__4u#1dx)K4HHV~>xzaZh=znT|^E-Gt>PC>Mo zoEM2Lie_OCMJ>B2CkBb!CI3F5T1i^chCk)Icaa~}$oC^@^m% zqoxp2#!?BBI`6)?q^lJZy)=D;j9c5ue$QTbX)!*ov|^%vC0rmESRp;8k84AuiA=_e z>!!$HM{$*7J~WQft+1@Bnes~W+N34?@&prd825b9aHBG=)Q*5he|}mdPbM4X5$(P~ z7P$vzm@?A7j{wJeobFhH(__jvE~K~J>Z?N&IM%aythdnzg=|dH&sQW3r(zl!7%Kv; zGUw|2{tfusxJ`{~jRj$iu=*U9S-H8jE3xlW7P#4%ROY3{Cfycj^EoC?7T`7i2i2{xeCr#(-Z{j5*YRx0|Cp(KSMFKJ zefXY$Lc93)U>F;eErPJj%+EyMUsN*}GY|zvQX+sk^ zP)ASJbpufnnI*ia)Mq1x4;+ate^av4JW^_;6Z`r~m$K#1(O2s#=o%j5pVVuHJ>^k4 zq5hPcU-%6DyR^tB24b01ePQg0y}o>=SRzf9B;Q|bT!&*2TSKUZ$nrs&n082%1swI@L?9kNwj^hzG+JntsNESf|oORsc0gH5}#r@<+Thp>b;sIB7q# z6RsESt;6!h2_}hsH|Me(mNcg~U94o;?}gAncpVet_fI}1CAtb(r9Z0S@p*rGh zp%90@w5w~U2g%6ge+I1V?ZTAJa~)1&f|e*5dWgIV(Im z%9>>_RoWqw{iCyZhQYEfWq^uf&~+b?Yw0p}N^`*$BDYsurO(m3wqNyWjS1JC-mnOx zOeXYN9KNklC%*qjjf*YfaJVIVk{zKFHV_z*I$=R-M5|!XFkRRZ{=|EHB(xrw2`0eu zn?rlOpc_{b6YHZXo>~PIf0<9$dkA)oQ4%tA$3~}csD@Qk^u4?v01Uruz@ah@MM7_K zsabl?4bySIr)i&~2V`7oG@(~PcXO#hf^le%r6B8iQL3)QF|5;N*nL`NL1Z)=NSh~3D=f$!Lu)mouF(O1WA{5`BDm;5nz~}%jWx| z{1%~)u@qtb##!rNH%MDH4^8V|sHdZYou>U^*XtoH} ow4F{U-jB5qWyXRqyB8ttb-~@*k%(KwdO|aQ-=M<}4#s}-FS^1b`2YX_ literal 0 HcmV?d00001 diff --git a/docs/build/html/classmlx_1_1core_1_1fast_1_1_layer_norm-members.html b/docs/build/html/classmlx_1_1core_1_1fast_1_1_layer_norm-members.html new file mode 100644 index 000000000..7c29baf14 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1fast_1_1_layer_norm-members.html @@ -0,0 +1,109 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::fast::LayerNorm Member List
+
+
+ +

This is the complete list of members for mlx::core::fast::LayerNorm, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + +
Custom(Stream stream, std::function< std::vector< array >(std::vector< array >)> fallback)mlx::core::fast::Custominlineexplicit
DEFINE_PRINT(LayerNorm) bool is_equivalent(const Primitive &other) const overridemlx::core::fast::LayerNorm
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::fast::LayerNorminlinevirtual
eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::fast::LayerNormvirtual
is_equivalent(const Primitive &other) constmlx::core::Primitiveinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::fast::Customvirtual
LayerNorm(Stream stream, std::function< std::vector< array >(std::vector< array >)> fallback, float eps)mlx::core::fast::LayerNorminline
operator=(const Primitive &other)=deletemlx::core::Primitive
operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs)mlx::core::Primitivevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os)=0mlx::core::Primitivepure virtual
stream()mlx::core::Primitiveinline
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::fast::LayerNormvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::fast::Customvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1fast_1_1_layer_norm.html b/docs/build/html/classmlx_1_1core_1_1fast_1_1_layer_norm.html new file mode 100644 index 000000000..dd9db97ac --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1fast_1_1_layer_norm.html @@ -0,0 +1,327 @@ + + + + + + + +MLX: mlx::core::fast::LayerNorm Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::fast::LayerNorm Class Reference
+
+
+ +

#include <fast_primitives.h>

+
+Inheritance diagram for mlx::core::fast::LayerNorm:
+
+
+ + +mlx::core::fast::Custom +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 LayerNorm (Stream stream, std::function< std::vector< array >(std::vector< array >)> fallback, float eps)
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
 DEFINE_PRINT (LayerNorm) bool is_equivalent(const Primitive &other) const override
 
- Public Member Functions inherited from mlx::core::fast::Custom
 Custom (Stream stream, std::function< std::vector< array >(std::vector< array >)> fallback)
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
virtual std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual void print (std::ostream &os)=0
 Print the primitive.
 
virtual bool is_equivalent (const Primitive &other) const
 Equivalence check defaults to false unless overridden by the primitive.
 
virtual std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs)
 Get the output shapes of the primitive.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ LayerNorm()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
mlx::core::fast::LayerNorm::LayerNorm (Stream stream,
std::function< std::vector< array >(std::vector< array >)> fallback,
float eps )
+
+inline
+
+ +
+
+

Member Function Documentation

+ +

◆ DEFINE_PRINT()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::fast::LayerNorm::DEFINE_PRINT (LayerNorm ) const &
+
+override
+
+ +
+
+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::fast::LayerNorm::eval_cpu (const std::vector< array > & inputs,
std::vector< array > & outputs )
+
+inlineoverridevirtual
+
+ +

A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.

+

To avoid unnecessary allocations, the evaluation function is responsible for allocating space for the array.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::fast::LayerNorm::eval_gpu (const std::vector< array > & inputs,
std::vector< array > & outputs )
+
+overridevirtual
+
+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::fast::LayerNorm::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::fast::Custom.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1fast_1_1_layer_norm.png b/docs/build/html/classmlx_1_1core_1_1fast_1_1_layer_norm.png new file mode 100644 index 0000000000000000000000000000000000000000..202a404d0d3ee549d25154cfcb6d4a2a78cd694a GIT binary patch literal 951 zcmeAS@N?(olHy`uVBq!ia0vp^bAh;ngBeJkdMtScNJ$6ygt-3y{~ySF@#br3|Doj; z2ATyD)6cwk@ZbSZ-1KbN5}+JsNswPK1CS2}=1jA%FfcGXdAc};R4~4s`?~M30#9o{ zf6&wK^^eZ-8ni6wwYKiw()UB;)X|!Q2A{<>mzwB$s90(yx=*q@)TdHglcuin({hvD z;+@{L`YJye>yN)nJGVGsd)cJ@C97wgIDYNPb+a|^!esY+%TdgllN#~!;yI<0ebu{7 zrCOUW=_SrCdgXK8ps&|qlJdPbb9t}5<<+#gJa6~CpYw0>>K|Hqw=is0q{_{Y8|UuL z7OyV1`DKk;Sg!Eommr>C96F3xAmb@)#% z+r$5Va+&$*{H^=$r>no-Q~D~*FtG3a`SrOGcPH(QI(B29mD=Rp{hpPsleY1^op!Kt z&xCC=uJ3co+$6hXj#qLy*UWI0%+nR`-`vr&UGl+av$N~-E6Mg7em{8ltTsE{zuK>Q zW?5$VuKinHUv>(zlKOqiVn*e=f9KZi6#ctXI`v!9PMLK3WRYu5iMlGeKhLa7*|~S~ z*G1P}WFLMQ@ILHsb-B#HNq;ZbOxk*+-e}Sn!7`PX%CN9Hs<(7X8Y@HH1aStR(^tyY zUq4e7d|S6j`1RRb@p3VSJp_XWS;6D~n}h{tigm1dwWMY4s#X0Pf{-|1 z^tZZa3a_>Ob$Nfh>b>UAZMV}NT`9h0#`h*9^uJzTYu)*Co6oJkxpUgjqi0sFvQq&H j*Mhxu)Xya)>>tC9I}5dR|L^kxW-$g&S3j3^P6 + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::fast::LayerNormVJP Member List
+
+
+ +

This is the complete list of members for mlx::core::fast::LayerNormVJP, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + +
Custom(Stream stream, std::function< std::vector< array >(std::vector< array >)> fallback)mlx::core::fast::Custominlineexplicit
DEFINE_PRINT(LayerNormVJP) bool is_equivalent(const Primitive &other) const overridemlx::core::fast::LayerNormVJP
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::fast::LayerNormVJPinlinevirtual
eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::fast::LayerNormVJPvirtual
is_equivalent(const Primitive &other) constmlx::core::Primitiveinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::fast::Customvirtual
LayerNormVJP(Stream stream, std::function< std::vector< array >(std::vector< array >)> fallback, float eps)mlx::core::fast::LayerNormVJPinline
operator=(const Primitive &other)=deletemlx::core::Primitive
operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs)mlx::core::Primitivevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os)=0mlx::core::Primitivepure virtual
stream()mlx::core::Primitiveinline
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::fast::Customvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::fast::Customvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1fast_1_1_layer_norm_v_j_p.html b/docs/build/html/classmlx_1_1core_1_1fast_1_1_layer_norm_v_j_p.html new file mode 100644 index 000000000..81a4b6e49 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1fast_1_1_layer_norm_v_j_p.html @@ -0,0 +1,284 @@ + + + + + + + +MLX: mlx::core::fast::LayerNormVJP Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::fast::LayerNormVJP Class Reference
+
+
+ +

#include <fast_primitives.h>

+
+Inheritance diagram for mlx::core::fast::LayerNormVJP:
+
+
+ + +mlx::core::fast::Custom +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 LayerNormVJP (Stream stream, std::function< std::vector< array >(std::vector< array >)> fallback, float eps)
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
 DEFINE_PRINT (LayerNormVJP) bool is_equivalent(const Primitive &other) const override
 
- Public Member Functions inherited from mlx::core::fast::Custom
 Custom (Stream stream, std::function< std::vector< array >(std::vector< array >)> fallback)
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
virtual std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
virtual std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual void print (std::ostream &os)=0
 Print the primitive.
 
virtual bool is_equivalent (const Primitive &other) const
 Equivalence check defaults to false unless overridden by the primitive.
 
virtual std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs)
 Get the output shapes of the primitive.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ LayerNormVJP()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
mlx::core::fast::LayerNormVJP::LayerNormVJP (Stream stream,
std::function< std::vector< array >(std::vector< array >)> fallback,
float eps )
+
+inline
+
+ +
+
+

Member Function Documentation

+ +

◆ DEFINE_PRINT()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::fast::LayerNormVJP::DEFINE_PRINT (LayerNormVJP ) const &
+
+override
+
+ +
+
+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::fast::LayerNormVJP::eval_cpu (const std::vector< array > & inputs,
std::vector< array > & outputs )
+
+inlineoverridevirtual
+
+ +

A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.

+

To avoid unnecessary allocations, the evaluation function is responsible for allocating space for the array.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::fast::LayerNormVJP::eval_gpu (const std::vector< array > & inputs,
std::vector< array > & outputs )
+
+overridevirtual
+
+ +

Implements mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1fast_1_1_layer_norm_v_j_p.png b/docs/build/html/classmlx_1_1core_1_1fast_1_1_layer_norm_v_j_p.png new file mode 100644 index 0000000000000000000000000000000000000000..e81afc1426049bda9e0cf122b880207911ac2fe1 GIT binary patch literal 994 zcmeAS@N?(olHy`uVBq!ia0vp^TY1SR%c<=xyZhAIs2~du+B*-tA0mugfbEer>7#NsyJzX3_Dj46+eZ6V(8X-3M zS*t|8*B_l9XsM}I_%`L6_thmYmKI4w$wbtPCih%k$YCIv$D7a`&3vQaz7<1_U&TMS z6BSzGq_m&U>bQ_V!)7wtn86${YAt=k6(Id#(qp^P}z4 z6%920UdavhYh0EpvAmbzL!Wf^UB99TZ;m-HKmYon_kYgU0_M1=_mUvqnJ z{l;t0WxrpOdf@yo>F(QO$#L_Ip6o5%oy>Uevwg($oR{Y(6xsjKVBB?S?E(FB0uA+S z+u1f8T+5hKkj)^yY-^w4v6~Z?$bX!$0C2BW+6T4i?ZA#6WQP>dpYAQ|IED} z{wr*kZ=J6=DWx^qdlGM+uBX|KttvMkT$_~AOh{mtkQ2w{g%g6jCO^EUxfG~=`IJe# z>92R^3-U3adVJd9Z#lz>n5$=Zuc&y za~4)S{Gy`$7i6BtlHpC3+= eb>i6kQJ#0h&d1B9wMGGRBZH@_pUXO@geCyM#`Wd^ literal 0 HcmV?d00001 diff --git a/docs/build/html/classmlx_1_1core_1_1fast_1_1_r_m_s_norm-members.html b/docs/build/html/classmlx_1_1core_1_1fast_1_1_r_m_s_norm-members.html new file mode 100644 index 000000000..ab88c81d7 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1fast_1_1_r_m_s_norm-members.html @@ -0,0 +1,109 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::fast::RMSNorm Member List
+
+
+ +

This is the complete list of members for mlx::core::fast::RMSNorm, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + +
Custom(Stream stream, std::function< std::vector< array >(std::vector< array >)> fallback)mlx::core::fast::Custominlineexplicit
DEFINE_PRINT(RMSNorm) bool is_equivalent(const Primitive &other) const overridemlx::core::fast::RMSNorm
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::fast::RMSNorminlinevirtual
eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::fast::RMSNormvirtual
is_equivalent(const Primitive &other) constmlx::core::Primitiveinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::fast::Customvirtual
operator=(const Primitive &other)=deletemlx::core::Primitive
operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs)mlx::core::Primitivevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os)=0mlx::core::Primitivepure virtual
RMSNorm(Stream stream, std::function< std::vector< array >(std::vector< array >)> fallback, float eps)mlx::core::fast::RMSNorminline
stream()mlx::core::Primitiveinline
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::fast::RMSNormvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::fast::Customvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1fast_1_1_r_m_s_norm.html b/docs/build/html/classmlx_1_1core_1_1fast_1_1_r_m_s_norm.html new file mode 100644 index 000000000..c476b2c96 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1fast_1_1_r_m_s_norm.html @@ -0,0 +1,327 @@ + + + + + + + +MLX: mlx::core::fast::RMSNorm Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::fast::RMSNorm Class Reference
+
+
+ +

#include <fast_primitives.h>

+
+Inheritance diagram for mlx::core::fast::RMSNorm:
+
+
+ + +mlx::core::fast::Custom +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 RMSNorm (Stream stream, std::function< std::vector< array >(std::vector< array >)> fallback, float eps)
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
 DEFINE_PRINT (RMSNorm) bool is_equivalent(const Primitive &other) const override
 
- Public Member Functions inherited from mlx::core::fast::Custom
 Custom (Stream stream, std::function< std::vector< array >(std::vector< array >)> fallback)
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
virtual std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual void print (std::ostream &os)=0
 Print the primitive.
 
virtual bool is_equivalent (const Primitive &other) const
 Equivalence check defaults to false unless overridden by the primitive.
 
virtual std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs)
 Get the output shapes of the primitive.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ RMSNorm()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
mlx::core::fast::RMSNorm::RMSNorm (Stream stream,
std::function< std::vector< array >(std::vector< array >)> fallback,
float eps )
+
+inline
+
+ +
+
+

Member Function Documentation

+ +

◆ DEFINE_PRINT()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::fast::RMSNorm::DEFINE_PRINT (RMSNorm ) const &
+
+override
+
+ +
+
+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::fast::RMSNorm::eval_cpu (const std::vector< array > & inputs,
std::vector< array > & outputs )
+
+inlineoverridevirtual
+
+ +

A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.

+

To avoid unnecessary allocations, the evaluation function is responsible for allocating space for the array.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::fast::RMSNorm::eval_gpu (const std::vector< array > & inputs,
std::vector< array > & outputs )
+
+overridevirtual
+
+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::fast::RMSNorm::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::fast::Custom.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1fast_1_1_r_m_s_norm.png b/docs/build/html/classmlx_1_1core_1_1fast_1_1_r_m_s_norm.png new file mode 100644 index 0000000000000000000000000000000000000000..0cb8e0a310f3b7a732e96136a5e706655baa4d09 GIT binary patch literal 927 zcmeAS@N?(olHy`uVBq!ia0vp^(}B2ygBeKPJ=n<#q@)9ULR|m<{|{uoc=NTi|Il&^ z1I+@7>1SR%c<=xyZhAIs2~du+B*-tA0mugfbEer>7#NtfJY5_^Dj46+yIm^eEWZ`yk2NDD^O+jqtBbZzew4C zZ1dlX-74n`Dovm7vyXR{v#-9n+i{Zcx{9fDwte5Y$KNdM#4?lYJ>~MApT2YGPx>MZ zb>sO%ahglF$T57-m1EeT&|PIKm))iNYHOyh=C)Pgcib6vI8`t<%>Tf|VE>Snq3#ej z!=Dy$h98{z3?GE!7#^taVL0HAtl(cuxR<6bhmfb9d&R^h^FJtg+CPtTnj}6Uui>Z0 zZT^$XYB>Izt~VEcpI`esRKK#Xg6~%RnIFk#e{L*aefvX|Zf#EV+sP--9$#SdxO)BC z=y#e+bf10fid~h%tm&M&r)@&zG5+7F?xiOcT9wbP6PdX1sNUM>S345zKW%z$slJoP z`m5I2?R$TzOqkMlY(9;R$Ws# zU-bB!uivh`U;nsbdJO;L$NvR9?$*Dvg?lC9_a_w4JSXqT&E zpEWN2E)$(Q<#FWn*|N(MH#=^RFI%+h=+;QrtqVP)BVRwMnQ5}e>3mpfc=6@eGcKfP zFW+pg_u15Q&7$jjd?sxEo3&@Qw%NSjuM$?}Re%x?$7UP8l`r?^O!>9vWoXf@d&jFd z--;Z!wbi?jQtY>Hzt4@G(^tK6YMH)j)%*`Zp + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::fast::RMSNormVJP Member List
+
+
+ +

This is the complete list of members for mlx::core::fast::RMSNormVJP, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + +
Custom(Stream stream, std::function< std::vector< array >(std::vector< array >)> fallback)mlx::core::fast::Custominlineexplicit
DEFINE_PRINT(RMSNormVJP) bool is_equivalent(const Primitive &other) const overridemlx::core::fast::RMSNormVJP
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::fast::RMSNormVJPinlinevirtual
eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::fast::RMSNormVJPvirtual
is_equivalent(const Primitive &other) constmlx::core::Primitiveinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::fast::Customvirtual
operator=(const Primitive &other)=deletemlx::core::Primitive
operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs)mlx::core::Primitivevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os)=0mlx::core::Primitivepure virtual
RMSNormVJP(Stream stream, std::function< std::vector< array >(std::vector< array >)> fallback, float eps)mlx::core::fast::RMSNormVJPinline
stream()mlx::core::Primitiveinline
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::fast::Customvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::fast::Customvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1fast_1_1_r_m_s_norm_v_j_p.html b/docs/build/html/classmlx_1_1core_1_1fast_1_1_r_m_s_norm_v_j_p.html new file mode 100644 index 000000000..4b010e9d1 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1fast_1_1_r_m_s_norm_v_j_p.html @@ -0,0 +1,284 @@ + + + + + + + +MLX: mlx::core::fast::RMSNormVJP Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::fast::RMSNormVJP Class Reference
+
+
+ +

#include <fast_primitives.h>

+
+Inheritance diagram for mlx::core::fast::RMSNormVJP:
+
+
+ + +mlx::core::fast::Custom +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 RMSNormVJP (Stream stream, std::function< std::vector< array >(std::vector< array >)> fallback, float eps)
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
 DEFINE_PRINT (RMSNormVJP) bool is_equivalent(const Primitive &other) const override
 
- Public Member Functions inherited from mlx::core::fast::Custom
 Custom (Stream stream, std::function< std::vector< array >(std::vector< array >)> fallback)
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
virtual std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
virtual std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual void print (std::ostream &os)=0
 Print the primitive.
 
virtual bool is_equivalent (const Primitive &other) const
 Equivalence check defaults to false unless overridden by the primitive.
 
virtual std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs)
 Get the output shapes of the primitive.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ RMSNormVJP()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + +
mlx::core::fast::RMSNormVJP::RMSNormVJP (Stream stream,
std::function< std::vector< array >(std::vector< array >)> fallback,
float eps )
+
+inline
+
+ +
+
+

Member Function Documentation

+ +

◆ DEFINE_PRINT()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::fast::RMSNormVJP::DEFINE_PRINT (RMSNormVJP ) const &
+
+override
+
+ +
+
+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::fast::RMSNormVJP::eval_cpu (const std::vector< array > & inputs,
std::vector< array > & outputs )
+
+inlineoverridevirtual
+
+ +

A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.

+

To avoid unnecessary allocations, the evaluation function is responsible for allocating space for the array.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::fast::RMSNormVJP::eval_gpu (const std::vector< array > & inputs,
std::vector< array > & outputs )
+
+overridevirtual
+
+ +

Implements mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1fast_1_1_r_m_s_norm_v_j_p.png b/docs/build/html/classmlx_1_1core_1_1fast_1_1_r_m_s_norm_v_j_p.png new file mode 100644 index 0000000000000000000000000000000000000000..39e2b0d047114e829281dfb67b725699313f8a94 GIT binary patch literal 981 zcmeAS@N?(olHy`uVBq!ia0vp^>w&m~gBeIJd+L1>NJ$6ygt-3y{~ySF@#br3|Doj; z2ATyD)6cwk@ZbSZ-1KbN5}+JsNswPK1CS2}=1jA%FfcGDc)B=-R4~4s`?hb1f`Ds& z>?FUv|Ap7*a*N1YS-rf){EH^{$tV-dH&i~mG{c; zv#*s`{}Y*1)3o`$>-_i=_r6WtBlWvk@^!?%n%lFlimvUwZ(7|Um7X0|w`H%#r=HpS z-v1Es^Rix1JJo*2%9+PD3VM8+wEK1T;quq+T<*92yiNbuUb8&<&6V#bHY^kOocHn0 z^>?o)f0g}yrGKAFW&81cXX{>nEEood&FW;aHRgWnP#?a-L{iYUM&umo2Zf=^E>ZuF^giZ z;r8=0o=#n>kgC!9Ucji|W0G-QM7;Fc-z%0BEQvg-WO3z&X1RC4-UIQUxBk{p_kLOS zDtguacem?(CSAICX8FE3Uu0JAe|!7QcDwrYp9M-29@i92Jrt#}rd1`~=8Nvmi}U!u z2CR6zH(n;J^!l~W-*1Gernmi5u@tKB_q@~(^94uzq%X;zeKnVsGBEJRF>Da%{<$wM zn^XVU?w5M1duHX&;b(Z@`GcuJ{vkI5Fk~5i2<~He;7&l`xT?}70fEq~sT}d4q4tNu zL$CH<{JwI3iR`KKwbrx6e{ZfWH46&feYa@A`X78}SLaKvJF~jB^!D|CXZB6X{?=YL z^Ud2C1<}#BZtOC?Z`Ag-^3Uqz*&4<7*H*q$T(xaRwB*+H;TJtE-<3*k`fZY9w&VG2 zjw}A5t3yMiXPu7^%b&FW+2&fO^J&(V#Z&kD*ZBTV{$#29Yo&d + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::fast::RoPE Member List
+
+
+ +

This is the complete list of members for mlx::core::fast::RoPE, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + +
Custom(Stream stream, std::function< std::vector< array >(std::vector< array >)> fallback)mlx::core::fast::Custominlineexplicit
DEFINE_PRINT(RoPE) bool is_equivalent(const Primitive &other) const overridemlx::core::fast::RoPE
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::fast::RoPEinlinevirtual
eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::fast::RoPEvirtual
is_equivalent(const Primitive &other) constmlx::core::Primitiveinlinevirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::fast::Customvirtual
operator=(const Primitive &other)=deletemlx::core::Primitive
operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs)mlx::core::Primitivevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os)=0mlx::core::Primitivepure virtual
RoPE(Stream stream, std::function< std::vector< array >(std::vector< array >)> fallback, int dims, bool traditional, float base, float scale, int offset, bool forward)mlx::core::fast::RoPEinline
stream()mlx::core::Primitiveinline
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::fast::RoPEvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::fast::Customvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1fast_1_1_ro_p_e.html b/docs/build/html/classmlx_1_1core_1_1fast_1_1_ro_p_e.html new file mode 100644 index 000000000..745b490be --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1fast_1_1_ro_p_e.html @@ -0,0 +1,352 @@ + + + + + + + +MLX: mlx::core::fast::RoPE Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::fast::RoPE Class Reference
+
+
+ +

#include <fast_primitives.h>

+
+Inheritance diagram for mlx::core::fast::RoPE:
+
+
+ + +mlx::core::fast::Custom +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 RoPE (Stream stream, std::function< std::vector< array >(std::vector< array >)> fallback, int dims, bool traditional, float base, float scale, int offset, bool forward)
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
 DEFINE_PRINT (RoPE) bool is_equivalent(const Primitive &other) const override
 
- Public Member Functions inherited from mlx::core::fast::Custom
 Custom (Stream stream, std::function< std::vector< array >(std::vector< array >)> fallback)
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
virtual std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual void print (std::ostream &os)=0
 Print the primitive.
 
virtual bool is_equivalent (const Primitive &other) const
 Equivalence check defaults to false unless overridden by the primitive.
 
virtual std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs)
 Get the output shapes of the primitive.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ RoPE()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
mlx::core::fast::RoPE::RoPE (Stream stream,
std::function< std::vector< array >(std::vector< array >)> fallback,
int dims,
bool traditional,
float base,
float scale,
int offset,
bool forward )
+
+inline
+
+ +
+
+

Member Function Documentation

+ +

◆ DEFINE_PRINT()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::fast::RoPE::DEFINE_PRINT (RoPE ) const &
+
+override
+
+ +
+
+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::fast::RoPE::eval_cpu (const std::vector< array > & inputs,
std::vector< array > & outputs )
+
+inlineoverridevirtual
+
+ +

A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.

+

To avoid unnecessary allocations, the evaluation function is responsible for allocating space for the array.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ eval_gpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::fast::RoPE::eval_gpu (const std::vector< array > & inputs,
std::vector< array > & outputs )
+
+overridevirtual
+
+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ vjp()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
std::vector< array > mlx::core::fast::RoPE::vjp (const std::vector< array > & primals,
const std::vector< array > & cotangents,
const std::vector< int > & argnums,
const std::vector< array > & outputs )
+
+overridevirtual
+
+ +

The vector-Jacobian product.

+ +

Reimplemented from mlx::core::fast::Custom.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1fast_1_1_ro_p_e.png b/docs/build/html/classmlx_1_1core_1_1fast_1_1_ro_p_e.png new file mode 100644 index 0000000000000000000000000000000000000000..62648d9411cecb26a750282334d3053f4627966a GIT binary patch literal 863 zcmeAS@N?(olHy`uVBq!ia0vp^?LgeY!3-pS<#W6RQqloFA+G=b{|7Q(y!l$%e`vXd zfo6fk^fNCWJa_;UH$5A+1SrQ@666=m0OW&#In(Sb3=B-KJY5_^Dj46+ecSh0i>Kvy zbI8ws>wj#?Y%o;Gtt#8EoS0LdTXI}lM zuju`I|A|j#{#82OJ?HzRZTFn#gI0 zU0%5B<0H#8i$y)kWZ$+g<9)5#yZPupiI@Kx9&a%9xBNDlFY>y@mh=AKlHVQOE~8?3 z%e(Kw^5lA-=bz>|SI=DXip$-4aYg*C-v^ILN{U=7dAo1Ill##{6GT+mJp+N1?h&9< z@0&CHSb3Krfib$}aalBTieR7hF(HN31y5=j1*XU_xSV8kAVmEaao1d`lc1sUv(%e! z(xN8{nzJl>RO}6;g`Qk~XFthN{b~J~k4cN(=6>1#c!T>9<&5fyKWcryZfyOUY8M~f z^Q+2d>#KW{U-e9UY!_qew(7~pHCNlNINn-vz;NmsW4_-D0}nr&bnND#Rkx2!cAsvf zv-a{oeZIH*t1@)oPFcKDzqn_%*CjUYIkg`Yip%1ka@1OVJAQGejOeuApWd)a?ap{p zb@<{rjeFDUJgep&KVTRA{@cI*8@DH0@7A>cf8CP}8Wuq6;@rhkg7)Pzd%hO@BRUUqg-6IeRsE?g9tI7#q*ALzbw^R_*NU<_Kulmw{63|dE1WdGQE5K)Eh%_%Xh{h zV#VPnA6te>2et3^mC}y6#jk6VvwZi@oQD#(c!KLg`m+Ci%N)~{~x~l qw|tt+J3TL`E5#le82{e)m)+34>;I=?*W7_wfx*+&&t;ucLK6Tj`k7Y% literal 0 HcmV?d00001 diff --git a/docs/build/html/classmlx_1_1core_1_1fast_1_1_scaled_dot_product_attention-members.html b/docs/build/html/classmlx_1_1core_1_1fast_1_1_scaled_dot_product_attention-members.html new file mode 100644 index 000000000..f75d5bd4d --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1fast_1_1_scaled_dot_product_attention-members.html @@ -0,0 +1,110 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::fast::ScaledDotProductAttention Member List
+
+
+ +

This is the complete list of members for mlx::core::fast::ScaledDotProductAttention, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + +
Custom(Stream stream, std::function< std::vector< array >(std::vector< array >)> fallback)mlx::core::fast::Custominlineexplicit
DEFINE_PRINT(ScaledDotProductAttention)mlx::core::fast::ScaledDotProductAttention
device()mlx::core::Primitiveinline
eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::fast::ScaledDotProductAttentioninlinevirtual
eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) overridemlx::core::fast::ScaledDotProductAttentioninlinevirtual
eval_gpu(const std::vector< array > &inputs, array &out)mlx::core::fast::ScaledDotProductAttention
is_equivalent(const Primitive &other) const overridemlx::core::fast::ScaledDotProductAttentionvirtual
jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) overridemlx::core::fast::Customvirtual
operator=(const Primitive &other)=deletemlx::core::Primitive
operator=(Primitive &&other)=deletemlx::core::Primitive
output_shapes(const std::vector< array > &inputs)mlx::core::Primitivevirtual
Primitive(Stream stream)mlx::core::Primitiveinlineexplicit
Primitive(const Primitive &other)=deletemlx::core::Primitive
Primitive(Primitive &&other)=deletemlx::core::Primitive
print(std::ostream &os)=0mlx::core::Primitivepure virtual
ScaledDotProductAttention(Stream stream, std::function< std::vector< array >(std::vector< array >)> fallback, const float scale, const bool needs_mask)mlx::core::fast::ScaledDotProductAttentioninlineexplicit
stream()mlx::core::Primitiveinline
vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) overridemlx::core::fast::Customvirtual
vmap(const std::vector< array > &inputs, const std::vector< int > &axes) overridemlx::core::fast::Customvirtual
~Primitive()=defaultmlx::core::Primitivevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1fast_1_1_scaled_dot_product_attention.html b/docs/build/html/classmlx_1_1core_1_1fast_1_1_scaled_dot_product_attention.html new file mode 100644 index 000000000..e5b515790 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1fast_1_1_scaled_dot_product_attention.html @@ -0,0 +1,333 @@ + + + + + + + +MLX: mlx::core::fast::ScaledDotProductAttention Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::fast::ScaledDotProductAttention Class Reference
+
+
+ +

#include <fast_primitives.h>

+
+Inheritance diagram for mlx::core::fast::ScaledDotProductAttention:
+
+
+ + +mlx::core::fast::Custom +mlx::core::Primitive + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 ScaledDotProductAttention (Stream stream, std::function< std::vector< array >(std::vector< array >)> fallback, const float scale, const bool needs_mask)
 
void eval_cpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
 
void eval_gpu (const std::vector< array > &inputs, std::vector< array > &outputs) override
 
void eval_gpu (const std::vector< array > &inputs, array &out)
 
bool is_equivalent (const Primitive &other) const override
 Equivalence check defaults to false unless overridden by the primitive.
 
 DEFINE_PRINT (ScaledDotProductAttention)
 
- Public Member Functions inherited from mlx::core::fast::Custom
 Custom (Stream stream, std::function< std::vector< array >(std::vector< array >)> fallback)
 
virtual std::pair< std::vector< array >, std::vector< int > > vmap (const std::vector< array > &inputs, const std::vector< int > &axes) override
 The primitive must know how to vectorize itself across the given axes.
 
virtual std::vector< arrayjvp (const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
 The Jacobian-vector product.
 
virtual std::vector< arrayvjp (const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
 The vector-Jacobian product.
 
- Public Member Functions inherited from mlx::core::Primitive
 Primitive (Stream stream)
 
const Devicedevice ()
 The device the primitive will run on.
 
const Streamstream ()
 The stream the primitive will run on.
 
virtual void print (std::ostream &os)=0
 Print the primitive.
 
virtual std::vector< std::vector< int > > output_shapes (const std::vector< array > &inputs)
 Get the output shapes of the primitive.
 
virtual ~Primitive ()=default
 
 Primitive (const Primitive &other)=delete
 
 Primitive (Primitive &&other)=delete
 
Primitiveoperator= (const Primitive &other)=delete
 
Primitiveoperator= (Primitive &&other)=delete
 
+

Constructor & Destructor Documentation

+ +

◆ ScaledDotProductAttention()

+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
mlx::core::fast::ScaledDotProductAttention::ScaledDotProductAttention (Stream stream,
std::function< std::vector< array >(std::vector< array >)> fallback,
const float scale,
const bool needs_mask )
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ DEFINE_PRINT()

+ +
+
+ + + + + + + +
mlx::core::fast::ScaledDotProductAttention::DEFINE_PRINT (ScaledDotProductAttention )
+
+ +
+
+ +

◆ eval_cpu()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::fast::ScaledDotProductAttention::eval_cpu (const std::vector< array > & inputs,
std::vector< array > & outputs )
+
+inlineoverridevirtual
+
+ +

A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.

+

To avoid unnecessary allocations, the evaluation function is responsible for allocating space for the array.

+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ eval_gpu() [1/2]

+ +
+
+ + + + + + + + + + + +
void mlx::core::fast::ScaledDotProductAttention::eval_gpu (const std::vector< array > & inputs,
array & out )
+
+ +
+
+ +

◆ eval_gpu() [2/2]

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::fast::ScaledDotProductAttention::eval_gpu (const std::vector< array > & inputs,
std::vector< array > & outputs )
+
+inlineoverridevirtual
+
+ +

Implements mlx::core::Primitive.

+ +
+
+ +

◆ is_equivalent()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::fast::ScaledDotProductAttention::is_equivalent (const Primitive & other) const
+
+overridevirtual
+
+ +

Equivalence check defaults to false unless overridden by the primitive.

+ +

Reimplemented from mlx::core::Primitive.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1fast_1_1_scaled_dot_product_attention.png b/docs/build/html/classmlx_1_1core_1_1fast_1_1_scaled_dot_product_attention.png new file mode 100644 index 0000000000000000000000000000000000000000..65f61e4a0d67c4cbc125902d2b8d2bc41e4f01a9 GIT binary patch literal 1075 zcmeAS@N?(olHy`uVBq!ia0vp^UxB!TgBeJ2T#oesQqloFA+G=b{|7Q(y!l$%e`vXd zfo6fk^fNCWJa_;UH$5A+1SrQ@666=m0OW&#In(Sb3=GWsJY5_^Dj46+yuh+J!*!Stw zZ{GO*xx3=Y?w^&f)sy$R{q)LSX}#HQM#uKGYvi-H@7uli^QBFG7oYDe&yN42G0Em& z={4^uB7R-`cf+nK=R__0d_&UXlaBbit-M9sr!Ofua{u+;-9O&1S0cl9cn_s*O9fw|Y} z{A0f08{DmXCp?)QlYiH5_xXpmh2K|9?>f8o`H_FsdsWw#akqU``6)L2Z`|Z3^A*>6 z8nI>iP6GOUQjcTA>maXAHika}<_sL}9Kxo;SD70;kZv& zaAJW&NNDrKYpYfPZ3(?9(Q|EnE%zqr3;Xp>28CXInEmelwK>1`yfWD0?)uw4d;Y6A z+w3})*vXjh-r)0g$*Z4HmwxS!$VuAU7d!8_O78JD$G3^^+xm55?q|uG?_yc2LM$#{j=6?;$8K~JNC%=2lLl#eQyw*%W+(1-|Xm+ z+>hNi_sTsswN_iT>-gIhw@u5Rw_mn)v)j{Cwu#T?dEnVyz3X4ZX3x%f^)UDFGx6=E zuY&y7ynUj-_VuIBJMHuK#z%yvwpVNKEt0T&b9&9Ao~^QV8*Ut56Kgj6{({8^`qsqX zef@F9Ikow*?>oO%SLaUps?vTt=T6f-p@(H>etVdgmV|F0Im_J)Q2z3&MMr6pTeW$|2!3N>X*ThS3j3^ HP6 + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::io::FileReader Member List
+
+
+ +

This is the complete list of members for mlx::core::io::FileReader, including all inherited members.

+ + + + + + + + + +
FileReader(std::ifstream is)mlx::core::io::FileReaderinlineexplicit
FileReader(std::string file_path)mlx::core::io::FileReaderinlineexplicit
good() const overridemlx::core::io::FileReaderinlinevirtual
is_open() const overridemlx::core::io::FileReaderinlinevirtual
label() const overridemlx::core::io::FileReaderinlinevirtual
read(char *data, size_t n) overridemlx::core::io::FileReaderinlinevirtual
seek(int64_t off, std::ios_base::seekdir way=std::ios_base::beg) overridemlx::core::io::FileReaderinlinevirtual
tell() overridemlx::core::io::FileReaderinlinevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1io_1_1_file_reader.html b/docs/build/html/classmlx_1_1core_1_1io_1_1_file_reader.html new file mode 100644 index 000000000..1ae86019b --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1io_1_1_file_reader.html @@ -0,0 +1,346 @@ + + + + + + + +MLX: mlx::core::io::FileReader Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::io::FileReader Class Reference
+
+
+ +

#include <load.h>

+
+Inheritance diagram for mlx::core::io::FileReader:
+
+
+ + +mlx::core::io::Reader + +
+ + + + + + + + + + + + + + + + + + +

+Public Member Functions

 FileReader (std::ifstream is)
 
 FileReader (std::string file_path)
 
bool is_open () const override
 
bool good () const override
 
size_t tell () override
 
void seek (int64_t off, std::ios_base::seekdir way=std::ios_base::beg) override
 
void read (char *data, size_t n) override
 
std::string label () const override
 
+

Constructor & Destructor Documentation

+ +

◆ FileReader() [1/2]

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::io::FileReader::FileReader (std::ifstream is)
+
+inlineexplicit
+
+ +
+
+ +

◆ FileReader() [2/2]

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::io::FileReader::FileReader (std::string file_path)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ good()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::io::FileReader::good () const
+
+inlineoverridevirtual
+
+ +

Implements mlx::core::io::Reader.

+ +
+
+ +

◆ is_open()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::io::FileReader::is_open () const
+
+inlineoverridevirtual
+
+ +

Implements mlx::core::io::Reader.

+ +
+
+ +

◆ label()

+ +
+
+ + + + + +
+ + + + + + + +
std::string mlx::core::io::FileReader::label () const
+
+inlineoverridevirtual
+
+ +

Implements mlx::core::io::Reader.

+ +
+
+ +

◆ read()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::io::FileReader::read (char * data,
size_t n )
+
+inlineoverridevirtual
+
+ +

Implements mlx::core::io::Reader.

+ +
+
+ +

◆ seek()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::io::FileReader::seek (int64_t off,
std::ios_base::seekdir way = std::ios_base::beg )
+
+inlineoverridevirtual
+
+ +

Implements mlx::core::io::Reader.

+ +
+
+ +

◆ tell()

+ +
+
+ + + + + +
+ + + + + + + +
size_t mlx::core::io::FileReader::tell ()
+
+inlineoverridevirtual
+
+ +

Implements mlx::core::io::Reader.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1io_1_1_file_reader.png b/docs/build/html/classmlx_1_1core_1_1io_1_1_file_reader.png new file mode 100644 index 0000000000000000000000000000000000000000..0a31de8a66d013e62136badc3528dffbeead85e6 GIT binary patch literal 642 zcmV-|0)737P)vTJr#LVva2S`&=-}Ys|Ns9r%~qrU000SeQchC<|NsC0|NsC0Hv*f~0006F zNklRq;%F|M7%XCZ0VWtHp?bf+c@N0~TR)D~@w7Y8FruOA}W z4jzy?r}wra8D5fh1xS1SZ>rSAb^d39;r>2#`aP0s?~9V(U)-Z!B1x}DNXI4mi94R{ zJ^S}Q?hWK7jwGG(4}~NnB@#(SN+gnult?5QX+f&B)~N&lzoQob%%%B|Ua#qx=0jSD zW@DNVDJKp9VkHs)L`oz8h?GbG5Gj!WAW|X$K%_(hfJli10Fe?203sz407Ob80Em=G z01)X-8(}nfNNRr$#;Rfkr;ya|W zpQ}9>-|c++*4)gITS)i5?dMlV#xB|Ww`^*Mv{-hmBH3lO)yJ567js)f^q8z2&Zq9G zWU~yXE|Pk=YV>X6kxrvY*q?U1xaUZ=SqCIvrK8tfq_bMyk4Sb}K2*sf-y!+-_Q~w6 z21u^GFRFWgasS^-F_N2^b1du_^ + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::io::FileWriter Member List
+
+
+ +

This is the complete list of members for mlx::core::io::FileWriter, including all inherited members.

+ + + + + + + + + +
FileWriter(std::ofstream os)mlx::core::io::FileWriterinlineexplicit
FileWriter(std::string file_path)mlx::core::io::FileWriterinlineexplicit
good() const overridemlx::core::io::FileWriterinlinevirtual
is_open() const overridemlx::core::io::FileWriterinlinevirtual
label() const overridemlx::core::io::FileWriterinlinevirtual
seek(int64_t off, std::ios_base::seekdir way=std::ios_base::beg) overridemlx::core::io::FileWriterinlinevirtual
tell() overridemlx::core::io::FileWriterinlinevirtual
write(const char *data, size_t n) overridemlx::core::io::FileWriterinlinevirtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1io_1_1_file_writer.html b/docs/build/html/classmlx_1_1core_1_1io_1_1_file_writer.html new file mode 100644 index 000000000..7367b3b87 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1io_1_1_file_writer.html @@ -0,0 +1,346 @@ + + + + + + + +MLX: mlx::core::io::FileWriter Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::io::FileWriter Class Reference
+
+
+ +

#include <load.h>

+
+Inheritance diagram for mlx::core::io::FileWriter:
+
+
+ + +mlx::core::io::Writer + +
+ + + + + + + + + + + + + + + + + + +

+Public Member Functions

 FileWriter (std::ofstream os)
 
 FileWriter (std::string file_path)
 
bool is_open () const override
 
bool good () const override
 
size_t tell () override
 
void seek (int64_t off, std::ios_base::seekdir way=std::ios_base::beg) override
 
void write (const char *data, size_t n) override
 
std::string label () const override
 
+

Constructor & Destructor Documentation

+ +

◆ FileWriter() [1/2]

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::io::FileWriter::FileWriter (std::ofstream os)
+
+inlineexplicit
+
+ +
+
+ +

◆ FileWriter() [2/2]

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::io::FileWriter::FileWriter (std::string file_path)
+
+inlineexplicit
+
+ +
+
+

Member Function Documentation

+ +

◆ good()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::io::FileWriter::good () const
+
+inlineoverridevirtual
+
+ +

Implements mlx::core::io::Writer.

+ +
+
+ +

◆ is_open()

+ +
+
+ + + + + +
+ + + + + + + +
bool mlx::core::io::FileWriter::is_open () const
+
+inlineoverridevirtual
+
+ +

Implements mlx::core::io::Writer.

+ +
+
+ +

◆ label()

+ +
+
+ + + + + +
+ + + + + + + +
std::string mlx::core::io::FileWriter::label () const
+
+inlineoverridevirtual
+
+ +

Implements mlx::core::io::Writer.

+ +
+
+ +

◆ seek()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::io::FileWriter::seek (int64_t off,
std::ios_base::seekdir way = std::ios_base::beg )
+
+inlineoverridevirtual
+
+ +

Implements mlx::core::io::Writer.

+ +
+
+ +

◆ tell()

+ +
+
+ + + + + +
+ + + + + + + +
size_t mlx::core::io::FileWriter::tell ()
+
+inlineoverridevirtual
+
+ +

Implements mlx::core::io::Writer.

+ +
+
+ +

◆ write()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
void mlx::core::io::FileWriter::write (const char * data,
size_t n )
+
+inlineoverridevirtual
+
+ +

Implements mlx::core::io::Writer.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1io_1_1_file_writer.png b/docs/build/html/classmlx_1_1core_1_1io_1_1_file_writer.png new file mode 100644 index 0000000000000000000000000000000000000000..3f1679897f3bd4c71bcf9d4034cc9a56cc8686ac GIT binary patch literal 612 zcmeAS@N?(olHy`uVBq!ia0vp^-9Q|`!3-qj?8+2@lyrbki0l9V|AEXGZ@!lHA6jl< zpjjX>{mhF84;}!;P0xlc0m^Zf1o;Is0Qq2G&NTZ90|VoGPZ!6K3dXl{Z}%Nm;Bh(L z9CG%5{YSGWX$G0w-r6mgD&5C)P=(##mA02=s!)JO5!aRH0= z=>~G8|MVcQnam77=CL#Q9CFd?KW<|rywfIh*~FPARnH4D^mvFfbS&X!aJj_Fp!AZ7 zVPXkmgU2og2bEX`g-Lo00-i|H|1TUplOvQc$%OT#lIO&diAy|6DlhButl$6Z&zq<$ z>*(qaGuOwem>%WbUnW`qtYBm8wy>C-xlb!sCuaZs<8!?%IDU83tI+G`_HHhW{FN#t zQXZXjXR+OpwY8U@{yxY1<^7v$N8MX0XDykR_Rn`8Ps}y#6{eA2I!}MwxzT&KUi-}x zPp|AN`IM0=UPiJUvDWj4Ne tbUyGOyT>I_m}_@=1U-3~lxD_m|J=N|>cy3~slfEW;OXk;vd$@?2>>FtAhiGh literal 0 HcmV?d00001 diff --git a/docs/build/html/classmlx_1_1core_1_1io_1_1_reader-members.html b/docs/build/html/classmlx_1_1core_1_1io_1_1_reader-members.html new file mode 100644 index 000000000..98887ab2b --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1io_1_1_reader-members.html @@ -0,0 +1,96 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::io::Reader Member List
+
+
+ +

This is the complete list of members for mlx::core::io::Reader, including all inherited members.

+ + + + + + + +
good() const =0mlx::core::io::Readerpure virtual
is_open() const =0mlx::core::io::Readerpure virtual
label() const =0mlx::core::io::Readerpure virtual
read(char *data, size_t n)=0mlx::core::io::Readerpure virtual
seek(int64_t off, std::ios_base::seekdir way=std::ios_base::beg)=0mlx::core::io::Readerpure virtual
tell()=0mlx::core::io::Readerpure virtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1io_1_1_reader.html b/docs/build/html/classmlx_1_1core_1_1io_1_1_reader.html new file mode 100644 index 000000000..2cd2194a8 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1io_1_1_reader.html @@ -0,0 +1,291 @@ + + + + + + + +MLX: mlx::core::io::Reader Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::io::Reader Class Referenceabstract
+
+
+ +

#include <load.h>

+
+Inheritance diagram for mlx::core::io::Reader:
+
+
+ + +mlx::core::io::FileReader + +
+ + + + + + + + + + + + + + +

+Public Member Functions

virtual bool is_open () const =0
 
virtual bool good () const =0
 
virtual size_t tell ()=0
 
virtual void seek (int64_t off, std::ios_base::seekdir way=std::ios_base::beg)=0
 
virtual void read (char *data, size_t n)=0
 
virtual std::string label () const =0
 
+

Member Function Documentation

+ +

◆ good()

+ +
+
+ + + + + +
+ + + + + + + +
virtual bool mlx::core::io::Reader::good () const
+
+pure virtual
+
+ +

Implemented in mlx::core::io::FileReader.

+ +
+
+ +

◆ is_open()

+ +
+
+ + + + + +
+ + + + + + + +
virtual bool mlx::core::io::Reader::is_open () const
+
+pure virtual
+
+ +

Implemented in mlx::core::io::FileReader.

+ +
+
+ +

◆ label()

+ +
+
+ + + + + +
+ + + + + + + +
virtual std::string mlx::core::io::Reader::label () const
+
+pure virtual
+
+ +

Implemented in mlx::core::io::FileReader.

+ +
+
+ +

◆ read()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual void mlx::core::io::Reader::read (char * data,
size_t n )
+
+pure virtual
+
+ +

Implemented in mlx::core::io::FileReader.

+ +
+
+ +

◆ seek()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual void mlx::core::io::Reader::seek (int64_t off,
std::ios_base::seekdir way = std::ios_base::beg )
+
+pure virtual
+
+ +

Implemented in mlx::core::io::FileReader.

+ +
+
+ +

◆ tell()

+ +
+
+ + + + + +
+ + + + + + + +
virtual size_t mlx::core::io::Reader::tell ()
+
+pure virtual
+
+ +

Implemented in mlx::core::io::FileReader.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1io_1_1_reader.png b/docs/build/html/classmlx_1_1core_1_1io_1_1_reader.png new file mode 100644 index 0000000000000000000000000000000000000000..a28b37482fca676db602abbbf3248d03471d1b56 GIT binary patch literal 647 zcmeAS@N?(olHy`uVBq!ia0vp^lYlsYgBeI3ZM_4cqyv0HT>t<74`jZ0^R=}9&~gg{ z%>s$(XI?yb@Bk=odNyncP>!=C$S;@y$Oi*+rrB2*7?^}ST^vIy7~jsl+xOalqvd$= z%30s*A5FHZPgt&4y-oaD*C}Onkr(MauXMdMQ}q^lT=MZ!@$8!7F-hbl-^=z5d#h*5 zc}kVc|IzI+=}y+e>6UdpA@>yXv#Sj_w>iI$E;cUBP`ta_RBpC$Uro42_wBw=mEC8) zeKTBkFST0o@h!_5G0(Etu#^8B%ClBxOzA9k&AoW6qR_ze(=XLYZyzwOERNbPd-8+V zCT)jXuj_-C>aF{1FU<4(`^wAjC&pCAPxoPcTfM8`R()Svu=~->aE%Dru1?er@4o`@O{Ope`DSr{E_$k z_V#CKeRrI8eb2SuZ)H?=^4k4PD`hVw-Q38PIsN86yJJ^&&$unM=y=N80=v?1)!C^Q zrHVIChc3>S?k|z@J@jwO + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::io::Writer Member List
+
+
+ +

This is the complete list of members for mlx::core::io::Writer, including all inherited members.

+ + + + + + + +
good() const =0mlx::core::io::Writerpure virtual
is_open() const =0mlx::core::io::Writerpure virtual
label() const =0mlx::core::io::Writerpure virtual
seek(int64_t off, std::ios_base::seekdir way=std::ios_base::beg)=0mlx::core::io::Writerpure virtual
tell()=0mlx::core::io::Writerpure virtual
write(const char *data, size_t n)=0mlx::core::io::Writerpure virtual
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1io_1_1_writer.html b/docs/build/html/classmlx_1_1core_1_1io_1_1_writer.html new file mode 100644 index 000000000..991f1a823 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1io_1_1_writer.html @@ -0,0 +1,291 @@ + + + + + + + +MLX: mlx::core::io::Writer Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::io::Writer Class Referenceabstract
+
+
+ +

#include <load.h>

+
+Inheritance diagram for mlx::core::io::Writer:
+
+
+ + +mlx::core::io::FileWriter + +
+ + + + + + + + + + + + + + +

+Public Member Functions

virtual bool is_open () const =0
 
virtual bool good () const =0
 
virtual size_t tell ()=0
 
virtual void seek (int64_t off, std::ios_base::seekdir way=std::ios_base::beg)=0
 
virtual void write (const char *data, size_t n)=0
 
virtual std::string label () const =0
 
+

Member Function Documentation

+ +

◆ good()

+ +
+
+ + + + + +
+ + + + + + + +
virtual bool mlx::core::io::Writer::good () const
+
+pure virtual
+
+ +

Implemented in mlx::core::io::FileWriter.

+ +
+
+ +

◆ is_open()

+ +
+
+ + + + + +
+ + + + + + + +
virtual bool mlx::core::io::Writer::is_open () const
+
+pure virtual
+
+ +

Implemented in mlx::core::io::FileWriter.

+ +
+
+ +

◆ label()

+ +
+
+ + + + + +
+ + + + + + + +
virtual std::string mlx::core::io::Writer::label () const
+
+pure virtual
+
+ +

Implemented in mlx::core::io::FileWriter.

+ +
+
+ +

◆ seek()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual void mlx::core::io::Writer::seek (int64_t off,
std::ios_base::seekdir way = std::ios_base::beg )
+
+pure virtual
+
+ +

Implemented in mlx::core::io::FileWriter.

+ +
+
+ +

◆ tell()

+ +
+
+ + + + + +
+ + + + + + + +
virtual size_t mlx::core::io::Writer::tell ()
+
+pure virtual
+
+ +

Implemented in mlx::core::io::FileWriter.

+ +
+
+ +

◆ write()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual void mlx::core::io::Writer::write (const char * data,
size_t n )
+
+pure virtual
+
+ +

Implemented in mlx::core::io::FileWriter.

+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1io_1_1_writer.png b/docs/build/html/classmlx_1_1core_1_1io_1_1_writer.png new file mode 100644 index 0000000000000000000000000000000000000000..70dfa5f6882969907fbe781de58cc2ae5b58d14a GIT binary patch literal 619 zcmV-x0+juUP)vTJr#LVva2S`&=-}Ys|Ns9r%~qrU000SeQchC<|NsC0|NsC0Hv*f~0005@ zNklCg)KDm5p*@v&~op<~58~>q;q^9&Pw{IM{*gk_=aW7c@0{B}E+v?}H`}RS-)?Pnc zk}D&Zb&yPh;%;kWl|2{v1hCL;eqtq)56* zQY2j@DUvRd6j@x%%;o?9p7R2LNtnpx^_oX!B9}#F8kw0~@*Dxch;#vfNV)((BwYX? zk}d!cNf!W!qzeE<(ggq_=>hv2RAG(xM%41%ol=2u9lAMBcA<0O( zkYpsf7)ky;eYn+MeSUVWwVii+@R`RhB%9)fU%QMybGXGFN&2Ksk)>*6EAJdI>eGyf)E1}@tb)cn|= zT;87>yR_?an2lUc?Q0jG<`?VYT|c?>*&DX(#-+*KK3{J4XqYdTE_dtV4$R;j`NPp4 zmXN#05sNPV9OeH2lDr~aNHUTxBpFE;l8h`arIa!Ue*g)}bh{@89OnQ4002ovPDHLk FV1nahCCUH* literal 0 HcmV?d00001 diff --git a/docs/build/html/classmlx_1_1core_1_1metal_1_1_device-members.html b/docs/build/html/classmlx_1_1core_1_1metal_1_1_device-members.html new file mode 100644 index 000000000..0d522e267 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1metal_1_1_device-members.html @@ -0,0 +1,112 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::metal::Device Member List
+
+
+ +

This is the complete list of members for mlx::core::metal::Device, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + +
argument_encoder(const std::vector< MTL::ArgumentDescriptor * > &arg_descs) constmlx::core::metal::Device
commit_command_buffer(int index)mlx::core::metal::Device
Device()mlx::core::metal::Device
Device(const Device &)=deletemlx::core::metal::Device
end_encoding(int index)mlx::core::metal::Device
get_command_buffer(int index)mlx::core::metal::Device
get_command_buffer_ops(int index)mlx::core::metal::Device
get_command_encoder(int index)mlx::core::metal::Device
get_function(const std::string &base_name, MTL::Library *mtl_lib, const std::string &specialized_name="", const MTLFCList &func_consts={})mlx::core::metal::Device
get_function(const std::string &base_name, const std::string &lib_name="mlx", const std::string &specialized_name="", const MTLFCList &func_consts={})mlx::core::metal::Device
get_kernel(const std::string &base_name, MTL::Library *mtl_lib, const std::string &hash_name="", const MTLFCList &func_consts={}, const std::vector< MTL::Function * > &linked_functions={})mlx::core::metal::Device
get_kernel(const std::string &base_name, const std::string &lib_name="mlx", const std::string &hash_name="", const MTLFCList &func_consts={}, const std::vector< MTL::Function * > &linked_functions={})mlx::core::metal::Device
get_library(const std::string &name)mlx::core::metal::Device
get_library(const std::string &name, const std::string &source_string, bool cache=true)mlx::core::metal::Device
get_library(const std::string &name, const MTL::StitchedLibraryDescriptor *desc, bool cache=true)mlx::core::metal::Device
increment_command_buffer_ops(int index)mlx::core::metal::Device
mtl_device()mlx::core::metal::Deviceinline
new_queue(int index)mlx::core::metal::Device
operator=(const Device &)=deletemlx::core::metal::Device
register_library(const std::string &lib_name, const std::string &lib_path)mlx::core::metal::Device
register_library(const std::string &lib_name, const std::function< std::string(const std::string &)> &lib_path_func=get_colocated_mtllib_path)mlx::core::metal::Device
~Device()mlx::core::metal::Device
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1metal_1_1_device.html b/docs/build/html/classmlx_1_1core_1_1metal_1_1_device.html new file mode 100644 index 000000000..9d6e636ad --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1metal_1_1_device.html @@ -0,0 +1,635 @@ + + + + + + + +MLX: mlx::core::metal::Device Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::metal::Device Class Reference
+
+
+ +

#include <device.h>

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Device ()
 
 Device (const Device &)=delete
 
Deviceoperator= (const Device &)=delete
 
 ~Device ()
 
MTL::Device * mtl_device ()
 
void new_queue (int index)
 
MTL::CommandBuffer * get_command_buffer (int index)
 
int get_command_buffer_ops (int index)
 
void increment_command_buffer_ops (int index)
 
void commit_command_buffer (int index)
 
CommandEncoderget_command_encoder (int index)
 
void end_encoding (int index)
 
void register_library (const std::string &lib_name, const std::string &lib_path)
 
void register_library (const std::string &lib_name, const std::function< std::string(const std::string &)> &lib_path_func=get_colocated_mtllib_path)
 
MTL::Library * get_library (const std::string &name)
 
MTL::Library * get_library (const std::string &name, const std::string &source_string, bool cache=true)
 
MTL::Library * get_library (const std::string &name, const MTL::StitchedLibraryDescriptor *desc, bool cache=true)
 
MTL::Function * get_function (const std::string &base_name, MTL::Library *mtl_lib, const std::string &specialized_name="", const MTLFCList &func_consts={})
 
MTL::Function * get_function (const std::string &base_name, const std::string &lib_name="mlx", const std::string &specialized_name="", const MTLFCList &func_consts={})
 
MTL::ComputePipelineState * get_kernel (const std::string &base_name, MTL::Library *mtl_lib, const std::string &hash_name="", const MTLFCList &func_consts={}, const std::vector< MTL::Function * > &linked_functions={})
 
MTL::ComputePipelineState * get_kernel (const std::string &base_name, const std::string &lib_name="mlx", const std::string &hash_name="", const MTLFCList &func_consts={}, const std::vector< MTL::Function * > &linked_functions={})
 
MTL::ArgumentEncoder * argument_encoder (const std::vector< MTL::ArgumentDescriptor * > &arg_descs) const
 
+

Constructor & Destructor Documentation

+ +

◆ Device() [1/2]

+ +
+
+ + + + + + + +
mlx::core::metal::Device::Device ()
+
+ +
+
+ +

◆ Device() [2/2]

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::metal::Device::Device (const Device & )
+
+delete
+
+ +
+
+ +

◆ ~Device()

+ +
+
+ + + + + + + +
mlx::core::metal::Device::~Device ()
+
+ +
+
+

Member Function Documentation

+ +

◆ argument_encoder()

+ +
+
+ + + + + + + +
MTL::ArgumentEncoder * mlx::core::metal::Device::argument_encoder (const std::vector< MTL::ArgumentDescriptor * > & arg_descs) const
+
+ +
+
+ +

◆ commit_command_buffer()

+ +
+
+ + + + + + + +
void mlx::core::metal::Device::commit_command_buffer (int index)
+
+ +
+
+ +

◆ end_encoding()

+ +
+
+ + + + + + + +
void mlx::core::metal::Device::end_encoding (int index)
+
+ +
+
+ +

◆ get_command_buffer()

+ +
+
+ + + + + + + +
MTL::CommandBuffer * mlx::core::metal::Device::get_command_buffer (int index)
+
+ +
+
+ +

◆ get_command_buffer_ops()

+ +
+
+ + + + + + + +
int mlx::core::metal::Device::get_command_buffer_ops (int index)
+
+ +
+
+ +

◆ get_command_encoder()

+ +
+
+ + + + + + + +
CommandEncoder & mlx::core::metal::Device::get_command_encoder (int index)
+
+ +
+
+ +

◆ get_function() [1/2]

+ +
+
+ + + + + + + + + + + + + + + + + + + + + +
MTL::Function * mlx::core::metal::Device::get_function (const std::string & base_name,
const std::string & lib_name = "mlx",
const std::string & specialized_name = "",
const MTLFCList & func_consts = {} )
+
+ +
+
+ +

◆ get_function() [2/2]

+ +
+
+ + + + + + + + + + + + + + + + + + + + + +
MTL::Function * mlx::core::metal::Device::get_function (const std::string & base_name,
MTL::Library * mtl_lib,
const std::string & specialized_name = "",
const MTLFCList & func_consts = {} )
+
+ +
+
+ +

◆ get_kernel() [1/2]

+ +
+
+ + + + + + + + + + + + + + + + + + + + + + + + + + +
MTL::ComputePipelineState * mlx::core::metal::Device::get_kernel (const std::string & base_name,
const std::string & lib_name = "mlx",
const std::string & hash_name = "",
const MTLFCList & func_consts = {},
const std::vector< MTL::Function * > & linked_functions = {} )
+
+ +
+
+ +

◆ get_kernel() [2/2]

+ +
+
+ + + + + + + + + + + + + + + + + + + + + + + + + + +
MTL::ComputePipelineState * mlx::core::metal::Device::get_kernel (const std::string & base_name,
MTL::Library * mtl_lib,
const std::string & hash_name = "",
const MTLFCList & func_consts = {},
const std::vector< MTL::Function * > & linked_functions = {} )
+
+ +
+
+ +

◆ get_library() [1/3]

+ +
+
+ + + + + + + +
MTL::Library * mlx::core::metal::Device::get_library (const std::string & name)
+
+ +
+
+ +

◆ get_library() [2/3]

+ +
+
+ + + + + + + + + + + + + + + + +
MTL::Library * mlx::core::metal::Device::get_library (const std::string & name,
const MTL::StitchedLibraryDescriptor * desc,
bool cache = true )
+
+ +
+
+ +

◆ get_library() [3/3]

+ +
+
+ + + + + + + + + + + + + + + + +
MTL::Library * mlx::core::metal::Device::get_library (const std::string & name,
const std::string & source_string,
bool cache = true )
+
+ +
+
+ +

◆ increment_command_buffer_ops()

+ +
+
+ + + + + + + +
void mlx::core::metal::Device::increment_command_buffer_ops (int index)
+
+ +
+
+ +

◆ mtl_device()

+ +
+
+ + + + + +
+ + + + + + + +
MTL::Device * mlx::core::metal::Device::mtl_device ()
+
+inline
+
+ +
+
+ +

◆ new_queue()

+ +
+
+ + + + + + + +
void mlx::core::metal::Device::new_queue (int index)
+
+ +
+
+ +

◆ operator=()

+ +
+
+ + + + + +
+ + + + + + + +
Device & mlx::core::metal::Device::operator= (const Device & )
+
+delete
+
+ +
+
+ +

◆ register_library() [1/2]

+ +
+
+ + + + + + + + + + + +
void mlx::core::metal::Device::register_library (const std::string & lib_name,
const std::function< std::string(const std::string &)> & lib_path_func = get_colocated_mtllib_path )
+
+ +
+
+ +

◆ register_library() [2/2]

+ +
+
+ + + + + + + + + + + +
void mlx::core::metal::Device::register_library (const std::string & lib_name,
const std::string & lib_path )
+
+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1metal_1_1_metal_allocator-members.html b/docs/build/html/classmlx_1_1core_1_1metal_1_1_metal_allocator-members.html new file mode 100644 index 000000000..6ddf10230 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1metal_1_1_metal_allocator-members.html @@ -0,0 +1,106 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::metal::MetalAllocator Member List
+
+ + + + + diff --git a/docs/build/html/classmlx_1_1core_1_1metal_1_1_metal_allocator.html b/docs/build/html/classmlx_1_1core_1_1metal_1_1_metal_allocator.html new file mode 100644 index 000000000..b19a22a49 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1metal_1_1_metal_allocator.html @@ -0,0 +1,388 @@ + + + + + + + +MLX: mlx::core::metal::MetalAllocator Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::metal::MetalAllocator Class Reference
+
+
+ +

#include <allocator.h>

+
+Inheritance diagram for mlx::core::metal::MetalAllocator:
+
+
+ + +mlx::core::allocator::Allocator + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

virtual Buffer malloc (size_t size, bool allow_swap=false) override
 Allocator for Metal GPUs.
 
virtual void free (Buffer buffer) override
 
size_t get_active_memory ()
 
size_t get_peak_memory ()
 
void reset_peak_memory ()
 
size_t get_cache_memory ()
 
size_t set_cache_limit (size_t limit)
 
size_t set_memory_limit (size_t limit, bool relaxed)
 
void clear_cache ()
 
- Public Member Functions inherited from mlx::core::allocator::Allocator
 Allocator ()=default
 
 Allocator (const Allocator &other)=delete
 
 Allocator (Allocator &&other)=delete
 
Allocatoroperator= (const Allocator &other)=delete
 
Allocatoroperator= (Allocator &&other)=delete
 
virtual ~Allocator ()=default
 
+ + + +

+Friends

MetalAllocatorallocator ()
 
+

Member Function Documentation

+ +

◆ clear_cache()

+ +
+
+ + + + + + + +
void mlx::core::metal::MetalAllocator::clear_cache ()
+
+ +
+
+ +

◆ free()

+ +
+
+ + + + + +
+ + + + + + + +
virtual void mlx::core::metal::MetalAllocator::free (Buffer buffer)
+
+overridevirtual
+
+
+ +

◆ get_active_memory()

+ +
+
+ + + + + +
+ + + + + + + +
size_t mlx::core::metal::MetalAllocator::get_active_memory ()
+
+inline
+
+ +
+
+ +

◆ get_cache_memory()

+ +
+
+ + + + + +
+ + + + + + + +
size_t mlx::core::metal::MetalAllocator::get_cache_memory ()
+
+inline
+
+ +
+
+ +

◆ get_peak_memory()

+ +
+
+ + + + + +
+ + + + + + + +
size_t mlx::core::metal::MetalAllocator::get_peak_memory ()
+
+inline
+
+ +
+
+ +

◆ malloc()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
virtual Buffer mlx::core::metal::MetalAllocator::malloc (size_t size,
bool allow_swap = false )
+
+overridevirtual
+
+ +

Allocator for Metal GPUs.

+ +

Implements mlx::core::allocator::Allocator.

+ +
+
+ +

◆ reset_peak_memory()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::metal::MetalAllocator::reset_peak_memory ()
+
+inline
+
+ +
+
+ +

◆ set_cache_limit()

+ +
+
+ + + + + + + +
size_t mlx::core::metal::MetalAllocator::set_cache_limit (size_t limit)
+
+ +
+
+ +

◆ set_memory_limit()

+ +
+
+ + + + + + + + + + + +
size_t mlx::core::metal::MetalAllocator::set_memory_limit (size_t limit,
bool relaxed )
+
+ +
+
+

Friends And Related Symbol Documentation

+ +

◆ allocator

+ +
+
+ + + + + +
+ + + + + + + +
MetalAllocator & allocator ()
+
+friend
+
+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1metal_1_1_metal_allocator.png b/docs/build/html/classmlx_1_1core_1_1metal_1_1_metal_allocator.png new file mode 100644 index 0000000000000000000000000000000000000000..c82190d6222eabf8781259a9fd6bce166fb65348 GIT binary patch literal 680 zcmeAS@N?(olHy`uVBq!ia0vp^JApWWgBeH`S*WN0Dd_;85ZC|z{{xvX-h3_XKeXJ! zK(jz%`k5CG9y|bwo1P6@0+iz{3GxeO0P?}WoN4wI1_mZ$PZ!6K3dXl{Z=QRtz|+F6 z9Qyg6f6b!W`}cR8h}Q%;RrHCH(mt*tLBV+4CqR!ZW_YmwB`ZU4R@A5K*TZ4^}I%GsLr6G64O_cp%)7j{&Hc;X~*I=7#l6G7Lb|P?anQ_tNy$ zRGBmbU&x)KP`&u5uH%2ZtEfyYBl}^{IA?#TNbY$046)K3RU-{mI!a zVP}6gm6V#Czn4N)NY-ye9EPsU+i^GGyFu}m3jly6oaR$ KpUXO@geCxqC`zON literal 0 HcmV?d00001 diff --git a/docs/build/html/classmlx_1_1core_1_1random_1_1_key_sequence-members.html b/docs/build/html/classmlx_1_1core_1_1random_1_1_key_sequence-members.html new file mode 100644 index 000000000..a3e82b9eb --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1random_1_1_key_sequence-members.html @@ -0,0 +1,94 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::random::KeySequence Member List
+
+
+ +

This is the complete list of members for mlx::core::random::KeySequence, including all inherited members.

+ + + + + +
default_()mlx::core::random::KeySequenceinlinestatic
KeySequence(uint64_t seed)mlx::core::random::KeySequenceexplicit
next()mlx::core::random::KeySequence
seed(uint64_t seed)mlx::core::random::KeySequence
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1random_1_1_key_sequence.html b/docs/build/html/classmlx_1_1core_1_1random_1_1_key_sequence.html new file mode 100644 index 000000000..4662bb44f --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1random_1_1_key_sequence.html @@ -0,0 +1,197 @@ + + + + + + + +MLX: mlx::core::random::KeySequence Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::random::KeySequence Class Reference
+
+
+ +

#include <random.h>

+ + + + + + + + +

+Public Member Functions

 KeySequence (uint64_t seed)
 
void seed (uint64_t seed)
 
array next ()
 
+ + + +

+Static Public Member Functions

static KeySequencedefault_ ()
 
+

Constructor & Destructor Documentation

+ +

◆ KeySequence()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::random::KeySequence::KeySequence (uint64_t seed)
+
+explicit
+
+ +
+
+

Member Function Documentation

+ +

◆ default_()

+ +
+
+ + + + + +
+ + + + + + + +
static KeySequence & mlx::core::random::KeySequence::default_ ()
+
+inlinestatic
+
+ +
+
+ +

◆ next()

+ +
+
+ + + + + + + +
array mlx::core::random::KeySequence::next ()
+
+ +
+
+ +

◆ seed()

+ +
+
+ + + + + + + +
void mlx::core::random::KeySequence::seed (uint64_t seed)
+
+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classmlx_1_1core_1_1scheduler_1_1_scheduler-members.html b/docs/build/html/classmlx_1_1core_1_1scheduler_1_1_scheduler-members.html new file mode 100644 index 000000000..7a145e396 --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1scheduler_1_1_scheduler-members.html @@ -0,0 +1,104 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
mlx::core::scheduler::Scheduler Member List
+
+ + + + + diff --git a/docs/build/html/classmlx_1_1core_1_1scheduler_1_1_scheduler.html b/docs/build/html/classmlx_1_1core_1_1scheduler_1_1_scheduler.html new file mode 100644 index 000000000..b9ec8053b --- /dev/null +++ b/docs/build/html/classmlx_1_1core_1_1scheduler_1_1_scheduler.html @@ -0,0 +1,478 @@ + + + + + + + +MLX: mlx::core::scheduler::Scheduler Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
mlx::core::scheduler::Scheduler Class Reference
+
+
+ +

#include <scheduler.h>

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 Scheduler ()
 
 Scheduler (const Scheduler &)=delete
 
 Scheduler (Scheduler &&)=delete
 
Scheduleroperator= (const Scheduler &)=delete
 
Scheduleroperator= (Scheduler &&)=delete
 
Stream new_stream (const Device &d)
 
template<typename F >
void enqueue (const Stream &stream, F &&f)
 
Stream get_default_stream (const Device &d)
 
void set_default_stream (const Stream &s)
 
void notify_new_task (const Stream &stream)
 
void notify_task_completion (const Stream &stream)
 
int n_active_tasks () const
 
void wait_for_one ()
 
 ~Scheduler ()
 
+

Constructor & Destructor Documentation

+ +

◆ Scheduler() [1/3]

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::scheduler::Scheduler::Scheduler ()
+
+inline
+
+ +
+
+ +

◆ Scheduler() [2/3]

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::scheduler::Scheduler::Scheduler (const Scheduler & )
+
+delete
+
+ +
+
+ +

◆ Scheduler() [3/3]

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::scheduler::Scheduler::Scheduler (Scheduler && )
+
+delete
+
+ +
+
+ +

◆ ~Scheduler()

+ +
+
+ + + + + +
+ + + + + + + +
mlx::core::scheduler::Scheduler::~Scheduler ()
+
+inline
+
+ +
+
+

Member Function Documentation

+ +

◆ enqueue()

+ +
+
+
+template<typename F >
+ + + + + + + + + + + +
void mlx::core::scheduler::Scheduler::enqueue (const Stream & stream,
F && f )
+
+ +
+
+ +

◆ get_default_stream()

+ +
+
+ + + + + +
+ + + + + + + +
Stream mlx::core::scheduler::Scheduler::get_default_stream (const Device & d)
+
+inline
+
+ +
+
+ +

◆ n_active_tasks()

+ +
+
+ + + + + +
+ + + + + + + +
int mlx::core::scheduler::Scheduler::n_active_tasks () const
+
+inline
+
+ +
+
+ +

◆ new_stream()

+ +
+
+ + + + + +
+ + + + + + + +
Stream mlx::core::scheduler::Scheduler::new_stream (const Device & d)
+
+inline
+
+ +
+
+ +

◆ notify_new_task()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::scheduler::Scheduler::notify_new_task (const Stream & stream)
+
+inline
+
+ +
+
+ +

◆ notify_task_completion()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::scheduler::Scheduler::notify_task_completion (const Stream & stream)
+
+inline
+
+ +
+
+ +

◆ operator=() [1/2]

+ +
+
+ + + + + +
+ + + + + + + +
Scheduler & mlx::core::scheduler::Scheduler::operator= (const Scheduler & )
+
+delete
+
+ +
+
+ +

◆ operator=() [2/2]

+ +
+
+ + + + + +
+ + + + + + + +
Scheduler & mlx::core::scheduler::Scheduler::operator= (Scheduler && )
+
+delete
+
+ +
+
+ +

◆ set_default_stream()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::scheduler::Scheduler::set_default_stream (const Stream & s)
+
+inline
+
+ +
+
+ +

◆ wait_for_one()

+ +
+
+ + + + + +
+ + + + + + + +
void mlx::core::scheduler::Scheduler::wait_for_one ()
+
+inline
+
+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classpocketfft_1_1detail_1_1_t__dcst23-members.html b/docs/build/html/classpocketfft_1_1detail_1_1_t__dcst23-members.html new file mode 100644 index 000000000..91ebfde05 --- /dev/null +++ b/docs/build/html/classpocketfft_1_1detail_1_1_t__dcst23-members.html @@ -0,0 +1,93 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
pocketfft::detail::T_dcst23< T0 > Member List
+
+
+ +

This is the complete list of members for pocketfft::detail::T_dcst23< T0 >, including all inherited members.

+ + + + +
exec(T c[], T0 fct, bool ortho, int type, bool cosine) constpocketfft::detail::T_dcst23< T0 >inline
length() constpocketfft::detail::T_dcst23< T0 >inline
T_dcst23(size_t length)pocketfft::detail::T_dcst23< T0 >inline
+ + + + diff --git a/docs/build/html/classpocketfft_1_1detail_1_1_t__dcst23.html b/docs/build/html/classpocketfft_1_1detail_1_1_t__dcst23.html new file mode 100644 index 000000000..c7c55e2aa --- /dev/null +++ b/docs/build/html/classpocketfft_1_1detail_1_1_t__dcst23.html @@ -0,0 +1,210 @@ + + + + + + + +MLX: pocketfft::detail::T_dcst23< T0 > Class Template Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
pocketfft::detail::T_dcst23< T0 > Class Template Reference
+
+
+ +

#include <pocketfft.h>

+ + + + + + + + + +

+Public Member Functions

 T_dcst23 (size_t length)
 
template<typename T >
void exec (T c[], T0 fct, bool ortho, int type, bool cosine) const
 
size_t length () const
 
+

Constructor & Destructor Documentation

+ +

◆ T_dcst23()

+ +
+
+
+template<typename T0 >
+ + + + + +
+ + + + + + + +
pocketfft::detail::T_dcst23< T0 >::T_dcst23 (size_t length)
+
+inline
+
+ +
+
+

Member Function Documentation

+ +

◆ exec()

+ +
+
+
+template<typename T0 >
+
+template<typename T >
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + +
void pocketfft::detail::T_dcst23< T0 >::exec (T c[],
T0 fct,
bool ortho,
int type,
bool cosine ) const
+
+inline
+
+ +
+
+ +

◆ length()

+ +
+
+
+template<typename T0 >
+ + + + + +
+ + + + + + + +
size_t pocketfft::detail::T_dcst23< T0 >::length () const
+
+inline
+
+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classpocketfft_1_1detail_1_1_t__dcst4-members.html b/docs/build/html/classpocketfft_1_1detail_1_1_t__dcst4-members.html new file mode 100644 index 000000000..3df2ebce9 --- /dev/null +++ b/docs/build/html/classpocketfft_1_1detail_1_1_t__dcst4-members.html @@ -0,0 +1,93 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
pocketfft::detail::T_dcst4< T0 > Member List
+
+
+ +

This is the complete list of members for pocketfft::detail::T_dcst4< T0 >, including all inherited members.

+ + + + +
exec(T c[], T0 fct, bool, int, bool cosine) constpocketfft::detail::T_dcst4< T0 >inline
length() constpocketfft::detail::T_dcst4< T0 >inline
T_dcst4(size_t length)pocketfft::detail::T_dcst4< T0 >inline
+ + + + diff --git a/docs/build/html/classpocketfft_1_1detail_1_1_t__dcst4.html b/docs/build/html/classpocketfft_1_1detail_1_1_t__dcst4.html new file mode 100644 index 000000000..8d03410d4 --- /dev/null +++ b/docs/build/html/classpocketfft_1_1detail_1_1_t__dcst4.html @@ -0,0 +1,210 @@ + + + + + + + +MLX: pocketfft::detail::T_dcst4< T0 > Class Template Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
pocketfft::detail::T_dcst4< T0 > Class Template Reference
+
+
+ +

#include <pocketfft.h>

+ + + + + + + + + +

+Public Member Functions

 T_dcst4 (size_t length)
 
template<typename T >
void exec (T c[], T0 fct, bool, int, bool cosine) const
 
size_t length () const
 
+

Constructor & Destructor Documentation

+ +

◆ T_dcst4()

+ +
+
+
+template<typename T0 >
+ + + + + +
+ + + + + + + +
pocketfft::detail::T_dcst4< T0 >::T_dcst4 (size_t length)
+
+inline
+
+ +
+
+

Member Function Documentation

+ +

◆ exec()

+ +
+
+
+template<typename T0 >
+
+template<typename T >
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + +
void pocketfft::detail::T_dcst4< T0 >::exec (T c[],
T0 fct,
bool ,
int ,
bool cosine ) const
+
+inline
+
+ +
+
+ +

◆ length()

+ +
+
+
+template<typename T0 >
+ + + + + +
+ + + + + + + +
size_t pocketfft::detail::T_dcst4< T0 >::length () const
+
+inline
+
+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classpocketfft_1_1detail_1_1_t__dct1-members.html b/docs/build/html/classpocketfft_1_1detail_1_1_t__dct1-members.html new file mode 100644 index 000000000..306b02f42 --- /dev/null +++ b/docs/build/html/classpocketfft_1_1detail_1_1_t__dct1-members.html @@ -0,0 +1,93 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
pocketfft::detail::T_dct1< T0 > Member List
+
+
+ +

This is the complete list of members for pocketfft::detail::T_dct1< T0 >, including all inherited members.

+ + + + +
exec(T c[], T0 fct, bool ortho, int, bool) constpocketfft::detail::T_dct1< T0 >inline
length() constpocketfft::detail::T_dct1< T0 >inline
T_dct1(size_t length)pocketfft::detail::T_dct1< T0 >inline
+ + + + diff --git a/docs/build/html/classpocketfft_1_1detail_1_1_t__dct1.html b/docs/build/html/classpocketfft_1_1detail_1_1_t__dct1.html new file mode 100644 index 000000000..ad8c292fc --- /dev/null +++ b/docs/build/html/classpocketfft_1_1detail_1_1_t__dct1.html @@ -0,0 +1,210 @@ + + + + + + + +MLX: pocketfft::detail::T_dct1< T0 > Class Template Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
pocketfft::detail::T_dct1< T0 > Class Template Reference
+
+
+ +

#include <pocketfft.h>

+ + + + + + + + + +

+Public Member Functions

 T_dct1 (size_t length)
 
template<typename T >
void exec (T c[], T0 fct, bool ortho, int, bool) const
 
size_t length () const
 
+

Constructor & Destructor Documentation

+ +

◆ T_dct1()

+ +
+
+
+template<typename T0 >
+ + + + + +
+ + + + + + + +
pocketfft::detail::T_dct1< T0 >::T_dct1 (size_t length)
+
+inline
+
+ +
+
+

Member Function Documentation

+ +

◆ exec()

+ +
+
+
+template<typename T0 >
+
+template<typename T >
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + +
void pocketfft::detail::T_dct1< T0 >::exec (T c[],
T0 fct,
bool ortho,
int ,
bool  ) const
+
+inline
+
+ +
+
+ +

◆ length()

+ +
+
+
+template<typename T0 >
+ + + + + +
+ + + + + + + +
size_t pocketfft::detail::T_dct1< T0 >::length () const
+
+inline
+
+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classpocketfft_1_1detail_1_1_t__dst1-members.html b/docs/build/html/classpocketfft_1_1detail_1_1_t__dst1-members.html new file mode 100644 index 000000000..5f72dd266 --- /dev/null +++ b/docs/build/html/classpocketfft_1_1detail_1_1_t__dst1-members.html @@ -0,0 +1,93 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
pocketfft::detail::T_dst1< T0 > Member List
+
+
+ +

This is the complete list of members for pocketfft::detail::T_dst1< T0 >, including all inherited members.

+ + + + +
exec(T c[], T0 fct, bool, int, bool) constpocketfft::detail::T_dst1< T0 >inline
length() constpocketfft::detail::T_dst1< T0 >inline
T_dst1(size_t length)pocketfft::detail::T_dst1< T0 >inline
+ + + + diff --git a/docs/build/html/classpocketfft_1_1detail_1_1_t__dst1.html b/docs/build/html/classpocketfft_1_1detail_1_1_t__dst1.html new file mode 100644 index 000000000..678dde855 --- /dev/null +++ b/docs/build/html/classpocketfft_1_1detail_1_1_t__dst1.html @@ -0,0 +1,210 @@ + + + + + + + +MLX: pocketfft::detail::T_dst1< T0 > Class Template Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
pocketfft::detail::T_dst1< T0 > Class Template Reference
+
+
+ +

#include <pocketfft.h>

+ + + + + + + + + +

+Public Member Functions

 T_dst1 (size_t length)
 
template<typename T >
void exec (T c[], T0 fct, bool, int, bool) const
 
size_t length () const
 
+

Constructor & Destructor Documentation

+ +

◆ T_dst1()

+ +
+
+
+template<typename T0 >
+ + + + + +
+ + + + + + + +
pocketfft::detail::T_dst1< T0 >::T_dst1 (size_t length)
+
+inline
+
+ +
+
+

Member Function Documentation

+ +

◆ exec()

+ +
+
+
+template<typename T0 >
+
+template<typename T >
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + +
void pocketfft::detail::T_dst1< T0 >::exec (T c[],
T0 fct,
bool ,
int ,
bool  ) const
+
+inline
+
+ +
+
+ +

◆ length()

+ +
+
+
+template<typename T0 >
+ + + + + +
+ + + + + + + +
size_t pocketfft::detail::T_dst1< T0 >::length () const
+
+inline
+
+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classpocketfft_1_1detail_1_1arr-members.html b/docs/build/html/classpocketfft_1_1detail_1_1arr-members.html new file mode 100644 index 000000000..0a83715c7 --- /dev/null +++ b/docs/build/html/classpocketfft_1_1detail_1_1arr-members.html @@ -0,0 +1,100 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
pocketfft::detail::arr< T > Member List
+
+
+ +

This is the complete list of members for pocketfft::detail::arr< T >, including all inherited members.

+ + + + + + + + + + + +
arr()pocketfft::detail::arr< T >inline
arr(size_t n)pocketfft::detail::arr< T >inline
arr(arr &&other)pocketfft::detail::arr< T >inline
data()pocketfft::detail::arr< T >inline
data() constpocketfft::detail::arr< T >inline
operator[](size_t idx)pocketfft::detail::arr< T >inline
operator[](size_t idx) constpocketfft::detail::arr< T >inline
resize(size_t n)pocketfft::detail::arr< T >inline
size() constpocketfft::detail::arr< T >inline
~arr()pocketfft::detail::arr< T >inline
+ + + + diff --git a/docs/build/html/classpocketfft_1_1detail_1_1arr.html b/docs/build/html/classpocketfft_1_1detail_1_1arr.html new file mode 100644 index 000000000..d81388ba2 --- /dev/null +++ b/docs/build/html/classpocketfft_1_1detail_1_1arr.html @@ -0,0 +1,391 @@ + + + + + + + +MLX: pocketfft::detail::arr< T > Class Template Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
pocketfft::detail::arr< T > Class Template Reference
+
+
+ +

#include <pocketfft.h>

+ + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 arr ()
 
 arr (size_t n)
 
 arr (arr &&other)
 
 ~arr ()
 
void resize (size_t n)
 
Toperator[] (size_t idx)
 
const Toperator[] (size_t idx) const
 
Tdata ()
 
const Tdata () const
 
size_t size () const
 
+

Constructor & Destructor Documentation

+ +

◆ arr() [1/3]

+ +
+
+
+template<typename T >
+ + + + + +
+ + + + + + + +
pocketfft::detail::arr< T >::arr ()
+
+inline
+
+ +
+
+ +

◆ arr() [2/3]

+ +
+
+
+template<typename T >
+ + + + + +
+ + + + + + + +
pocketfft::detail::arr< T >::arr (size_t n)
+
+inline
+
+ +
+
+ +

◆ arr() [3/3]

+ +
+
+
+template<typename T >
+ + + + + +
+ + + + + + + +
pocketfft::detail::arr< T >::arr (arr< T > && other)
+
+inline
+
+ +
+
+ +

◆ ~arr()

+ +
+
+
+template<typename T >
+ + + + + +
+ + + + + + + +
pocketfft::detail::arr< T >::~arr ()
+
+inline
+
+ +
+
+

Member Function Documentation

+ +

◆ data() [1/2]

+ +
+
+
+template<typename T >
+ + + + + +
+ + + + + + + +
T * pocketfft::detail::arr< T >::data ()
+
+inline
+
+ +
+
+ +

◆ data() [2/2]

+ +
+
+
+template<typename T >
+ + + + + +
+ + + + + + + +
const T * pocketfft::detail::arr< T >::data () const
+
+inline
+
+ +
+
+ +

◆ operator[]() [1/2]

+ +
+
+
+template<typename T >
+ + + + + +
+ + + + + + + +
T & pocketfft::detail::arr< T >::operator[] (size_t idx)
+
+inline
+
+ +
+
+ +

◆ operator[]() [2/2]

+ +
+
+
+template<typename T >
+ + + + + +
+ + + + + + + +
const T & pocketfft::detail::arr< T >::operator[] (size_t idx) const
+
+inline
+
+ +
+
+ +

◆ resize()

+ +
+
+
+template<typename T >
+ + + + + +
+ + + + + + + +
void pocketfft::detail::arr< T >::resize (size_t n)
+
+inline
+
+ +
+
+ +

◆ size()

+ +
+
+
+template<typename T >
+ + + + + +
+ + + + + + + +
size_t pocketfft::detail::arr< T >::size () const
+
+inline
+
+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classpocketfft_1_1detail_1_1arr__info-members.html b/docs/build/html/classpocketfft_1_1detail_1_1arr__info-members.html new file mode 100644 index 000000000..38eb5333c --- /dev/null +++ b/docs/build/html/classpocketfft_1_1detail_1_1arr__info-members.html @@ -0,0 +1,99 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
pocketfft::detail::arr_info Member List
+
+
+ +

This is the complete list of members for pocketfft::detail::arr_info, including all inherited members.

+ + + + + + + + + + +
arr_info(const shape_t &shape_, const stride_t &stride_)pocketfft::detail::arr_infoinline
ndim() constpocketfft::detail::arr_infoinline
shape() constpocketfft::detail::arr_infoinline
shape(size_t i) constpocketfft::detail::arr_infoinline
shppocketfft::detail::arr_infoprotected
size() constpocketfft::detail::arr_infoinline
strpocketfft::detail::arr_infoprotected
stride() constpocketfft::detail::arr_infoinline
stride(size_t i) constpocketfft::detail::arr_infoinline
+ + + + diff --git a/docs/build/html/classpocketfft_1_1detail_1_1arr__info.html b/docs/build/html/classpocketfft_1_1detail_1_1arr__info.html new file mode 100644 index 000000000..08df02140 --- /dev/null +++ b/docs/build/html/classpocketfft_1_1detail_1_1arr__info.html @@ -0,0 +1,357 @@ + + + + + + + +MLX: pocketfft::detail::arr_info Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
pocketfft::detail::arr_info Class Reference
+
+
+ +

#include <pocketfft.h>

+
+Inheritance diagram for pocketfft::detail::arr_info:
+
+
+ + +pocketfft::detail::cndarr< T > +pocketfft::detail::ndarr< T > + +
+ + + + + + + + + + + + + + + + +

+Public Member Functions

 arr_info (const shape_t &shape_, const stride_t &stride_)
 
size_t ndim () const
 
size_t size () const
 
const shape_tshape () const
 
size_t shape (size_t i) const
 
const stride_tstride () const
 
const ptrdiff_t & stride (size_t i) const
 
+ + + + + +

+Protected Attributes

shape_t shp
 
stride_t str
 
+

Constructor & Destructor Documentation

+ +

◆ arr_info()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
pocketfft::detail::arr_info::arr_info (const shape_t & shape_,
const stride_t & stride_ )
+
+inline
+
+ +
+
+

Member Function Documentation

+ +

◆ ndim()

+ +
+
+ + + + + +
+ + + + + + + +
size_t pocketfft::detail::arr_info::ndim () const
+
+inline
+
+ +
+
+ +

◆ shape() [1/2]

+ +
+
+ + + + + +
+ + + + + + + +
const shape_t & pocketfft::detail::arr_info::shape () const
+
+inline
+
+ +
+
+ +

◆ shape() [2/2]

+ +
+
+ + + + + +
+ + + + + + + +
size_t pocketfft::detail::arr_info::shape (size_t i) const
+
+inline
+
+ +
+
+ +

◆ size()

+ +
+
+ + + + + +
+ + + + + + + +
size_t pocketfft::detail::arr_info::size () const
+
+inline
+
+ +
+
+ +

◆ stride() [1/2]

+ +
+
+ + + + + +
+ + + + + + + +
const stride_t & pocketfft::detail::arr_info::stride () const
+
+inline
+
+ +
+
+ +

◆ stride() [2/2]

+ +
+
+ + + + + +
+ + + + + + + +
const ptrdiff_t & pocketfft::detail::arr_info::stride (size_t i) const
+
+inline
+
+ +
+
+

Member Data Documentation

+ +

◆ shp

+ +
+
+ + + + + +
+ + + + +
shape_t pocketfft::detail::arr_info::shp
+
+protected
+
+ +
+
+ +

◆ str

+ +
+
+ + + + + +
+ + + + +
stride_t pocketfft::detail::arr_info::str
+
+protected
+
+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classpocketfft_1_1detail_1_1arr__info.png b/docs/build/html/classpocketfft_1_1detail_1_1arr__info.png new file mode 100644 index 0000000000000000000000000000000000000000..8cf1d3ced302b69089d970e4947289599f2076c4 GIT binary patch literal 1040 zcmeAS@N?(olHy`uVBq!ia0vp^tAV(KgBeJ!=X<>wNJ$6ygt-3y{~ySF@#br3|Doj; z2ATyD)6cwk@ZbSZ-1KbN5}+JsNswPK1CS2}=1jA%FfcIB_H=O!sbG9N_w_nuO`eu+ zr>;N${P)!ExbIY+R+HBJq~jIC#=G*zChRXbx6_$LU8$(0C)C9y{+O21qV><$%)NX5 zKy1JFIKiCjGks0?rbBN^mtI}pBYyekXU@MV zlZ`G#@$}tPYwOv3)pWnfPMyiS(|R99AClR5&UbNy(B@X*%;=T#%eJq6U$lE$rEH}4 z6V=z}ve#XIbVcmti*<*j%<4}^={S|IYpg$bt8K^0*DvgEWo%may0q^9iCrtN$0yeo z^DoiO>#KRWZSvCxE!9qQu1s5&7h%7P-Nj{}|9|7`JuBq*z5lVNtD|Ni!YkcZEuWtg zV0iJGm7)EG!nNzKWo|S_=wC@r@K4%W&A?FSe2k0Xmr)qQh3Ola8us;wGJHv0!>~X< znU$edQk&t$=?F%Lcr+#Vs@JcZ{{7x+?s+Qx6E;p-vae^V%Ga&wrH`Q@cKzmroLkS9 ze<=PlA!Ygb2 z^NT{{ZQg92`$EqYO7oBPwWWGGRZ98}D)6x{(Ht%b5!{=YOy4{wc z{Iq8M=cu}hBU0M6$ALljoip3Yc4n!C*I^C5{D^jwH=nG2y^q}OA^*HIcKzqy?$>+y*H8DV zOAk(dB4`vg>B@AFWqFZzEeftbzGSlhMXA~93YD9ukMeu^{hDF=VBNJ3X_+d!%MS0_ zoDsP`g?0Cqbw8$lOtLyWYqOV@RrkTdju_v=IaL$Zyp;rpV+1T5BX6HF`Lx{B{Pw*s zRd4LpD`#(v-}G+Ju{~w`ay#a{FOAzRKX22qru!Skx-Z<_wCwWa$o4P1i>>yg%Wm3u z`t^(bC;!GH!mw{m>|V2!`PX0nlJkhZGw1Q8T=mm|oZnewcbm<6w)D*Me-G!nPP(t1 ztm;{N_uuhJPo|%U@C=GS<~8Zpu86(g|MC1~KXhPyhlEhoR$$&`@O1TaS?83{1OT-s B79aos literal 0 HcmV?d00001 diff --git a/docs/build/html/classpocketfft_1_1detail_1_1cfftp-members.html b/docs/build/html/classpocketfft_1_1detail_1_1cfftp-members.html new file mode 100644 index 000000000..65676b3eb --- /dev/null +++ b/docs/build/html/classpocketfft_1_1detail_1_1cfftp-members.html @@ -0,0 +1,92 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
pocketfft::detail::cfftp< T0 > Member List
+
+
+ +

This is the complete list of members for pocketfft::detail::cfftp< T0 >, including all inherited members.

+ + + +
cfftp(size_t length_)pocketfft::detail::cfftp< T0 >inline
exec(T c[], T0 fct, bool fwd) constpocketfft::detail::cfftp< T0 >inline
+ + + + diff --git a/docs/build/html/classpocketfft_1_1detail_1_1cfftp.html b/docs/build/html/classpocketfft_1_1detail_1_1cfftp.html new file mode 100644 index 000000000..77e4017fe --- /dev/null +++ b/docs/build/html/classpocketfft_1_1detail_1_1cfftp.html @@ -0,0 +1,171 @@ + + + + + + + +MLX: pocketfft::detail::cfftp< T0 > Class Template Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
pocketfft::detail::cfftp< T0 > Class Template Reference
+
+
+ +

#include <pocketfft.h>

+ + + + + + + +

+Public Member Functions

template<typename T >
void exec (T c[], T0 fct, bool fwd) const
 
 cfftp (size_t length_)
 
+

Constructor & Destructor Documentation

+ +

◆ cfftp()

+ +
+
+
+template<typename T0 >
+ + + + + +
+ + + + + + + +
pocketfft::detail::cfftp< T0 >::cfftp (size_t length_)
+
+inline
+
+ +
+
+

Member Function Documentation

+ +

◆ exec()

+ +
+
+
+template<typename T0 >
+
+template<typename T >
+ + + + + +
+ + + + + + + + + + + + + + + + +
void pocketfft::detail::cfftp< T0 >::exec (T c[],
T0 fct,
bool fwd ) const
+
+inline
+
+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classpocketfft_1_1detail_1_1cndarr-members.html b/docs/build/html/classpocketfft_1_1detail_1_1cndarr-members.html new file mode 100644 index 000000000..1d2ac7085 --- /dev/null +++ b/docs/build/html/classpocketfft_1_1detail_1_1cndarr-members.html @@ -0,0 +1,102 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
pocketfft::detail::cndarr< T > Member List
+
+
+ +

This is the complete list of members for pocketfft::detail::cndarr< T >, including all inherited members.

+ + + + + + + + + + + + + +
arr_info(const shape_t &shape_, const stride_t &stride_)pocketfft::detail::arr_infoinline
cndarr(const void *data_, const shape_t &shape_, const stride_t &stride_)pocketfft::detail::cndarr< T >inline
dpocketfft::detail::cndarr< T >protected
ndim() constpocketfft::detail::arr_infoinline
operator[](ptrdiff_t ofs) constpocketfft::detail::cndarr< T >inline
shape() constpocketfft::detail::arr_infoinline
shape(size_t i) constpocketfft::detail::arr_infoinline
shppocketfft::detail::arr_infoprotected
size() constpocketfft::detail::arr_infoinline
strpocketfft::detail::arr_infoprotected
stride() constpocketfft::detail::arr_infoinline
stride(size_t i) constpocketfft::detail::arr_infoinline
+ + + + diff --git a/docs/build/html/classpocketfft_1_1detail_1_1cndarr.html b/docs/build/html/classpocketfft_1_1detail_1_1cndarr.html new file mode 100644 index 000000000..742ba67f6 --- /dev/null +++ b/docs/build/html/classpocketfft_1_1detail_1_1cndarr.html @@ -0,0 +1,229 @@ + + + + + + + +MLX: pocketfft::detail::cndarr< T > Class Template Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
pocketfft::detail::cndarr< T > Class Template Reference
+
+
+ +

#include <pocketfft.h>

+
+Inheritance diagram for pocketfft::detail::cndarr< T >:
+
+
+ + +pocketfft::detail::arr_info +pocketfft::detail::ndarr< T > + +
+ + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 cndarr (const void *data_, const shape_t &shape_, const stride_t &stride_)
 
const Toperator[] (ptrdiff_t ofs) const
 
- Public Member Functions inherited from pocketfft::detail::arr_info
 arr_info (const shape_t &shape_, const stride_t &stride_)
 
size_t ndim () const
 
size_t size () const
 
const shape_tshape () const
 
size_t shape (size_t i) const
 
const stride_tstride () const
 
const ptrdiff_t & stride (size_t i) const
 
+ + + + + + + + +

+Protected Attributes

const chard
 
- Protected Attributes inherited from pocketfft::detail::arr_info
shape_t shp
 
stride_t str
 
+

Constructor & Destructor Documentation

+ +

◆ cndarr()

+ +
+
+
+template<typename T >
+ + + + + +
+ + + + + + + + + + + + + + + + +
pocketfft::detail::cndarr< T >::cndarr (const void * data_,
const shape_t & shape_,
const stride_t & stride_ )
+
+inline
+
+ +
+
+

Member Function Documentation

+ +

◆ operator[]()

+ +
+
+
+template<typename T >
+ + + + + +
+ + + + + + + +
const T & pocketfft::detail::cndarr< T >::operator[] (ptrdiff_t ofs) const
+
+inline
+
+ +
+
+

Member Data Documentation

+ +

◆ d

+ +
+
+
+template<typename T >
+ + + + + +
+ + + + +
const char* pocketfft::detail::cndarr< T >::d
+
+protected
+
+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classpocketfft_1_1detail_1_1cndarr.png b/docs/build/html/classpocketfft_1_1detail_1_1cndarr.png new file mode 100644 index 0000000000000000000000000000000000000000..268d77cafbb0952a67dfc3b44e99ea88660ca187 GIT binary patch literal 1037 zcmeAS@N?(olHy`uVBq!ia0vp^tAV(KgBeJ!=X<>wNJ$6ygt-3y{~ySF@#br3|Doj; z2ATyD)6cwk@ZbSZ-1KbN5}+JsNswPK1CS2}=1jA%FfcIB@N{tusbG9N_w_nuO`eu+ zr>;N${P)P@oz*Q#KbCIh(YJz6;P&4{o}b+B_u9LBP^oO`nX2+Nb&bl(_0QMj$?W*R zG3%N5-g9j#?xt+^iJnE*=5Kbpb5-h|POpl5M)Ix{-jiK=>;7t0XwJ2tf8)HFrSjUN z4zi|?AM_(`3z%2R{&ko3y;i#S+>>KyRJwG6KL@J=a(YEvIJ4S{dypq}sFHT1=I>aC2V)$hg#&BW!My7^+ zJ)#U>Qr9pn&`)M%s6|s^@%)Zrldzzr(=jcjMf%C2f)}sP+xq)$`h!Et8METQm=)$< zn(V)~!(;jR2<3NKy!q4Alr~-X8>XkLmF*Gs<|*sa^-a?+Ns32@6kmws=`;1e%dOXM zF5&Ic0Z5Y#${gjH->fJs68~x-#<~d?~=i(%KMx7GH*OB{c?YE*+jc@uXn{g|K7i5 zc2zC&&8ykFXFdPAYajb3F?04EWw%dOxjnK~^jmlBc$V#{F0t^H9Vb?-V{>uY7Xk9z z`oHySI!=5)5#j0r3=JhEU}&DRET?X0R=nT4EAC$PdcBHCHcp@zGYXq@<+|=_X?8uE zre!}}*KdwIs(dm|uhygJ((yHdl0CLr?0VI zZ@7MLzVxCe)pM`EYhS)*<%9AySCz`Pule~!p$;Wcm3yI z<=4xAF|jUAdeiyWrM3J|bo0&?zTGzY>xv`QO?|P_zNdqp1^@qWPU+mdKI;Vst0EhYf<^TWy literal 0 HcmV?d00001 diff --git a/docs/build/html/classpocketfft_1_1detail_1_1fftblue-members.html b/docs/build/html/classpocketfft_1_1detail_1_1fftblue-members.html new file mode 100644 index 000000000..ba6b7bcdd --- /dev/null +++ b/docs/build/html/classpocketfft_1_1detail_1_1fftblue-members.html @@ -0,0 +1,93 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
pocketfft::detail::fftblue< T0 > Member List
+
+
+ +

This is the complete list of members for pocketfft::detail::fftblue< T0 >, including all inherited members.

+ + + + +
exec(cmplx< T > c[], T0 fct, bool fwd) constpocketfft::detail::fftblue< T0 >inline
exec_r(T c[], T0 fct, bool fwd)pocketfft::detail::fftblue< T0 >inline
fftblue(size_t length)pocketfft::detail::fftblue< T0 >inline
+ + + + diff --git a/docs/build/html/classpocketfft_1_1detail_1_1fftblue.html b/docs/build/html/classpocketfft_1_1detail_1_1fftblue.html new file mode 100644 index 000000000..519a05ae5 --- /dev/null +++ b/docs/build/html/classpocketfft_1_1detail_1_1fftblue.html @@ -0,0 +1,212 @@ + + + + + + + +MLX: pocketfft::detail::fftblue< T0 > Class Template Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
pocketfft::detail::fftblue< T0 > Class Template Reference
+
+
+ +

#include <pocketfft.h>

+ + + + + + + + + + +

+Public Member Functions

 fftblue (size_t length)
 
template<typename T >
void exec (cmplx< T > c[], T0 fct, bool fwd) const
 
template<typename T >
void exec_r (T c[], T0 fct, bool fwd)
 
+

Constructor & Destructor Documentation

+ +

◆ fftblue()

+ +
+
+
+template<typename T0 >
+ + + + + +
+ + + + + + + +
pocketfft::detail::fftblue< T0 >::fftblue (size_t length)
+
+inline
+
+ +
+
+

Member Function Documentation

+ +

◆ exec()

+ +
+
+
+template<typename T0 >
+
+template<typename T >
+ + + + + +
+ + + + + + + + + + + + + + + + +
void pocketfft::detail::fftblue< T0 >::exec (cmplx< T > c[],
T0 fct,
bool fwd ) const
+
+inline
+
+ +
+
+ +

◆ exec_r()

+ +
+
+
+template<typename T0 >
+
+template<typename T >
+ + + + + +
+ + + + + + + + + + + + + + + + +
void pocketfft::detail::fftblue< T0 >::exec_r (T c[],
T0 fct,
bool fwd )
+
+inline
+
+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classpocketfft_1_1detail_1_1multi__iter-members.html b/docs/build/html/classpocketfft_1_1detail_1_1multi__iter-members.html new file mode 100644 index 000000000..fc671bdf6 --- /dev/null +++ b/docs/build/html/classpocketfft_1_1detail_1_1multi__iter-members.html @@ -0,0 +1,101 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
pocketfft::detail::multi_iter< N > Member List
+
+
+ +

This is the complete list of members for pocketfft::detail::multi_iter< N >, including all inherited members.

+ + + + + + + + + + + + +
advance(size_t n)pocketfft::detail::multi_iter< N >inline
iofs(size_t i) constpocketfft::detail::multi_iter< N >inline
iofs(size_t j, size_t i) constpocketfft::detail::multi_iter< N >inline
length_in() constpocketfft::detail::multi_iter< N >inline
length_out() constpocketfft::detail::multi_iter< N >inline
multi_iter(const arr_info &iarr_, const arr_info &oarr_, size_t idim_)pocketfft::detail::multi_iter< N >inline
oofs(size_t i) constpocketfft::detail::multi_iter< N >inline
oofs(size_t j, size_t i) constpocketfft::detail::multi_iter< N >inline
remaining() constpocketfft::detail::multi_iter< N >inline
stride_in() constpocketfft::detail::multi_iter< N >inline
stride_out() constpocketfft::detail::multi_iter< N >inline
+ + + + diff --git a/docs/build/html/classpocketfft_1_1detail_1_1multi__iter.html b/docs/build/html/classpocketfft_1_1detail_1_1multi__iter.html new file mode 100644 index 000000000..1558f0d92 --- /dev/null +++ b/docs/build/html/classpocketfft_1_1detail_1_1multi__iter.html @@ -0,0 +1,437 @@ + + + + + + + +MLX: pocketfft::detail::multi_iter< N > Class Template Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
pocketfft::detail::multi_iter< N > Class Template Reference
+
+
+ +

#include <pocketfft.h>

+ + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 multi_iter (const arr_info &iarr_, const arr_info &oarr_, size_t idim_)
 
void advance (size_t n)
 
ptrdiff_t iofs (size_t i) const
 
ptrdiff_t iofs (size_t j, size_t i) const
 
ptrdiff_t oofs (size_t i) const
 
ptrdiff_t oofs (size_t j, size_t i) const
 
size_t length_in () const
 
size_t length_out () const
 
ptrdiff_t stride_in () const
 
ptrdiff_t stride_out () const
 
size_t remaining () const
 
+

Constructor & Destructor Documentation

+ +

◆ multi_iter()

+ +
+
+
+template<size_t N>
+ + + + + +
+ + + + + + + + + + + + + + + + +
pocketfft::detail::multi_iter< N >::multi_iter (const arr_info & iarr_,
const arr_info & oarr_,
size_t idim_ )
+
+inline
+
+ +
+
+

Member Function Documentation

+ +

◆ advance()

+ +
+
+
+template<size_t N>
+ + + + + +
+ + + + + + + +
void pocketfft::detail::multi_iter< N >::advance (size_t n)
+
+inline
+
+ +
+
+ +

◆ iofs() [1/2]

+ +
+
+
+template<size_t N>
+ + + + + +
+ + + + + + + +
ptrdiff_t pocketfft::detail::multi_iter< N >::iofs (size_t i) const
+
+inline
+
+ +
+
+ +

◆ iofs() [2/2]

+ +
+
+
+template<size_t N>
+ + + + + +
+ + + + + + + + + + + +
ptrdiff_t pocketfft::detail::multi_iter< N >::iofs (size_t j,
size_t i ) const
+
+inline
+
+ +
+
+ +

◆ length_in()

+ +
+
+
+template<size_t N>
+ + + + + +
+ + + + + + + +
size_t pocketfft::detail::multi_iter< N >::length_in () const
+
+inline
+
+ +
+
+ +

◆ length_out()

+ +
+
+
+template<size_t N>
+ + + + + +
+ + + + + + + +
size_t pocketfft::detail::multi_iter< N >::length_out () const
+
+inline
+
+ +
+
+ +

◆ oofs() [1/2]

+ +
+
+
+template<size_t N>
+ + + + + +
+ + + + + + + +
ptrdiff_t pocketfft::detail::multi_iter< N >::oofs (size_t i) const
+
+inline
+
+ +
+
+ +

◆ oofs() [2/2]

+ +
+
+
+template<size_t N>
+ + + + + +
+ + + + + + + + + + + +
ptrdiff_t pocketfft::detail::multi_iter< N >::oofs (size_t j,
size_t i ) const
+
+inline
+
+ +
+
+ +

◆ remaining()

+ +
+
+
+template<size_t N>
+ + + + + +
+ + + + + + + +
size_t pocketfft::detail::multi_iter< N >::remaining () const
+
+inline
+
+ +
+
+ +

◆ stride_in()

+ +
+
+
+template<size_t N>
+ + + + + +
+ + + + + + + +
ptrdiff_t pocketfft::detail::multi_iter< N >::stride_in () const
+
+inline
+
+ +
+
+ +

◆ stride_out()

+ +
+
+
+template<size_t N>
+ + + + + +
+ + + + + + + +
ptrdiff_t pocketfft::detail::multi_iter< N >::stride_out () const
+
+inline
+
+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classpocketfft_1_1detail_1_1ndarr-members.html b/docs/build/html/classpocketfft_1_1detail_1_1ndarr-members.html new file mode 100644 index 000000000..695ada829 --- /dev/null +++ b/docs/build/html/classpocketfft_1_1detail_1_1ndarr-members.html @@ -0,0 +1,104 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
pocketfft::detail::ndarr< T > Member List
+
+
+ +

This is the complete list of members for pocketfft::detail::ndarr< T >, including all inherited members.

+ + + + + + + + + + + + + + + +
arr_info(const shape_t &shape_, const stride_t &stride_)pocketfft::detail::arr_infoinline
cndarr(const void *data_, const shape_t &shape_, const stride_t &stride_)pocketfft::detail::cndarr< T >inline
dpocketfft::detail::cndarr< T >protected
ndarr(void *data_, const shape_t &shape_, const stride_t &stride_)pocketfft::detail::ndarr< T >inline
ndim() constpocketfft::detail::arr_infoinline
operator[](ptrdiff_t ofs)pocketfft::detail::ndarr< T >inline
pocketfft::detail::cndarr::operator[](ptrdiff_t ofs) constpocketfft::detail::cndarr< T >inline
shape() constpocketfft::detail::arr_infoinline
shape(size_t i) constpocketfft::detail::arr_infoinline
shppocketfft::detail::arr_infoprotected
size() constpocketfft::detail::arr_infoinline
strpocketfft::detail::arr_infoprotected
stride() constpocketfft::detail::arr_infoinline
stride(size_t i) constpocketfft::detail::arr_infoinline
+ + + + diff --git a/docs/build/html/classpocketfft_1_1detail_1_1ndarr.html b/docs/build/html/classpocketfft_1_1detail_1_1ndarr.html new file mode 100644 index 000000000..515d2c773 --- /dev/null +++ b/docs/build/html/classpocketfft_1_1detail_1_1ndarr.html @@ -0,0 +1,209 @@ + + + + + + + +MLX: pocketfft::detail::ndarr< T > Class Template Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
pocketfft::detail::ndarr< T > Class Template Reference
+
+
+ +

#include <pocketfft.h>

+
+Inheritance diagram for pocketfft::detail::ndarr< T >:
+
+
+ + +pocketfft::detail::cndarr< T > +pocketfft::detail::arr_info + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

 ndarr (void *data_, const shape_t &shape_, const stride_t &stride_)
 
Toperator[] (ptrdiff_t ofs)
 
- Public Member Functions inherited from pocketfft::detail::cndarr< T >
 cndarr (const void *data_, const shape_t &shape_, const stride_t &stride_)
 
const Toperator[] (ptrdiff_t ofs) const
 
- Public Member Functions inherited from pocketfft::detail::arr_info
 arr_info (const shape_t &shape_, const stride_t &stride_)
 
size_t ndim () const
 
size_t size () const
 
const shape_tshape () const
 
size_t shape (size_t i) const
 
const stride_tstride () const
 
const ptrdiff_t & stride (size_t i) const
 
+ + + + + + + + + +

+Additional Inherited Members

- Protected Attributes inherited from pocketfft::detail::cndarr< T >
const chard
 
- Protected Attributes inherited from pocketfft::detail::arr_info
shape_t shp
 
stride_t str
 
+

Constructor & Destructor Documentation

+ +

◆ ndarr()

+ +
+
+
+template<typename T >
+ + + + + +
+ + + + + + + + + + + + + + + + +
pocketfft::detail::ndarr< T >::ndarr (void * data_,
const shape_t & shape_,
const stride_t & stride_ )
+
+inline
+
+ +
+
+

Member Function Documentation

+ +

◆ operator[]()

+ +
+
+
+template<typename T >
+ + + + + +
+ + + + + + + +
T & pocketfft::detail::ndarr< T >::operator[] (ptrdiff_t ofs)
+
+inline
+
+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classpocketfft_1_1detail_1_1ndarr.png b/docs/build/html/classpocketfft_1_1detail_1_1ndarr.png new file mode 100644 index 0000000000000000000000000000000000000000..96f688ccda63c89376608b377f575b96caf2c9a8 GIT binary patch literal 1032 zcmeAS@N?(olHy`uVBq!ia0vp^tAV(KgBeJ!=X<>wNJ$6ygt-3y{~ySF@#br3|Doj; z2ATyD)6cwk@ZbSZ-1KbN5}+JsNswPK1CS2}=1jA%FfcGr_H=O!sbG9N_jTVBEuNNc zr>viU=l@7i5Gc-^bnok%OqKo>CIYvA8npfFmPo$i(xI}mbE%Yxnoxo?72sml{`nO{((L zu1`|0w_3eadVAM5;_@^7$knc|`b-d@$$d*Y72|GuWX>aXcF^X~_C z9({G<`nq{**wya}^X$20S$g=_MO%kS_m}^lX&QfpfB*K{IL}YqARp+bt9n+;t}#5e zqn+VEoIb-17wL`r;(i~if(LAEb-I7&c7b%+#=Vjwr*Mr)wAz^wU`x zsx7q{Zk&!}JP>z|i{TrZk`BXTzAYY;qEAG6UWz;CHR)T_=C(S4zT*oM>n>&Ic&yr1 zF*WYF{iK{PlG%smu6=m+l1I@Z>1*X;Q@?tI8Ci%=sppFGOOMJ4JnNWz(ej#&sr(L` zOK*NERBn~NCKD$-P2u_-gHx)8Q7l2`sn@Gg{uRH>y8YEDz4C43>W5QGpI&|Tb=FnI zk|mSgSH`{Wc3-peQuvyyN~QBR%v-*ELz(+Z>8B5Qv!}O(##R3}%@%(7c+a`lrGM(q z?^Wc68(@ri%)0oW8lMa~yws!NC6IS*QE8djC5+cz(?=eR%HLhqO$k zOIv>D)?_b?>`z&GaOb)|pj3BxR(F<`RQI`!-8D0p-TbVm6kdG{oZwzV6P!ZP)|r{l z?60rcwwJYP_VsyNZ}~@t$JNG3$IC8yQa$hbx1-M2Vm@eJ%bK{$``YI(BG>F{m5NN~ zeu%!NC%qb&0K4z}FO7Tt)v9iPdhC_#^S7?7R{Fgs?m0Le%7D_LyXswGp(UAnGb2w` zypq2csibsY8|0(wZ~sRKPSl?W^4RG}SC==X8`=a8$^GJ&{kh(%_`AdtV4h|0boFyt I=akR{0FJ5)UjP6A literal 0 HcmV?d00001 diff --git a/docs/build/html/classpocketfft_1_1detail_1_1pocketfft__c-members.html b/docs/build/html/classpocketfft_1_1detail_1_1pocketfft__c-members.html new file mode 100644 index 000000000..cd88f36a7 --- /dev/null +++ b/docs/build/html/classpocketfft_1_1detail_1_1pocketfft__c-members.html @@ -0,0 +1,93 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
pocketfft::detail::pocketfft_c< T0 > Member List
+
+
+ +

This is the complete list of members for pocketfft::detail::pocketfft_c< T0 >, including all inherited members.

+ + + + +
exec(cmplx< T > c[], T0 fct, bool fwd) constpocketfft::detail::pocketfft_c< T0 >inline
length() constpocketfft::detail::pocketfft_c< T0 >inline
pocketfft_c(size_t length)pocketfft::detail::pocketfft_c< T0 >inline
+ + + + diff --git a/docs/build/html/classpocketfft_1_1detail_1_1pocketfft__c.html b/docs/build/html/classpocketfft_1_1detail_1_1pocketfft__c.html new file mode 100644 index 000000000..8eac32ad6 --- /dev/null +++ b/docs/build/html/classpocketfft_1_1detail_1_1pocketfft__c.html @@ -0,0 +1,200 @@ + + + + + + + +MLX: pocketfft::detail::pocketfft_c< T0 > Class Template Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
pocketfft::detail::pocketfft_c< T0 > Class Template Reference
+
+
+ +

#include <pocketfft.h>

+ + + + + + + + + +

+Public Member Functions

 pocketfft_c (size_t length)
 
template<typename T >
void exec (cmplx< T > c[], T0 fct, bool fwd) const
 
size_t length () const
 
+

Constructor & Destructor Documentation

+ +

◆ pocketfft_c()

+ +
+
+
+template<typename T0 >
+ + + + + +
+ + + + + + + +
pocketfft::detail::pocketfft_c< T0 >::pocketfft_c (size_t length)
+
+inline
+
+ +
+
+

Member Function Documentation

+ +

◆ exec()

+ +
+
+
+template<typename T0 >
+
+template<typename T >
+ + + + + +
+ + + + + + + + + + + + + + + + +
void pocketfft::detail::pocketfft_c< T0 >::exec (cmplx< T > c[],
T0 fct,
bool fwd ) const
+
+inline
+
+ +
+
+ +

◆ length()

+ +
+
+
+template<typename T0 >
+ + + + + +
+ + + + + + + +
size_t pocketfft::detail::pocketfft_c< T0 >::length () const
+
+inline
+
+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classpocketfft_1_1detail_1_1pocketfft__r-members.html b/docs/build/html/classpocketfft_1_1detail_1_1pocketfft__r-members.html new file mode 100644 index 000000000..14d5e1b2f --- /dev/null +++ b/docs/build/html/classpocketfft_1_1detail_1_1pocketfft__r-members.html @@ -0,0 +1,93 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
pocketfft::detail::pocketfft_r< T0 > Member List
+
+
+ +

This is the complete list of members for pocketfft::detail::pocketfft_r< T0 >, including all inherited members.

+ + + + +
exec(T c[], T0 fct, bool fwd) constpocketfft::detail::pocketfft_r< T0 >inline
length() constpocketfft::detail::pocketfft_r< T0 >inline
pocketfft_r(size_t length)pocketfft::detail::pocketfft_r< T0 >inline
+ + + + diff --git a/docs/build/html/classpocketfft_1_1detail_1_1pocketfft__r.html b/docs/build/html/classpocketfft_1_1detail_1_1pocketfft__r.html new file mode 100644 index 000000000..7cf6d4e6d --- /dev/null +++ b/docs/build/html/classpocketfft_1_1detail_1_1pocketfft__r.html @@ -0,0 +1,200 @@ + + + + + + + +MLX: pocketfft::detail::pocketfft_r< T0 > Class Template Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
pocketfft::detail::pocketfft_r< T0 > Class Template Reference
+
+
+ +

#include <pocketfft.h>

+ + + + + + + + + +

+Public Member Functions

 pocketfft_r (size_t length)
 
template<typename T >
void exec (T c[], T0 fct, bool fwd) const
 
size_t length () const
 
+

Constructor & Destructor Documentation

+ +

◆ pocketfft_r()

+ +
+
+
+template<typename T0 >
+ + + + + +
+ + + + + + + +
pocketfft::detail::pocketfft_r< T0 >::pocketfft_r (size_t length)
+
+inline
+
+ +
+
+

Member Function Documentation

+ +

◆ exec()

+ +
+
+
+template<typename T0 >
+
+template<typename T >
+ + + + + +
+ + + + + + + + + + + + + + + + +
void pocketfft::detail::pocketfft_r< T0 >::exec (T c[],
T0 fct,
bool fwd ) const
+
+inline
+
+ +
+
+ +

◆ length()

+ +
+
+
+template<typename T0 >
+ + + + + +
+ + + + + + + +
size_t pocketfft::detail::pocketfft_r< T0 >::length () const
+
+inline
+
+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classpocketfft_1_1detail_1_1rev__iter-members.html b/docs/build/html/classpocketfft_1_1detail_1_1rev__iter-members.html new file mode 100644 index 000000000..39dfa4688 --- /dev/null +++ b/docs/build/html/classpocketfft_1_1detail_1_1rev__iter-members.html @@ -0,0 +1,95 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
pocketfft::detail::rev_iter Member List
+
+
+ +

This is the complete list of members for pocketfft::detail::rev_iter, including all inherited members.

+ + + + + + +
advance()pocketfft::detail::rev_iterinline
ofs() constpocketfft::detail::rev_iterinline
remaining() constpocketfft::detail::rev_iterinline
rev_iter(const arr_info &arr_, const shape_t &axes)pocketfft::detail::rev_iterinline
rev_ofs() constpocketfft::detail::rev_iterinline
+ + + + diff --git a/docs/build/html/classpocketfft_1_1detail_1_1rev__iter.html b/docs/build/html/classpocketfft_1_1detail_1_1rev__iter.html new file mode 100644 index 000000000..3775ffc19 --- /dev/null +++ b/docs/build/html/classpocketfft_1_1detail_1_1rev__iter.html @@ -0,0 +1,240 @@ + + + + + + + +MLX: pocketfft::detail::rev_iter Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
pocketfft::detail::rev_iter Class Reference
+
+
+ +

#include <pocketfft.h>

+ + + + + + + + + + + + +

+Public Member Functions

 rev_iter (const arr_info &arr_, const shape_t &axes)
 
void advance ()
 
ptrdiff_t ofs () const
 
ptrdiff_t rev_ofs () const
 
size_t remaining () const
 
+

Constructor & Destructor Documentation

+ +

◆ rev_iter()

+ +
+
+ + + + + +
+ + + + + + + + + + + +
pocketfft::detail::rev_iter::rev_iter (const arr_info & arr_,
const shape_t & axes )
+
+inline
+
+ +
+
+

Member Function Documentation

+ +

◆ advance()

+ +
+
+ + + + + +
+ + + + + + + +
void pocketfft::detail::rev_iter::advance ()
+
+inline
+
+ +
+
+ +

◆ ofs()

+ +
+
+ + + + + +
+ + + + + + + +
ptrdiff_t pocketfft::detail::rev_iter::ofs () const
+
+inline
+
+ +
+
+ +

◆ remaining()

+ +
+
+ + + + + +
+ + + + + + + +
size_t pocketfft::detail::rev_iter::remaining () const
+
+inline
+
+ +
+
+ +

◆ rev_ofs()

+ +
+
+ + + + + +
+ + + + + + + +
ptrdiff_t pocketfft::detail::rev_iter::rev_ofs () const
+
+inline
+
+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classpocketfft_1_1detail_1_1rfftp-members.html b/docs/build/html/classpocketfft_1_1detail_1_1rfftp-members.html new file mode 100644 index 000000000..778f37aa0 --- /dev/null +++ b/docs/build/html/classpocketfft_1_1detail_1_1rfftp-members.html @@ -0,0 +1,92 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
pocketfft::detail::rfftp< T0 > Member List
+
+
+ +

This is the complete list of members for pocketfft::detail::rfftp< T0 >, including all inherited members.

+ + + +
exec(T c[], T0 fct, bool r2hc) constpocketfft::detail::rfftp< T0 >inline
rfftp(size_t length_)pocketfft::detail::rfftp< T0 >inline
+ + + + diff --git a/docs/build/html/classpocketfft_1_1detail_1_1rfftp.html b/docs/build/html/classpocketfft_1_1detail_1_1rfftp.html new file mode 100644 index 000000000..deddc38d0 --- /dev/null +++ b/docs/build/html/classpocketfft_1_1detail_1_1rfftp.html @@ -0,0 +1,171 @@ + + + + + + + +MLX: pocketfft::detail::rfftp< T0 > Class Template Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
pocketfft::detail::rfftp< T0 > Class Template Reference
+
+
+ +

#include <pocketfft.h>

+ + + + + + + +

+Public Member Functions

template<typename T >
void exec (T c[], T0 fct, bool r2hc) const
 
 rfftp (size_t length_)
 
+

Constructor & Destructor Documentation

+ +

◆ rfftp()

+ +
+
+
+template<typename T0 >
+ + + + + +
+ + + + + + + +
pocketfft::detail::rfftp< T0 >::rfftp (size_t length_)
+
+inline
+
+ +
+
+

Member Function Documentation

+ +

◆ exec()

+ +
+
+
+template<typename T0 >
+
+template<typename T >
+ + + + + +
+ + + + + + + + + + + + + + + + +
void pocketfft::detail::rfftp< T0 >::exec (T c[],
T0 fct,
bool r2hc ) const
+
+inline
+
+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classpocketfft_1_1detail_1_1simple__iter-members.html b/docs/build/html/classpocketfft_1_1detail_1_1simple__iter-members.html new file mode 100644 index 000000000..0bd71a255 --- /dev/null +++ b/docs/build/html/classpocketfft_1_1detail_1_1simple__iter-members.html @@ -0,0 +1,94 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
pocketfft::detail::simple_iter Member List
+
+
+ +

This is the complete list of members for pocketfft::detail::simple_iter, including all inherited members.

+ + + + + +
advance()pocketfft::detail::simple_iterinline
ofs() constpocketfft::detail::simple_iterinline
remaining() constpocketfft::detail::simple_iterinline
simple_iter(const arr_info &arr_)pocketfft::detail::simple_iterinline
+ + + + diff --git a/docs/build/html/classpocketfft_1_1detail_1_1simple__iter.html b/docs/build/html/classpocketfft_1_1detail_1_1simple__iter.html new file mode 100644 index 000000000..973bde1d8 --- /dev/null +++ b/docs/build/html/classpocketfft_1_1detail_1_1simple__iter.html @@ -0,0 +1,209 @@ + + + + + + + +MLX: pocketfft::detail::simple_iter Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
pocketfft::detail::simple_iter Class Reference
+
+
+ +

#include <pocketfft.h>

+ + + + + + + + + + +

+Public Member Functions

 simple_iter (const arr_info &arr_)
 
void advance ()
 
ptrdiff_t ofs () const
 
size_t remaining () const
 
+

Constructor & Destructor Documentation

+ +

◆ simple_iter()

+ +
+
+ + + + + +
+ + + + + + + +
pocketfft::detail::simple_iter::simple_iter (const arr_info & arr_)
+
+inline
+
+ +
+
+

Member Function Documentation

+ +

◆ advance()

+ +
+
+ + + + + +
+ + + + + + + +
void pocketfft::detail::simple_iter::advance ()
+
+inline
+
+ +
+
+ +

◆ ofs()

+ +
+
+ + + + + +
+ + + + + + + +
ptrdiff_t pocketfft::detail::simple_iter::ofs () const
+
+inline
+
+ +
+
+ +

◆ remaining()

+ +
+
+ + + + + +
+ + + + + + + +
size_t pocketfft::detail::simple_iter::remaining () const
+
+inline
+
+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classpocketfft_1_1detail_1_1sincos__2pibyn-members.html b/docs/build/html/classpocketfft_1_1detail_1_1sincos__2pibyn-members.html new file mode 100644 index 000000000..8220ef462 --- /dev/null +++ b/docs/build/html/classpocketfft_1_1detail_1_1sincos__2pibyn-members.html @@ -0,0 +1,92 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
pocketfft::detail::sincos_2pibyn< T > Member List
+
+
+ +

This is the complete list of members for pocketfft::detail::sincos_2pibyn< T >, including all inherited members.

+ + + +
operator[](size_t idx) constpocketfft::detail::sincos_2pibyn< T >inline
sincos_2pibyn(size_t n)pocketfft::detail::sincos_2pibyn< T >inline
+ + + + diff --git a/docs/build/html/classpocketfft_1_1detail_1_1sincos__2pibyn.html b/docs/build/html/classpocketfft_1_1detail_1_1sincos__2pibyn.html new file mode 100644 index 000000000..c39332681 --- /dev/null +++ b/docs/build/html/classpocketfft_1_1detail_1_1sincos__2pibyn.html @@ -0,0 +1,159 @@ + + + + + + + +MLX: pocketfft::detail::sincos_2pibyn< T > Class Template Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
pocketfft::detail::sincos_2pibyn< T > Class Template Reference
+
+
+ +

#include <pocketfft.h>

+ + + + + + +

+Public Member Functions

 sincos_2pibyn (size_t n)
 
cmplx< Toperator[] (size_t idx) const
 
+

Constructor & Destructor Documentation

+ +

◆ sincos_2pibyn()

+ +
+
+
+template<typename T >
+ + + + + +
+ + + + + + + +
pocketfft::detail::sincos_2pibyn< T >::sincos_2pibyn (size_t n)
+
+inline
+
+ +
+
+

Member Function Documentation

+ +

◆ operator[]()

+ +
+
+
+template<typename T >
+ + + + + +
+ + + + + + + +
cmplx< T > pocketfft::detail::sincos_2pibyn< T >::operator[] (size_t idx) const
+
+inline
+
+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classpocketfft_1_1detail_1_1threading_1_1concurrent__queue-members.html b/docs/build/html/classpocketfft_1_1detail_1_1threading_1_1concurrent__queue-members.html new file mode 100644 index 000000000..119b98e56 --- /dev/null +++ b/docs/build/html/classpocketfft_1_1detail_1_1threading_1_1concurrent__queue-members.html @@ -0,0 +1,93 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
pocketfft::detail::threading::concurrent_queue< T > Member List
+
+ + + + + diff --git a/docs/build/html/classpocketfft_1_1detail_1_1threading_1_1concurrent__queue.html b/docs/build/html/classpocketfft_1_1detail_1_1threading_1_1concurrent__queue.html new file mode 100644 index 000000000..624dd5302 --- /dev/null +++ b/docs/build/html/classpocketfft_1_1detail_1_1threading_1_1concurrent__queue.html @@ -0,0 +1,187 @@ + + + + + + + +MLX: pocketfft::detail::threading::concurrent_queue< T > Class Template Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
pocketfft::detail::threading::concurrent_queue< T > Class Template Reference
+
+
+ +

#include <pocketfft.h>

+ + + + + + + + +

+Public Member Functions

void push (T val)
 
bool try_pop (T &val)
 
bool empty () const
 
+

Member Function Documentation

+ +

◆ empty()

+ +
+
+
+template<typename T >
+ + + + + +
+ + + + + + + +
bool pocketfft::detail::threading::concurrent_queue< T >::empty () const
+
+inline
+
+ +
+
+ +

◆ push()

+ +
+
+
+template<typename T >
+ + + + + +
+ + + + + + + +
void pocketfft::detail::threading::concurrent_queue< T >::push (T val)
+
+inline
+
+ +
+
+ +

◆ try_pop()

+ +
+
+
+template<typename T >
+ + + + + +
+ + + + + + + +
bool pocketfft::detail::threading::concurrent_queue< T >::try_pop (T & val)
+
+inline
+
+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classpocketfft_1_1detail_1_1threading_1_1latch-members.html b/docs/build/html/classpocketfft_1_1detail_1_1threading_1_1latch-members.html new file mode 100644 index 000000000..1a1651f32 --- /dev/null +++ b/docs/build/html/classpocketfft_1_1detail_1_1threading_1_1latch-members.html @@ -0,0 +1,94 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
pocketfft::detail::threading::latch Member List
+
+ + + + + diff --git a/docs/build/html/classpocketfft_1_1detail_1_1threading_1_1latch.html b/docs/build/html/classpocketfft_1_1detail_1_1threading_1_1latch.html new file mode 100644 index 000000000..b4bc2ced1 --- /dev/null +++ b/docs/build/html/classpocketfft_1_1detail_1_1threading_1_1latch.html @@ -0,0 +1,209 @@ + + + + + + + +MLX: pocketfft::detail::threading::latch Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
pocketfft::detail::threading::latch Class Reference
+
+
+ +

#include <pocketfft.h>

+ + + + + + + + + + +

+Public Member Functions

 latch (size_t n)
 
void count_down ()
 
void wait ()
 
bool is_ready ()
 
+

Constructor & Destructor Documentation

+ +

◆ latch()

+ +
+
+ + + + + +
+ + + + + + + +
pocketfft::detail::threading::latch::latch (size_t n)
+
+inline
+
+ +
+
+

Member Function Documentation

+ +

◆ count_down()

+ +
+
+ + + + + +
+ + + + + + + +
void pocketfft::detail::threading::latch::count_down ()
+
+inline
+
+ +
+
+ +

◆ is_ready()

+ +
+
+ + + + + +
+ + + + + + + +
bool pocketfft::detail::threading::latch::is_ready ()
+
+inline
+
+ +
+
+ +

◆ wait()

+ +
+
+ + + + + +
+ + + + + + + +
void pocketfft::detail::threading::latch::wait ()
+
+inline
+
+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/classpocketfft_1_1detail_1_1threading_1_1thread__pool-members.html b/docs/build/html/classpocketfft_1_1detail_1_1threading_1_1thread__pool-members.html new file mode 100644 index 000000000..a0affd7d6 --- /dev/null +++ b/docs/build/html/classpocketfft_1_1detail_1_1threading_1_1thread__pool-members.html @@ -0,0 +1,96 @@ + + + + + + + +MLX: Member List + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
pocketfft::detail::threading::thread_pool Member List
+
+ + + + + diff --git a/docs/build/html/classpocketfft_1_1detail_1_1threading_1_1thread__pool.html b/docs/build/html/classpocketfft_1_1detail_1_1threading_1_1thread__pool.html new file mode 100644 index 000000000..d09be49b4 --- /dev/null +++ b/docs/build/html/classpocketfft_1_1detail_1_1threading_1_1thread__pool.html @@ -0,0 +1,263 @@ + + + + + + + +MLX: pocketfft::detail::threading::thread_pool Class Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
pocketfft::detail::threading::thread_pool Class Reference
+
+
+ +

#include <pocketfft.h>

+ + + + + + + + + + + + + + +

+Public Member Functions

 thread_pool (size_t nthreads)
 
 thread_pool ()
 
 ~thread_pool ()
 
void submit (std::function< void()> work)
 
void shutdown ()
 
void restart ()
 
+

Constructor & Destructor Documentation

+ +

◆ thread_pool() [1/2]

+ +
+
+ + + + + +
+ + + + + + + +
pocketfft::detail::threading::thread_pool::thread_pool (size_t nthreads)
+
+inlineexplicit
+
+ +
+
+ +

◆ thread_pool() [2/2]

+ +
+
+ + + + + +
+ + + + + + + +
pocketfft::detail::threading::thread_pool::thread_pool ()
+
+inline
+
+ +
+
+ +

◆ ~thread_pool()

+ +
+
+ + + + + +
+ + + + + + + +
pocketfft::detail::threading::thread_pool::~thread_pool ()
+
+inline
+
+ +
+
+

Member Function Documentation

+ +

◆ restart()

+ +
+
+ + + + + +
+ + + + + + + +
void pocketfft::detail::threading::thread_pool::restart ()
+
+inline
+
+ +
+
+ +

◆ shutdown()

+ +
+
+ + + + + +
+ + + + + + + +
void pocketfft::detail::threading::thread_pool::shutdown ()
+
+inline
+
+ +
+
+ +

◆ submit()

+ +
+
+ + + + + +
+ + + + + + + +
void pocketfft::detail::threading::thread_pool::submit (std::function< void()> work)
+
+inline
+
+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/build/html/clipboard.js b/docs/build/html/clipboard.js new file mode 100644 index 000000000..42c1fb0e0 --- /dev/null +++ b/docs/build/html/clipboard.js @@ -0,0 +1,61 @@ +/** + +The code below is based on the Doxygen Awesome project, see +https://github.com/jothepro/doxygen-awesome-css + +MIT License + +Copyright (c) 2021 - 2022 jothepro + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +*/ + +let clipboard_title = "Copy to clipboard" +let clipboard_icon = `` +let clipboard_successIcon = `` +let clipboard_successDuration = 1000 + +$(function() { + if(navigator.clipboard) { + const fragments = document.getElementsByClassName("fragment") + for(const fragment of fragments) { + const clipboard_div = document.createElement("div") + clipboard_div.classList.add("clipboard") + clipboard_div.innerHTML = clipboard_icon + clipboard_div.title = clipboard_title + $(clipboard_div).click(function() { + const content = this.parentNode.cloneNode(true) + // filter out line number and folded fragments from file listings + content.querySelectorAll(".lineno, .ttc, .foldclosed").forEach((node) => { node.remove() }) + let text = content.textContent + // remove trailing newlines and trailing spaces from empty lines + text = text.replace(/^\s*\n/gm,'\n').replace(/\n*$/,'') + navigator.clipboard.writeText(text); + this.classList.add("success") + this.innerHTML = clipboard_successIcon + window.setTimeout(() => { // switch back to normal icon after timeout + this.classList.remove("success") + this.innerHTML = clipboard_icon + }, clipboard_successDuration); + }) + fragment.insertBefore(clipboard_div, fragment.firstChild) + } + } +}) diff --git a/docs/build/html/closed.png b/docs/build/html/closed.png new file mode 100644 index 0000000000000000000000000000000000000000..98cc2c909da37a6df914fbf67780eebd99c597f5 GIT binary patch literal 132 zcmeAS@N?(olHy`uVBq!ia0vp^oFL4>1|%O$WD@{V-kvUwAr*{o@8{^CZMh(5KoB^r_<4^zF@3)Cp&&t3hdujKf f*?bjBoY!V+E))@{xMcbjXe@)LtDnm{r-UW|*e5JT literal 0 HcmV?d00001 diff --git a/docs/build/html/common_2binary_8h.html b/docs/build/html/common_2binary_8h.html new file mode 100644 index 000000000..b8410e482 --- /dev/null +++ b/docs/build/html/common_2binary_8h.html @@ -0,0 +1,117 @@ + + + + + + + +MLX: mlx/backend/common/binary.h File Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
binary.h File Reference
+
+
+
#include "mlx/allocator.h"
+#include "mlx/array.h"
+#include "mlx/backend/common/utils.h"
+
+

Go to the source code of this file.

+ + + + + + +

+Namespaces

namespace  mlx
 
namespace  mlx::core
 
+

Variable Documentation

+ +

◆ op

+ +
+
+ + + + +
Op op
+
+ +
+
+
+ + + + diff --git a/docs/build/html/common_2binary_8h_source.html b/docs/build/html/common_2binary_8h_source.html new file mode 100644 index 000000000..10e707cd3 --- /dev/null +++ b/docs/build/html/common_2binary_8h_source.html @@ -0,0 +1,764 @@ + + + + + + + +MLX: mlx/backend/common/binary.h Source File + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
binary.h
+
+
+Go to the documentation of this file.
1// Copyright © 2023 Apple Inc.
+
2
+
3#pragma once
+
4#include "mlx/allocator.h"
+
5#include "mlx/array.h"
+ +
7
+
8namespace mlx::core {
+
9
+
10namespace {
+
11
+
12enum class BinaryOpType {
+
13 ScalarScalar,
+
14 ScalarVector,
+
15 VectorScalar,
+
16 VectorVector,
+
17 General,
+
18};
+
19
+
20BinaryOpType get_binary_op_type(const array& a, const array& b) {
+
21 BinaryOpType bopt;
+
22 if (a.data_size() == 1 && b.data_size() == 1) {
+
23 bopt = BinaryOpType::ScalarScalar;
+
24 } else if (a.data_size() == 1 && b.flags().contiguous) {
+
25 bopt = BinaryOpType::ScalarVector;
+
26 } else if (b.data_size() == 1 && a.flags().contiguous) {
+
27 bopt = BinaryOpType::VectorScalar;
+
28 } else if (
+
29 a.flags().row_contiguous && b.flags().row_contiguous ||
+
30 a.flags().col_contiguous && b.flags().col_contiguous) {
+
31 bopt = BinaryOpType::VectorVector;
+
32 } else {
+
33 bopt = BinaryOpType::General;
+
34 }
+
35 return bopt;
+
36}
+
37
+
38void set_binary_op_output_data(
+
39 const array& a,
+
40 const array& b,
+
41 array& out,
+
42 BinaryOpType bopt,
+
43 bool donate_with_move = false) {
+
44 switch (bopt) {
+
45 case BinaryOpType::ScalarScalar:
+
46 out.set_data(
+
47 allocator::malloc_or_wait(out.itemsize()), 1, a.strides(), a.flags());
+
48 break;
+
49 case BinaryOpType::ScalarVector:
+
50 if (b.is_donatable() && b.itemsize() == out.itemsize()) {
+
51 if (donate_with_move) {
+
52 out.move_shared_buffer(b);
+
53 } else {
+
54 out.copy_shared_buffer(b);
+
55 }
+
56 } else {
+
57 out.set_data(
+
58 allocator::malloc_or_wait(b.data_size() * out.itemsize()),
+
59 b.data_size(),
+
60 b.strides(),
+
61 b.flags());
+
62 }
+
63 break;
+
64 case BinaryOpType::VectorScalar:
+
65 if (a.is_donatable() && a.itemsize() == out.itemsize()) {
+
66 if (donate_with_move) {
+
67 out.move_shared_buffer(a);
+
68 } else {
+
69 out.copy_shared_buffer(a);
+
70 }
+
71 } else {
+
72 out.set_data(
+
73 allocator::malloc_or_wait(a.data_size() * out.itemsize()),
+
74 a.data_size(),
+
75 a.strides(),
+
76 a.flags());
+
77 }
+
78 break;
+
79 case BinaryOpType::VectorVector:
+
80 if (a.is_donatable() && a.itemsize() == out.itemsize()) {
+
81 if (donate_with_move) {
+
82 out.move_shared_buffer(a);
+
83 } else {
+
84 out.copy_shared_buffer(a);
+
85 }
+
86 } else if (b.is_donatable() && b.itemsize() == out.itemsize()) {
+
87 if (donate_with_move) {
+
88 out.move_shared_buffer(b);
+
89 } else {
+
90 out.copy_shared_buffer(b);
+
91 }
+
92 } else {
+
93 out.set_data(
+
94 allocator::malloc_or_wait(a.data_size() * out.itemsize()),
+
95 a.data_size(),
+
96 a.strides(),
+
97 a.flags());
+
98 }
+
99 break;
+
100 case BinaryOpType::General:
+
101 if (a.is_donatable() && a.flags().row_contiguous &&
+
102 a.itemsize() == out.itemsize() && a.size() == out.size()) {
+
103 if (donate_with_move) {
+
104 out.move_shared_buffer(a);
+
105 } else {
+
106 out.copy_shared_buffer(a);
+
107 }
+
108 } else if (
+
109 b.is_donatable() && b.flags().row_contiguous &&
+
110 b.itemsize() == out.itemsize() && b.size() == out.size()) {
+
111 if (donate_with_move) {
+
112 out.move_shared_buffer(b);
+
113 } else {
+
114 out.copy_shared_buffer(b);
+
115 }
+
116 } else {
+
117 out.set_data(allocator::malloc_or_wait(out.nbytes()));
+
118 }
+
119 break;
+
120 }
+
121}
+
122
+
123struct UseDefaultBinaryOp {
+
124 template <typename T, typename U>
+
125 void operator()(const T* a, const T* b, U* dst, int size) {
+
126 // Should we throw? This should normally never be called.
+
127 assert(false);
+
128 }
+
129
+
130 template <typename T, typename U>
+
131 void operator()(const T* a, const T* b, U* dst_a, U* dst_b, int size) {
+
132 // Should we throw? This should normally never be called.
+
133 assert(false);
+
134 }
+
135};
+
136
+
137template <typename T, typename U, typename Op>
+
138struct DefaultVectorScalar {
+
139 Op op;
+
140
+
141 DefaultVectorScalar(Op op_) : op(op_) {}
+
142
+
143 void operator()(const T* a, const T* b, U* dst, int size) {
+
144 T scalar = *b;
+
145 while (size-- > 0) {
+
146 *dst = op(*a, scalar);
+
147 dst++;
+
148 a++;
+
149 }
+
150 }
+
151
+
152 void operator()(const T* a, const T* b, U* dst_a, U* dst_b, int size) {
+
153 T scalar = *b;
+
154 while (size-- > 0) {
+
155 auto dst = op(*a, scalar);
+
156 *dst_a = dst.first;
+
157 *dst_b = dst.second;
+
158 dst_a++;
+
159 dst_b++;
+
160 a++;
+
161 }
+
162 }
+
163};
+
164
+
165template <typename T, typename U, typename Op>
+
166struct DefaultScalarVector {
+
167 Op op;
+
168
+
169 DefaultScalarVector(Op op_) : op(op_) {}
+
170
+
171 void operator()(const T* a, const T* b, U* dst, int size) {
+
172 T scalar = *a;
+
173 while (size-- > 0) {
+
174 *dst = op(scalar, *b);
+
175 dst++;
+
176 b++;
+
177 }
+
178 }
+
179
+
180 void operator()(const T* a, const T* b, U* dst_a, U* dst_b, int size) {
+
181 T scalar = *a;
+
182 while (size-- > 0) {
+
183 auto dst = op(scalar, *b);
+
184 *dst_a = dst.first;
+
185 *dst_b = dst.second;
+
186 dst_a++;
+
187 dst_b++;
+
188 b++;
+
189 }
+
190 }
+
191};
+
192
+
193template <typename T, typename U, typename Op>
+
194struct DefaultVectorVector {
+
195 Op op;
+
196
+
197 DefaultVectorVector(Op op_) : op(op_) {}
+
198
+
199 void operator()(const T* a, const T* b, U* dst, int size) {
+
200 while (size-- > 0) {
+
201 *dst = op(*a, *b);
+
202 dst++;
+
203 a++;
+
204 b++;
+
205 }
+
206 }
+
207
+
208 void operator()(const T* a, const T* b, U* dst_a, U* dst_b, int size) {
+
209 while (size-- > 0) {
+
210 auto dst = op(*a, *b);
+
211 *dst_a = dst.first;
+
212 *dst_b = dst.second;
+
213 dst_a++;
+
214 dst_b++;
+
215 a++;
+
216 b++;
+
217 }
+
218 }
+
219};
+
220
+
221template <typename T, typename U, typename Op>
+
222void binary_op_dims1(const array& a, const array& b, array& out, Op op) {
+
223 const T* a_ptr = a.data<T>();
+
224 const T* b_ptr = b.data<T>();
+
225 U* dst = out.data<U>();
+
226 size_t a_idx = 0;
+
227 size_t b_idx = 0;
+
228 for (size_t i = 0; i < out.size(); ++i) {
+
229 dst[i] = op(a_ptr[a_idx], b_ptr[b_idx]);
+
230 a_idx += a.strides()[0];
+
231 b_idx += b.strides()[0];
+
232 }
+
233}
+
234
+
235template <typename T, typename U, typename Op>
+
236void binary_op_dims1(
+
237 const array& a,
+
238 const array& b,
+
239 array& out,
+
240 Op op,
+
241 int stride) {
+
242 const T* a_ptr = a.data<T>();
+
243 const T* b_ptr = b.data<T>();
+
244 U* dst = out.data<U>();
+
245 size_t a_idx = 0;
+
246 size_t b_idx = 0;
+
247 for (size_t i = 0; i < a.shape()[0]; i++) {
+
248 op(a_ptr + a_idx, b_ptr + b_idx, dst, stride);
+
249 a_idx += a.strides()[0];
+
250 b_idx += b.strides()[0];
+
251 dst += stride;
+
252 }
+
253}
+
254
+
255template <typename T, typename U, typename Op>
+
256void binary_op_dims2(const array& a, const array& b, array& out, Op op) {
+
257 const T* a_ptr = a.data<T>();
+
258 const T* b_ptr = b.data<T>();
+
259 U* dst = out.data<U>();
+
260 size_t a_idx = 0;
+
261 size_t b_idx = 0;
+
262 size_t out_idx = 0;
+
263 for (size_t i = 0; i < a.shape()[0]; ++i) {
+
264 for (size_t j = 0; j < a.shape()[1]; ++j) {
+
265 dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx]);
+
266 a_idx += a.strides()[1];
+
267 b_idx += b.strides()[1];
+
268 }
+
269 a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
+
270 b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
+
271 }
+
272}
+
273
+
274template <typename T, typename U, typename Op>
+
275void binary_op_dims2(
+
276 const array& a,
+
277 const array& b,
+
278 array& out,
+
279 Op op,
+
280 int stride) {
+
281 const T* a_ptr = a.data<T>();
+
282 const T* b_ptr = b.data<T>();
+
283 U* dst = out.data<U>();
+
284 size_t a_idx = 0;
+
285 size_t b_idx = 0;
+
286 for (size_t i = 0; i < a.shape()[0]; ++i) {
+
287 for (size_t j = 0; j < a.shape()[1]; ++j) {
+
288 op(a_ptr + a_idx, b_ptr + b_idx, dst, stride);
+
289 a_idx += a.strides()[1];
+
290 b_idx += b.strides()[1];
+
291 dst += stride;
+
292 }
+
293 a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
+
294 b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
+
295 }
+
296}
+
297
+
298template <typename T, typename U, typename Op>
+
299void binary_op_dims3(const array& a, const array& b, array& out, Op op) {
+
300 const T* a_ptr = a.data<T>();
+
301 const T* b_ptr = b.data<T>();
+
302 U* dst = out.data<U>();
+
303 size_t a_idx = 0;
+
304 size_t b_idx = 0;
+
305 size_t out_idx = 0;
+
306 for (size_t i = 0; i < a.shape()[0]; ++i) {
+
307 for (size_t j = 0; j < a.shape()[1]; ++j) {
+
308 for (size_t k = 0; k < a.shape()[2]; ++k) {
+
309 dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx]);
+
310 a_idx += a.strides()[2];
+
311 b_idx += b.strides()[2];
+
312 }
+
313 a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
+
314 b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
+
315 }
+
316 a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
+
317 b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
+
318 }
+
319}
+
320
+
321template <typename T, typename U, typename Op>
+
322void binary_op_dims4(const array& a, const array& b, array& out, Op op) {
+
323 const T* a_ptr = a.data<T>();
+
324 const T* b_ptr = b.data<T>();
+
325 U* dst = out.data<U>();
+
326 size_t a_idx = 0;
+
327 size_t b_idx = 0;
+
328 size_t out_idx = 0;
+
329 for (size_t i = 0; i < a.shape()[0]; ++i) {
+
330 for (size_t j = 0; j < a.shape()[1]; ++j) {
+
331 for (size_t k = 0; k < a.shape()[2]; ++k) {
+
332 for (size_t ii = 0; ii < a.shape()[3]; ++ii) {
+
333 dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx]);
+
334 a_idx += a.strides()[3];
+
335 b_idx += b.strides()[3];
+
336 }
+
337 a_idx += a.strides()[2] - a.strides()[3] * a.shape()[3];
+
338 b_idx += b.strides()[2] - b.strides()[3] * b.shape()[3];
+
339 }
+
340 a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
+
341 b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
+
342 }
+
343 a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
+
344 b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
+
345 }
+
346}
+
347
+
348template <typename T, typename U, typename Op>
+
349void binary_op_dispatch_dims(
+
350 const array& a,
+
351 const array& b,
+
352 array& out,
+
353 Op op) {
+
354 switch (out.ndim()) {
+
355 case 1:
+
356 binary_op_dims1<T, U, Op>(a, b, out, op);
+
357 return;
+
358 case 2:
+
359 binary_op_dims2<T, U, Op>(a, b, out, op);
+
360 return;
+
361 case 3:
+
362 binary_op_dims3<T, U, Op>(a, b, out, op);
+
363 return;
+
364 case 4:
+
365 binary_op_dims4<T, U, Op>(a, b, out, op);
+
366 return;
+
367 }
+
368
+
369 const T* a_ptr = a.data<T>();
+
370 const T* b_ptr = b.data<T>();
+
371 U* dst = out.data<U>();
+
372 for (size_t i = 0; i < out.size(); i++) {
+
373 int a_idx = elem_to_loc(i, a.shape(), a.strides());
+
374 int b_idx = elem_to_loc(i, b.shape(), b.strides());
+
375 dst[i] = op(a_ptr[a_idx], b_ptr[b_idx]);
+
376 }
+
377}
+
378
+
379template <typename T, typename U, typename Op>
+
380void binary_op_dispatch_dims(
+
381 const array& a,
+
382 const array& b,
+
383 array& out,
+
384 Op op,
+
385 int dim,
+
386 int stride) {
+
387 // Number of dimensions to loop over for vectorized ops
+
388 switch (dim) {
+
389 case 1:
+
390 binary_op_dims1<T, U, Op>(a, b, out, op, stride);
+
391 return;
+
392 case 2:
+
393 binary_op_dims2<T, U, Op>(a, b, out, op, stride);
+
394 return;
+
395 }
+
396
+
397 const T* a_ptr = a.data<T>();
+
398 const T* b_ptr = b.data<T>();
+
399 U* dst = out.data<U>();
+
400 for (size_t i = 0; i < out.size(); i += stride) {
+
401 int a_idx = elem_to_loc(i, a.shape(), a.strides());
+
402 int b_idx = elem_to_loc(i, b.shape(), b.strides());
+
403 op(a_ptr + a_idx, b_ptr + b_idx, dst, stride);
+
404 dst += stride;
+
405 }
+
406}
+
407
+
408template <
+
409 typename T,
+
410 typename U,
+
411 typename Op,
+
412 typename OpSV,
+
413 typename OpVS,
+
414 typename OpVV>
+
415void binary_op(
+
416 const array& a,
+
417 const array& b,
+
418 array& out,
+
419 Op op,
+
420 OpSV opsv,
+
421 OpVS opvs,
+
422 OpVV opvv) {
+
423 auto bopt = get_binary_op_type(a, b);
+
424 set_binary_op_output_data(a, b, out, bopt);
+
425
+
426 // The full computation is scalar scalar so call the base op once
+
427 if (bopt == BinaryOpType::ScalarScalar) {
+
428 *(out.data<U>()) = op(*a.data<T>(), *b.data<T>());
+
429 return;
+
430 }
+
431
+
432 // The full computation is scalar vector so delegate to the op
+
433 if (bopt == BinaryOpType::ScalarVector) {
+
434 opsv(a.data<T>(), b.data<T>(), out.data<U>(), b.data_size());
+
435 return;
+
436 }
+
437
+
438 // The full computation is vector scalar so delegate to the op
+
439 if (bopt == BinaryOpType::VectorScalar) {
+
440 opvs(a.data<T>(), b.data<T>(), out.data<U>(), a.data_size());
+
441 return;
+
442 }
+
443
+
444 // The full computation is vector vector so delegate to the op
+
445 if (bopt == BinaryOpType::VectorVector) {
+
446 opvv(a.data<T>(), b.data<T>(), out.data<U>(), out.size());
+
447 return;
+
448 }
+
449
+
450 // General computation so let's try to optimize
+
451
+
452 // Get the left-most dim such that the array is row contiguous after
+
453 auto& strides = out.strides();
+
454 auto leftmost_rc_dim = [&strides](const array& arr) {
+
455 int d = arr.ndim() - 1;
+
456 for (; d >= 0 && arr.strides()[d] == strides[d]; d--) {
+
457 }
+
458 return d + 1;
+
459 };
+
460 auto a_rc_dim = leftmost_rc_dim(a);
+
461 auto b_rc_dim = leftmost_rc_dim(b);
+
462
+
463 // Get the left-most dim such that the array is a broadcasted "scalar" after
+
464 auto leftmost_s_dim = [](const array& arr) {
+
465 int d = arr.ndim() - 1;
+
466 for (; d >= 0 && arr.strides()[d] == 0; d--) {
+
467 }
+
468 return d + 1;
+
469 };
+
470 auto a_s_dim = leftmost_s_dim(a);
+
471 auto b_s_dim = leftmost_s_dim(b);
+
472
+
473 auto ndim = out.ndim();
+
474
+
475 // Case 1: LxM and FxM where L and F are broadcastable and M is row contiguous
+
476 int dim = ndim;
+
477 if (int d = std::max(a_rc_dim, b_rc_dim); d < ndim) {
+
478 bopt = BinaryOpType::VectorVector;
+
479 dim = d;
+
480 // Case 2: LxM and Fx1 where L and F are broadcastable and M is row
+
481 // contiguous
+
482 } else if (int d = std::max(a_rc_dim, b_s_dim); d < ndim) {
+
483 bopt = BinaryOpType::VectorScalar;
+
484 dim = d;
+
485 // Case 3: Lx1 and FxM where L and F are broadcastable and M is row
+
486 // contiguous
+
487 } else if (int d = std::max(a_s_dim, b_rc_dim); d < ndim) {
+
488 bopt = BinaryOpType::ScalarVector;
+
489 dim = d;
+
490 }
+
491
+
492 // Can be sure dim > 0 since otherwise we would have used one of the fully
+
493 // contiguous methods above. Except for the case that the flags do not
+
494 // correspond to the underlying contiguity.
+
495 size_t stride;
+
496 if (dim == 0 || strides[dim - 1] < 16) {
+
497 stride = 1;
+
498 bopt = BinaryOpType::General;
+
499 dim = ndim;
+
500 } else {
+
501 stride = strides[dim - 1];
+
502 }
+
503
+
504 switch (bopt) {
+
505 case BinaryOpType::VectorVector:
+
506 binary_op_dispatch_dims<T, U>(a, b, out, opvv, dim, stride);
+
507 break;
+
508 case BinaryOpType::VectorScalar:
+
509 binary_op_dispatch_dims<T, U>(a, b, out, opvs, dim, stride);
+
510 break;
+
511 case BinaryOpType::ScalarVector:
+
512 binary_op_dispatch_dims<T, U>(a, b, out, opsv, dim, stride);
+
513 break;
+
514 default:
+
515 binary_op_dispatch_dims<T, U>(a, b, out, op);
+
516 break;
+
517 }
+
518}
+
519
+
520template <typename T, typename Op, typename OpSV, typename OpVS, typename OpVV>
+
521void binary_op(
+
522 const array& a,
+
523 const array& b,
+
524 array& out,
+
525 Op op,
+
526 OpSV opsv,
+
527 OpVS opvs,
+
528 OpVV opvv) {
+
529 // TODO: The following mess of constexpr evaluations can probably be achieved
+
530 // with template specializations and overloading. Would it be simpler?
+
531
+
532 if (std::is_same<decltype(opsv), UseDefaultBinaryOp>::value) {
+
533 if (std::is_same<decltype(opvs), UseDefaultBinaryOp>::value) {
+
534 if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
+
535 // All ops are UseDefaultBinaryOp (why oh why would someone call that?)
+
536 binary_op<T, T>(
+
537 a,
+
538 b,
+
539 out,
+
540 op,
+
541 DefaultScalarVector<T, T, Op>(op),
+
542 DefaultVectorScalar<T, T, Op>(op),
+
543 DefaultVectorVector<T, T, Op>(op));
+
544 } else {
+
545 // opsv and opvs were UseDefaultBinaryOp
+
546 binary_op<T, T>(
+
547 a,
+
548 b,
+
549 out,
+
550 op,
+
551 DefaultScalarVector<T, T, Op>(op),
+
552 DefaultVectorScalar<T, T, Op>(op),
+
553 opvv);
+
554 }
+
555 } else if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
+
556 // opsv and opvv were UseDefaultBinaryOp
+
557 binary_op<T, T>(
+
558 a,
+
559 b,
+
560 out,
+
561 op,
+
562 DefaultScalarVector<T, T, Op>(op),
+
563 opvs,
+
564 DefaultVectorVector<T, T, Op>(op));
+
565 } else {
+
566 // opsv was UseDefaultBinaryOp
+
567 binary_op<T, T>(
+
568 a, b, out, op, DefaultScalarVector<T, T, Op>(op), opvs, opvv);
+
569 }
+
570 } else if (std::is_same<decltype(opvs), UseDefaultBinaryOp>::value) {
+
571 if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
+
572 // opvs and opvv were UseDefaultBinaryOp
+
573 binary_op<T, T>(
+
574 a,
+
575 b,
+
576 out,
+
577 op,
+
578 opsv,
+
579 DefaultVectorScalar<T, T, Op>(op),
+
580 DefaultVectorVector<T, T, Op>(op));
+
581 } else {
+
582 // opvs was UseDefaultBinaryOp
+
583 binary_op<T, T>(
+
584 a, b, out, op, opsv, DefaultVectorScalar<T, T, Op>(op), opvv);
+
585 }
+
586 } else if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
+
587 // opvv was UseDefaultBinaryOp
+
588 binary_op<T, T>(
+
589 a, b, out, op, opsv, opvs, DefaultVectorVector<T, T, Op>(op));
+
590 } else {
+
591 // All ops provided
+
592 binary_op<T, T>(a, b, out, op, opsv, opvs, opvv);
+
593 }
+
594}
+
595
+
596template <typename T, typename Op>
+
597void binary_op(const array& a, const array& b, array& out, Op op) {
+
598 DefaultScalarVector<T, T, Op> opsv(op);
+
599 DefaultVectorScalar<T, T, Op> opvs(op);
+
600 DefaultVectorVector<T, T, Op> opvv(op);
+
601 binary_op<T, T>(a, b, out, op, opsv, opvs, opvv);
+
602}
+
603
+
604template <typename... Ops>
+
605void binary(const array& a, const array& b, array& out, Ops... ops) {
+
606 switch (out.dtype()) {
+
607 case bool_:
+
608 binary_op<bool>(a, b, out, ops...);
+
609 break;
+
610 case uint8:
+
611 binary_op<uint8_t>(a, b, out, ops...);
+
612 break;
+
613 case uint16:
+
614 binary_op<uint16_t>(a, b, out, ops...);
+
615 break;
+
616 case uint32:
+
617 binary_op<uint32_t>(a, b, out, ops...);
+
618 break;
+
619 case uint64:
+
620 binary_op<uint64_t>(a, b, out, ops...);
+
621 break;
+
622 case int8:
+
623 binary_op<int8_t>(a, b, out, ops...);
+
624 break;
+
625 case int16:
+
626 binary_op<int16_t>(a, b, out, ops...);
+
627 break;
+
628 case int32:
+
629 binary_op<int32_t>(a, b, out, ops...);
+
630 break;
+
631 case int64:
+
632 binary_op<int64_t>(a, b, out, ops...);
+
633 break;
+
634 case float16:
+
635 binary_op<float16_t>(a, b, out, ops...);
+
636 break;
+
637 case float32:
+
638 binary_op<float>(a, b, out, ops...);
+
639 break;
+
640 case bfloat16:
+
641 binary_op<bfloat16_t>(a, b, out, ops...);
+
642 break;
+
643 case complex64:
+
644 binary_op<complex64_t>(a, b, out, ops...);
+
645 break;
+
646 }
+
647}
+
648
+
649} // namespace
+
650
+
651} // namespace mlx::core
+ + + +
Op op
Definition binary.h:139
+
Buffer malloc_or_wait(size_t size)
+
Definition allocator.h:7
+
constexpr Dtype bool_
Definition dtype.h:60
+
constexpr Dtype uint64
Definition dtype.h:65
+
constexpr Dtype uint16
Definition dtype.h:63
+
stride_t elem_to_loc(int elem, const std::vector< int > &shape, const std::vector< stride_t > &strides)
Definition utils.h:12
+
constexpr Dtype bfloat16
Definition dtype.h:74
+
constexpr Dtype int32
Definition dtype.h:69
+
constexpr Dtype float32
Definition dtype.h:73
+
constexpr Dtype int16
Definition dtype.h:68
+
constexpr Dtype int8
Definition dtype.h:67
+
constexpr Dtype int64
Definition dtype.h:70
+
constexpr Dtype uint8
Definition dtype.h:62
+ +
constexpr Dtype float16
Definition dtype.h:72
+
constexpr Dtype uint32
Definition dtype.h:64
+
constexpr Dtype complex64
Definition dtype.h:75
+
+ + + + diff --git a/docs/build/html/common_2compiled__preamble_8h.html b/docs/build/html/common_2compiled__preamble_8h.html new file mode 100644 index 000000000..824d31aae --- /dev/null +++ b/docs/build/html/common_2compiled__preamble_8h.html @@ -0,0 +1,118 @@ + + + + + + + +MLX: mlx/backend/common/compiled_preamble.h File Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
compiled_preamble.h File Reference
+
+
+ +

Go to the source code of this file.

+ + + + +

+Functions

const char * get_kernel_preamble ()
 
+

Function Documentation

+ +

◆ get_kernel_preamble()

+ +
+
+ + + + + + + +
const char * get_kernel_preamble ()
+
+ +
+
+
+ + + + diff --git a/docs/build/html/common_2compiled__preamble_8h_source.html b/docs/build/html/common_2compiled__preamble_8h_source.html new file mode 100644 index 000000000..3fecd12e0 --- /dev/null +++ b/docs/build/html/common_2compiled__preamble_8h_source.html @@ -0,0 +1,107 @@ + + + + + + + +MLX: mlx/backend/common/compiled_preamble.h Source File + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
compiled_preamble.h
+
+
+Go to the documentation of this file.
1// Copyright © 2023-24 Apple Inc.
+
2
+
3#pragma once
+
4
+
5// clang-format off
+ +
7#include "mlx/types/complex.h"
+ +
9// clang-format on
+
10
+
11const char* get_kernel_preamble();
+ +
const char * get_kernel_preamble()
+ + +
+ + + + diff --git a/docs/build/html/common_2copy_8h.html b/docs/build/html/common_2copy_8h.html new file mode 100644 index 000000000..2a90f4911 --- /dev/null +++ b/docs/build/html/common_2copy_8h.html @@ -0,0 +1,122 @@ + + + + + + + +MLX: mlx/backend/common/copy.h File Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
copy.h File Reference
+
+
+
#include "mlx/array.h"
+#include "mlx/backend/common/utils.h"
+
+

Go to the source code of this file.

+ + + + + + +

+Namespaces

namespace  mlx
 
namespace  mlx::core
 
+ + + +

+Enumerations

enum class  mlx::core::CopyType { mlx::core::Scalar +, mlx::core::Vector +, mlx::core::General +, mlx::core::GeneralGeneral + }
 
+ + + + + + + + +

+Functions

void mlx::core::copy (const array &src, array &dst, CopyType ctype)
 
void mlx::core::copy_inplace (const array &src, array &dst, CopyType ctype)
 
template<typename stride_t >
void mlx::core::copy_inplace (const array &src, array &dst, const std::vector< int > &data_shape, const std::vector< stride_t > &i_strides, const std::vector< stride_t > &o_strides, int64_t i_offset, int64_t o_offset, CopyType ctype)
 
+
+ + + + diff --git a/docs/build/html/common_2copy_8h_source.html b/docs/build/html/common_2copy_8h_source.html new file mode 100644 index 000000000..050fc5aa9 --- /dev/null +++ b/docs/build/html/common_2copy_8h_source.html @@ -0,0 +1,145 @@ + + + + + + + +MLX: mlx/backend/common/copy.h Source File + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
copy.h
+
+
+Go to the documentation of this file.
1// Copyright © 2023-2024 Apple Inc.
+
2
+
3#pragma once
+
4
+
5#include "mlx/array.h"
+ +
7
+
8namespace mlx::core {
+
9
+
+
10enum class CopyType {
+
11 // Copy a raw scalar input into the full contiguous output
+
12 Scalar,
+
13
+
14 // Copy the raw input buffer contiguously into a raw output buffer of the same
+
15 // size
+
16 Vector,
+
17
+
18 // Copy the full virtual input to the full contiguous output
+
19 General,
+
20
+
21 // Copy the full virtual input to the full virtual output. We assume the
+
22 // input and output have the same shape.
+ +
24};
+
+
25
+
26void copy(const array& src, array& dst, CopyType ctype);
+
27void copy_inplace(const array& src, array& dst, CopyType ctype);
+
28
+
29template <typename stride_t>
+ +
31 const array& src,
+
32 array& dst,
+
33 const std::vector<int>& data_shape,
+
34 const std::vector<stride_t>& i_strides,
+
35 const std::vector<stride_t>& o_strides,
+
36 int64_t i_offset,
+
37 int64_t o_offset,
+
38 CopyType ctype);
+
39
+
40} // namespace mlx::core
+ + +
Definition array.h:20
+
Definition allocator.h:7
+
void copy(const array &src, array &dst, CopyType ctype)
+
void copy_inplace(const array &src, array &dst, CopyType ctype)
+
CopyType
Definition copy.h:10
+ + + + +
+ + + + diff --git a/docs/build/html/common_2reduce_8h.html b/docs/build/html/common_2reduce_8h.html new file mode 100644 index 000000000..951bd9ebd --- /dev/null +++ b/docs/build/html/common_2reduce_8h.html @@ -0,0 +1,136 @@ + + + + + + + +MLX: mlx/backend/common/reduce.h File Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
reduce.h File Reference
+
+
+ +

Go to the source code of this file.

+ + + + +

+Classes

struct  mlx::core::ReductionPlan
 
+ + + + + +

+Namespaces

namespace  mlx
 
namespace  mlx::core
 
+ + + +

+Enumerations

enum  mlx::core::ReductionOpType {
+  mlx::core::ContiguousAllReduce +, mlx::core::ContiguousReduce +, mlx::core::ContiguousStridedReduce +, mlx::core::GeneralContiguousReduce +,
+  mlx::core::GeneralStridedReduce +, mlx::core::GeneralReduce +
+ }
 
+

Variable Documentation

+ +

◆ op

+ +
+
+ + + + +
Op op
+
+ +
+
+
+ + + + diff --git a/docs/build/html/common_2reduce_8h_source.html b/docs/build/html/common_2reduce_8h_source.html new file mode 100644 index 000000000..813b7d68b --- /dev/null +++ b/docs/build/html/common_2reduce_8h_source.html @@ -0,0 +1,483 @@ + + + + + + + +MLX: mlx/backend/common/reduce.h Source File + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
reduce.h
+
+
+Go to the documentation of this file.
1// Copyright © 2023 Apple Inc.
+
2
+
3#pragma once
+
4
+ +
6
+
7namespace mlx::core {
+
8
+
+ +
10 // Self-explanatory. Read everything and produce 1 output.
+ +
12
+
13 // The input is contiguous and the last axis is reduced
+
14 // N1xR1xN2xR2x...xNnxRn
+ +
16
+
17 // The input is contiguous and the last axis is not reduced
+
18 // R1xN1xR2xN2x...xRnxNn
+ +
20
+
21 // The input is not contiguous but the last axis is and it is reduced so we
+
22 // need to figure out the offsets but we can call the contiguous reduce after
+
23 // that.
+
24 // N3xR1xN1xR4x...xRn
+ +
26
+
27 // The input is not contiguous but the last reduction axis and the last axis
+
28 // are so we need to figure out the offset but we can call the strided reduce
+
29 // after that.
+ +
31
+
32 // The input is not contiguous after the reduction axis and it may contain
+
33 // 0-stride axes or transpositions. We could copy the strides and produce a
+
34 // transposed outcome or we can read the input out of order and write the
+
35 // output in order.
+ +
37};
+
+
38
+
+ + +
41 std::vector<int> shape;
+
42 std::vector<size_t> strides;
+
43
+
+ +
45 ReductionOpType type_,
+
46 std::vector<int> shape_,
+
47 std::vector<size_t> strides_)
+
48 : type(type_), shape(std::move(shape_)), strides(std::move(strides_)) {}
+
+ +
50};
+
+
51
+
52namespace {
+
53
+
54// Helper for the ndimensional strided loop
+
55// Should this be in utils?
+
56inline void nd_loop(
+
57 std::function<void(int)> callback,
+
58 const std::vector<int>& shape,
+
59 const std::vector<size_t>& strides) {
+
60 std::function<void(int, int)> loop_inner;
+
61 loop_inner = [&](int dim, int offset) {
+
62 if (dim < shape.size() - 1) {
+
63 int size = shape[dim];
+
64 size_t stride = strides[dim];
+
65 for (int i = 0; i < size; i++) {
+
66 loop_inner(dim + 1, offset + i * stride);
+
67 }
+
68 } else {
+
69 int size = shape[dim];
+
70 size_t stride = strides[dim];
+
71 for (int i = 0; i < size; i++) {
+
72 callback(offset + i * stride);
+
73 }
+
74 }
+
75 };
+
76 loop_inner(0, 0);
+
77}
+
78
+
79std::pair<std::vector<int>, std::vector<size_t>> shapes_without_reduction_axes(
+
80 const array& x,
+
81 const std::vector<int>& axes) {
+
82 std::vector<int> shape = x.shape();
+
83 std::vector<size_t> strides = x.strides();
+
84
+
85 for (int i = axes.size() - 1; i >= 0; i--) {
+
86 int a = axes[i];
+
87 shape.erase(shape.begin() + a);
+
88 strides.erase(strides.begin() + a);
+
89 }
+
90
+
91 return std::make_pair(shape, strides);
+
92}
+
93
+
94template <typename T, typename U, typename Op>
+
95struct DefaultStridedReduce {
+
96 Op op;
+
97
+
98 DefaultStridedReduce(Op op_) : op(op_) {}
+
99
+
100 void operator()(const T* x, U* accumulator, int size, size_t stride) {
+
101 for (int i = 0; i < size; i++) {
+
102 U* moving_accumulator = accumulator;
+
103 for (int j = 0; j < stride; j++) {
+
104 op(moving_accumulator, *x);
+
105 moving_accumulator++;
+
106 x++;
+
107 }
+
108 }
+
109 }
+
110};
+
111
+
112template <typename T, typename U, typename Op>
+
113struct DefaultContiguousReduce {
+
114 Op op;
+
115
+
116 DefaultContiguousReduce(Op op_) : op(op_) {}
+
117
+
118 void operator()(const T* x, U* accumulator, int size) {
+
119 while (size-- > 0) {
+
120 op(accumulator, *x);
+
121 x++;
+
122 }
+
123 }
+
124};
+
125
+
126ReductionPlan get_reduction_plan(const array& x, const std::vector<int> axes) {
+
127 // The data is all there and we are reducing over everything
+
128 if (x.size() == x.data_size() && axes.size() == x.ndim() &&
+
129 x.flags().contiguous) {
+
130 return ContiguousAllReduce;
+
131 }
+
132
+
133 // Row contiguous input so the output is row contiguous
+
134 if (x.flags().row_contiguous) {
+
135 // Merge consecutive axes
+
136 std::vector<int> shape = {x.shape(axes[0])};
+
137 std::vector<size_t> strides = {x.strides()[axes[0]]};
+
138 for (int i = 1; i < axes.size(); i++) {
+
139 if (axes[i] - 1 == axes[i - 1]) {
+
140 shape.back() *= x.shape(axes[i]);
+
141 strides.back() = x.strides()[axes[i]];
+
142 } else {
+
143 shape.push_back(x.shape(axes[i]));
+
144 strides.push_back(x.strides()[axes[i]]);
+
145 }
+
146 }
+
147
+
148 if (strides.back() == 1) {
+
149 return ReductionPlan(ContiguousReduce, shape, strides);
+
150 } else if (strides.back() > 1) {
+
151 return ReductionPlan(ContiguousStridedReduce, shape, strides);
+
152 }
+
153 }
+
154
+
155 // Let's check if we can optimize our access patterns
+
156 //
+
157 // 1. We have a reduction axis with stride 1. Simply call
+
158 // GeneralContiguousReduce and be done with it.
+
159 // 2. We have transpositions and we are not reducing over the axis with
+
160 // stride 1. However, we are reducing over an axis where everything is
+
161 // contiguous in memory to the right of that axis. We can call strided
+
162 // reduce and be done with it.
+
163 // 2. We have weird transpositions and expands. Copy the strides to the
+
164 // output, then call strided reduce.
+
165
+
166 // Sort reduction axes by stride in order to merge them and figure out if we
+
167 // have a contiguous reduction.
+
168 std::vector<std::pair<int, size_t>> reductions;
+
169 for (auto a : axes) {
+
170 reductions.push_back(std::make_pair(x.shape(a), x.strides()[a]));
+
171 }
+
172 std::sort(reductions.begin(), reductions.end(), [](auto a, auto b) {
+
173 return a.second > b.second;
+
174 });
+
175 // Extract the two smallest and try to merge them in case the contiguous
+
176 // reduction can be bigger than just the last axis.
+
177 for (int i = reductions.size() - 1; i >= 1; i--) {
+
178 auto a = reductions[i];
+
179 auto b = reductions[i - 1];
+
180
+
181 // b.stride = a.shape * a.stride then a and b are contiguous
+
182 if (b.second == a.first * a.second) {
+
183 reductions.erase(reductions.begin() + i);
+
184 reductions[i - 1] = std::make_pair(a.first * b.first, a.second);
+
185 }
+
186 }
+
187
+
188 std::vector<int> shape;
+
189 std::vector<size_t> strides;
+
190 for (auto r : reductions) {
+
191 shape.push_back(r.first);
+
192 strides.push_back(r.second);
+
193 }
+
194
+
195 // We can call the contiguous reduction op for every weird way the input is
+
196 // structured in the rest of the axes.
+
197 if (strides.back() == 1) {
+
198 return ReductionPlan(GeneralContiguousReduce, shape, strides);
+
199 }
+
200
+
201 // Delegate to the general strided reduction op if the axes after
+
202 // strides.back() are contiguous.
+
203 if (strides.back() > 1) {
+
204 int size = 1;
+
205 for (int i = x.ndim() - 1; i >= 0; i--) {
+
206 if (axes.back() == i) {
+
207 continue;
+
208 }
+
209 if (x.strides()[i] != size) {
+
210 break;
+
211 }
+
212 size *= x.shape(i);
+
213 }
+
214 if (size >= strides.back()) {
+
215 return ReductionPlan(GeneralStridedReduce, shape, strides);
+
216 }
+
217 }
+
218
+
219 return ReductionPlan(GeneralReduce, shape, strides);
+
220}
+
221
+
222template <typename T, typename U, typename OpS, typename OpC, typename Op>
+
223void reduction_op(
+
224 const array& x,
+
225 array& out,
+
226 const std::vector<int>& axes,
+
227 U init,
+
228 OpS ops,
+
229 OpC opc,
+
230 Op op) {
+
231 out.set_data(allocator::malloc_or_wait(out.nbytes()));
+
232 ReductionPlan plan = get_reduction_plan(x, axes);
+
233
+
234 if (plan.type == ContiguousAllReduce) {
+
235 U* out_ptr = out.data<U>();
+
236 *out_ptr = init;
+
237 opc(x.data<T>(), out_ptr, x.size());
+
238 return;
+
239 }
+
240
+
241 std::vector<int> shape;
+
242 std::vector<size_t> strides;
+
243
+
244 if (plan.type == ContiguousReduce && plan.shape.size() == 1) {
+
245 int reduction_size = plan.shape[0];
+
246 const T* x_ptr = x.data<T>();
+
247 U* out_ptr = out.data<U>();
+
248 for (int i = 0; i < out.size(); i++, out_ptr++, x_ptr += reduction_size) {
+
249 *out_ptr = init;
+
250 opc(x_ptr, out_ptr, reduction_size);
+
251 }
+
252 return;
+
253 }
+
254
+
255 if (plan.type == GeneralContiguousReduce || plan.type == ContiguousReduce) {
+
256 int reduction_size = plan.shape.back();
+
257 plan.shape.pop_back();
+
258 plan.strides.pop_back();
+
259 const T* x_ptr = x.data<T>();
+
260 U* out_ptr = out.data<U>();
+
261 // Unrolling the following loop (and implementing it in order for
+
262 // ContiguousReduce) should hold extra performance boost.
+
263 std::tie(shape, strides) = shapes_without_reduction_axes(x, axes);
+
264 if (plan.shape.size() == 0) {
+
265 for (int i = 0; i < out.size(); i++, out_ptr++) {
+
266 int offset = elem_to_loc(i, shape, strides);
+
267 *out_ptr = init;
+
268 opc(x_ptr + offset, out_ptr, reduction_size);
+
269 }
+
270 } else {
+
271 for (int i = 0; i < out.size(); i++, out_ptr++) {
+
272 int offset = elem_to_loc(i, shape, strides);
+
273 *out_ptr = init;
+
274 nd_loop(
+
275 [&](int extra_offset) {
+
276 opc(x_ptr + offset + extra_offset, out_ptr, reduction_size);
+
277 },
+
278 plan.shape,
+
279 plan.strides);
+
280 }
+
281 }
+
282 return;
+
283 }
+
284
+
285 if (plan.type == ContiguousStridedReduce && plan.shape.size() == 1) {
+
286 int reduction_size = plan.shape.back();
+
287 size_t reduction_stride = plan.strides.back();
+
288 plan.shape.pop_back();
+
289 plan.strides.pop_back();
+
290 const T* x_ptr = x.data<T>();
+
291 U* out_ptr = out.data<U>();
+
292 for (int i = 0; i < out.size(); i += reduction_stride) {
+
293 std::fill_n(out_ptr, reduction_stride, init);
+
294 ops(x_ptr, out_ptr, reduction_size, reduction_stride);
+
295 x_ptr += reduction_stride * reduction_size;
+
296 out_ptr += reduction_stride;
+
297 }
+
298 return;
+
299 }
+
300
+
301 if (plan.type == GeneralStridedReduce ||
+
302 plan.type == ContiguousStridedReduce) {
+
303 int reduction_size = plan.shape.back();
+
304 size_t reduction_stride = plan.strides.back();
+
305 plan.shape.pop_back();
+
306 plan.strides.pop_back();
+
307 const T* x_ptr = x.data<T>();
+
308 U* out_ptr = out.data<U>();
+
309 std::tie(shape, strides) = shapes_without_reduction_axes(x, axes);
+
310 if (plan.shape.size() == 0) {
+
311 for (int i = 0; i < out.size(); i += reduction_stride) {
+
312 int offset = elem_to_loc(i, shape, strides);
+
313 std::fill_n(out_ptr, reduction_stride, init);
+
314 ops(x_ptr + offset, out_ptr, reduction_size, reduction_stride);
+
315 out_ptr += reduction_stride;
+
316 }
+
317 } else {
+
318 for (int i = 0; i < out.size(); i += reduction_stride) {
+
319 int offset = elem_to_loc(i, shape, strides);
+
320 std::fill_n(out_ptr, reduction_stride, init);
+
321 nd_loop(
+
322 [&](int extra_offset) {
+
323 ops(x_ptr + offset + extra_offset,
+
324 out_ptr,
+
325 reduction_size,
+
326 reduction_stride);
+
327 },
+
328 plan.shape,
+
329 plan.strides);
+
330 out_ptr += reduction_stride;
+
331 }
+
332 }
+
333 return;
+
334 }
+
335
+
336 if (plan.type == GeneralReduce) {
+
337 const T* x_ptr = x.data<T>();
+
338 U* out_ptr = out.data<U>();
+
339 std::tie(shape, strides) = shapes_without_reduction_axes(x, axes);
+
340 for (int i = 0; i < out.size(); i++, out_ptr++) {
+
341 int offset = elem_to_loc(i, shape, strides);
+
342 U val = init;
+
343 nd_loop(
+
344 [&](int extra_offset) { op(&val, *(x_ptr + offset + extra_offset)); },
+
345 plan.shape,
+
346 plan.strides);
+
347 *out_ptr = val;
+
348 }
+
349 }
+
350}
+
351
+
352template <typename T, typename U, typename Op>
+
353void reduction_op(
+
354 const array& x,
+
355 array& out,
+
356 const std::vector<int>& axes,
+
357 U init,
+
358 Op op) {
+
359 DefaultStridedReduce<T, U, Op> ops(op);
+
360 DefaultContiguousReduce<T, U, Op> opc(op);
+
361 reduction_op<T, U>(x, out, axes, init, ops, opc, op);
+
362}
+
363
+
364} // namespace
+
365
+
366} // namespace mlx::core
+ +
Op op
Definition binary.h:139
+
array std(const array &a, bool keepdims, int ddof=0, StreamOrDevice s={})
Computes the standard deviation of the elements of an array.
+
Buffer malloc_or_wait(size_t size)
+
Definition allocator.h:7
+
ReductionOpType
Definition reduce.h:9
+
@ GeneralReduce
Definition reduce.h:36
+
@ GeneralContiguousReduce
Definition reduce.h:25
+
@ ContiguousStridedReduce
Definition reduce.h:19
+
@ ContiguousReduce
Definition reduce.h:15
+
@ GeneralStridedReduce
Definition reduce.h:30
+
@ ContiguousAllReduce
Definition reduce.h:11
+
stride_t elem_to_loc(int elem, const std::vector< int > &shape, const std::vector< stride_t > &strides)
Definition utils.h:12
+
Definition reduce.h:39
+
ReductionOpType type
Definition reduce.h:40
+
ReductionPlan(ReductionOpType type_, std::vector< int > shape_, std::vector< size_t > strides_)
Definition reduce.h:44
+
std::vector< int > shape
Definition reduce.h:41
+
std::vector< size_t > strides
Definition reduce.h:42
+
ReductionPlan(ReductionOpType type_)
Definition reduce.h:49
+
+ + + + diff --git a/docs/build/html/common_2ternary_8h.html b/docs/build/html/common_2ternary_8h.html new file mode 100644 index 000000000..2bdd394cc --- /dev/null +++ b/docs/build/html/common_2ternary_8h.html @@ -0,0 +1,103 @@ + + + + + + + +MLX: mlx/backend/common/ternary.h File Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
ternary.h File Reference
+
+
+
#include "mlx/allocator.h"
+#include "mlx/array.h"
+#include "mlx/backend/common/ops.h"
+#include "mlx/backend/common/utils.h"
+
+

Go to the source code of this file.

+ + + + + + +

+Namespaces

namespace  mlx
 
namespace  mlx::core
 
+
+ + + + diff --git a/docs/build/html/common_2ternary_8h_source.html b/docs/build/html/common_2ternary_8h_source.html new file mode 100644 index 000000000..edce1f51f --- /dev/null +++ b/docs/build/html/common_2ternary_8h_source.html @@ -0,0 +1,327 @@ + + + + + + + +MLX: mlx/backend/common/ternary.h Source File + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
ternary.h
+
+
+Go to the documentation of this file.
1// Copyright © 2023 Apple Inc.
+
2
+
3#pragma once
+
4#include "mlx/allocator.h"
+
5#include "mlx/array.h"
+ + +
8namespace mlx::core {
+
9
+
10namespace {
+
11
+
12// TODO: Add support for more combinations of input types.
+
13enum class TernaryOpType {
+
14 ScalarScalarScalar,
+
15 General,
+
16};
+
17
+
18TernaryOpType
+
19get_ternary_op_type(const array& a, const array& b, const array& c) {
+
20 TernaryOpType topt;
+
21 if (a.data_size() == 1 && b.data_size() == 1 && c.data_size() == 1) {
+
22 topt = TernaryOpType::ScalarScalarScalar;
+
23 } else {
+
24 topt = TernaryOpType::General;
+
25 }
+
26 return topt;
+
27}
+
28
+
29void set_ternary_op_output_data(
+
30 const array& a,
+
31 const array& b,
+
32 const array& c,
+
33 array& out,
+
34 TernaryOpType topt,
+
35 bool donate_with_move = false) {
+
36 switch (topt) {
+
37 case TernaryOpType::ScalarScalarScalar:
+
38 out.set_data(
+
39 allocator::malloc_or_wait(out.itemsize()), 1, b.strides(), b.flags());
+
40 break;
+
41 case TernaryOpType::General:
+
42 out.set_data(allocator::malloc_or_wait(out.nbytes()));
+
43 break;
+
44 }
+
45}
+
46
+
47template <typename T1, typename T2, typename T3, typename U, typename Op>
+
48void ternary_op_dims1(
+
49 const array& a,
+
50 const array& b,
+
51 const array& c,
+
52 array& out,
+
53 Op op) {
+
54 const T1* a_ptr = a.data<T1>();
+
55 const T2* b_ptr = b.data<T2>();
+
56 const T3* c_ptr = c.data<T3>();
+
57
+
58 U* dst = out.data<U>();
+
59 size_t a_idx = 0;
+
60 size_t b_idx = 0;
+
61 size_t c_idx = 0;
+
62 for (size_t i = 0; i < out.size(); ++i) {
+
63 dst[i] = op(a_ptr[a_idx], b_ptr[b_idx], c_ptr[c_idx]);
+
64 a_idx += a.strides()[0];
+
65 b_idx += b.strides()[0];
+
66 c_idx += c.strides()[0];
+
67 }
+
68}
+
69
+
70template <typename T1, typename T2, typename T3, typename U, typename Op>
+
71void ternary_op_dims2(
+
72 const array& a,
+
73 const array& b,
+
74 const array& c,
+
75 array& out,
+
76 Op op) {
+
77 const T1* a_ptr = a.data<T1>();
+
78 const T2* b_ptr = b.data<T2>();
+
79 const T3* c_ptr = c.data<T3>();
+
80
+
81 U* dst = out.data<U>();
+
82 size_t a_idx = 0;
+
83 size_t b_idx = 0;
+
84 size_t c_idx = 0;
+
85 size_t out_idx = 0;
+
86 for (size_t i = 0; i < a.shape()[0]; ++i) {
+
87 for (size_t j = 0; j < a.shape()[1]; ++j) {
+
88 dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx], c_ptr[c_idx]);
+
89 a_idx += a.strides()[1];
+
90 b_idx += b.strides()[1];
+
91 c_idx += c.strides()[1];
+
92 }
+
93 a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
+
94 b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
+
95 c_idx += c.strides()[0] - c.strides()[1] * c.shape()[1];
+
96 }
+
97}
+
98
+
99template <typename T1, typename T2, typename T3, typename U, typename Op>
+
100void ternary_op_dims3(
+
101 const array& a,
+
102 const array& b,
+
103 const array& c,
+
104 array& out,
+
105 Op op) {
+
106 const T1* a_ptr = a.data<T1>();
+
107 const T2* b_ptr = b.data<T2>();
+
108 const T3* c_ptr = c.data<T3>();
+
109 U* dst = out.data<U>();
+
110 size_t a_idx = 0;
+
111 size_t b_idx = 0;
+
112 size_t c_idx = 0;
+
113 size_t out_idx = 0;
+
114 for (size_t i = 0; i < a.shape()[0]; ++i) {
+
115 for (size_t j = 0; j < a.shape()[1]; ++j) {
+
116 for (size_t k = 0; k < a.shape()[2]; ++k) {
+
117 dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx], c_ptr[c_idx]);
+
118 a_idx += a.strides()[2];
+
119 b_idx += b.strides()[2];
+
120 c_idx += c.strides()[2];
+
121 }
+
122 a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
+
123 b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
+
124 c_idx += c.strides()[1] - c.strides()[2] * c.shape()[2];
+
125 }
+
126 a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
+
127 b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
+
128 c_idx += c.strides()[0] - c.strides()[1] * c.shape()[1];
+
129 }
+
130}
+
131
+
132template <typename T1, typename T2, typename T3, typename U, typename Op>
+
133void ternary_op_dims4(
+
134 const array& a,
+
135 const array& b,
+
136 const array& c,
+
137 array& out,
+
138 Op op) {
+
139 const T1* a_ptr = a.data<T1>();
+
140 const T2* b_ptr = b.data<T2>();
+
141 const T3* c_ptr = c.data<T3>();
+
142
+
143 U* dst = out.data<U>();
+
144 size_t a_idx = 0;
+
145 size_t b_idx = 0;
+
146 size_t c_idx = 0;
+
147 size_t out_idx = 0;
+
148 for (size_t i = 0; i < a.shape()[0]; ++i) {
+
149 for (size_t j = 0; j < a.shape()[1]; ++j) {
+
150 for (size_t k = 0; k < a.shape()[2]; ++k) {
+
151 for (size_t ii = 0; ii < a.shape()[3]; ++ii) {
+
152 dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx], c_ptr[c_idx]);
+
153 a_idx += a.strides()[3];
+
154 b_idx += b.strides()[3];
+
155 c_idx += c.strides()[3];
+
156 }
+
157 a_idx += a.strides()[2] - a.strides()[3] * a.shape()[3];
+
158 b_idx += b.strides()[2] - b.strides()[3] * b.shape()[3];
+
159 c_idx += c.strides()[2] - c.strides()[3] * c.shape()[3];
+
160 }
+
161 a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
+
162 b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
+
163 c_idx += c.strides()[1] - c.strides()[2] * c.shape()[2];
+
164 }
+
165 a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
+
166 b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
+
167 c_idx += c.strides()[0] - c.strides()[1] * c.shape()[1];
+
168 }
+
169}
+
170
+
171template <typename T1, typename T2, typename T3, typename U, typename Op>
+
172void ternary_op_dispatch_dims(
+
173 const array& a,
+
174 const array& b,
+
175 const array& c,
+
176 array& out,
+
177 Op op) {
+
178 switch (out.ndim()) {
+
179 case 1:
+
180 ternary_op_dims1<T1, T2, T3, U, Op>(a, b, c, out, op);
+
181 return;
+
182 case 2:
+
183 ternary_op_dims2<T1, T2, T3, U, Op>(a, b, c, out, op);
+
184 return;
+
185 case 3:
+
186 ternary_op_dims3<T1, T2, T3, U, Op>(a, b, c, out, op);
+
187 return;
+
188 case 4:
+
189 ternary_op_dims4<T1, T2, T3, U, Op>(a, b, c, out, op);
+
190 return;
+
191 }
+
192
+
193 const T1* a_ptr = a.data<T1>();
+
194 const T2* b_ptr = b.data<T2>();
+
195 const T3* c_ptr = c.data<T3>();
+
196 U* dst = out.data<U>();
+
197 for (size_t i = 0; i < out.size(); i++) {
+
198 int a_idx = elem_to_loc(i, a.shape(), a.strides());
+
199 int b_idx = elem_to_loc(i, b.shape(), b.strides());
+
200 int c_idx = elem_to_loc(i, c.shape(), c.strides());
+
201 dst[i] = op(a_ptr[a_idx], b_ptr[b_idx], c_ptr[c_idx]);
+
202 }
+
203}
+
204
+
205template <typename T1, typename T2, typename T3, typename U, typename Op>
+
206void ternary_op(
+
207 const array& a,
+
208 const array& b,
+
209 const array& c,
+
210 array& out,
+
211 Op op) {
+
212 TernaryOpType topt = get_ternary_op_type(a, b, c);
+
213 set_ternary_op_output_data(a, b, c, out, topt);
+
214
+
215 // The full computation is scalar-scalar-scalar so we call the base op once.
+
216 if (topt == TernaryOpType::ScalarScalarScalar) {
+
217 *(out.data<U>()) = op(*a.data<T1>(), *b.data<T2>(), *c.data<T3>());
+
218 return;
+
219 }
+
220
+
221 ternary_op_dispatch_dims<T1, T2, T3, U>(a, b, c, out, op);
+
222}
+
223
+
224} // namespace
+
225
+
226} // namespace mlx::core
+ + + + +
Op op
Definition binary.h:139
+
Buffer malloc_or_wait(size_t size)
+
Definition allocator.h:7
+
stride_t elem_to_loc(int elem, const std::vector< int > &shape, const std::vector< stride_t > &strides)
Definition utils.h:12
+ +
+ + + + diff --git a/docs/build/html/common_2unary_8h.html b/docs/build/html/common_2unary_8h.html new file mode 100644 index 000000000..c0017b1c4 --- /dev/null +++ b/docs/build/html/common_2unary_8h.html @@ -0,0 +1,103 @@ + + + + + + + +MLX: mlx/backend/common/unary.h File Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
unary.h File Reference
+
+
+
#include "mlx/allocator.h"
+#include "mlx/array.h"
+#include "mlx/backend/common/utils.h"
+#include "mlx/utils.h"
+
+

Go to the source code of this file.

+ + + + + + +

+Namespaces

namespace  mlx
 
namespace  mlx::core
 
+
+ + + + diff --git a/docs/build/html/common_2unary_8h_source.html b/docs/build/html/common_2unary_8h_source.html new file mode 100644 index 000000000..1fa0f5e15 --- /dev/null +++ b/docs/build/html/common_2unary_8h_source.html @@ -0,0 +1,229 @@ + + + + + + + +MLX: mlx/backend/common/unary.h Source File + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
unary.h
+
+
+Go to the documentation of this file.
1// Copyright © 2023 Apple Inc.
+
2
+
3#pragma once
+
4
+
5#include "mlx/allocator.h"
+
6#include "mlx/array.h"
+ +
8#include "mlx/utils.h"
+
9
+
10namespace mlx::core {
+
11
+
12namespace {
+
13
+
14void set_unary_output_data(const array& in, array& out) {
+
15 if (in.is_donatable() && in.itemsize() == out.itemsize()) {
+
16 out.copy_shared_buffer(in);
+
17 } else {
+
18 auto size = in.data_size();
+
19 out.set_data(
+
20 allocator::malloc_or_wait(size * out.itemsize()),
+
21 size,
+
22 in.strides(),
+
23 in.flags());
+
24 }
+
25}
+
26
+
27template <typename T, typename Op>
+
28void unary_op(const array& a, array& out, Op op) {
+
29 const T* a_ptr = a.data<T>();
+
30 if (a.flags().contiguous) {
+
31 set_unary_output_data(a, out);
+
32 T* dst = out.data<T>();
+
33 for (size_t i = 0; i < a.data_size(); ++i) {
+
34 dst[i] = op(a_ptr[i]);
+
35 }
+
36 } else {
+
37 out.set_data(allocator::malloc_or_wait(out.nbytes()));
+
38 T* dst = out.data<T>();
+
39 for (size_t i = 0; i < out.size(); ++i) {
+
40 // TODO this is super inefficient, need to fix.
+
41 int a_idx = elem_to_loc(i, a.shape(), a.strides());
+
42 dst[i] = op(a_ptr[a_idx]);
+
43 }
+
44 }
+
45}
+
46
+
47template <typename Op>
+
48void unary(const array& a, array& out, Op op) {
+
49 switch (out.dtype()) {
+
50 case bool_:
+
51 unary_op<bool>(a, out, op);
+
52 break;
+
53 case uint8:
+
54 unary_op<uint8_t>(a, out, op);
+
55 break;
+
56 case uint16:
+
57 unary_op<uint16_t>(a, out, op);
+
58 break;
+
59 case uint32:
+
60 unary_op<uint32_t>(a, out, op);
+
61 break;
+
62 case uint64:
+
63 unary_op<uint64_t>(a, out, op);
+
64 break;
+
65 case int8:
+
66 unary_op<int8_t>(a, out, op);
+
67 break;
+
68 case int16:
+
69 unary_op<int16_t>(a, out, op);
+
70 break;
+
71 case int32:
+
72 unary_op<int32_t>(a, out, op);
+
73 break;
+
74 case int64:
+
75 unary_op<int64_t>(a, out, op);
+
76 break;
+
77 case float16:
+
78 unary_op<float16_t>(a, out, op);
+
79 break;
+
80 case float32:
+
81 unary_op<float>(a, out, op);
+
82 break;
+
83 case bfloat16:
+
84 unary_op<bfloat16_t>(a, out, op);
+
85 break;
+
86 case complex64:
+
87 unary_op<complex64_t>(a, out, op);
+
88 break;
+
89 }
+
90}
+
91
+
92template <typename Op>
+
93void unary_fp(const array& a, array& out, Op op) {
+
94 switch (out.dtype()) {
+
95 case bfloat16:
+
96 unary_op<bfloat16_t>(a, out, op);
+
97 break;
+
98 case float16:
+
99 unary_op<float16_t>(a, out, op);
+
100 break;
+
101 case float32:
+
102 unary_op<float>(a, out, op);
+
103 break;
+
104 case complex64:
+
105 unary_op<complex64_t>(a, out, op);
+
106 break;
+
107 default:
+
108 std::ostringstream err;
+
109 err << "[unary_fp] Does not support " << out.dtype();
+
110 throw std::runtime_error(err.str());
+
111 }
+
112}
+
113
+
114} // namespace
+
115
+
116} // namespace mlx::core
+ + + +
Op op
Definition binary.h:139
+
Buffer malloc_or_wait(size_t size)
+
Definition allocator.h:7
+
constexpr Dtype bool_
Definition dtype.h:60
+
constexpr Dtype uint64
Definition dtype.h:65
+
constexpr Dtype uint16
Definition dtype.h:63
+
stride_t elem_to_loc(int elem, const std::vector< int > &shape, const std::vector< stride_t > &strides)
Definition utils.h:12
+
constexpr Dtype bfloat16
Definition dtype.h:74
+
constexpr Dtype int32
Definition dtype.h:69
+
constexpr Dtype float32
Definition dtype.h:73
+
constexpr Dtype int16
Definition dtype.h:68
+
constexpr Dtype int8
Definition dtype.h:67
+
constexpr Dtype int64
Definition dtype.h:70
+
constexpr Dtype uint8
Definition dtype.h:62
+
constexpr Dtype float16
Definition dtype.h:72
+
constexpr Dtype uint32
Definition dtype.h:64
+
constexpr Dtype complex64
Definition dtype.h:75
+ +
+ + + + diff --git a/docs/build/html/compile_8h.html b/docs/build/html/compile_8h.html new file mode 100644 index 000000000..3689b75f0 --- /dev/null +++ b/docs/build/html/compile_8h.html @@ -0,0 +1,130 @@ + + + + + + + +MLX: mlx/compile.h File Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
compile.h File Reference
+
+
+
#include "mlx/array.h"
+
+

Go to the source code of this file.

+ + + + + + +

+Namespaces

namespace  mlx
 
namespace  mlx::core
 
+ + + +

+Enumerations

enum class  mlx::core::CompileMode { mlx::core::disabled +, mlx::core::no_simplify +, mlx::core::no_fuse +, mlx::core::enabled + }
 
+ + + + + + + + + + +

+Functions

void mlx::core::disable_compile ()
 Globally disable compilation.
 
void mlx::core::enable_compile ()
 Globally enable compilation.
 
void mlx::core::set_compile_mode (CompileMode mode)
 Set the compiler mode to the given value.
 
+ + + + +

+Variables

std::function< std::vector< array >(const std::vector< array > &) mlx::core::compile )(const std::function< std::vector< array >(const std::vector< array > &)> &fun, bool shapeless=false)
 Compile takes a function and returns a compiled function.
 
+
+ + + + diff --git a/docs/build/html/compile_8h_source.html b/docs/build/html/compile_8h_source.html new file mode 100644 index 000000000..8960d1ddc --- /dev/null +++ b/docs/build/html/compile_8h_source.html @@ -0,0 +1,123 @@ + + + + + + + +MLX: mlx/compile.h Source File + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
compile.h
+
+
+Go to the documentation of this file.
1// Copyright © 2023-2024 Apple Inc.
+
2
+
3#pragma once
+
4
+
5#include "mlx/array.h"
+
6
+
7namespace mlx::core {
+
8
+ +
10
+
12std::function<std::vector<array>(const std::vector<array>&)> compile(
+
13 const std::function<std::vector<array>(const std::vector<array>&)>& fun,
+
14 bool shapeless = false);
+
15
+ +
21
+ +
26
+ +
29} // namespace mlx::core
+ +
Definition allocator.h:7
+
void enable_compile()
Globally enable compilation.
+
void set_compile_mode(CompileMode mode)
Set the compiler mode to the given value.
+
void disable_compile()
Globally disable compilation.
+
std::function< std::vector< array >(const std::vector< array > &) compile)(const std::function< std::vector< array >(const std::vector< array > &)> &fun, bool shapeless=false)
Compile takes a function and returns a compiled function.
+
CompileMode
Definition compile.h:9
+ + + + +
+ + + + diff --git a/docs/build/html/compile__impl_8h.html b/docs/build/html/compile__impl_8h.html new file mode 100644 index 000000000..cfe23b4a1 --- /dev/null +++ b/docs/build/html/compile__impl_8h.html @@ -0,0 +1,108 @@ + + + + + + + +MLX: mlx/compile_impl.h File Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
compile_impl.h File Reference
+
+
+
#include "mlx/device.h"
+
+

Go to the source code of this file.

+ + + + + + + + +

+Namespaces

namespace  mlx
 
namespace  mlx::core
 
namespace  mlx::core::detail
 
+ + + +

+Functions

bool mlx::core::detail::compile_available_for_device (const Device &device)
 
+
+ + + + diff --git a/docs/build/html/compile__impl_8h_source.html b/docs/build/html/compile__impl_8h_source.html new file mode 100644 index 000000000..83d84e37d --- /dev/null +++ b/docs/build/html/compile__impl_8h_source.html @@ -0,0 +1,107 @@ + + + + + + + +MLX: mlx/compile_impl.h Source File + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
compile_impl.h
+
+
+Go to the documentation of this file.
1// Copyright © 2023-2024 Apple Inc.
+
2
+
3#pragma once
+
4
+
5#include "mlx/device.h"
+
6
+
7namespace mlx::core::detail {
+
8
+ +
10
+
11}
+ +
Definition ops.h:8
+
bool compile_available_for_device(const Device &device)
+
Definition device.h:7
+
+ + + + diff --git a/docs/build/html/compiled_8h.html b/docs/build/html/compiled_8h.html new file mode 100644 index 000000000..aae146e42 --- /dev/null +++ b/docs/build/html/compiled_8h.html @@ -0,0 +1,131 @@ + + + + + + + +MLX: mlx/backend/common/compiled.h File Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
compiled.h File Reference
+
+
+
#include <iomanip>
+#include <sstream>
+#include <unordered_set>
+#include "mlx/array.h"
+#include "mlx/primitives.h"
+
+

Go to the source code of this file.

+ + + + + + +

+Namespaces

namespace  mlx
 
namespace  mlx::core
 
+ + + + + + + + + + + + + + + + + + + + + + + + +

+Functions

bool mlx::core::is_static_cast (const Primitive &p)
 
std::string mlx::core::build_lib_name (const std::vector< array > &inputs, const std::vector< array > &outputs, const std::vector< array > &tape, const std::unordered_set< uintptr_t > &constant_ids)
 
std::string mlx::core::get_type_string (Dtype d)
 
template<typename T >
void mlx::core::print_float_constant (std::ostream &os, const array &x)
 
template<typename T >
void mlx::core::print_int_constant (std::ostream &os, const array &x)
 
template<typename T >
void mlx::core::print_complex_constant (std::ostream &os, const array &x)
 
void mlx::core::print_constant (std::ostream &os, const array &x)
 
bool mlx::core::is_scalar (const array &x)
 
bool mlx::core::compiled_check_contiguity (const std::vector< array > &inputs, const std::vector< int > &shape)
 
void mlx::core::compiled_allocate_outputs (const std::vector< array > &inputs, std::vector< array > &outputs, const std::vector< array > &inputs_, const std::unordered_set< uintptr_t > &constant_ids_, bool contiguous, bool move_buffers=false)
 
+
+ + + + diff --git a/docs/build/html/compiled_8h_source.html b/docs/build/html/compiled_8h_source.html new file mode 100644 index 000000000..6e37a563e --- /dev/null +++ b/docs/build/html/compiled_8h_source.html @@ -0,0 +1,195 @@ + + + + + + + +MLX: mlx/backend/common/compiled.h Source File + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
compiled.h
+
+
+Go to the documentation of this file.
1// Copyright © 2023-2024 Apple Inc.
+
2#pragma once
+
3
+
4#include <iomanip>
+
5#include <sstream>
+
6#include <unordered_set>
+
7
+
8#include "mlx/array.h"
+
9#include "mlx/primitives.h"
+
10
+
11namespace mlx::core {
+
12
+
+
13inline bool is_static_cast(const Primitive& p) {
+
14 return (
+
15 typeid(p) == typeid(Broadcast) || typeid(p) == typeid(Copy) ||
+
16 typeid(p) == typeid(StopGradient) || typeid(p) == typeid(AsType));
+
17}
+
+
18
+
19std::string build_lib_name(
+
20 const std::vector<array>& inputs,
+
21 const std::vector<array>& outputs,
+
22 const std::vector<array>& tape,
+
23 const std::unordered_set<uintptr_t>& constant_ids);
+
24
+
25std::string get_type_string(Dtype d);
+
26
+
27template <typename T>
+
+
28void print_float_constant(std::ostream& os, const array& x) {
+
29 auto old_precision = os.precision();
+
30 os << std::setprecision(std::numeric_limits<float>::digits10 + 1)
+
31 << x.item<T>() << std::setprecision(old_precision);
+
32}
+
+
33
+
34template <typename T>
+
+
35void print_int_constant(std::ostream& os, const array& x) {
+
36 os << x.item<T>();
+
37}
+
+
38
+
39template <typename T>
+
+
40void print_complex_constant(std::ostream& os, const array& x) {
+
41 auto old_precision = os.precision();
+
42 T constant = x.item<T>();
+
43
+
44 os << get_type_string(x.dtype()) << "("
+
45 << std::setprecision(std::numeric_limits<float>::digits10 + 1)
+
46 << constant.real() << ", " << constant.imag() << ")"
+
47 << std::setprecision(old_precision);
+
48}
+
+
49
+
50void print_constant(std::ostream& os, const array& x);
+
51
+
+
52inline bool is_scalar(const array& x) {
+
53 return x.ndim() == 0;
+
54}
+
+
55
+
56// Check if we can use a contiguous operation given inputs and the output shape
+ +
58 const std::vector<array>& inputs,
+
59 const std::vector<int>& shape);
+
60
+
61// Allocate space for the outputs possibly with input donation
+ +
63 const std::vector<array>& inputs,
+
64 std::vector<array>& outputs,
+
65 const std::vector<array>& inputs_,
+
66 const std::unordered_set<uintptr_t>& constant_ids_,
+
67 bool contiguous,
+
68 bool move_buffers = false);
+
69
+
70} // namespace mlx::core
+ +
Definition primitives.h:416
+
Definition primitives.h:525
+
Definition primitives.h:680
+
Definition primitives.h:48
+
Definition primitives.h:1919
+
Definition array.h:20
+
size_t ndim() const
The number of dimensions of the array.
Definition array.h:94
+
T item()
Get the value from a scalar array.
Definition array.h:489
+
Dtype dtype() const
Get the arrays data type.
Definition array.h:127
+
Definition allocator.h:7
+
void print_complex_constant(std::ostream &os, const array &x)
Definition compiled.h:40
+
bool compiled_check_contiguity(const std::vector< array > &inputs, const std::vector< int > &shape)
+
std::string build_lib_name(const std::vector< array > &inputs, const std::vector< array > &outputs, const std::vector< array > &tape, const std::unordered_set< uintptr_t > &constant_ids)
+
void print_constant(std::ostream &os, const array &x)
+
void print_float_constant(std::ostream &os, const array &x)
Definition compiled.h:28
+
void print_int_constant(std::ostream &os, const array &x)
Definition compiled.h:35
+
bool is_scalar(const array &x)
Definition compiled.h:52
+
void compiled_allocate_outputs(const std::vector< array > &inputs, std::vector< array > &outputs, const std::vector< array > &inputs_, const std::unordered_set< uintptr_t > &constant_ids_, bool contiguous, bool move_buffers=false)
+
std::string get_type_string(Dtype d)
+
bool is_static_cast(const Primitive &p)
Definition compiled.h:13
+ +
Definition dtype.h:15
+
+ + + + diff --git a/docs/build/html/conv_2loader_8h.html b/docs/build/html/conv_2loader_8h.html new file mode 100644 index 000000000..f5125528c --- /dev/null +++ b/docs/build/html/conv_2loader_8h.html @@ -0,0 +1,91 @@ + + + + + + + +MLX: mlx/backend/metal/kernels/steel/conv/loader.h File Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
loader.h File Reference
+
+ + + + + diff --git a/docs/build/html/conv_2loader_8h_source.html b/docs/build/html/conv_2loader_8h_source.html new file mode 100644 index 000000000..b3bb84bc8 --- /dev/null +++ b/docs/build/html/conv_2loader_8h_source.html @@ -0,0 +1,100 @@ + + + + + + + +MLX: mlx/backend/metal/kernels/steel/conv/loader.h Source File + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
loader.h
+
+ + + + + diff --git a/docs/build/html/conv_2params_8h.html b/docs/build/html/conv_2params_8h.html new file mode 100644 index 000000000..72736b010 --- /dev/null +++ b/docs/build/html/conv_2params_8h.html @@ -0,0 +1,111 @@ + + + + + + + +MLX: mlx/backend/metal/kernels/steel/conv/params.h File Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+ +
params.h File Reference
+
+
+ +

Go to the source code of this file.

+ + + + + + + + + + +

+Classes

struct  MLXConvParams< NDIM >
 
struct  mlx::steel::ImplicitGemmConv2DParams
 
struct  mlx::steel::Conv2DGeneralJumpParams
 
struct  mlx::steel::Conv2DGeneralBaseInfo
 
+ + + + + +

+Namespaces

namespace  mlx
 
namespace  mlx::steel
 
+
+ + + + diff --git a/docs/build/html/conv_2params_8h_source.html b/docs/build/html/conv_2params_8h_source.html new file mode 100644 index 000000000..deb4bf9a4 --- /dev/null +++ b/docs/build/html/conv_2params_8h_source.html @@ -0,0 +1,202 @@ + + + + + + + +MLX: mlx/backend/metal/kernels/steel/conv/params.h Source File + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
params.h
+
+
+Go to the documentation of this file.
1// Copyright © 2024 Apple Inc.
+
2
+
3#pragma once
+
4
+
5template <int NDIM>
+
+ +
7 const int N; // Batch size
+
8 const int C; // In channels
+
9 const int O; // Out channels
+
10 const int iS[NDIM]; // Input spatial dim
+
11 const int wS[NDIM]; // Weight spatial dim
+
12 const int oS[NDIM]; // Output spatial dim
+
13 const int str[NDIM]; // Kernel strides
+
14 const int pad[NDIM]; // Input padding
+
15 const int kdil[NDIM]; // Kernel dilation
+
16 const int idil[NDIM]; // Input dilation
+
17 const size_t in_strides[NDIM + 2]; // In strides
+
18 const size_t wt_strides[NDIM + 2]; // Wt strides
+
19 const size_t out_strides[NDIM + 2]; // Out strides
+
20 const int groups; // Input channel groups
+
21 const bool flip;
+
22};
+
+
23
+
24namespace mlx {
+
25namespace steel {
+
26
+
+ +
28 const int M;
+
29 const int N;
+
30 const int K;
+
31
+ +
33
+
34 const int inp_jump_w;
+
35 const int inp_jump_h;
+
36 const int inp_jump_c;
+
37
+
38 const int tiles_n;
+
39 const int tiles_m;
+
40 const int swizzle_log;
+
41};
+
+
42
+
+ +
44 const int f_wgt_jump_h;
+
45 const int f_wgt_jump_w;
+
46
+
47 const int f_out_jump_h;
+
48 const int f_out_jump_w;
+
49
+
50 const int adj_out_h;
+
51 const int adj_out_w;
+
52 const int adj_out_hw;
+
53 const int adj_implicit_m;
+
54};
+
+
55
+
+ + + +
59};
+
+
60
+
61} // namespace steel
+
62} // namespace mlx
+
Definition allocator.h:7
+
Definition params.h:6
+
const int C
Definition params.h:8
+
const size_t out_strides[NDIM+2]
Definition params.h:19
+
const int oS[NDIM]
Definition params.h:12
+
const int iS[NDIM]
Definition params.h:10
+
const int kdil[NDIM]
Definition params.h:15
+
const int str[NDIM]
Definition params.h:13
+
const size_t wt_strides[NDIM+2]
Definition params.h:18
+
const bool flip
Definition params.h:21
+
const size_t in_strides[NDIM+2]
Definition params.h:17
+
const int wS[NDIM]
Definition params.h:11
+
const int O
Definition params.h:9
+
const int N
Definition params.h:7
+
const int pad[NDIM]
Definition params.h:14
+
const int groups
Definition params.h:20
+
const int idil[NDIM]
Definition params.h:16
+
Definition params.h:56
+
int weight_base
Definition params.h:57
+
int weight_size
Definition params.h:58
+ +
const int f_out_jump_w
Definition params.h:48
+
const int f_wgt_jump_h
Definition params.h:44
+
const int f_wgt_jump_w
Definition params.h:45
+
const int adj_implicit_m
Definition params.h:53
+
const int f_out_jump_h
Definition params.h:47
+
const int adj_out_h
Definition params.h:50
+
const int adj_out_w
Definition params.h:51
+
const int adj_out_hw
Definition params.h:52
+ +
const int inp_jump_h
Definition params.h:35
+
const int M
Definition params.h:28
+
const int N
Definition params.h:29
+
const int tiles_m
Definition params.h:39
+
const int tiles_n
Definition params.h:38
+
const int inp_jump_c
Definition params.h:36
+
const int gemm_k_iterations
Definition params.h:32
+
const int inp_jump_w
Definition params.h:34
+
const int swizzle_log
Definition params.h:40
+
const int K
Definition params.h:30
+
+ + + + diff --git a/docs/build/html/conv_8h.html b/docs/build/html/conv_8h.html new file mode 100644 index 000000000..e62539e7c --- /dev/null +++ b/docs/build/html/conv_8h.html @@ -0,0 +1,92 @@ + + + + + + + +MLX: mlx/backend/metal/kernels/steel/conv/conv.h File Reference + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
conv.h File Reference
+
+ + + + + diff --git a/docs/build/html/conv_8h_source.html b/docs/build/html/conv_8h_source.html new file mode 100644 index 000000000..a7219c852 --- /dev/null +++ b/docs/build/html/conv_8h_source.html @@ -0,0 +1,108 @@ + + + + + + + +MLX: mlx/backend/metal/kernels/steel/conv/conv.h Source File + + + + + + + + + + + +
+
+ + + + + + +
+
MLX +
+
+
+ + + + + + + + + +
+
+ + +
+
+
+
+
+
Loading...
+
Searching...
+
No Matches
+
+
+
+
+ + +
+
+
conv.h
+
+
+Go to the documentation of this file.
1// Copyright © 2024 Apple Inc.
+
2
+
3#pragma once
+
4
+ +
6
+ + +
9
+
10using namespace metal;
+
11using namespace mlx::steel;
+ + + +
Definition bf16.h:265
+
Definition loader_channel_l.h:14
+
+ + + + diff --git a/docs/build/html/cookie.js b/docs/build/html/cookie.js new file mode 100644 index 000000000..53ad21d98 --- /dev/null +++ b/docs/build/html/cookie.js @@ -0,0 +1,58 @@ +/*! + Cookie helper functions + Copyright (c) 2023 Dimitri van Heesch + Released under MIT license. +*/ +let Cookie = { + cookie_namespace: 'doxygen_', + + readSetting(cookie,defVal) { + if (window.chrome) { + const val = localStorage.getItem(this.cookie_namespace+cookie) || + sessionStorage.getItem(this.cookie_namespace+cookie); + if (val) return val; + } else { + let myCookie = this.cookie_namespace+cookie+"="; + if (document.cookie) { + const index = document.cookie.indexOf(myCookie); + if (index != -1) { + const valStart = index + myCookie.length; + let valEnd = document.cookie.indexOf(";", valStart); + if (valEnd == -1) { + valEnd = document.cookie.length; + } + return document.cookie.substring(valStart, valEnd); + } + } + } + return defVal; + }, + + writeSetting(cookie,val,days=10*365) { // default days='forever', 0=session cookie, -1=delete + if (window.chrome) { + if (days==0) { + sessionStorage.setItem(this.cookie_namespace+cookie,val); + } else { + localStorage.setItem(this.cookie_namespace+cookie,val); + } + } else { + let date = new Date(); + date.setTime(date.getTime()+(days*24*60*60*1000)); + const expiration = days!=0 ? "expires="+date.toGMTString()+";" : ""; + document.cookie = this.cookie_namespace + cookie + "=" + + val + "; SameSite=Lax;" + expiration + "path=/"; + } + }, + + eraseSetting(cookie) { + if (window.chrome) { + if (localStorage.getItem(this.cookie_namespace+cookie)) { + localStorage.removeItem(this.cookie_namespace+cookie); + } else if (sessionStorage.getItem(this.cookie_namespace+cookie)) { + sessionStorage.removeItem(this.cookie_namespace+cookie); + } + } else { + this.writeSetting(cookie,'',-1); + } + }, +} diff --git a/docs/build/html/cpp/ops.html b/docs/build/html/cpp/ops.html index 159c599d7..a43770828 100644 --- a/docs/build/html/cpp/ops.html +++ b/docs/build/html/cpp/ops.html @@ -8,7 +8,7 @@ - Operations — MLX 0.12.0 documentation + Operations — MLX 0.13.0 documentation @@ -36,7 +36,7 @@ - + @@ -44,7 +44,7 @@ - + @@ -131,8 +131,8 @@ - MLX 0.12.0 documentation - Home - + MLX 0.13.0 documentation - Home + @@ -255,6 +255,7 @@
  • mlx.core.arcsin
  • mlx.core.arcsinh
  • mlx.core.arctan
  • +
  • mlx.core.arctan2
  • mlx.core.arctanh
  • mlx.core.argmax
  • mlx.core.argmin
  • @@ -264,11 +265,17 @@
  • mlx.core.atleast_1d
  • mlx.core.atleast_2d
  • mlx.core.atleast_3d
  • -
  • mlx.core.broadcast_to
  • +
  • mlx.core.bitwise_and
  • +
  • mlx.core.bitwise_or
  • +
  • mlx.core.bitwise_xor
  • mlx.core.block_masked_mm
  • +
  • mlx.core.block_sparse_mm
  • +
  • mlx.core.broadcast_to
  • mlx.core.ceil
  • mlx.core.clip
  • mlx.core.concatenate
  • +
  • mlx.core.conj
  • +
  • mlx.core.conjugate
  • mlx.core.convolve
  • mlx.core.conv1d
  • mlx.core.conv2d
  • @@ -305,6 +312,7 @@
  • mlx.core.isnan
  • mlx.core.isneginf
  • mlx.core.isposinf
  • +
  • mlx.core.left_shift
  • mlx.core.less
  • mlx.core.less_equal
  • mlx.core.linspace
  • @@ -341,6 +349,7 @@
  • mlx.core.reciprocal
  • mlx.core.repeat
  • mlx.core.reshape
  • +
  • mlx.core.right_shift
  • mlx.core.round
  • mlx.core.rsqrt
  • mlx.core.save
  • @@ -436,8 +445,10 @@
  • Metal
  • +
  • mlx.optimizers.clip_grad_norm
  • Tree Utils
  • @@ -751,7 +764,9 @@ document.write(` `); - + @@ -767,6 +782,298 @@ document.write(`
    +
    +

    Contents

    +
    +
    @@ -778,6 +1085,1617 @@ document.write(`

    Operations#

    +
    +
    +array arange(double start, double stop, double step, Dtype dtype, StreamOrDevice s = {})#
    +

    A 1D array of numbers starting at start (optional), stopping at stop, stepping by step (optional).

    +
    + +
    +
    +array arange(double start, double stop, double step, StreamOrDevice s = {})#
    +
    + +
    +
    +array arange(double start, double stop, Dtype dtype, StreamOrDevice s = {})#
    +
    + +
    +
    +array arange(double start, double stop, StreamOrDevice s = {})#
    +
    + +
    +
    +array arange(double stop, Dtype dtype, StreamOrDevice s = {})#
    +
    + +
    +
    +array arange(double stop, StreamOrDevice s = {})#
    +
    + +
    +
    +array arange(int start, int stop, int step, StreamOrDevice s = {})#
    +
    + +
    +
    +array arange(int start, int stop, StreamOrDevice s = {})#
    +
    + +
    +
    +array arange(int stop, StreamOrDevice s = {})#
    +
    + +
    +
    +array linspace(double start, double stop, int num = 50, Dtype dtype = float32, StreamOrDevice s = {})#
    +

    A 1D array of num evenly spaced numbers in the range [start, stop]

    +
    + +
    +
    +array astype(array a, Dtype dtype, StreamOrDevice s = {})#
    +

    Convert an array to the given data type.

    +
    + +
    +
    +array as_strided(array a, std::vector<int> shape, std::vector<size_t> strides, size_t offset, StreamOrDevice s = {})#
    +

    Create a view of an array with the given shape and strides.

    +
    + +
    +
    +array copy(array a, StreamOrDevice s = {})#
    +

    Copy another array.

    +
    + +
    +
    +array full(std::vector<int> shape, array vals, Dtype dtype, StreamOrDevice s = {})#
    +

    Fill an array of the given shape with the given value(s).

    +
    + +
    +
    +array full(std::vector<int> shape, array vals, StreamOrDevice s = {})#
    +
    + +
    +
    +template<typename T>
    array full(std::vector<int> shape, T val, Dtype dtype, StreamOrDevice s = {})#
    +
    + +
    +
    +template<typename T>
    array full(std::vector<int> shape, T val, StreamOrDevice s = {})#
    +
    + +
    +
    +array zeros(const std::vector<int> &shape, Dtype dtype, StreamOrDevice s = {})#
    +

    Fill an array of the given shape with zeros.

    +
    + +
    +
    +inline array zeros(const std::vector<int> &shape, StreamOrDevice s = {})#
    +
    + +
    +
    +array zeros_like(const array &a, StreamOrDevice s = {})#
    +
    + +
    +
    +array ones(const std::vector<int> &shape, Dtype dtype, StreamOrDevice s = {})#
    +

    Fill an array of the given shape with ones.

    +
    + +
    +
    +inline array ones(const std::vector<int> &shape, StreamOrDevice s = {})#
    +
    + +
    +
    +array ones_like(const array &a, StreamOrDevice s = {})#
    +
    + +
    +
    +array eye(int n, int m, int k, Dtype dtype, StreamOrDevice s = {})#
    +

    Fill an array of the given shape (n,m) with ones in the specified diagonal k, and zeros everywhere else.

    +
    + +
    +
    +inline array eye(int n, Dtype dtype, StreamOrDevice s = {})#
    +
    + +
    +
    +inline array eye(int n, int m, StreamOrDevice s = {})#
    +
    + +
    +
    +inline array eye(int n, int m, int k, StreamOrDevice s = {})#
    +
    + +
    +
    +inline array eye(int n, StreamOrDevice s = {})#
    +
    + +
    +
    +array identity(int n, Dtype dtype, StreamOrDevice s = {})#
    +

    Create a square matrix of shape (n,n) of zeros, and ones in the major diagonal.

    +
    + +
    +
    +inline array identity(int n, StreamOrDevice s = {})#
    +
    + +
    +
    +array tri(int n, int m, int k, Dtype type, StreamOrDevice s = {})#
    +
    + +
    +
    +inline array tri(int n, Dtype type, StreamOrDevice s = {})#
    +
    + +
    +
    +array tril(array x, int k = 0, StreamOrDevice s = {})#
    +
    + +
    +
    +array triu(array x, int k = 0, StreamOrDevice s = {})#
    +
    + +
    +
    +array reshape(const array &a, std::vector<int> shape, StreamOrDevice s = {})#
    +

    Reshape an array to the given shape.

    +
    + +
    +
    +array flatten(const array &a, int start_axis, int end_axis = -1, StreamOrDevice s = {})#
    +

    Flatten the dimensions in the range [start_axis, end_axis] .

    +
    + +
    +
    +array flatten(const array &a, StreamOrDevice s = {})#
    +

    Flatten the array to 1D.

    +
    + +
    +
    +array squeeze(const array &a, const std::vector<int> &axes, StreamOrDevice s = {})#
    +

    Remove singleton dimensions at the given axes.

    +
    + +
    +
    +inline array squeeze(const array &a, int axis, StreamOrDevice s = {})#
    +

    Remove singleton dimensions at the given axis.

    +
    + +
    +
    +array squeeze(const array &a, StreamOrDevice s = {})#
    +

    Remove all singleton dimensions.

    +
    + +
    +
    +array expand_dims(const array &a, const std::vector<int> &axes, StreamOrDevice s = {})#
    +

    Add a singleton dimension at the given axes.

    +
    + +
    +
    +array expand_dims(const array &a, int axis, StreamOrDevice s = {})#
    +

    Add a singleton dimension at the given axis.

    +
    + +
    +
    +array slice(const array &a, std::vector<int> start, std::vector<int> stop, std::vector<int> strides, StreamOrDevice s = {})#
    +

    Slice an array.

    +
    + +
    +
    +array slice(const array &a, const std::vector<int> &start, const std::vector<int> &stop, StreamOrDevice s = {})#
    +

    Slice an array with a stride of 1 in each dimension.

    +
    + +
    +
    +array slice_update(const array &src, const array &update, std::vector<int> start, std::vector<int> stop, std::vector<int> strides, StreamOrDevice s = {})#
    +

    Update a slice from the source array.

    +
    + +
    +
    +array slice_update(const array &src, const array &update, std::vector<int> start, std::vector<int> stop, StreamOrDevice s = {})#
    +

    Update a slice from the source array with stride 1 in each dimension.

    +
    + +
    +
    +std::vector<array> split(const array &a, int num_splits, int axis, StreamOrDevice s = {})#
    +

    Split an array into sub-arrays along a given axis.

    +
    + +
    +
    +std::vector<array> split(const array &a, int num_splits, StreamOrDevice s = {})#
    +
    + +
    +
    +std::vector<array> split(const array &a, const std::vector<int> &indices, int axis, StreamOrDevice s = {})#
    +
    + +
    +
    +std::vector<array> split(const array &a, const std::vector<int> &indices, StreamOrDevice s = {})#
    +
    + +
    +
    +std::vector<array> meshgrid(const std::vector<array> &arrays, bool sparse = false, std::string indexing = "xy", StreamOrDevice s = {})#
    +

    A vector of coordinate arrays from coordinate vectors.

    +
    + +
    +
    +array clip(const array &a, const std::optional<array> &a_min = std::nullopt, const std::optional<array> &a_max = std::nullopt, StreamOrDevice s = {})#
    +

    Clip (limit) the values in an array.

    +
    + +
    +
    +array concatenate(const std::vector<array> &arrays, int axis, StreamOrDevice s = {})#
    +

    Concatenate arrays along a given axis.

    +
    + +
    +
    +array concatenate(const std::vector<array> &arrays, StreamOrDevice s = {})#
    +
    + +
    +
    +array stack(const std::vector<array> &arrays, int axis, StreamOrDevice s = {})#
    +

    Stack arrays along a new axis.

    +
    + +
    +
    +array stack(const std::vector<array> &arrays, StreamOrDevice s = {})#
    +
    + +
    +
    +array repeat(const array &arr, int repeats, int axis, StreamOrDevice s = {})#
    +

    Repeat an array along an axis.

    +
    + +
    +
    +array repeat(const array &arr, int repeats, StreamOrDevice s = {})#
    +
    + +
    +
    +array tile(const array &arr, std::vector<int> reps, StreamOrDevice s = {})#
    +
    + +
    +
    +array transpose(const array &a, std::vector<int> axes, StreamOrDevice s = {})#
    +

    Permutes the dimensions according to the given axes.

    +
    + +
    +
    +inline array transpose(const array &a, std::initializer_list<int> axes, StreamOrDevice s = {})#
    +
    + +
    +
    +array swapaxes(const array &a, int axis1, int axis2, StreamOrDevice s = {})#
    +

    Swap two axes of an array.

    +
    + +
    +
    +array moveaxis(const array &a, int source, int destination, StreamOrDevice s = {})#
    +

    Move an axis of an array.

    +
    + +
    +
    +array pad(const array &a, const std::vector<int> &axes, const std::vector<int> &low_pad_size, const std::vector<int> &high_pad_size, const array &pad_value = array(0), StreamOrDevice s = {})#
    +

    Pad an array with a constant value.

    +
    + +
    +
    +array pad(const array &a, const std::vector<std::pair<int, int>> &pad_width, const array &pad_value = array(0), StreamOrDevice s = {})#
    +

    Pad an array with a constant value along all axes.

    +
    + +
    +
    +array pad(const array &a, const std::pair<int, int> &pad_width, const array &pad_value = array(0), StreamOrDevice s = {})#
    +
    + +
    +
    +array pad(const array &a, int pad_width, const array &pad_value = array(0), StreamOrDevice s = {})#
    +
    + +
    +
    +array transpose(const array &a, StreamOrDevice s = {})#
    +

    Permutes the dimensions in reverse order.

    +
    + +
    +
    +array broadcast_to(const array &a, const std::vector<int> &shape, StreamOrDevice s = {})#
    +

    Broadcast an array to a given shape.

    +
    + +
    +
    +std::vector<array> broadcast_arrays(const std::vector<array> &inputs, StreamOrDevice s = {})#
    +

    Broadcast a vector of arrays against one another.

    +
    + +
    +
    +array equal(const array &a, const array &b, StreamOrDevice s = {})#
    +

    Returns the bool array with (a == b) element-wise.

    +
    + +
    +
    +inline array operator==(const array &a, const array &b)#
    +
    + +
    +
    +template<typename T>
    array operator==(T a, const array &b)#
    +
    + +
    +
    +template<typename T>
    array operator==(const array &a, T b)#
    +
    + +
    +
    +array not_equal(const array &a, const array &b, StreamOrDevice s = {})#
    +

    Returns the bool array with (a != b) element-wise.

    +
    + +
    +
    +inline array operator!=(const array &a, const array &b)#
    +
    + +
    +
    +template<typename T>
    array operator!=(T a, const array &b)#
    +
    + +
    +
    +template<typename T>
    array operator!=(const array &a, T b)#
    +
    + +
    +
    +array greater(const array &a, const array &b, StreamOrDevice s = {})#
    +

    Returns bool array with (a > b) element-wise.

    +
    + +
    +
    +inline array operator>(const array &a, const array &b)#
    +
    + +
    +
    +template<typename T>
    array operator>(T a, const array &b)#
    +
    + +
    +
    +template<typename T>
    array operator>(const array &a, T b)#
    +
    + +
    +
    +array greater_equal(const array &a, const array &b, StreamOrDevice s = {})#
    +

    Returns bool array with (a >= b) element-wise.

    +
    + +
    +
    +inline array operator>=(const array &a, const array &b)#
    +
    + +
    +
    +template<typename T>
    array operator>=(T a, const array &b)#
    +
    + +
    +
    +template<typename T>
    array operator>=(const array &a, T b)#
    +
    + +
    +
    +array less(const array &a, const array &b, StreamOrDevice s = {})#
    +

    Returns bool array with (a < b) element-wise.

    +
    + +
    +
    +inline array operator<(const array &a, const array &b)#
    +
    + +
    +
    +template<typename T>
    array operator<(T a, const array &b)#
    +
    + +
    +
    +template<typename T>
    array operator<(const array &a, T b)#
    +
    + +
    +
    +array less_equal(const array &a, const array &b, StreamOrDevice s = {})#
    +

    Returns bool array with (a <= b) element-wise.

    +
    + +
    +
    +inline array operator<=(const array &a, const array &b)#
    +
    + +
    +
    +template<typename T>
    array operator<=(T a, const array &b)#
    +
    + +
    +
    +template<typename T>
    array operator<=(const array &a, T b)#
    +
    + +
    +
    +array array_equal(const array &a, const array &b, bool equal_nan, StreamOrDevice s = {})#
    +

    True if two arrays have the same shape and elements.

    +
    + +
    +
    +inline array array_equal(const array &a, const array &b, StreamOrDevice s = {})#
    +
    + +
    +
    +array isnan(const array &a, StreamOrDevice s = {})#
    +
    + +
    +
    +array isinf(const array &a, StreamOrDevice s = {})#
    +
    + +
    +
    +array isposinf(const array &a, StreamOrDevice s = {})#
    +
    + +
    +
    +array isneginf(const array &a, StreamOrDevice s = {})#
    +
    + +
    +
    +array where(const array &condition, const array &x, const array &y, StreamOrDevice s = {})#
    +

    Select from x or y depending on condition.

    +
    + +
    +
    +array all(const array &a, bool keepdims, StreamOrDevice s = {})#
    +

    True if all elements in the array are true (or non-zero).

    +
    + +
    +
    +inline array all(const array &a, StreamOrDevice s = {})#
    +
    + +
    +
    +array allclose(const array &a, const array &b, double rtol = 1e-5, double atol = 1e-8, bool equal_nan = false, StreamOrDevice s = {})#
    +

    True if the two arrays are equal within the specified tolerance.

    +
    + +
    +
    +array isclose(const array &a, const array &b, double rtol = 1e-5, double atol = 1e-8, bool equal_nan = false, StreamOrDevice s = {})#
    +

    Returns a boolean array where two arrays are element-wise equal within the specified tolerance.

    +
    + +
    +
    +array all(const array &a, const std::vector<int> &axes, bool keepdims = false, StreamOrDevice s = {})#
    +

    Reduces the input along the given axes.

    +

    An output value is true if all the corresponding inputs are true.

    +
    + +
    +
    +array all(const array &a, int axis, bool keepdims = false, StreamOrDevice s = {})#
    +

    Reduces the input along the given axis.

    +

    An output value is true if all the corresponding inputs are true.

    +
    + +
    +
    +array any(const array &a, bool keepdims, StreamOrDevice s = {})#
    +

    True if any elements in the array are true (or non-zero).

    +
    + +
    +
    +inline array any(const array &a, StreamOrDevice s = {})#
    +
    + +
    +
    +array any(const array &a, const std::vector<int> &axes, bool keepdims = false, StreamOrDevice s = {})#
    +

    Reduces the input along the given axes.

    +

    An output value is true if any of the corresponding inputs are true.

    +
    + +
    +
    +array any(const array &a, int axis, bool keepdims = false, StreamOrDevice s = {})#
    +

    Reduces the input along the given axis.

    +

    An output value is true if any of the corresponding inputs are true.

    +
    + +
    +
    +array sum(const array &a, bool keepdims, StreamOrDevice s = {})#
    +

    Sums the elements of an array.

    +
    + +
    +
    +inline array sum(const array &a, StreamOrDevice s = {})#
    +
    + +
    +
    +array sum(const array &a, const std::vector<int> &axes, bool keepdims = false, StreamOrDevice s = {})#
    +

    Sums the elements of an array along the given axes.

    +
    + +
    +
    +array sum(const array &a, int axis, bool keepdims = false, StreamOrDevice s = {})#
    +

    Sums the elements of an array along the given axis.

    +
    + +
    +
    +array mean(const array &a, bool keepdims, StreamOrDevice s = {})#
    +

    Computes the mean of the elements of an array.

    +
    + +
    +
    +inline array mean(const array &a, StreamOrDevice s = {})#
    +
    + +
    +
    +array mean(const array &a, const std::vector<int> &axes, bool keepdims = false, StreamOrDevice s = {})#
    +

    Computes the mean of the elements of an array along the given axes.

    +
    + +
    +
    +array mean(const array &a, int axis, bool keepdims = false, StreamOrDevice s = {})#
    +

    Computes the mean of the elements of an array along the given axis.

    +
    + +
    +
    +array var(const array &a, bool keepdims, int ddof = 0, StreamOrDevice s = {})#
    +

    Computes the variance of the elements of an array.

    +
    + +
    +
    +inline array var(const array &a, StreamOrDevice s = {})#
    +
    + +
    +
    +array var(const array &a, const std::vector<int> &axes, bool keepdims = false, int ddof = 0, StreamOrDevice s = {})#
    +

    Computes the variance of the elements of an array along the given axes.

    +
    + +
    +
    +array var(const array &a, int axis, bool keepdims = false, int ddof = 0, StreamOrDevice s = {})#
    +

    Computes the variance of the elements of an array along the given axis.

    +
    + +
    +
    +array std(const array &a, bool keepdims, int ddof = 0, StreamOrDevice s = {})#
    +

    Computes the standard deviation of the elements of an array.

    +
    + +
    +
    +inline array std(const array &a, StreamOrDevice s = {})#
    +
    + +
    +
    +array std(const array &a, const std::vector<int> &axes, bool keepdims = false, int ddof = 0, StreamOrDevice s = {})#
    +

    Computes the standard deviatoin of the elements of an array along the given axes.

    +
    + +
    +
    +array std(const array &a, int axis, bool keepdims = false, int ddof = 0, StreamOrDevice s = {})#
    +

    Computes the standard deviation of the elements of an array along the given axis.

    +
    + +
    +
    +array prod(const array &a, bool keepdims, StreamOrDevice s = {})#
    +

    The product of all elements of the array.

    +
    + +
    +
    +inline array prod(const array &a, StreamOrDevice s = {})#
    +
    + +
    +
    +array prod(const array &a, const std::vector<int> &axes, bool keepdims = false, StreamOrDevice s = {})#
    +

    The product of the elements of an array along the given axes.

    +
    + +
    +
    +array prod(const array &a, int axis, bool keepdims = false, StreamOrDevice s = {})#
    +

    The product of the elements of an array along the given axis.

    +
    + +
    +
    +array max(const array &a, bool keepdims, StreamOrDevice s = {})#
    +

    The maximum of all elements of the array.

    +
    + +
    +
    +inline array max(const array &a, StreamOrDevice s = {})#
    +
    + +
    +
    +array max(const array &a, const std::vector<int> &axes, bool keepdims = false, StreamOrDevice s = {})#
    +

    The maximum of the elements of an array along the given axes.

    +
    + +
    +
    +array max(const array &a, int axis, bool keepdims = false, StreamOrDevice s = {})#
    +

    The maximum of the elements of an array along the given axis.

    +
    + +
    +
    +array min(const array &a, bool keepdims, StreamOrDevice s = {})#
    +

    The minimum of all elements of the array.

    +
    + +
    +
    +inline array min(const array &a, StreamOrDevice s = {})#
    +
    + +
    +
    +array min(const array &a, const std::vector<int> &axes, bool keepdims = false, StreamOrDevice s = {})#
    +

    The minimum of the elements of an array along the given axes.

    +
    + +
    +
    +array min(const array &a, int axis, bool keepdims = false, StreamOrDevice s = {})#
    +

    The minimum of the elements of an array along the given axis.

    +
    + +
    +
    +array argmin(const array &a, bool keepdims, StreamOrDevice s = {})#
    +

    Returns the index of the minimum value in the array.

    +
    + +
    +
    +inline array argmin(const array &a, StreamOrDevice s = {})#
    +
    + +
    +
    +array argmin(const array &a, int axis, bool keepdims = false, StreamOrDevice s = {})#
    +

    Returns the indices of the minimum values along a given axis.

    +
    + +
    +
    +array argmax(const array &a, bool keepdims, StreamOrDevice s = {})#
    +

    Returns the index of the maximum value in the array.

    +
    + +
    +
    +inline array argmax(const array &a, StreamOrDevice s = {})#
    +
    + +
    +
    +array argmax(const array &a, int axis, bool keepdims = false, StreamOrDevice s = {})#
    +

    Returns the indices of the maximum values along a given axis.

    +
    + +
    +
    +array sort(const array &a, StreamOrDevice s = {})#
    +

    Returns a sorted copy of the flattened array.

    +
    + +
    +
    +array sort(const array &a, int axis, StreamOrDevice s = {})#
    +

    Returns a sorted copy of the array along a given axis.

    +
    + +
    +
    +array argsort(const array &a, StreamOrDevice s = {})#
    +

    Returns indices that sort the flattened array.

    +
    + +
    +
    +array argsort(const array &a, int axis, StreamOrDevice s = {})#
    +

    Returns indices that sort the array along a given axis.

    +
    + +
    +
    +array partition(const array &a, int kth, StreamOrDevice s = {})#
    +

    Returns a partitioned copy of the flattened array such that the smaller kth elements are first.

    +
    + +
    +
    +array partition(const array &a, int kth, int axis, StreamOrDevice s = {})#
    +

    Returns a partitioned copy of the array along a given axis such that the smaller kth elements are first.

    +
    + +
    +
    +array argpartition(const array &a, int kth, StreamOrDevice s = {})#
    +

    Returns indices that partition the flattened array such that the smaller kth elements are first.

    +
    + +
    +
    +array argpartition(const array &a, int kth, int axis, StreamOrDevice s = {})#
    +

    Returns indices that partition the array along a given axis such that the smaller kth elements are first.

    +
    + +
    +
    +array topk(const array &a, int k, StreamOrDevice s = {})#
    +

    Returns topk elements of the flattened array.

    +
    + +
    +
    +array topk(const array &a, int k, int axis, StreamOrDevice s = {})#
    +

    Returns topk elements of the array along a given axis.

    +
    + +
    +
    +array logsumexp(const array &a, bool keepdims, StreamOrDevice s = {})#
    +

    The logsumexp of all elements of the array.

    +
    + +
    +
    +inline array logsumexp(const array &a, StreamOrDevice s = {})#
    +
    + +
    +
    +array logsumexp(const array &a, const std::vector<int> &axes, bool keepdims = false, StreamOrDevice s = {})#
    +

    The logsumexp of the elements of an array along the given axes.

    +
    + +
    +
    +array logsumexp(const array &a, int axis, bool keepdims = false, StreamOrDevice s = {})#
    +

    The logsumexp of the elements of an array along the given axis.

    +
    + +
    +
    +array abs(const array &a, StreamOrDevice s = {})#
    +

    Absolute value of elements in an array.

    +
    + +
    +
    +array negative(const array &a, StreamOrDevice s = {})#
    +

    Negate an array.

    +
    + +
    +
    +array operator-(const array &a)#
    +
    + +
    +
    +array sign(const array &a, StreamOrDevice s = {})#
    +

    The sign of the elements in an array.

    +
    + +
    +
    +array logical_not(const array &a, StreamOrDevice s = {})#
    +

    Logical not of an array.

    +
    + +
    +
    +array logical_and(const array &a, const array &b, StreamOrDevice s = {})#
    +

    Logical and of two arrays.

    +
    + +
    +
    +array operator&&(const array &a, const array &b)#
    +
    + +
    +
    +array logical_or(const array &a, const array &b, StreamOrDevice s = {})#
    +

    Logical or of two arrays.

    +
    + +
    +
    +array operator||(const array &a, const array &b)#
    +
    + +
    +
    +array reciprocal(const array &a, StreamOrDevice s = {})#
    +

    The reciprocal (1/x) of the elements in an array.

    +
    + +
    +
    +array add(const array &a, const array &b, StreamOrDevice s = {})#
    +

    Add two arrays.

    +
    + +
    +
    +array operator+(const array &a, const array &b)#
    +
    + +
    +
    +template<typename T>
    array operator+(T a, const array &b)#
    +
    + +
    +
    +template<typename T>
    array operator+(const array &a, T b)#
    +
    + +
    +
    +array subtract(const array &a, const array &b, StreamOrDevice s = {})#
    +

    Subtract two arrays.

    +
    + +
    +
    +array operator-(const array &a, const array &b)#
    +
    + +
    +
    +template<typename T>
    array operator-(T a, const array &b)#
    +
    + +
    +
    +template<typename T>
    array operator-(const array &a, T b)#
    +
    + +
    +
    +array multiply(const array &a, const array &b, StreamOrDevice s = {})#
    +

    Multiply two arrays.

    +
    + +
    +
    +array operator*(const array &a, const array &b)#
    +
    + +
    +
    +template<typename T>
    array operator*(T a, const array &b)#
    +
    + +
    +
    +template<typename T>
    array operator*(const array &a, T b)#
    +
    + +
    +
    +array divide(const array &a, const array &b, StreamOrDevice s = {})#
    +

    Divide two arrays.

    +
    + +
    +
    +array operator/(const array &a, const array &b)#
    +
    + +
    +
    +array operator/(double a, const array &b)#
    +
    + +
    +
    +array operator/(const array &a, double b)#
    +
    + +
    +
    +std::vector<array> divmod(const array &a, const array &b, StreamOrDevice s = {})#
    +

    Compute the element-wise quotient and remainder.

    +
    + +
    +
    +array floor_divide(const array &a, const array &b, StreamOrDevice s = {})#
    +

    Compute integer division.

    +

    Equivalent to doing floor(a / x).

    +
    + +
    +
    +array remainder(const array &a, const array &b, StreamOrDevice s = {})#
    +

    Compute the element-wise remainder of division.

    +
    + +
    +
    +array operator%(const array &a, const array &b)#
    +
    + +
    +
    +template<typename T>
    array operator%(T a, const array &b)#
    +
    + +
    +
    +template<typename T>
    array operator%(const array &a, T b)#
    +
    + +
    +
    +array maximum(const array &a, const array &b, StreamOrDevice s = {})#
    +

    Element-wise maximum between two arrays.

    +
    + +
    +
    +array minimum(const array &a, const array &b, StreamOrDevice s = {})#
    +

    Element-wise minimum between two arrays.

    +
    + +
    +
    +array floor(const array &a, StreamOrDevice s = {})#
    +

    Floor the element of an array.

    +
    + +
    +
    +array ceil(const array &a, StreamOrDevice s = {})#
    +

    Ceil the element of an array.

    +
    + +
    +
    +array square(const array &a, StreamOrDevice s = {})#
    +

    Square the elements of an array.

    +
    + +
    +
    +array exp(const array &a, StreamOrDevice s = {})#
    +

    Exponential of the elements of an array.

    +
    + +
    +
    +array sin(const array &a, StreamOrDevice s = {})#
    +

    Sine of the elements of an array.

    +
    + +
    +
    +array cos(const array &a, StreamOrDevice s = {})#
    +

    Cosine of the elements of an array.

    +
    + +
    +
    +array tan(const array &a, StreamOrDevice s = {})#
    +

    Tangent of the elements of an array.

    +
    + +
    +
    +array arcsin(const array &a, StreamOrDevice s = {})#
    +

    Arc Sine of the elements of an array.

    +
    + +
    +
    +array arccos(const array &a, StreamOrDevice s = {})#
    +

    Arc Cosine of the elements of an array.

    +
    + +
    +
    +array arctan(const array &a, StreamOrDevice s = {})#
    +

    Arc Tangent of the elements of an array.

    +
    + +
    +
    +array arctan2(const array &a, const array &b, StreamOrDevice s = {})#
    +

    Inverse tangent of the ratio of two arrays.

    +
    + +
    +
    +array sinh(const array &a, StreamOrDevice s = {})#
    +

    Hyperbolic Sine of the elements of an array.

    +
    + +
    +
    +array cosh(const array &a, StreamOrDevice s = {})#
    +

    Hyperbolic Cosine of the elements of an array.

    +
    + +
    +
    +array tanh(const array &a, StreamOrDevice s = {})#
    +

    Hyperbolic Tangent of the elements of an array.

    +
    + +
    +
    +array arcsinh(const array &a, StreamOrDevice s = {})#
    +

    Inverse Hyperbolic Sine of the elements of an array.

    +
    + +
    +
    +array arccosh(const array &a, StreamOrDevice s = {})#
    +

    Inverse Hyperbolic Cosine of the elements of an array.

    +
    + +
    +
    +array arctanh(const array &a, StreamOrDevice s = {})#
    +

    Inverse Hyperbolic Tangent of the elements of an array.

    +
    + +
    +
    +array degrees(const array &a, StreamOrDevice s = {})#
    +

    Convert the elements of an array from Radians to Degrees.

    +
    + +
    +
    +array radians(const array &a, StreamOrDevice s = {})#
    +

    Convert the elements of an array from Degrees to Radians.

    +
    + +
    +
    +array log(const array &a, StreamOrDevice s = {})#
    +

    Natural logarithm of the elements of an array.

    +
    + +
    +
    +array log2(const array &a, StreamOrDevice s = {})#
    +

    Log base 2 of the elements of an array.

    +
    + +
    +
    +array log10(const array &a, StreamOrDevice s = {})#
    +

    Log base 10 of the elements of an array.

    +
    + +
    +
    +array log1p(const array &a, StreamOrDevice s = {})#
    +

    Natural logarithm of one plus elements in the array: log(1 + a).

    +
    + +
    +
    +array logaddexp(const array &a, const array &b, StreamOrDevice s = {})#
    +

    Log-add-exp of one elements in the array: log(exp(a) + exp(b)).

    +
    + +
    +
    +array sigmoid(const array &a, StreamOrDevice s = {})#
    +

    Element-wise logistic sigmoid of the array: 1 / (1 + exp(-x).

    +
    + +
    +
    +array erf(const array &a, StreamOrDevice s = {})#
    +

    Computes the error function of the elements of an array.

    +
    + +
    +
    +array erfinv(const array &a, StreamOrDevice s = {})#
    +

    Computes the inverse error function of the elements of an array.

    +
    + +
    +
    +array expm1(const array &a, StreamOrDevice s = {})#
    +

    Computes the expm1 function of the elements of an array.

    +
    + +
    +
    +array stop_gradient(const array &a, StreamOrDevice s = {})#
    +

    Stop the flow of gradients.

    +
    + +
    +
    +array round(const array &a, int decimals, StreamOrDevice s = {})#
    +

    Round a floating point number.

    +
    + +
    +
    +inline array round(const array &a, StreamOrDevice s = {})#
    +
    + +
    +
    +array matmul(const array &a, const array &b, StreamOrDevice s = {})#
    +

    Matrix-matrix multiplication.

    +
    + +
    +
    +array gather(const array &a, const std::vector<array> &indices, const std::vector<int> &axes, const std::vector<int> &slice_sizes, StreamOrDevice s = {})#
    +

    Gather array entries given indices and slices.

    +
    + +
    +
    +inline array gather(const array &a, const array &indices, int axis, const std::vector<int> &slice_sizes, StreamOrDevice s = {})#
    +
    + +
    +
    +array take(const array &a, const array &indices, int axis, StreamOrDevice s = {})#
    +

    Take array slices at the given indices of the specified axis.

    +
    + +
    +
    +array take(const array &a, const array &indices, StreamOrDevice s = {})#
    +

    Take array entries at the given indices treating the array as flattened.

    +
    + +
    +
    +array take_along_axis(const array &a, const array &indices, int axis, StreamOrDevice s = {})#
    +

    Take array entries given indices along the axis.

    +
    + +
    +
    +array scatter(const array &a, const std::vector<array> &indices, const array &updates, const std::vector<int> &axes, StreamOrDevice s = {})#
    +

    Scatter updates to given linear indices.

    +
    + +
    +
    +inline array scatter(const array &a, const array &indices, const array &updates, int axis, StreamOrDevice s = {})#
    +
    + +
    +
    +array scatter_add(const array &a, const std::vector<array> &indices, const array &updates, const std::vector<int> &axes, StreamOrDevice s = {})#
    +

    Scatter and add updates to given indices.

    +
    + +
    +
    +inline array scatter_add(const array &a, const array &indices, const array &updates, int axis, StreamOrDevice s = {})#
    +
    + +
    +
    +array scatter_prod(const array &a, const std::vector<array> &indices, const array &updates, const std::vector<int> &axes, StreamOrDevice s = {})#
    +

    Scatter and prod updates to given indices.

    +
    + +
    +
    +inline array scatter_prod(const array &a, const array &indices, const array &updates, int axis, StreamOrDevice s = {})#
    +
    + +
    +
    +array scatter_max(const array &a, const std::vector<array> &indices, const array &updates, const std::vector<int> &axes, StreamOrDevice s = {})#
    +

    Scatter and max updates to given linear indices.

    +
    + +
    +
    +inline array scatter_max(const array &a, const array &indices, const array &updates, int axis, StreamOrDevice s = {})#
    +
    + +
    +
    +array scatter_min(const array &a, const std::vector<array> &indices, const array &updates, const std::vector<int> &axes, StreamOrDevice s = {})#
    +

    Scatter and min updates to given linear indices.

    +
    + +
    +
    +inline array scatter_min(const array &a, const array &indices, const array &updates, int axis, StreamOrDevice s = {})#
    +
    + +
    +
    +array sqrt(const array &a, StreamOrDevice s = {})#
    +

    Square root the elements of an array.

    +
    + +
    +
    +array rsqrt(const array &a, StreamOrDevice s = {})#
    +

    Square root and reciprocal the elements of an array.

    +
    + +
    +
    +array softmax(const array &a, const std::vector<int> &axes, bool precise = false, StreamOrDevice s = {})#
    +

    Softmax of an array.

    +
    + +
    +
    +array softmax(const array &a, bool precise = false, StreamOrDevice s = {})#
    +

    Softmax of an array.

    +
    + +
    +
    +inline array softmax(const array &a, int axis, bool precise = false, StreamOrDevice s = {})#
    +

    Softmax of an array.

    +
    + +
    +
    +array power(const array &a, const array &b, StreamOrDevice s = {})#
    +

    Raise elements of a to the power of b element-wise.

    +
    + +
    +
    +array cumsum(const array &a, int axis, bool reverse = false, bool inclusive = true, StreamOrDevice s = {})#
    +

    Cumulative sum of an array.

    +
    + +
    +
    +array cumprod(const array &a, int axis, bool reverse = false, bool inclusive = true, StreamOrDevice s = {})#
    +

    Cumulative product of an array.

    +
    + +
    +
    +array cummax(const array &a, int axis, bool reverse = false, bool inclusive = true, StreamOrDevice s = {})#
    +

    Cumulative max of an array.

    +
    + +
    +
    +array cummin(const array &a, int axis, bool reverse = false, bool inclusive = true, StreamOrDevice s = {})#
    +

    Cumulative min of an array.

    +
    + +
    +
    +array conv_general(array input, array weight, std::vector<int> stride = {}, std::vector<int> padding_lo = {}, std::vector<int> padding_hi = {}, std::vector<int> kernel_dilation = {}, std::vector<int> input_dilation = {}, int groups = 1, bool flip = false, StreamOrDevice s = {})#
    +

    General convolution with a filter.

    +
    + +
    +
    +inline array conv_general(const array &input, const array &weight, std::vector<int> stride = {}, std::vector<int> padding = {}, std::vector<int> kernel_dilation = {}, std::vector<int> input_dilation = {}, int groups = 1, bool flip = false, StreamOrDevice s = {})#
    +

    General convolution with a filter.

    +
    + +
    +
    +array conv1d(const array &input, const array &weight, int stride = 1, int padding = 0, int dilation = 1, int groups = 1, StreamOrDevice s = {})#
    +

    1D convolution with a filter

    +
    + +
    +
    +array conv2d(const array &input, const array &weight, const std::pair<int, int> &stride = {1, 1}, const std::pair<int, int> &padding = {0, 0}, const std::pair<int, int> &dilation = {1, 1}, int groups = 1, StreamOrDevice s = {})#
    +

    2D convolution with a filter

    +
    + +
    +
    +array quantized_matmul(const array &x, const array &w, const array &scales, const array &biases, bool transpose = true, int group_size = 64, int bits = 4, StreamOrDevice s = {})#
    +

    Quantized matmul multiplies x with a quantized matrix w.

    +
    + +
    +
    +std::tuple<array, array, array> quantize(const array &w, int group_size = 64, int bits = 4, StreamOrDevice s = {})#
    +

    Quantize a matrix along its last axis.

    +
    + +
    +
    +array dequantize(const array &w, const array &scales, const array &biases, int group_size = 64, int bits = 4, StreamOrDevice s = {})#
    +

    Dequantize a matrix produced by quantize()

    +
    + +
    +
    +array tensordot(const array &a, const array &b, const int axis = 2, StreamOrDevice s = {})#
    +

    Returns a contraction of a and b over multiple dimensions.

    +
    + +
    +
    +array tensordot(const array &a, const array &b, const std::vector<int> &axes_a, const std::vector<int> &axes_b, StreamOrDevice s = {})#
    +
    + +
    +
    +array outer(const array &a, const array &b, StreamOrDevice s = {})#
    +

    Compute the outer product of two vectors.

    +
    + +
    +
    +array inner(const array &a, const array &b, StreamOrDevice s = {})#
    +

    Compute the inner product of two vectors.

    +
    + +
    +
    +array addmm(array c, array a, array b, const float &alpha = 1.f, const float &beta = 1.f, StreamOrDevice s = {})#
    +

    Compute D = beta * C + alpha * (A @ B)

    +
    + +
    +
    +array block_masked_mm(array a, array b, int block_size, std::optional<array> mask_out = std::nullopt, std::optional<array> mask_lhs = std::nullopt, std::optional<array> mask_rhs = std::nullopt, StreamOrDevice s = {})#
    +

    Compute matrix product with block masking.

    +
    + +
    +
    +array block_sparse_mm(array a, array b, std::optional<array> lhs_indices = std::nullopt, std::optional<array> rhs_indices = std::nullopt, StreamOrDevice s = {})#
    +

    Compute matrix product with matrix-level gather.

    +
    + +
    +
    +array diagonal(const array &a, int offset = 0, int axis1 = 0, int axis2 = 1, StreamOrDevice s = {})#
    +

    Extract a diagonal or construct a diagonal array.

    +
    + +
    +
    +array diag(const array &a, int k = 0, StreamOrDevice s = {})#
    +

    Extract diagonal from a 2d array or create a diagonal matrix.

    +
    + +
    +
    +std::vector<array> depends(const std::vector<array> &inputs, const std::vector<array> &dependencies)#
    +

    Implements the identity function but allows injecting dependencies to other arrays.

    +

    This ensures that these other arrays will have been computed when the outputs of this function are computed.

    +
    + +
    +
    +array atleast_1d(const array &a, StreamOrDevice s = {})#
    +

    convert an array to an atleast ndim array

    +
    + +
    +
    +std::vector<array> atleast_1d(const std::vector<array> &a, StreamOrDevice s = {})#
    +
    + +
    +
    +array atleast_2d(const array &a, StreamOrDevice s = {})#
    +
    + +
    +
    +std::vector<array> atleast_2d(const std::vector<array> &a, StreamOrDevice s = {})#
    +
    + +
    +
    +array atleast_3d(const array &a, StreamOrDevice s = {})#
    +
    + +
    +
    +std::vector<array> atleast_3d(const std::vector<array> &a, StreamOrDevice s = {})#
    +
    + +
    +
    +array number_of_elements(const array &a, std::vector<int> axes, bool inverted, Dtype dtype = int32, StreamOrDevice s = {})#
    +

    Extract the number of elements along some axes as a scalar array.

    +

    Used to allow shape dependent shapeless compilation (pun intended).

    +
    + +
    +
    +array conjugate(const array &a, StreamOrDevice s = {})#
    +
    + +
    +
    +array bitwise_and(const array &a, const array &b, StreamOrDevice s = {})#
    +

    Bitwise and.

    +
    + +
    +
    +array operator&(const array &a, const array &b)#
    +
    + +
    +
    +array bitwise_or(const array &a, const array &b, StreamOrDevice s = {})#
    +

    Bitwise inclusive or.

    +
    + +
    +
    +array operator|(const array &a, const array &b)#
    +
    + +
    +
    +array bitwise_xor(const array &a, const array &b, StreamOrDevice s = {})#
    +

    Bitwise exclusive or.

    +
    + +
    +
    +array operator^(const array &a, const array &b)#
    +
    + +
    +
    +array left_shift(const array &a, const array &b, StreamOrDevice s = {})#
    +

    Shift bits to the left.

    +
    + +
    +
    +array operator<<(const array &a, const array &b)#
    +
    + +
    +
    +array right_shift(const array &a, const array &b, StreamOrDevice s = {})#
    +

    Shift bits to the right.

    +
    + +
    +
    +array operator>>(const array &a, const array &b)#
    +
    +
    @@ -792,12 +2710,12 @@ document.write(`

    previous

    -

    mlx.utils.tree_map_with_path

    +

    mlx.utils.tree_reduce

    +