mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-22 02:58:16 +08:00
Allow offset to be an mx.array for mx.fast.rope
(#1724)
* allow offset for rope * comment
This commit is contained in:
@@ -66,7 +66,7 @@ void RoPE::eval_gpu(
|
||||
// Special case for inference (single time step and contiguous)
|
||||
bool single = in.flags().row_contiguous && (mat_size == in.shape(-1));
|
||||
|
||||
bool with_freqs = inputs.size() == 2;
|
||||
bool with_freqs = inputs.size() == 3;
|
||||
std::ostringstream kname;
|
||||
kname << "rope_" << (single ? "single_" : "")
|
||||
<< ((with_freqs) ? "freqs_" : "") << (forward_ ? "" : "vjp_")
|
||||
@@ -78,7 +78,7 @@ void RoPE::eval_gpu(
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
compute_encoder.set_input_array(donated ? out : in, 0);
|
||||
compute_encoder.set_output_array(out, 1);
|
||||
compute_encoder.set_bytes(offset_, 2);
|
||||
compute_encoder.set_input_array(inputs[1], 2);
|
||||
compute_encoder.set_bytes(scale_, 3);
|
||||
|
||||
size_t n_batch = in.size() / mat_size;
|
||||
@@ -101,7 +101,7 @@ void RoPE::eval_gpu(
|
||||
}
|
||||
|
||||
if (with_freqs) {
|
||||
auto& freqs = inputs[1];
|
||||
auto& freqs = inputs[2];
|
||||
compute_encoder.set_input_array(freqs, 10);
|
||||
auto freq_stride = freqs.strides()[0];
|
||||
compute_encoder.set_bytes(freq_stride, 11);
|
||||
|
Reference in New Issue
Block a user