* add test

* fix rope

* fix test
This commit is contained in:
Awni Hannun 2024-08-20 17:37:52 -07:00 committed by GitHub
parent bb1b76d9dc
commit d40e76809f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 26 additions and 17 deletions

View File

@@ -22,23 +22,18 @@ void rope_single_impl(
float sintheta = metal::fast::sin(theta); float sintheta = metal::fast::sin(theta);
// Compute the input and output indices // Compute the input and output indices
uint in_index_1, in_index_2; uint index_1, index_2;
uint out_index_1, out_index_2;
if (traditional) { if (traditional) {
out_index_1 = 2 * pos.x + pos.y * stride; index_1 = 2 * pos.x + pos.y * stride;
out_index_2 = out_index_1 + 1; index_2 = index_1 + 1;
in_index_1 = 2 * pos.x + pos.y * stride;
in_index_2 = in_index_1 + 1;
} else { } else {
out_index_1 = pos.x + pos.y * stride; index_1 = pos.x + pos.y * stride;
out_index_2 = out_index_1 + grid.x; index_2 = index_1 + grid.x;
in_index_1 = pos.x + pos.y * stride;
in_index_2 = in_index_1 + grid.x;
} }
// Read and write the output // Read and write the output
float x1 = static_cast<float>(in[in_index_1]); float x1 = static_cast<float>(in[index_1]);
float x2 = static_cast<float>(in[in_index_2]); float x2 = static_cast<float>(in[index_2]);
float rx1; float rx1;
float rx2; float rx2;
if (forward) { if (forward) {
@@ -48,8 +43,8 @@ void rope_single_impl(
rx1 = x2 * sintheta + x1 * costheta; rx1 = x2 * sintheta + x1 * costheta;
rx2 = x2 * costheta - x1 * sintheta; rx2 = x2 * costheta - x1 * sintheta;
} }
out[out_index_1] = static_cast<T>(rx1); out[index_1] = static_cast<T>(rx1);
out[out_index_2] = static_cast<T>(rx2); out[index_2] = static_cast<T>(rx2);
} }
template <typename T, bool traditional, bool forward> template <typename T, bool traditional, bool forward>

View File

@@ -86,7 +86,7 @@ void RoPE::eval_gpu(
MTL::Size group_dims; MTL::Size group_dims;
MTL::Size grid_dims; MTL::Size grid_dims;
if (single) { if (single) {
compute_encoder->setBytes(&out_strides[1], sizeof(size_t), 4); compute_encoder->setBytes(out_strides, sizeof(size_t), 4);
uint32_t dim0 = dims_ / 2; uint32_t dim0 = dims_ / 2;
group_dims = get_block_dims(dim0, n_batch, 1); group_dims = get_block_dims(dim0, n_batch, 1);
grid_dims = MTL::Size(dim0, n_batch, 1); grid_dims = MTL::Size(dim0, n_batch, 1);

View File

@@ -340,7 +340,7 @@ array rope(
if (inputs.size() == 2 && if (inputs.size() == 2 &&
(inputs[1].ndim() != 1 || inputs[1].shape(0) != dims / 2)) { (inputs[1].ndim() != 1 || inputs[1].shape(0) != dims / 2)) {
std::ostringstream msg; std::ostringstream msg;
msg << "[rope] freqs must be one dimensional with size " << dims msg << "[rope] freqs must be one dimensional with size " << dims / 2
<< " but got shape " << inputs[1].shape() << "."; << " but got shape " << inputs[1].shape() << ".";
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
} }

View File

@@ -8,7 +8,7 @@ import mlx_tests
def rope_orig(x, dims, traditional, base, scale, offset, freqs=None): def rope_orig(x, dims, traditional, base, scale, offset, freqs=None):
N = x.shape[1] + offset N = x.shape[-2] + offset
dtype = x.dtype dtype = x.dtype
half_D = dims // 2 half_D = dims // 2
positions = mx.arange(offset, N, dtype=dtype) * scale positions = mx.arange(offset, N, dtype=dtype) * scale
@@ -143,6 +143,20 @@ class TestFast(mlx_tests.MLXTestCase):
) )
self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype]) self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
# Test transpose into rope
dims, _, base, scale, offset, traditional = defaults
x = mx.random.uniform(shape=(1, 1, 4, dims)).swapaxes(1, 2)
rx = rope_orig(x, dims, traditional, base, scale, offset)
rx_fast = mx.fast.rope(
1.0 * x, # multiply here to allow donation
dims,
traditional=traditional,
base=base,
scale=scale,
offset=offset,
)
self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[mx.float32])
def test_rope_with_freqs(self): def test_rope_with_freqs(self):
# Check throws # Check throws
T = 4 T = 4