unparser: do a better job of roundtripping strings
Handle complex f-strings. Backport of:
a993e901eb
#
This commit is contained in:
parent
e9612696fd
commit
ec16c2d7c2
@ -65,6 +65,11 @@ def interleave(inter, f, seq):
|
|||||||
f(x)
|
f(x)
|
||||||
|
|
||||||
|
|
||||||
|
_SINGLE_QUOTES = ("'", '"')
|
||||||
|
_MULTI_QUOTES = ('"""', "'''")
|
||||||
|
_ALL_QUOTES = _SINGLE_QUOTES + _MULTI_QUOTES
|
||||||
|
|
||||||
|
|
||||||
def is_simple_tuple(slice_value):
|
def is_simple_tuple(slice_value):
|
||||||
# when unparsing a non-empty tuple, the parantheses can be safely
|
# when unparsing a non-empty tuple, the parantheses can be safely
|
||||||
# omitted if there aren't any elements that explicitly requires
|
# omitted if there aren't any elements that explicitly requires
|
||||||
@ -86,7 +91,7 @@ class Unparser:
|
|||||||
output source code for the abstract syntax; original formatting
|
output source code for the abstract syntax; original formatting
|
||||||
is disregarded. """
|
is disregarded. """
|
||||||
|
|
||||||
def __init__(self, py_ver_consistent=False):
|
def __init__(self, py_ver_consistent=False, _avoid_backslashes=False):
|
||||||
"""Traverse an AST and generate its source.
|
"""Traverse an AST and generate its source.
|
||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
@ -118,6 +123,7 @@ def __init__(self, py_ver_consistent=False):
|
|||||||
self._indent = 0
|
self._indent = 0
|
||||||
self._py_ver_consistent = py_ver_consistent
|
self._py_ver_consistent = py_ver_consistent
|
||||||
self._precedences = {}
|
self._precedences = {}
|
||||||
|
self._avoid_backslashes = _avoid_backslashes
|
||||||
|
|
||||||
def items_view(self, traverser, items):
|
def items_view(self, traverser, items):
|
||||||
"""Traverse and separate the given *items* with a comma and append it to
|
"""Traverse and separate the given *items* with a comma and append it to
|
||||||
@ -596,6 +602,53 @@ def _With(self, t):
|
|||||||
def _AsyncWith(self, t):
|
def _AsyncWith(self, t):
|
||||||
self._generic_With(t, async_=True)
|
self._generic_With(t, async_=True)
|
||||||
|
|
||||||
|
def _str_literal_helper(
|
||||||
|
self, string, quote_types=_ALL_QUOTES, escape_special_whitespace=False
|
||||||
|
):
|
||||||
|
"""Helper for writing string literals, minimizing escapes.
|
||||||
|
Returns the tuple (string literal to write, possible quote types).
|
||||||
|
"""
|
||||||
|
def escape_char(c):
|
||||||
|
# \n and \t are non-printable, but we only escape them if
|
||||||
|
# escape_special_whitespace is True
|
||||||
|
if not escape_special_whitespace and c in "\n\t":
|
||||||
|
return c
|
||||||
|
# Always escape backslashes and other non-printable characters
|
||||||
|
if c == "\\" or not c.isprintable():
|
||||||
|
return c.encode("unicode_escape").decode("ascii")
|
||||||
|
return c
|
||||||
|
|
||||||
|
escaped_string = "".join(map(escape_char, string))
|
||||||
|
possible_quotes = quote_types
|
||||||
|
if "\n" in escaped_string:
|
||||||
|
possible_quotes = [q for q in possible_quotes if q in _MULTI_QUOTES]
|
||||||
|
possible_quotes = [q for q in possible_quotes if q not in escaped_string]
|
||||||
|
if not possible_quotes:
|
||||||
|
# If there aren't any possible_quotes, fallback to using repr
|
||||||
|
# on the original string. Try to use a quote from quote_types,
|
||||||
|
# e.g., so that we use triple quotes for docstrings.
|
||||||
|
string = repr(string)
|
||||||
|
quote = next((q for q in quote_types if string[0] in q), string[0])
|
||||||
|
return string[1:-1], [quote]
|
||||||
|
if escaped_string:
|
||||||
|
# Sort so that we prefer '''"''' over """\""""
|
||||||
|
possible_quotes.sort(key=lambda q: q[0] == escaped_string[-1])
|
||||||
|
# If we're using triple quotes and we'd need to escape a final
|
||||||
|
# quote, escape it
|
||||||
|
if possible_quotes[0][0] == escaped_string[-1]:
|
||||||
|
assert len(possible_quotes[0]) == 3
|
||||||
|
escaped_string = escaped_string[:-1] + "\\" + escaped_string[-1]
|
||||||
|
return escaped_string, possible_quotes
|
||||||
|
|
||||||
|
def _write_str_avoiding_backslashes(self, string, quote_types=_ALL_QUOTES):
|
||||||
|
"""Write string literal value w/a best effort attempt to avoid backslashes."""
|
||||||
|
string, quote_types = self._str_literal_helper(string, quote_types=quote_types)
|
||||||
|
quote_type = quote_types[0]
|
||||||
|
self.write("{quote_type}{string}{quote_type}".format(
|
||||||
|
quote_type=quote_type,
|
||||||
|
string=string,
|
||||||
|
))
|
||||||
|
|
||||||
# expr
|
# expr
|
||||||
def _Bytes(self, t):
|
def _Bytes(self, t):
|
||||||
self.write(repr(t.s))
|
self.write(repr(t.s))
|
||||||
@ -625,33 +678,53 @@ def _Str(self, tree):
|
|||||||
def _JoinedStr(self, t):
|
def _JoinedStr(self, t):
|
||||||
# JoinedStr(expr* values)
|
# JoinedStr(expr* values)
|
||||||
self.write("f")
|
self.write("f")
|
||||||
|
|
||||||
|
if self._avoid_backslashes:
|
||||||
string = StringIO()
|
string = StringIO()
|
||||||
self._fstring_JoinedStr(t, string.write)
|
self._fstring_JoinedStr(t, string.write)
|
||||||
# Deviation from `unparse.py`: Try to find an unused quote.
|
self._write_str_avoiding_backslashes(string.getvalue())
|
||||||
# This change is made to handle _very_ complex f-strings.
|
return
|
||||||
v = string.getvalue()
|
|
||||||
if '\n' in v or '\r' in v:
|
# If we don't need to avoid backslashes globally (i.e., we only need
|
||||||
quote_types = ["'''", '"""']
|
# to avoid them inside FormattedValues), it's cosmetically preferred
|
||||||
else:
|
# to use escaped whitespace. That is, it's preferred to use backslashes
|
||||||
quote_types = ["'", '"', '"""', "'''"]
|
# for cases like: f"{x}\n". To accomplish this, we keep track of what
|
||||||
for quote_type in quote_types:
|
# in our buffer corresponds to FormattedValues and what corresponds to
|
||||||
if quote_type not in v:
|
# Constant parts of the f-string, and allow escapes accordingly.
|
||||||
v = "{quote_type}{v}{quote_type}".format(quote_type=quote_type, v=v)
|
buffer = []
|
||||||
break
|
for value in t.values:
|
||||||
else:
|
meth = getattr(self, "_fstring_" + type(value).__name__)
|
||||||
v = repr(v)
|
string = StringIO()
|
||||||
self.write(v)
|
meth(value, string.write)
|
||||||
|
buffer.append((string.getvalue(), isinstance(value, ast.Constant)))
|
||||||
|
new_buffer = []
|
||||||
|
quote_types = _ALL_QUOTES
|
||||||
|
for value, is_constant in buffer:
|
||||||
|
# Repeatedly narrow down the list of possible quote_types
|
||||||
|
value, quote_types = self._str_literal_helper(
|
||||||
|
value, quote_types=quote_types,
|
||||||
|
escape_special_whitespace=is_constant
|
||||||
|
)
|
||||||
|
new_buffer.append(value)
|
||||||
|
value = "".join(new_buffer)
|
||||||
|
quote_type = quote_types[0]
|
||||||
|
self.write("{quote_type}{value}{quote_type}".format(
|
||||||
|
quote_type=quote_type,
|
||||||
|
value=value,
|
||||||
|
))
|
||||||
|
|
||||||
def _FormattedValue(self, t):
|
def _FormattedValue(self, t):
|
||||||
# FormattedValue(expr value, int? conversion, expr? format_spec)
|
# FormattedValue(expr value, int? conversion, expr? format_spec)
|
||||||
self.write("f")
|
self.write("f")
|
||||||
string = StringIO()
|
string = StringIO()
|
||||||
self._fstring_JoinedStr(t, string.write)
|
self._fstring_JoinedStr(t, string.write)
|
||||||
self.write(repr(string.getvalue()))
|
self._write_str_avoiding_backslashes(string.getvalue())
|
||||||
|
|
||||||
def _fstring_JoinedStr(self, t, write):
|
def _fstring_JoinedStr(self, t, write):
|
||||||
for value in t.values:
|
for value in t.values:
|
||||||
|
print(" ", value)
|
||||||
meth = getattr(self, "_fstring_" + type(value).__name__)
|
meth = getattr(self, "_fstring_" + type(value).__name__)
|
||||||
|
print(meth)
|
||||||
meth(value, write)
|
meth(value, write)
|
||||||
|
|
||||||
def _fstring_Str(self, t, write):
|
def _fstring_Str(self, t, write):
|
||||||
@ -667,13 +740,18 @@ def _fstring_FormattedValue(self, t, write):
|
|||||||
write("{")
|
write("{")
|
||||||
|
|
||||||
expr = StringIO()
|
expr = StringIO()
|
||||||
unparser = type(self)(py_ver_consistent=self._py_ver_consistent)
|
unparser = type(self)(
|
||||||
|
py_ver_consistent=self._py_ver_consistent,
|
||||||
|
_avoid_backslashes=True,
|
||||||
|
)
|
||||||
unparser.set_precedence(pnext(_Precedence.TEST), t.value)
|
unparser.set_precedence(pnext(_Precedence.TEST), t.value)
|
||||||
unparser.visit(t.value, expr)
|
unparser.visit(t.value, expr)
|
||||||
expr = expr.getvalue().rstrip("\n")
|
expr = expr.getvalue().rstrip("\n")
|
||||||
|
|
||||||
if expr.startswith("{"):
|
if expr.startswith("{"):
|
||||||
write(" ") # Separate pair of opening brackets as "{ {"
|
write(" ") # Separate pair of opening brackets as "{ {"
|
||||||
|
if "\\" in expr:
|
||||||
|
raise ValueError("Unable to avoid backslash in f-string expression part")
|
||||||
write(expr)
|
write(expr)
|
||||||
if t.conversion != -1:
|
if t.conversion != -1:
|
||||||
conversion = chr(t.conversion)
|
conversion = chr(t.conversion)
|
||||||
@ -707,6 +785,8 @@ def _write_constant(self, value):
|
|||||||
if raw.startswith(r"'\\u"):
|
if raw.startswith(r"'\\u"):
|
||||||
raw = "'\\" + raw[3:]
|
raw = "'\\" + raw[3:]
|
||||||
self.write(raw)
|
self.write(raw)
|
||||||
|
elif self._avoid_backslashes and isinstance(value, str):
|
||||||
|
self._write_str_avoiding_backslashes(value)
|
||||||
else:
|
else:
|
||||||
self.write(repr(value))
|
self.write(repr(value))
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user