Allow non-UTF-8 encoding in sbang hook (#26793)

Currently Spack reads full files containing shebangs to memory as
strings, meaning Spack would have to guess their encoding. Currently
Spack has a fixed guess of UTF-8.

This is unnecessary, since e.g. the Linux kernel does not assume an
encoding on paths at all, it's just bytes and some delimiters on the
byte level.

This commit does the following:

1. Shebangs are treated as bytes, so that e.g. latin1 encoded files do
not raise a UnicodeDecodeError; a test is added for this.
2. No more bytes than necessary are read into memory: we only have to read
until the first newline, and from there on we can copy the file byte for
byte instead of decoding and re-encoding text.
3. We cap the number of bytes read to 4096, if no newline is found
before that, we don't attempt to patch it.
4. Add support for luajit too.

This should make Spack both more efficient and usable for non-UTF8
files.
This commit is contained in:
Harmen Stoppels 2021-10-27 11:59:10 +02:00 committed by GitHub
parent 2fd87046cd
commit e04b172eb0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 189 additions and 79 deletions

View File

@ -6,8 +6,10 @@
import filecmp import filecmp
import os import os
import re import re
import shutil
import stat import stat
import sys import sys
import tempfile
import llnl.util.filesystem as fs import llnl.util.filesystem as fs
import llnl.util.tty as tty import llnl.util.tty as tty
@ -19,9 +21,14 @@
#: Different Linux distributions have different limits, but 127 is the #: Different Linux distributions have different limits, but 127 is the
#: smallest among all modern versions. #: smallest among all modern versions.
if sys.platform == 'darwin': if sys.platform == 'darwin':
shebang_limit = 511 system_shebang_limit = 511
else: else:
shebang_limit = 127 system_shebang_limit = 127
#: Spack itself also limits the shebang line to at most 4KB, which should be plenty.
spack_shebang_limit = 4096
interpreter_regex = re.compile(b'#![ \t]*?([^ \t\0\n]+)')
def sbang_install_path(): def sbang_install_path():
@ -29,10 +36,10 @@ def sbang_install_path():
sbang_root = str(spack.store.unpadded_root) sbang_root = str(spack.store.unpadded_root)
install_path = os.path.join(sbang_root, "bin", "sbang") install_path = os.path.join(sbang_root, "bin", "sbang")
path_length = len(install_path) path_length = len(install_path)
if path_length > shebang_limit: if path_length > system_shebang_limit:
msg = ('Install tree root is too long. Spack cannot patch shebang lines' msg = ('Install tree root is too long. Spack cannot patch shebang lines'
' when script path length ({0}) exceeds limit ({1}).\n {2}') ' when script path length ({0}) exceeds limit ({1}).\n {2}')
msg = msg.format(path_length, shebang_limit, install_path) msg = msg.format(path_length, system_shebang_limit, install_path)
raise SbangPathError(msg) raise SbangPathError(msg)
return install_path return install_path
@ -49,71 +56,92 @@ def sbang_shebang_line():
return '#!/bin/sh %s' % sbang_install_path() return '#!/bin/sh %s' % sbang_install_path()
def get_interpreter(binary_string):
    """Extract the interpreter from a shebang line given as bytes.

    After ``#!`` the interpreter may be preceded by spaces and tabs; the
    interpreter itself is every byte that follows until the first space, tab,
    NUL byte, newline or the end of the input.

    Returns the interpreter as bytes, or ``None`` when the input does not
    match that pattern.
    """
    found = interpreter_regex.match(binary_string)
    if found is None:
        return None
    return found.group(1)
def filter_shebang(path):
    """
    Adds a second shebang line, using sbang, at the beginning of a file, if necessary.

    Note: Spack imposes a relaxed shebang line limit, meaning that a newline or end of
    file must occur before ``spack_shebang_limit`` bytes. If not, the file is not
    patched.

    Args:
        path: path of the file to inspect and, when needed, patch in place.

    Returns:
        True if the file was patched, False if it was left untouched.
    """
    with open(path, 'rb') as original:
        # If there is no shebang, we shouldn't replace anything.
        old_shebang_line = original.read(2)
        if old_shebang_line != b'#!':
            return False

        # Stop reading after b'\n'. Note that old_shebang_line includes the first b'\n'.
        old_shebang_line += original.readline(spack_shebang_limit - 2)

        # If the shebang line is short, we don't have to do anything.
        if len(old_shebang_line) <= system_shebang_limit:
            return False

        # Whenever we can't find a newline within the maximum number of bytes, we will
        # not attempt to rewrite it. In principle we could still get the interpreter if
        # only the arguments are truncated, but note that for PHP we need the full line
        # since we have to append `?>` to it. Since our shebang limit is already very
        # generous, it's unlikely to happen, and it should be fine to ignore.
        # Use endswith() rather than indexing: on Python 3 indexing bytes yields an
        # int, which never compares equal to b'\n'.
        if (
            len(old_shebang_line) == spack_shebang_limit and
            not old_shebang_line.endswith(b'\n')
        ):
            return False

        # This line will be prepended to file
        new_sbang_line = (sbang_shebang_line() + '\n').encode('utf-8')

        # Skip files that are already using sbang.
        if old_shebang_line == new_sbang_line:
            return False

        interpreter = get_interpreter(old_shebang_line)

        # If there was only whitespace we don't have to do anything.
        if not interpreter:
            return False

        # Store the file permissions, the patched version needs the same.
        saved_mode = os.stat(path).st_mode

        # No need to delete on success since we'll move it and overwrite the original.
        patched = tempfile.NamedTemporaryFile('wb', delete=False)

        try:
            patched.write(new_sbang_line)

            # Note: if the interpreter string was encoded with UTF-16, there would
            # have been a \0 byte between all characters of lua, node, php; meaning
            # that it would lead to truncation of the interpreter. So we don't have
            # to worry about weird encodings here, and just looking at bytes is
            # justified.
            if interpreter.endswith((b'/lua', b'/luajit')):
                # Use --! instead of #! on second line for lua.
                patched.write(b'--!' + old_shebang_line[2:])
            elif interpreter.endswith(b'/node'):
                # Use //! instead of #! on second line for node.js.
                patched.write(b'//!' + old_shebang_line[2:])
            elif interpreter.endswith(b'/php'):
                # Use <?php #!... ?> instead of #!... on second line for php.
                patched.write(b'<?php ' + old_shebang_line + b' ?>')
            else:
                patched.write(old_shebang_line)

            # After copying the remainder of the file, we can close the original
            shutil.copyfileobj(original, patched)
        except BaseException:
            # Don't leave a stray temporary file behind when patching fails.
            patched.close()
            os.remove(patched.name)
            raise

    # And close the temporary file so we can move it.
    patched.close()

    # Overwrite original file with patched file, and keep the original mode
    shutil.move(patched.name, path)
    os.chmod(path, saved_mode)
    return True
def filter_shebangs_in_directory(directory, filenames=None): def filter_shebangs_in_directory(directory, filenames=None):
@ -138,8 +166,8 @@ def filter_shebangs_in_directory(directory, filenames=None):
continue continue
# test the file for a long shebang, and filter # test the file for a long shebang, and filter
if shebang_too_long(path): if filter_shebang(path):
filter_shebang(path) tty.debug("Patched overlong shebang in %s" % path)
def install_sbang(): def install_sbang():

View File

@ -21,7 +21,7 @@
import spack.store import spack.store
from spack.util.executable import which from spack.util.executable import which
too_long = sbang.shebang_limit + 1 too_long = sbang.system_shebang_limit + 1
short_line = "#!/this/is/short/bin/bash\n" short_line = "#!/this/is/short/bin/bash\n"
@ -31,6 +31,10 @@
lua_in_text = ("line\n") * 100 + "lua\n" + ("line\n" * 100) lua_in_text = ("line\n") * 100 + "lua\n" + ("line\n" * 100)
lua_line_patched = "--!/this/" + ('x' * too_long) + "/is/lua\n" lua_line_patched = "--!/this/" + ('x' * too_long) + "/is/lua\n"
luajit_line = "#!/this/" + ('x' * too_long) + "/is/luajit\n"
luajit_in_text = ("line\n") * 100 + "lua\n" + ("line\n" * 100)
luajit_line_patched = "--!/this/" + ('x' * too_long) + "/is/luajit\n"
node_line = "#!/this/" + ('x' * too_long) + "/is/node\n" node_line = "#!/this/" + ('x' * too_long) + "/is/node\n"
node_in_text = ("line\n") * 100 + "lua\n" + ("line\n" * 100) node_in_text = ("line\n") * 100 + "lua\n" + ("line\n" * 100)
node_line_patched = "//!/this/" + ('x' * too_long) + "/is/node\n" node_line_patched = "//!/this/" + ('x' * too_long) + "/is/node\n"
@ -84,7 +88,7 @@ def __init__(self, sbang_line):
f.write(last_line) f.write(last_line)
self.make_executable(self.lua_shebang) self.make_executable(self.lua_shebang)
# Lua script with long shebang # Lua occurring in text, not in shebang
self.lua_textbang = os.path.join(self.tempdir, 'lua_in_text') self.lua_textbang = os.path.join(self.tempdir, 'lua_in_text')
with open(self.lua_textbang, 'w') as f: with open(self.lua_textbang, 'w') as f:
f.write(short_line) f.write(short_line)
@ -92,6 +96,21 @@ def __init__(self, sbang_line):
f.write(last_line) f.write(last_line)
self.make_executable(self.lua_textbang) self.make_executable(self.lua_textbang)
# Luajit script with long shebang
self.luajit_shebang = os.path.join(self.tempdir, 'luajit')
with open(self.luajit_shebang, 'w') as f:
f.write(luajit_line)
f.write(last_line)
self.make_executable(self.luajit_shebang)
# Luajit occuring in text, not in shebang
self.luajit_textbang = os.path.join(self.tempdir, 'luajit_in_text')
with open(self.luajit_textbang, 'w') as f:
f.write(short_line)
f.write(luajit_in_text)
f.write(last_line)
self.make_executable(self.luajit_textbang)
# Node script with long shebang # Node script with long shebang
self.node_shebang = os.path.join(self.tempdir, 'node') self.node_shebang = os.path.join(self.tempdir, 'node')
with open(self.node_shebang, 'w') as f: with open(self.node_shebang, 'w') as f:
@ -99,7 +118,7 @@ def __init__(self, sbang_line):
f.write(last_line) f.write(last_line)
self.make_executable(self.node_shebang) self.make_executable(self.node_shebang)
# Node script with long shebang # Node occuring in text, not in shebang
self.node_textbang = os.path.join(self.tempdir, 'node_in_text') self.node_textbang = os.path.join(self.tempdir, 'node_in_text')
with open(self.node_textbang, 'w') as f: with open(self.node_textbang, 'w') as f:
f.write(short_line) f.write(short_line)
@ -114,7 +133,7 @@ def __init__(self, sbang_line):
f.write(last_line) f.write(last_line)
self.make_executable(self.php_shebang) self.make_executable(self.php_shebang)
# php script with long shebang # php occuring in text, not in shebang
self.php_textbang = os.path.join(self.tempdir, 'php_in_text') self.php_textbang = os.path.join(self.tempdir, 'php_in_text')
with open(self.php_textbang, 'w') as f: with open(self.php_textbang, 'w') as f:
f.write(short_line) f.write(short_line)
@ -157,16 +176,22 @@ def script_dir(sbang_line):
sdir.destroy() sdir.destroy()
@pytest.mark.parametrize('shebang,interpreter', [
    (b'#!/path/to/interpreter argument\n', b'/path/to/interpreter'),
    (b'#!  /path/to/interpreter truncated-argum', b'/path/to/interpreter'),
    (b'#! \t \t/path/to/interpreter\t \targument', b'/path/to/interpreter'),
    (b'#! \t \t /path/to/interpreter', b'/path/to/interpreter'),
    (b'#!/path/to/interpreter\0', b'/path/to/interpreter'),
    (b'#!/path/to/interpreter multiple args\n', b'/path/to/interpreter'),
    (b'#!\0/path/to/interpreter arg\n', None),
    (b'#!\n/path/to/interpreter arg\n', None),
    (b'#!', None)
])
def test_shebang_interpreter_regex(shebang, interpreter):
    """Check that ``get_interpreter`` extracts the expected interpreter bytes."""
    # The comparison has to be asserted; as a bare expression it is evaluated
    # and discarded, so the test could never fail.
    assert sbang.get_interpreter(shebang) == interpreter
def test_shebang_handling(script_dir, sbang_line): def test_shebang_handling(script_dir, sbang_line):
assert sbang.shebang_too_long(script_dir.lua_shebang)
assert sbang.shebang_too_long(script_dir.long_shebang)
assert sbang.shebang_too_long(script_dir.nonexec_long_shebang)
assert not sbang.shebang_too_long(script_dir.short_shebang)
assert not sbang.shebang_too_long(script_dir.has_sbang)
assert not sbang.shebang_too_long(script_dir.binary)
assert not sbang.shebang_too_long(script_dir.directory)
sbang.filter_shebangs_in_directory(script_dir.tempdir) sbang.filter_shebangs_in_directory(script_dir.tempdir)
# Make sure this is untouched # Make sure this is untouched
@ -191,6 +216,12 @@ def test_shebang_handling(script_dir, sbang_line):
assert f.readline() == lua_line_patched assert f.readline() == lua_line_patched
assert f.readline() == last_line assert f.readline() == last_line
# Make sure this got patched.
with open(script_dir.luajit_shebang, 'r') as f:
assert f.readline() == sbang_line
assert f.readline() == luajit_line_patched
assert f.readline() == last_line
# Make sure this got patched. # Make sure this got patched.
with open(script_dir.node_shebang, 'r') as f: with open(script_dir.node_shebang, 'r') as f:
assert f.readline() == sbang_line assert f.readline() == sbang_line
@ -199,8 +230,12 @@ def test_shebang_handling(script_dir, sbang_line):
assert filecmp.cmp(script_dir.lua_textbang, assert filecmp.cmp(script_dir.lua_textbang,
os.path.join(script_dir.tempdir, 'lua_in_text')) os.path.join(script_dir.tempdir, 'lua_in_text'))
assert filecmp.cmp(script_dir.luajit_textbang,
os.path.join(script_dir.tempdir, 'luajit_in_text'))
assert filecmp.cmp(script_dir.node_textbang, assert filecmp.cmp(script_dir.node_textbang,
os.path.join(script_dir.tempdir, 'node_in_text')) os.path.join(script_dir.tempdir, 'node_in_text'))
assert filecmp.cmp(script_dir.php_textbang,
os.path.join(script_dir.tempdir, 'php_in_text'))
# Make sure this is untouched # Make sure this is untouched
with open(script_dir.has_sbang, 'r') as f: with open(script_dir.has_sbang, 'r') as f:
@ -261,7 +296,7 @@ def test_install_sbang(install_mockery):
def test_install_sbang_too_long(tmpdir): def test_install_sbang_too_long(tmpdir):
root = str(tmpdir) root = str(tmpdir)
num_extend = sbang.shebang_limit - len(root) - len('/bin/sbang') num_extend = sbang.system_shebang_limit - len(root) - len('/bin/sbang')
long_path = root long_path = root
while num_extend > 1: while num_extend > 1:
add = min(num_extend, 255) add = min(num_extend, 255)
@ -282,7 +317,7 @@ def test_sbang_hook_skips_nonexecutable_blobs(tmpdir):
# consisting of invalid UTF-8. The latter is technically not really necessary for # consisting of invalid UTF-8. The latter is technically not really necessary for
# the test, but binary blobs accidentally starting with b'#!' usually do not contain # the test, but binary blobs accidentally starting with b'#!' usually do not contain
# valid UTF-8, so we also ensure that Spack does not attempt to decode as UTF-8. # valid UTF-8, so we also ensure that Spack does not attempt to decode as UTF-8.
contents = b'#!' + b'\x80' * sbang.shebang_limit contents = b'#!' + b'\x80' * sbang.system_shebang_limit
file = str(tmpdir.join('non-executable.sh')) file = str(tmpdir.join('non-executable.sh'))
with open(file, 'wb') as f: with open(file, 'wb') as f:
f.write(contents) f.write(contents)
@ -292,3 +327,50 @@ def test_sbang_hook_skips_nonexecutable_blobs(tmpdir):
# Make sure there is no sbang shebang. # Make sure there is no sbang shebang.
with open(file, 'rb') as f: with open(file, 'rb') as f:
assert b'sbang' not in f.readline() assert b'sbang' not in f.readline()
def test_sbang_handles_non_utf8_files(tmpdir):
    """An overlong shebang made of latin1 bytes is patched without decoding."""
    # The shebang consists of copyright signs (0xa9 in latin1), and the body of
    # the file contains one more.
    overlong = b'\xa9' * sbang.system_shebang_limit
    contents = b'#!' + overlong + b'\nand another symbol: \xa9'

    # Sanity check: the bytes are valid latin1 but not valid utf-8.
    assert contents.decode('latin1')
    with pytest.raises(UnicodeDecodeError):
        contents.decode('utf-8')

    # Write the bytes to a file on disk.
    file = str(tmpdir.join('latin1.sh'))
    with open(file, 'wb') as f:
        f.write(contents)

    # Patching must report success ...
    assert sbang.filter_shebang(file)

    # ... and keep the original bytes intact after the new sbang line.
    with open(file, 'rb') as f:
        new_contents = f.read()

    assert contents in new_contents
    assert b'sbang' in new_contents
@pytest.fixture
def shebang_limits_system_8_spack_16():
    """Temporarily set the system shebang limit to 8 and Spack's limit to 16."""
    saved_limits = (sbang.system_shebang_limit, sbang.spack_shebang_limit)
    sbang.system_shebang_limit = 8
    sbang.spack_shebang_limit = 16
    yield
    # Put the module-level limits back the way they were.
    sbang.system_shebang_limit, sbang.spack_shebang_limit = saved_limits
def test_shebang_exceeds_spack_shebang_limit(shebang_limits_system_8_spack_16, tmpdir):
    """Tests whether shebangs longer than Spack's limit are skipped"""
    file = str(tmpdir.join('longer_than_spack_limit.sh'))
    # A shebang with no newline that runs past Spack's own limit.
    payload = b'#!' + b'x' * sbang.spack_shebang_limit
    with open(file, 'wb') as f:
        f.write(payload)

    # Then Spack shouldn't try to add a shebang
    was_patched = sbang.filter_shebang(file)
    assert not was_patched

    with open(file, 'rb') as f:
        remaining = f.read()
    assert b'sbang' not in remaining