mlx/mlx/fast.h
Jagrit Digani 3290bfa690
Add new sdpa function overload (#2035)
* Add new sdpa function overload

* Address comments

* Remove std::variant from cpp sdpa function
2025-04-03 11:58:28 -07:00

93 lines
2.1 KiB
C++

// Copyright © 2023-2024 Apple Inc.
#pragma once
#include <functional>
#include <optional>
#include <string>
#include <tuple>
#include <utility>
#include <variant>
#include <vector>

#include "mlx/utils.h"
namespace mlx::core::fast {
array rms_norm(
const array& x,
const std::optional<array>& weight,
float eps,
StreamOrDevice s = {});
array layer_norm(
const array& x,
const std::optional<array>& weight,
const std::optional<array>& bias,
float eps,
StreamOrDevice s = {});
array rope(
const array& x,
int dims,
bool traditional,
std::optional<float> base,
float scale,
int offset,
const std::optional<array>& freqs = std::nullopt,
StreamOrDevice s = {});
array rope(
const array& x,
int dims,
bool traditional,
std::optional<float> base,
float scale,
const array& offset,
const std::optional<array>& freqs = std::nullopt,
StreamOrDevice s = {});
/** Computes: O = softmax(Q @ K.T) @ V **/
array scaled_dot_product_attention(
const array& queries,
const array& keys,
const array& values,
const float scale,
const std::string& mask_mode = "",
const std::vector<array>& mask_arrs = {},
StreamOrDevice s = {});
std::tuple<array, array, array> affine_quantize(
const array& w,
int group_size = 64,
int bits = 4,
StreamOrDevice s = {});
array affine_dequantize(
const array& w,
const array& scales,
const array& biases,
int group_size = 64,
int bits = 4,
StreamOrDevice s = {});
typedef std::variant<int, bool, Dtype> TemplateArg;
typedef std::function<std::vector<array>(
const std::vector<array>&,
const std::vector<Shape>&,
const std::vector<Dtype>&,
std::tuple<int, int, int>,
std::tuple<int, int, int>,
std::vector<std::pair<std::string, TemplateArg>>,
std::optional<float>,
bool,
StreamOrDevice)>
MetalKernelFunction;
MetalKernelFunction metal_kernel(
const std::string& name,
const std::vector<std::string>& input_names,
const std::vector<std::string>& output_names,
const std::string& source,
const std::string& header = "",
bool ensure_row_contiguous = true,
bool atomic_outputs = false);
} // namespace mlx::core::fast