Allow non-UTF-8 encoding in sbang hook (#26793)

Currently Spack reads full files containing shebangs to memory as
strings, meaning Spack would have to guess their encoding. Currently
Spack has a fixed guess of UTF-8.

This is unnecessary, since e.g. the Linux kernel does not assume an
encoding on paths at all, it's just bytes and some delimiters on the
byte level.

This commit does the following:

1. Shebangs are treated as bytes, so that e.g. latin1 encoded files do
not throw UnicodeEncoding errors, and adds a test for this.
2. No more bytes than necessary are read to memory, we only have to read
until the first newline, and from there on we an copy the file byte by
bytes instead of decoding and re-encoding text.
3. We cap the number of bytes read to 4096, if no newline is found
before that, we don't attempt to patch it.
4. Add support for luajit too.

This should make Spack both more efficient and usable for non-UTF8
files.
This commit is contained in:
Harmen Stoppels 2021-10-27 11:59:10 +02:00 committed by GitHub
parent 2fd87046cd
commit e04b172eb0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 189 additions and 79 deletions

View File

@ -6,8 +6,10 @@
import filecmp
import os
import re
import shutil
import stat
import sys
import tempfile
import llnl.util.filesystem as fs
import llnl.util.tty as tty
@ -19,9 +21,14 @@
#: Different Linux distributions have different limits, but 127 is the
#: smallest among all modern versions.
if sys.platform == 'darwin':
shebang_limit = 511
system_shebang_limit = 511
else:
shebang_limit = 127
system_shebang_limit = 127
#: Spack itself also limits the shebang line to at most 4KB, which should be plenty.
spack_shebang_limit = 4096
interpreter_regex = re.compile(b'#![ \t]*?([^ \t\0\n]+)')
def sbang_install_path():
@ -29,10 +36,10 @@ def sbang_install_path():
sbang_root = str(spack.store.unpadded_root)
install_path = os.path.join(sbang_root, "bin", "sbang")
path_length = len(install_path)
if path_length > shebang_limit:
if path_length > system_shebang_limit:
msg = ('Install tree root is too long. Spack cannot patch shebang lines'
' when script path length ({0}) exceeds limit ({1}).\n {2}')
msg = msg.format(path_length, shebang_limit, install_path)
msg = msg.format(path_length, system_shebang_limit, install_path)
raise SbangPathError(msg)
return install_path
@ -49,71 +56,92 @@ def sbang_shebang_line():
return '#!/bin/sh %s' % sbang_install_path()
def shebang_too_long(path):
"""Detects whether a file has a shebang line that is too long."""
if not os.path.isfile(path):
return False
with open(path, 'rb') as script:
bytes = script.read(2)
if bytes != b'#!':
return False
line = bytes + script.readline()
return len(line) > shebang_limit
def get_interpreter(binary_string):
# The interpreter may be preceded with ' ' and \t, is itself any byte that
# follows until the first occurrence of ' ', \t, \0, \n or end of file.
match = interpreter_regex.match(binary_string)
return None if match is None else match.group(1)
def filter_shebang(path):
"""Adds a second shebang line, using sbang, at the beginning of a file."""
with open(path, 'rb') as original_file:
original = original_file.read()
if sys.version_info >= (2, 7):
original = original.decode(encoding='UTF-8')
"""
Adds a second shebang line, using sbang, at the beginning of a file, if necessary.
Note: Spack imposes a relaxed shebang line limit, meaning that a newline or end of
file must occur before ``spack_shebang_limit`` bytes. If not, the file is not
patched.
"""
with open(path, 'rb') as original:
# If there is no shebang, we shouldn't replace anything.
old_shebang_line = original.read(2)
if old_shebang_line != b'#!':
return False
# Stop reading after b'\n'. Note that old_shebang_line includes the first b'\n'.
old_shebang_line += original.readline(spack_shebang_limit - 2)
# If the shebang line is short, we don't have to do anything.
if len(old_shebang_line) <= system_shebang_limit:
return False
# Whenever we can't find a newline within the maximum number of bytes, we will
# not attempt to rewrite it. In principle we could still get the interpreter if
# only the arguments are truncated, but note that for PHP we need the full line
# since we have to append `?>` to it. Since our shebang limit is already very
# generous, it's unlikely to happen, and it should be fine to ignore.
if (
len(old_shebang_line) == spack_shebang_limit and
old_shebang_line[-1] != b'\n'
):
return False
# This line will be prepended to file
new_sbang_line = (sbang_shebang_line() + '\n').encode('utf-8')
# Skip files that are already using sbang.
if old_shebang_line == new_sbang_line:
return
interpreter = get_interpreter(old_shebang_line)
# If there was only whitespace we don't have to do anything.
if not interpreter:
return False
# Store the file permissions, the patched version needs the same.
saved_mode = os.stat(path).st_mode
# No need to delete since we'll move it and overwrite the original.
patched = tempfile.NamedTemporaryFile('wb', delete=False)
patched.write(new_sbang_line)
# Note that in Python this does not go out of bounds even if interpreter is a
# short byte array.
# Note: if the interpreter string was encoded with UTF-16, there would have
# been a \0 byte between all characters of lua, node, php; meaning that it would
# lead to truncation of the interpreter. So we don't have to worry about weird
# encodings here, and just looking at bytes is justified.
if interpreter[-4:] == b'/lua' or interpreter[-7:] == b'/luajit':
# Use --! instead of #! on second line for lua.
patched.write(b'--!' + old_shebang_line[2:])
elif interpreter[-5:] == b'/node':
# Use //! instead of #! on second line for node.js.
patched.write(b'//!' + old_shebang_line[2:])
elif interpreter[-4:] == b'/php':
# Use <?php #!... ?> instead of #!... on second line for php.
patched.write(b'<?php ' + old_shebang_line + b' ?>')
else:
original = original.decode('UTF-8')
patched.write(old_shebang_line)
# This line will be prepended to file
new_sbang_line = '%s\n' % sbang_shebang_line()
# After copying the remainder of the file, we can close the original
shutil.copyfileobj(original, patched)
# Skip files that are already using sbang.
if original.startswith(new_sbang_line):
return
# And close the temporary file so we can move it.
patched.close()
# In the following, newlines have to be excluded in the regular expression
# else any mention of "lua" in the document will lead to spurious matches.
# Use --! instead of #! on second line for lua.
if re.search(r'^#!(/[^/\n]*)*lua\b', original):
original = re.sub(r'^#', '--', original)
# Use <?php #! instead of #! on second line for php.
if re.search(r'^#!(/[^/\n]*)*php\b', original):
original = re.sub(r'^#', '<?php #', original) + ' ?>'
# Use //! instead of #! on second line for node.js.
if re.search(r'^#!(/[^/\n]*)*node\b', original):
original = re.sub(r'^#', '//', original)
# Change non-writable files to be writable if needed.
saved_mode = None
if not os.access(path, os.W_OK):
st = os.stat(path)
saved_mode = st.st_mode
os.chmod(path, saved_mode | stat.S_IWRITE)
with open(path, 'wb') as new_file:
if sys.version_info >= (2, 7):
new_file.write(new_sbang_line.encode(encoding='UTF-8'))
new_file.write(original.encode(encoding='UTF-8'))
else:
new_file.write(new_sbang_line.encode('UTF-8'))
new_file.write(original.encode('UTF-8'))
# Restore original permissions.
if saved_mode is not None:
os.chmod(path, saved_mode)
tty.debug("Patched overlong shebang in %s" % path)
# Overwrite original file with patched file, and keep the original mode
shutil.move(patched.name, path)
os.chmod(path, saved_mode)
return True
def filter_shebangs_in_directory(directory, filenames=None):
@ -138,8 +166,8 @@ def filter_shebangs_in_directory(directory, filenames=None):
continue
# test the file for a long shebang, and filter
if shebang_too_long(path):
filter_shebang(path)
if filter_shebang(path):
tty.debug("Patched overlong shebang in %s" % path)
def install_sbang():

View File

@ -21,7 +21,7 @@
import spack.store
from spack.util.executable import which
too_long = sbang.shebang_limit + 1
too_long = sbang.system_shebang_limit + 1
short_line = "#!/this/is/short/bin/bash\n"
@ -31,6 +31,10 @@
lua_in_text = ("line\n") * 100 + "lua\n" + ("line\n" * 100)
lua_line_patched = "--!/this/" + ('x' * too_long) + "/is/lua\n"
luajit_line = "#!/this/" + ('x' * too_long) + "/is/luajit\n"
luajit_in_text = ("line\n") * 100 + "lua\n" + ("line\n" * 100)
luajit_line_patched = "--!/this/" + ('x' * too_long) + "/is/luajit\n"
node_line = "#!/this/" + ('x' * too_long) + "/is/node\n"
node_in_text = ("line\n") * 100 + "lua\n" + ("line\n" * 100)
node_line_patched = "//!/this/" + ('x' * too_long) + "/is/node\n"
@ -84,7 +88,7 @@ def __init__(self, sbang_line):
f.write(last_line)
self.make_executable(self.lua_shebang)
# Lua script with long shebang
# Lua occurring in text, not in shebang
self.lua_textbang = os.path.join(self.tempdir, 'lua_in_text')
with open(self.lua_textbang, 'w') as f:
f.write(short_line)
@ -92,6 +96,21 @@ def __init__(self, sbang_line):
f.write(last_line)
self.make_executable(self.lua_textbang)
# Luajit script with long shebang
self.luajit_shebang = os.path.join(self.tempdir, 'luajit')
with open(self.luajit_shebang, 'w') as f:
f.write(luajit_line)
f.write(last_line)
self.make_executable(self.luajit_shebang)
# Luajit occuring in text, not in shebang
self.luajit_textbang = os.path.join(self.tempdir, 'luajit_in_text')
with open(self.luajit_textbang, 'w') as f:
f.write(short_line)
f.write(luajit_in_text)
f.write(last_line)
self.make_executable(self.luajit_textbang)
# Node script with long shebang
self.node_shebang = os.path.join(self.tempdir, 'node')
with open(self.node_shebang, 'w') as f:
@ -99,7 +118,7 @@ def __init__(self, sbang_line):
f.write(last_line)
self.make_executable(self.node_shebang)
# Node script with long shebang
# Node occuring in text, not in shebang
self.node_textbang = os.path.join(self.tempdir, 'node_in_text')
with open(self.node_textbang, 'w') as f:
f.write(short_line)
@ -114,7 +133,7 @@ def __init__(self, sbang_line):
f.write(last_line)
self.make_executable(self.php_shebang)
# php script with long shebang
# php occuring in text, not in shebang
self.php_textbang = os.path.join(self.tempdir, 'php_in_text')
with open(self.php_textbang, 'w') as f:
f.write(short_line)
@ -157,16 +176,22 @@ def script_dir(sbang_line):
sdir.destroy()
@pytest.mark.parametrize('shebang,interpreter', [
(b'#!/path/to/interpreter argument\n', b'/path/to/interpreter'),
(b'#! /path/to/interpreter truncated-argum', b'/path/to/interpreter'),
(b'#! \t \t/path/to/interpreter\t \targument', b'/path/to/interpreter'),
(b'#! \t \t /path/to/interpreter', b'/path/to/interpreter'),
(b'#!/path/to/interpreter\0', b'/path/to/interpreter'),
(b'#!/path/to/interpreter multiple args\n', b'/path/to/interpreter'),
(b'#!\0/path/to/interpreter arg\n', None),
(b'#!\n/path/to/interpreter arg\n', None),
(b'#!', None)
])
def test_shebang_interpreter_regex(shebang, interpreter):
sbang.get_interpreter(shebang) == interpreter
def test_shebang_handling(script_dir, sbang_line):
assert sbang.shebang_too_long(script_dir.lua_shebang)
assert sbang.shebang_too_long(script_dir.long_shebang)
assert sbang.shebang_too_long(script_dir.nonexec_long_shebang)
assert not sbang.shebang_too_long(script_dir.short_shebang)
assert not sbang.shebang_too_long(script_dir.has_sbang)
assert not sbang.shebang_too_long(script_dir.binary)
assert not sbang.shebang_too_long(script_dir.directory)
sbang.filter_shebangs_in_directory(script_dir.tempdir)
# Make sure this is untouched
@ -191,6 +216,12 @@ def test_shebang_handling(script_dir, sbang_line):
assert f.readline() == lua_line_patched
assert f.readline() == last_line
# Make sure this got patched.
with open(script_dir.luajit_shebang, 'r') as f:
assert f.readline() == sbang_line
assert f.readline() == luajit_line_patched
assert f.readline() == last_line
# Make sure this got patched.
with open(script_dir.node_shebang, 'r') as f:
assert f.readline() == sbang_line
@ -199,8 +230,12 @@ def test_shebang_handling(script_dir, sbang_line):
assert filecmp.cmp(script_dir.lua_textbang,
os.path.join(script_dir.tempdir, 'lua_in_text'))
assert filecmp.cmp(script_dir.luajit_textbang,
os.path.join(script_dir.tempdir, 'luajit_in_text'))
assert filecmp.cmp(script_dir.node_textbang,
os.path.join(script_dir.tempdir, 'node_in_text'))
assert filecmp.cmp(script_dir.php_textbang,
os.path.join(script_dir.tempdir, 'php_in_text'))
# Make sure this is untouched
with open(script_dir.has_sbang, 'r') as f:
@ -261,7 +296,7 @@ def test_install_sbang(install_mockery):
def test_install_sbang_too_long(tmpdir):
root = str(tmpdir)
num_extend = sbang.shebang_limit - len(root) - len('/bin/sbang')
num_extend = sbang.system_shebang_limit - len(root) - len('/bin/sbang')
long_path = root
while num_extend > 1:
add = min(num_extend, 255)
@ -282,7 +317,7 @@ def test_sbang_hook_skips_nonexecutable_blobs(tmpdir):
# consisting of invalid UTF-8. The latter is technically not really necessary for
# the test, but binary blobs accidentally starting with b'#!' usually do not contain
# valid UTF-8, so we also ensure that Spack does not attempt to decode as UTF-8.
contents = b'#!' + b'\x80' * sbang.shebang_limit
contents = b'#!' + b'\x80' * sbang.system_shebang_limit
file = str(tmpdir.join('non-executable.sh'))
with open(file, 'wb') as f:
f.write(contents)
@ -292,3 +327,50 @@ def test_sbang_hook_skips_nonexecutable_blobs(tmpdir):
# Make sure there is no sbang shebang.
with open(file, 'rb') as f:
assert b'sbang' not in f.readline()
def test_sbang_handles_non_utf8_files(tmpdir):
# We have an executable with a copyright sign as filename
contents = (b'#!' + b'\xa9' * sbang.system_shebang_limit +
b'\nand another symbol: \xa9')
# Make sure it's indeed valid latin1 but invalid utf-8.
assert contents.decode('latin1')
with pytest.raises(UnicodeDecodeError):
contents.decode('utf-8')
# Put it in an executable file
file = str(tmpdir.join('latin1.sh'))
with open(file, 'wb') as f:
f.write(contents)
# Run sbang
assert sbang.filter_shebang(file)
with open(file, 'rb') as f:
new_contents = f.read()
assert contents in new_contents
assert b'sbang' in new_contents
@pytest.fixture
def shebang_limits_system_8_spack_16():
system_limit, sbang.system_shebang_limit = sbang.system_shebang_limit, 8
spack_limit, sbang.spack_shebang_limit = sbang.spack_shebang_limit, 16
yield
sbang.system_shebang_limit = system_limit
sbang.spack_shebang_limit = spack_limit
def test_shebang_exceeds_spack_shebang_limit(shebang_limits_system_8_spack_16, tmpdir):
"""Tests whether shebangs longer than Spack's limit are skipped"""
file = str(tmpdir.join('longer_than_spack_limit.sh'))
with open(file, 'wb') as f:
f.write(b'#!' + b'x' * sbang.spack_shebang_limit)
# Then Spack shouldn't try to add a shebang
assert not sbang.filter_shebang(file)
with open(file, 'rb') as f:
assert b'sbang' not in f.read()