Allow offset to be an mx.array for mx.fast.rope (#1724)

* allow offset for rope

* comment
This commit is contained in:
Awni Hannun
2024-12-19 15:51:44 -08:00
committed by GitHub
parent c3628eea49
commit 0308e9af71
8 changed files with 97 additions and 52 deletions

View File

@@ -66,7 +66,7 @@ void RoPE::eval_gpu(
// Special case for inference (single time step and contiguous)
bool single = in.flags().row_contiguous && (mat_size == in.shape(-1));
bool with_freqs = inputs.size() == 2;
bool with_freqs = inputs.size() == 3;
std::ostringstream kname;
kname << "rope_" << (single ? "single_" : "")
<< ((with_freqs) ? "freqs_" : "") << (forward_ ? "" : "vjp_")
@@ -78,7 +78,7 @@ void RoPE::eval_gpu(
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(donated ? out : in, 0);
compute_encoder.set_output_array(out, 1);
compute_encoder.set_bytes(offset_, 2);
compute_encoder.set_input_array(inputs[1], 2);
compute_encoder.set_bytes(scale_, 3);
size_t n_batch = in.size() / mat_size;
@@ -101,7 +101,7 @@ void RoPE::eval_gpu(
}
if (with_freqs) {
auto& freqs = inputs[1];
auto& freqs = inputs[2];
compute_encoder.set_input_array(freqs, 10);
auto freq_stride = freqs.strides()[0];
compute_encoder.set_bytes(freq_stride, 11);