mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Compare commits
5 Commits
704fd1ae28
...
jit-nax
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5cf6f10bef | ||
|
|
7c1abc50c0 | ||
|
|
2b95d0c270 | ||
|
|
b054838780 | ||
|
|
dd79d3c465 |
@@ -22,26 +22,28 @@ inline __device__ float2 plus_f2(const float2& a, const float2& b) {
|
||||
}
|
||||
|
||||
// Similar to cub::BlockReduce, but result is broadcasted to every thread.
|
||||
template <typename T, int BLOCK_DIM>
|
||||
template <typename T, int BLOCK_DIM, int GROUP_DIM = WARP_SIZE>
|
||||
struct BlockBroadcastReduce {
|
||||
static_assert(WARP_SIZE <= BLOCK_DIM && BLOCK_DIM <= WARP_SIZE * WARP_SIZE);
|
||||
static_assert(BLOCK_DIM % WARP_SIZE == 0);
|
||||
using TempStorage = T[BLOCK_DIM / WARP_SIZE];
|
||||
using TempStorage = T[std::max(BLOCK_DIM / WARP_SIZE, 1)];
|
||||
|
||||
cg::thread_block& block;
|
||||
TempStorage& temp;
|
||||
|
||||
template <typename Op>
|
||||
__device__ T Reduce(const T& input, const Op& op, const T& init_value) {
|
||||
auto warp = cg::tiled_partition<WARP_SIZE>(block);
|
||||
auto warp = cg::tiled_partition<GROUP_DIM>(block);
|
||||
T x = cg::reduce(warp, input, op);
|
||||
if (warp.thread_rank() == 0) {
|
||||
temp[warp.meta_group_rank()] = x;
|
||||
if constexpr (BLOCK_DIM > GROUP_DIM) {
|
||||
if (warp.thread_rank() == 0) {
|
||||
temp[warp.meta_group_rank()] = x;
|
||||
}
|
||||
block.sync();
|
||||
x = warp.thread_rank() < warp.meta_group_size() ? temp[warp.thread_rank()]
|
||||
: init_value;
|
||||
return cg::reduce(warp, x, op);
|
||||
} else {
|
||||
return x;
|
||||
}
|
||||
block.sync();
|
||||
x = warp.thread_rank() < warp.meta_group_size() ? temp[warp.thread_rank()]
|
||||
: init_value;
|
||||
return cg::reduce(warp, x, op);
|
||||
}
|
||||
|
||||
__device__ T Sum(const T& input) {
|
||||
@@ -49,6 +51,52 @@ struct BlockBroadcastReduce {
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, int BLOCK_DIM, int REDUCE_DIM, int N_READS = 4>
|
||||
__global__ void rms_norm_small(
|
||||
const T* x,
|
||||
const T* w,
|
||||
T* out,
|
||||
float eps,
|
||||
uint32_t axis_size,
|
||||
uint32_t n_rows,
|
||||
int64_t w_stride) {
|
||||
auto grid = cg::this_grid();
|
||||
auto block = cg::this_thread_block();
|
||||
|
||||
using BlockReduceT = BlockBroadcastReduce<float, BLOCK_DIM, REDUCE_DIM>;
|
||||
__shared__ typename BlockReduceT::TempStorage temp;
|
||||
|
||||
auto row =
|
||||
(grid.block_rank() * block.dim_threads().y) + block.thread_index().y;
|
||||
if (row >= n_rows) {
|
||||
return;
|
||||
}
|
||||
x += row * axis_size;
|
||||
out += row * axis_size;
|
||||
|
||||
// Normalizer.
|
||||
float normalizer = 0;
|
||||
auto index = block.thread_index().x;
|
||||
auto xn = load_vector<N_READS>(x, index, axis_size, T(0));
|
||||
#pragma unroll
|
||||
for (int i = 0; i < N_READS; ++i) {
|
||||
float t = static_cast<float>(xn[i]);
|
||||
normalizer += t * t;
|
||||
}
|
||||
|
||||
normalizer = BlockReduceT{block, temp}.Sum(normalizer);
|
||||
normalizer = rsqrt(normalizer / axis_size + eps);
|
||||
|
||||
// Outputs.
|
||||
auto wn = load_vector<N_READS>(w, index, axis_size, w_stride, T(0));
|
||||
#pragma unroll
|
||||
for (int i = 0; i < N_READS; ++i) {
|
||||
float y = static_cast<float>(xn[i]) * normalizer;
|
||||
xn[i] = wn[i] * static_cast<T>(y);
|
||||
}
|
||||
store_vector<N_READS>(out, index, xn, axis_size);
|
||||
}
|
||||
|
||||
template <typename T, int BLOCK_DIM, int N_READS = 4>
|
||||
__global__ void rms_norm(
|
||||
const T* x,
|
||||
@@ -94,6 +142,74 @@ __global__ void rms_norm(
|
||||
}
|
||||
}
|
||||
|
||||
template <
|
||||
typename T,
|
||||
bool HAS_W,
|
||||
int BLOCK_DIM,
|
||||
int REDUCE_DIM,
|
||||
int N_READS = 4>
|
||||
__global__ void rms_norm_vjp_small(
|
||||
const T* x,
|
||||
const T* w,
|
||||
const T* g,
|
||||
T* gx,
|
||||
T* gw,
|
||||
float eps,
|
||||
int32_t axis_size,
|
||||
int32_t n_rows,
|
||||
int64_t w_stride) {
|
||||
auto grid = cg::this_grid();
|
||||
auto block = cg::this_thread_block();
|
||||
|
||||
using BlockReduceF2 = BlockBroadcastReduce<float2, BLOCK_DIM, REDUCE_DIM>;
|
||||
__shared__ typename BlockReduceF2::TempStorage temp;
|
||||
|
||||
auto row =
|
||||
(grid.block_rank() * block.dim_threads().y) + block.thread_index().y;
|
||||
if (row >= n_rows) {
|
||||
return;
|
||||
}
|
||||
|
||||
x += row * axis_size;
|
||||
g += row * axis_size;
|
||||
gx += row * axis_size;
|
||||
gw += row * axis_size;
|
||||
|
||||
// Normalizer.
|
||||
float2 factors = {};
|
||||
auto index = block.thread_index().x;
|
||||
auto xn = load_vector<N_READS>(x, index, axis_size, T(0));
|
||||
auto gn = load_vector<N_READS>(g, index, axis_size, T(0));
|
||||
auto wn = load_vector<N_READS>(w, index, axis_size, w_stride, T(0));
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
float t = static_cast<float>(xn[i]);
|
||||
float wi = wn[i];
|
||||
float gi = gn[i];
|
||||
float wg = wi * gi;
|
||||
factors = plus_f2(factors, {wg * t, t * t});
|
||||
}
|
||||
|
||||
factors = BlockReduceF2{block, temp}.Reduce(factors, plus_f2, {});
|
||||
float meangwx = factors.x / axis_size;
|
||||
float normalizer = rsqrt(factors.y / axis_size + eps);
|
||||
float normalizer3 = normalizer * normalizer * normalizer;
|
||||
|
||||
// Outputs.
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
float xi = xn[i];
|
||||
float wi = wn[i];
|
||||
float gi = gn[i];
|
||||
xn[i] = static_cast<T>(normalizer * wi * gi - xi * meangwx * normalizer3);
|
||||
if constexpr (HAS_W) {
|
||||
wn[i] = static_cast<T>(gi * xi * normalizer);
|
||||
}
|
||||
}
|
||||
store_vector<N_READS>(gx, index, xn, axis_size);
|
||||
if constexpr (HAS_W) {
|
||||
store_vector<N_READS>(gw, index, wn, axis_size);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, bool HAS_W, int BLOCK_DIM, int N_READS = 4>
|
||||
__global__ void rms_norm_vjp(
|
||||
const T* x,
|
||||
@@ -107,12 +223,8 @@ __global__ void rms_norm_vjp(
|
||||
auto grid = cg::this_grid();
|
||||
auto block = cg::this_thread_block();
|
||||
|
||||
using BlockReduceF = BlockBroadcastReduce<float, BLOCK_DIM>;
|
||||
using BlockReduceF2 = BlockBroadcastReduce<float2, BLOCK_DIM>;
|
||||
__shared__ union {
|
||||
typename BlockReduceF::TempStorage f;
|
||||
typename BlockReduceF2::TempStorage f2;
|
||||
} temp;
|
||||
__shared__ typename BlockReduceF2::TempStorage temp;
|
||||
|
||||
x += grid.block_rank() * axis_size;
|
||||
g += grid.block_rank() * axis_size;
|
||||
@@ -134,7 +246,7 @@ __global__ void rms_norm_vjp(
|
||||
factors = plus_f2(factors, {wg * t, t * t});
|
||||
}
|
||||
}
|
||||
factors = BlockReduceF2{block, temp.f2}.Reduce(factors, plus_f2, {});
|
||||
factors = BlockReduceF2{block, temp}.Reduce(factors, plus_f2, {});
|
||||
float meangwx = factors.x / axis_size;
|
||||
float normalizer = rsqrt(factors.y / axis_size + eps);
|
||||
float normalizer3 = normalizer * normalizer * normalizer;
|
||||
@@ -169,6 +281,43 @@ bool RMSNorm::use_fallback(Stream s) {
|
||||
return s.device == Device::cpu;
|
||||
}
|
||||
|
||||
template <int n_per_thread, typename F>
|
||||
void dispatch_group_dim(int axis_size, F&& f) {
|
||||
if (axis_size <= n_per_thread * 8) {
|
||||
f(std::integral_constant<int, 8>{},
|
||||
std::integral_constant<int, 1>(),
|
||||
std::integral_constant<int, 16>());
|
||||
} else if (axis_size <= n_per_thread * 16) {
|
||||
f(std::integral_constant<int, 16>{},
|
||||
std::integral_constant<int, 1>(),
|
||||
std::integral_constant<int, 8>());
|
||||
} else if (axis_size <= n_per_thread * 32) {
|
||||
f(std::integral_constant<int, 32>{},
|
||||
std::integral_constant<int, 1>(),
|
||||
std::integral_constant<int, 4>());
|
||||
} else if (axis_size <= n_per_thread * 32 * 2) {
|
||||
f(std::integral_constant<int, 32>{},
|
||||
std::integral_constant<int, 2>(),
|
||||
std::integral_constant<int, 2>());
|
||||
} else if (axis_size <= n_per_thread * 32 * 4) {
|
||||
f(std::integral_constant<int, 32>{},
|
||||
std::integral_constant<int, 4>(),
|
||||
std::integral_constant<int, 1>());
|
||||
} else if (axis_size <= n_per_thread * 32 * 8) {
|
||||
f(std::integral_constant<int, 32>{},
|
||||
std::integral_constant<int, 8>(),
|
||||
std::integral_constant<int, 1>());
|
||||
} else if (axis_size <= n_per_thread * 32 * 16) {
|
||||
f(std::integral_constant<int, 32>{},
|
||||
std::integral_constant<int, 16>(),
|
||||
std::integral_constant<int, 1>());
|
||||
} else {
|
||||
f(std::integral_constant<int, 32>{},
|
||||
std::integral_constant<int, 32>(),
|
||||
std::integral_constant<int, 1>());
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: There are duplicate code with backend/metal/normalization.cpp
|
||||
void RMSNorm::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
@@ -216,12 +365,33 @@ void RMSNorm::eval_gpu(
|
||||
dispatch_float_types(out.dtype(), "rms_norm", [&](auto type_tag) {
|
||||
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
constexpr int N_READS = 16 / sizeof(DataType);
|
||||
dispatch_block_dim(cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
|
||||
auto kernel = cu::rms_norm<DataType, block_dim(), N_READS>;
|
||||
if (axis_size <= N_READS * 1024) {
|
||||
dispatch_group_dim<N_READS>(
|
||||
axis_size, [&](auto group_dim, auto n_groups, auto groups_per_block) {
|
||||
constexpr int block_dim = n_groups() * group_dim();
|
||||
auto kernel =
|
||||
cu::rms_norm_small<DataType, block_dim, group_dim(), N_READS>;
|
||||
auto n_blocks =
|
||||
(n_rows + groups_per_block() - 1) / groups_per_block();
|
||||
encoder.add_kernel_node(
|
||||
kernel,
|
||||
n_blocks,
|
||||
{block_dim, groups_per_block()},
|
||||
0,
|
||||
gpu_ptr<DataType>(x),
|
||||
gpu_ptr<DataType>(w),
|
||||
gpu_ptr<DataType>(out),
|
||||
eps_,
|
||||
axis_size,
|
||||
n_rows,
|
||||
w_stride);
|
||||
});
|
||||
} else {
|
||||
auto kernel = cu::rms_norm<DataType, 1024, N_READS>;
|
||||
encoder.add_kernel_node(
|
||||
kernel,
|
||||
n_rows,
|
||||
block_dim(),
|
||||
1024,
|
||||
0,
|
||||
gpu_ptr<DataType>(x),
|
||||
gpu_ptr<DataType>(w),
|
||||
@@ -229,7 +399,7 @@ void RMSNorm::eval_gpu(
|
||||
eps_,
|
||||
axis_size,
|
||||
w_stride);
|
||||
});
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@@ -306,27 +476,51 @@ void RMSNormVJP::eval_gpu(
|
||||
dispatch_bool(has_w, [&](auto has_w_constant) {
|
||||
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
constexpr int N_READS = 16 / sizeof(DataType);
|
||||
dispatch_block_dim(
|
||||
cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
|
||||
auto kernel = cu::rms_norm_vjp<
|
||||
DataType,
|
||||
has_w_constant.value,
|
||||
block_dim(),
|
||||
N_READS>;
|
||||
encoder.add_kernel_node(
|
||||
kernel,
|
||||
n_rows,
|
||||
block_dim(),
|
||||
0,
|
||||
gpu_ptr<DataType>(x),
|
||||
gpu_ptr<DataType>(w),
|
||||
gpu_ptr<DataType>(g),
|
||||
gpu_ptr<DataType>(gx),
|
||||
gpu_ptr<DataType>(gw_temp),
|
||||
eps_,
|
||||
axis_size,
|
||||
w_stride);
|
||||
});
|
||||
if (axis_size <= N_READS * 1024) {
|
||||
dispatch_group_dim<N_READS>(
|
||||
axis_size,
|
||||
[&](auto group_dim, auto n_groups, auto groups_per_block) {
|
||||
constexpr int block_dim = group_dim() * n_groups();
|
||||
auto kernel = cu::rms_norm_vjp_small<
|
||||
DataType,
|
||||
has_w_constant.value,
|
||||
block_dim,
|
||||
group_dim(),
|
||||
N_READS>;
|
||||
auto n_blocks =
|
||||
(n_rows + groups_per_block() - 1) / groups_per_block();
|
||||
encoder.add_kernel_node(
|
||||
kernel,
|
||||
n_blocks,
|
||||
{block_dim, groups_per_block()},
|
||||
0,
|
||||
gpu_ptr<DataType>(x),
|
||||
gpu_ptr<DataType>(w),
|
||||
gpu_ptr<DataType>(g),
|
||||
gpu_ptr<DataType>(gx),
|
||||
gpu_ptr<DataType>(gw_temp),
|
||||
eps_,
|
||||
axis_size,
|
||||
n_rows,
|
||||
w_stride);
|
||||
});
|
||||
} else {
|
||||
auto kernel =
|
||||
cu::rms_norm_vjp<DataType, has_w_constant.value, 1024, N_READS>;
|
||||
encoder.add_kernel_node(
|
||||
kernel,
|
||||
n_rows,
|
||||
1024,
|
||||
0,
|
||||
gpu_ptr<DataType>(x),
|
||||
gpu_ptr<DataType>(w),
|
||||
gpu_ptr<DataType>(g),
|
||||
gpu_ptr<DataType>(gx),
|
||||
gpu_ptr<DataType>(gw_temp),
|
||||
eps_,
|
||||
axis_size,
|
||||
w_stride);
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
|
||||
@@ -144,13 +144,13 @@ inline BytesKey<SDPACacheKey> build_sdpa_cache_key(
|
||||
|
||||
auto& sdpa_cache() {
|
||||
static LRUBytesKeyCache<SDPACacheKey, fe::graph::Graph> cache(
|
||||
"MLX_CUDA_SDPA_CACHE_SIZE", /* default_capacity */ 16);
|
||||
"MLX_CUDA_SDPA_CACHE_SIZE", /* default_capacity */ 64);
|
||||
return cache;
|
||||
}
|
||||
|
||||
auto& sdpa_backward_cache() {
|
||||
static LRUBytesKeyCache<SDPACacheKey, fe::graph::Graph> cache(
|
||||
"MLX_CUDA_SDPA_BACKWARD_CACHE_SIZE", /* default_capacity */ 16);
|
||||
"MLX_CUDA_SDPA_BACKWARD_CACHE_SIZE", /* default_capacity */ 64);
|
||||
return cache;
|
||||
}
|
||||
|
||||
@@ -207,8 +207,14 @@ fe::graph::Graph build_sdpa_graph(
|
||||
auto options = fe::graph::SDPA_attributes()
|
||||
.set_name("sdpa_cudnn")
|
||||
.set_attn_scale(scale)
|
||||
.set_causal_mask(do_causal)
|
||||
.set_generate_stats(output_logsumexp);
|
||||
if (do_causal) {
|
||||
if (q.shape(2) > k.shape(2)) {
|
||||
options.set_causal_mask(do_causal);
|
||||
} else {
|
||||
options.set_causal_mask_bottom_right(do_causal);
|
||||
}
|
||||
}
|
||||
if (mask_arr) {
|
||||
auto bias_ = graph.tensor(fe::graph::Tensor_attributes().set_name("BIAS"));
|
||||
set_tensor_attrs(bias_, BIAS, *mask_arr);
|
||||
@@ -282,7 +288,14 @@ fe::graph::Graph build_sdpa_backward_graph(
|
||||
auto options = fe::graph::SDPA_backward_attributes()
|
||||
.set_name("sdpa_backward_cudnn")
|
||||
.set_attn_scale(scale)
|
||||
.set_causal_mask(do_causal);
|
||||
.set_attn_scale(scale);
|
||||
if (do_causal) {
|
||||
if (q.shape(2) > k.shape(2)) {
|
||||
options.set_causal_mask(do_causal);
|
||||
} else {
|
||||
options.set_causal_mask_bottom_right(do_causal);
|
||||
}
|
||||
}
|
||||
if (mask_arr) {
|
||||
auto bias_ = graph.tensor(fe::graph::Tensor_attributes().set_name("BIAS"));
|
||||
set_tensor_attrs(bias_, BIAS, *mask_arr);
|
||||
@@ -340,6 +353,7 @@ bool supports_sdpa_cudnn(
|
||||
const array& q,
|
||||
const array& k,
|
||||
const array& v,
|
||||
bool do_causal,
|
||||
Stream s) {
|
||||
static bool enabled = env::get_var("MLX_CUDA_USE_CUDNN_SPDA", 1);
|
||||
if (!enabled) {
|
||||
@@ -351,8 +365,8 @@ bool supports_sdpa_cudnn(
|
||||
return false;
|
||||
}
|
||||
|
||||
// Only use cuDNN for prefilling and training.
|
||||
if (q.shape(2) != k.shape(2)) {
|
||||
// Only use cuDNN for prefilling (T_q > 1) and training (T_q == T_kv).
|
||||
if ((q.shape(2) == 1) && (q.shape(2) != k.shape(2))) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -520,7 +534,7 @@ bool ScaledDotProductAttention::use_fallback(
|
||||
|
||||
return !supports_sdpa_vector(
|
||||
q, k, v, has_mask, has_arr_mask, do_causal, output_logsumexp) &&
|
||||
!supports_sdpa_cudnn(q, k, v, s);
|
||||
!supports_sdpa_cudnn(q, k, v, do_causal, s);
|
||||
}
|
||||
|
||||
bool ScaledDotProductAttention::supports_bool_mask() {
|
||||
|
||||
@@ -16,7 +16,77 @@ INPUT_FILE=${SRC_DIR}/mlx/backend/metal/kernels/${SRC_FILE}.h
|
||||
OUTPUT_FILE=${OUTPUT_DIR}/${SRC_NAME}.cpp
|
||||
|
||||
mkdir -p "$OUTPUT_DIR"
|
||||
CONTENT=$($CC -I"$SRC_DIR" -I"$JIT_INCLUDES" -DMLX_METAL_JIT -E -P "$INPUT_FILE" $CFLAGS 2>/dev/null)
|
||||
# CONTENT=$($CC -I"$SRC_DIR" -I"$JIT_INCLUDES" -DMLX_METAL_JIT -E -P "$INPUT_FILE" $CFLAGS 2>/dev/null)
|
||||
|
||||
CCC="xcrun -sdk macosx metal -x metal"
|
||||
|
||||
HDRS=$( $CCC -I"$SRC_DIR" -I"$JIT_INCLUDES" -DMLX_METAL_JIT -E -P -CC -C -H "$INPUT_FILE" $CFLAGS -w 2>&1 1>/dev/null )
|
||||
|
||||
declare -a HDRS_LIST=($HDRS)
|
||||
declare -a HDRS_STACK=()
|
||||
declare -a HDRS_SORTED=()
|
||||
|
||||
length=${#HDRS_LIST[@]}
|
||||
|
||||
HDRS_LIST+=(".")
|
||||
|
||||
for ((i=0; i<${length}; i+=2));
|
||||
do
|
||||
|
||||
header="${HDRS_LIST[$i+1]#$SRC_DIR/}"
|
||||
|
||||
str_this="${HDRS_LIST[$i]}"
|
||||
str_next="${HDRS_LIST[$i + 2]}"
|
||||
|
||||
depth_this=${#str_this}
|
||||
depth_next=${#str_next}
|
||||
|
||||
# If we have a dependency then we stack it
|
||||
if [ $depth_next -gt $depth_this ]; then
|
||||
HDRS_STACK=($header ${HDRS_STACK[@]})
|
||||
|
||||
# If we are done with this level
|
||||
else
|
||||
# We add the header to out list
|
||||
HDRS_SORTED+=($header)
|
||||
|
||||
# Pop the stacked up dependencies
|
||||
pop_len=$((depth_this - depth_next))
|
||||
for popped_header in "${HDRS_STACK[@]:0:$pop_len}"
|
||||
do
|
||||
HDRS_SORTED+=($popped_header)
|
||||
done
|
||||
|
||||
HDRS_STACK=(${HDRS_STACK[@]:$pop_len})
|
||||
fi
|
||||
|
||||
done
|
||||
|
||||
HDRS_SORTED+=("${INPUT_FILE#$SRC_DIR/}")
|
||||
|
||||
CONTENT=$(
|
||||
echo "// Copyright © 2025 Apple Inc."
|
||||
echo ""
|
||||
echo "// Auto generated source for $INPUT_FILE"
|
||||
echo ""
|
||||
|
||||
for header in "${HDRS_SORTED[@]}"
|
||||
do
|
||||
echo "///////////////////////////////////////////////////////////////////////////////"
|
||||
echo "// Contents from \"${header}\""
|
||||
echo "///////////////////////////////////////////////////////////////////////////////"
|
||||
echo ""
|
||||
|
||||
echo "#line 1 \"${header}\""
|
||||
|
||||
grep -h -v -G -e "#include \".*.h\"" -e "#pragma once" "${SRC_DIR}/${header}"
|
||||
|
||||
echo ""
|
||||
|
||||
done
|
||||
|
||||
echo "///////////////////////////////////////////////////////////////////////////////"
|
||||
)
|
||||
|
||||
cat << EOF > "$OUTPUT_FILE"
|
||||
namespace mlx::core::metal {
|
||||
|
||||
@@ -407,7 +407,10 @@ class Module(dict):
|
||||
instance).
|
||||
|
||||
Args:
|
||||
apply_fn (Callable): The function to apply to the modules.
|
||||
apply_fn (Callable): The function to apply to the modules which
|
||||
takes two parameters. The first parameter is the string path of
|
||||
the module (e.g. ``"model.layers.0.linear"``). The second
|
||||
parameter is the module object.
|
||||
|
||||
Returns:
|
||||
The module instance after updating submodules.
|
||||
|
||||
@@ -1238,8 +1238,18 @@ void init_transforms(nb::module_& m) {
|
||||
same in number, shape, and type as the inputs of ``fun`` (i.e. the ``primals``).
|
||||
|
||||
Returns:
|
||||
list(array): A list of the Jacobian-vector products which
|
||||
is the same in number, shape, and type of the inputs to ``fun``.
|
||||
tuple(list(array), list(array)): A tuple with the outputs of
|
||||
``fun`` in the first position and the Jacobian-vector products
|
||||
in the second position.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
outs, jvps = mx.jvp(mx.sin, (mx.array(1.0),), (mx.array(1.0),))
|
||||
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"vjp",
|
||||
@@ -1277,8 +1287,18 @@ void init_transforms(nb::module_& m) {
|
||||
same in number, shape, and type as the outputs of ``fun``.
|
||||
|
||||
Returns:
|
||||
list(array): A list of the vector-Jacobian products which
|
||||
is the same in number, shape, and type of the outputs of ``fun``.
|
||||
tuple(list(array), list(array)): A tuple with the outputs of
|
||||
``fun`` in the first position and the vector-Jacobian products
|
||||
in the second position.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
outs, vjps = mx.vjp(mx.sin, (mx.array(1.0),), (mx.array(1.0),))
|
||||
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"value_and_grad",
|
||||
|
||||
Reference in New Issue
Block a user