filter_file: fix multiple invocations on the same file (#13234)

Since the backup file is only created on the first invocation, it will
contain the original file without any modifications. Further invocations
will then read the backup file, effectively reverting prior invocations.

This can be reproduced easily by trying to install likwid, which will
try to install into /usr/local. Work around this by creating a temporary
file to read from.
This commit is contained in:
Michael Kuhn 2019-10-17 00:15:24 +02:00 committed by Todd Gamblin
parent 1ef71376f2
commit ffe87ed49f
2 changed files with 33 additions and 2 deletions

View File

@ -153,6 +153,7 @@ def groupid_to_group(x):
tty.debug(msg.format(filename, regex)) tty.debug(msg.format(filename, regex))
backup_filename = filename + "~" backup_filename = filename + "~"
tmp_filename = filename + ".spack~"
if ignore_absent and not os.path.exists(filename): if ignore_absent and not os.path.exists(filename):
msg = 'FILTER FILE: file "{0}" not found. Skipping to next file.' msg = 'FILTER FILE: file "{0}" not found. Skipping to next file.'
@ -164,6 +165,10 @@ def groupid_to_group(x):
if not os.path.exists(backup_filename): if not os.path.exists(backup_filename):
shutil.copy(filename, backup_filename) shutil.copy(filename, backup_filename)
# Create a temporary file to read from. We cannot use backup_filename
# in case filter_file is invoked multiple times on the same file.
shutil.copy(filename, tmp_filename)
try: try:
extra_kwargs = {} extra_kwargs = {}
if sys.version_info > (3, 0): if sys.version_info > (3, 0):
@ -171,7 +176,7 @@ def groupid_to_group(x):
# Open as a text file and filter until the end of the file is # Open as a text file and filter until the end of the file is
# reached or we found a marker in the line if it was specified # reached or we found a marker in the line if it was specified
with open(backup_filename, mode='r', **extra_kwargs) as input_file: with open(tmp_filename, mode='r', **extra_kwargs) as input_file:
with open(filename, mode='w', **extra_kwargs) as output_file: with open(filename, mode='w', **extra_kwargs) as output_file:
# Using iter and readline is a workaround needed not to # Using iter and readline is a workaround needed not to
# disable input_file.tell(), which will happen if we call # disable input_file.tell(), which will happen if we call
@ -190,17 +195,19 @@ def groupid_to_group(x):
# If we stopped filtering at some point, reopen the file in # If we stopped filtering at some point, reopen the file in
# binary mode and copy verbatim the remaining part # binary mode and copy verbatim the remaining part
if current_position and stop_at: if current_position and stop_at:
with open(backup_filename, mode='rb') as input_file: with open(tmp_filename, mode='rb') as input_file:
input_file.seek(current_position) input_file.seek(current_position)
with open(filename, mode='ab') as output_file: with open(filename, mode='ab') as output_file:
output_file.writelines(input_file.readlines()) output_file.writelines(input_file.readlines())
except BaseException: except BaseException:
os.remove(tmp_filename)
# clean up the original file on failure. # clean up the original file on failure.
shutil.move(backup_filename, filename) shutil.move(backup_filename, filename)
raise raise
finally: finally:
os.remove(tmp_filename)
if not backup and os.path.exists(backup_filename): if not backup and os.path.exists(backup_filename):
os.remove(backup_filename) os.remove(backup_filename)

View File

@ -362,3 +362,27 @@ def test_filter_files_with_different_encodings(
with open(target_file, mode='r', **extra_kwargs) as f: with open(target_file, mode='r', **extra_kwargs) as f:
assert replacement in f.read() assert replacement in f.read()
def test_filter_files_multiple(tmpdir):
# All files given as input to this test must satisfy the pre-requisite
# that the 'replacement' string is not present in the file initially and
# that there's at least one match for the regex
original_file = os.path.join(
spack.paths.test_path, 'data', 'filter_file', 'x86_cpuid_info.c'
)
target_file = os.path.join(str(tmpdir), 'x86_cpuid_info.c')
shutil.copy(original_file, target_file)
# This should not raise exceptions
fs.filter_file(r'\<malloc.h\>', '<unistd.h>', target_file)
fs.filter_file(r'\<string.h\>', '<unistd.h>', target_file)
fs.filter_file(r'\<stdio.h\>', '<unistd.h>', target_file)
# Check the strings have been replaced
extra_kwargs = {}
if sys.version_info > (3, 0):
extra_kwargs = {'errors': 'surrogateescape'}
with open(target_file, mode='r', **extra_kwargs) as f:
assert '<malloc.h>' not in f.read()
assert '<string.h>' not in f.read()
assert '<stdio.h>' not in f.read()