mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00

* spelling: accumulates Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: across Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: additional Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: against Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: among Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: array Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: at least Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: available Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: axes Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: basically Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: bfloat Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: bounds Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: broadcast Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: buffer Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: class Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: coefficients Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: collision Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: combinations Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: committing Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: computation Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: consider Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: constructing Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: conversions 
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: correctly Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: corresponding Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: declaration Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: default Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: dependency Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: destination Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: destructor Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: dimensions Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: divided Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: element-wise Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: elements Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: endianness Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: equivalent Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: explicitly Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: github Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: indices Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: irregularly Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: memory Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: metallib Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: negative Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: notable Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: optional Signed-off-by: 
Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: otherwise Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: overridden Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: partially Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: partition Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: perform Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: perturbations Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: positively Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: primitive Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: repeat Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: repeats Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: respect Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: respectively Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: result Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: rounding Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: separate Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: skipping Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: structure Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: the Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: transpose Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: unnecessary Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: unneeded Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: unsupported Signed-off-by: Josh Soref 
<2119212+jsoref@users.noreply.github.com> --------- Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
63 lines
2.5 KiB
Metal
63 lines
2.5 KiB
Metal
// Copyright © 2023 Apple Inc.
|
|
|
|
#include <metal_stdlib>
|
|
|
|
#include "mlx/backend/metal/kernels/bf16.h"
|
|
#include "mlx/backend/metal/kernels/utils.h"
|
|
|
|
/// Computes `out = alpha * x + beta * y` for inputs that may have arbitrary
/// strides (for example transposed or broadcast views).
///
/// Each thread produces exactly one output element, the one at `index`.
/// `out` is indexed directly by `index`, so it is written as a densely
/// packed array in the order of the flat index.
/// NOTE(review): there is no bounds check here, so the dispatch is assumed
/// to launch exactly one thread per output element — confirm on the host.
///
/// Buffers:
///   x, y       input elements, addressed through x_strides / y_strides
///   out        output elements
///   alpha/beta float scalars, converted to T before use
///   shape      extent of each of the `ndim` dimensions
template <typename T>
[[kernel]] void axpby_general(
    device const T* x [[buffer(0)]],
    device const T* y [[buffer(1)]],
    device T* out [[buffer(2)]],
    constant const float& alpha [[buffer(3)]],
    constant const float& beta [[buffer(4)]],
    constant const int* shape [[buffer(5)]],
    constant const size_t* x_strides [[buffer(6)]],
    constant const size_t* y_strides [[buffer(7)]],
    constant const int& ndim [[buffer(8)]],
    uint index [[thread_position_in_grid]]) {
  // The scalars arrive as float; bring them into the element type once.
  const T a = static_cast<T>(alpha);
  const T b = static_cast<T>(beta);

  // elem_to_loc (from utils.h, not shown here) presumably maps the flat
  // index through `shape` and each input's strides to a memory offset.
  const auto x_loc = elem_to_loc(index, shape, x_strides, ndim);
  const auto y_loc = elem_to_loc(index, shape, y_strides, ndim);

  // Kept as one expression so operand types and evaluation are unchanged.
  out[index] = a * x[x_loc] + b * y[y_loc];
}
|
|
|
|
/// Computes `out = alpha * x + beta * y` when `x`, `y` and `out` can all be
/// addressed with the same flat index (i.e. they are assumed to share one
/// contiguous layout element-for-element).
///
/// Each thread handles the single element at `index`.
/// NOTE(review): no bounds check — the grid is assumed to match the
/// element count exactly; confirm on the host side.
template <typename T>
[[kernel]] void axpby_contiguous(
    device const T* x [[buffer(0)]],
    device const T* y [[buffer(1)]],
    device T* out [[buffer(2)]],
    constant const float& alpha [[buffer(3)]],
    constant const float& beta [[buffer(4)]],
    uint index [[thread_position_in_grid]]) {
  // Convert the float scalars to the element type before combining.
  const T a = static_cast<T>(alpha);
  const T b = static_cast<T>(beta);

  out[index] = a * x[index] + b * y[index];
}
|
|
|
|
// Explicitly instantiates the strided kernel for one element type.
// `host_name` gives the specialization a type-suffixed function name in the
// compiled library, e.g. `axpby_general_float32`.
#define instantiate_axpby_general(type_name, type)          \
  template [[host_name("axpby_general_" #type_name)]]       \
  [[kernel]] void axpby_general<type>(                      \
      device const type* x [[buffer(0)]],                   \
      device const type* y [[buffer(1)]],                   \
      device type* out [[buffer(2)]],                       \
      constant const float& alpha [[buffer(3)]],            \
      constant const float& beta [[buffer(4)]],             \
      constant const int* shape [[buffer(5)]],              \
      constant const size_t* x_strides [[buffer(6)]],       \
      constant const size_t* y_strides [[buffer(7)]],       \
      constant const int& ndim [[buffer(8)]],               \
      uint index [[thread_position_in_grid]]);

// Explicitly instantiates the contiguous kernel for one element type,
// exported as e.g. `axpby_contiguous_float32`.
#define instantiate_axpby_contiguous(type_name, type)       \
  template [[host_name("axpby_contiguous_" #type_name)]]    \
  [[kernel]] void axpby_contiguous<type>(                   \
      device const type* x [[buffer(0)]],                   \
      device const type* y [[buffer(1)]],                   \
      device type* out [[buffer(2)]],                       \
      constant const float& alpha [[buffer(3)]],            \
      constant const float& beta [[buffer(4)]],             \
      uint index [[thread_position_in_grid]]);

// Instantiates both kernel variants for element type `type`, using
// `type_name` as the suffix of the exported function names.
#define instantiate_axpby(type_name, type)        \
  instantiate_axpby_general(type_name, type)      \
  instantiate_axpby_contiguous(type_name, type)
|
|
|
|
// Element types exported from this library. Each line emits the pair
// axpby_general_<name> / axpby_contiguous_<name>.
instantiate_axpby(float32, float);
instantiate_axpby(float16, half);
instantiate_axpby(bfloat16, bfloat16_t);
instantiate_axpby(complex64, complex64_t);