Files
why3/bench/test_mlw_printer

173 lines
6.2 KiB
Python
Executable File

#!/usr/bin/env python3
import sys
import sexpdata
import math
import os
from subprocess import Popen, PIPE, DEVNULL
# Failure traces are enabled by the mere presence of DEBUG in the
# environment, whatever its value (even the empty string).
debug = os.environ.get("DEBUG") is not None
# Keyword arguments for sexpdata.loads: don't convert nil or #t in s-exp
loads_kwargs = dict(nil=None, true=None)
class CommandFailed(Exception):
    """Signals that a why3 subprocess finished with a non-zero return code."""

    def __init__(self, msg: str) -> None:
        # Human-readable description of the failure; the driver prints it.
        self.msg = msg
class NotEqual(Exception):
    """Signals that two s-expressions differ.

    Attributes:
        path: list of child indices locating the mismatch.
        sexp0: the sub-expression found on the first (expected) side.
        sexp1: the sub-expression found on the second (actual) side.
    """

    def __init__(self, path, sexp0, sexp1) -> None:
        self.path, self.sexp0, self.sexp1 = path, sexp0, sexp1
# Read a whyml file as a s-exp
def read(why3, filename):
    """Run ``why3 pp --output=sexp`` on *filename* and parse its output.

    Args:
        why3: path or name of the why3 executable.
        filename: whyml file handed to the command.

    Returns:
        The value built by ``sexpdata.loads`` from the command's standard
        output, or ``None`` when the command printed nothing.

    Raises:
        CommandFailed: the command exited with a non-zero return code.
    """
    command = [why3, "pp", "--output=sexp", filename]
    # The command's stderr is discarded; only stdout is captured.
    proc = Popen(command, stdout=PIPE, stderr=DEVNULL, encoding='utf8')
    output, _ = proc.communicate()
    if proc.returncode != 0:
        raise CommandFailed("cannot print s-expr for original file (returncode={})"
                            .format(proc.returncode))
    if not output:
        return None
    return sexpdata.loads(output, **loads_kwargs)
# Pretty-print a whyml file and read the result as a s-exp
def print_and_read(why3, filename):
    """Pipe ``why3 pp --output=mlw`` into ``why3 pp --output=sexp -``.

    The first process prints *filename* as whyml source; the second reads
    that text on its standard input and prints it as an s-expression, which
    is then parsed.

    Args:
        why3: path or name of the why3 executable.
        filename: whyml file handed to the first command.

    Returns:
        The value built by ``sexpdata.loads`` from the second command's
        standard output, or ``None`` when it printed nothing.

    Raises:
        CommandFailed: the second command exited with a non-zero return
            code.  The return code of the first command is not inspected.
    """
    printer_cmd = [why3, "pp", "--output=mlw", filename]
    parser_cmd = [why3, "pp", "--output=sexp", "-"]
    # The context managers close the pipes and wait for both children, so
    # neither process is left unreaped.
    with Popen(printer_cmd, stdout=PIPE, stderr=DEVNULL, encoding='utf8') as p1:
        with Popen(parser_cmd, stdin=p1.stdout, stdout=PIPE, stderr=DEVNULL,
                   encoding='utf8') as p2:
            # Drop our copy of the pipe's read end: if the parser exits early
            # the printer then gets SIGPIPE instead of blocking on a full
            # pipe that we would be keeping open.
            p1.stdout.close()
            s, _ = p2.communicate()
    if p2.returncode == 0:
        return sexpdata.loads(s, **loads_kwargs) if s else None
    else:
        raise CommandFailed("cannot print s-expr for pretty-printed output (returncode={})"
                            .format(p2.returncode))
def is_location(sexp):
    """Tell whether *sexp* has the shape of a source location.

    A location is recognised as a sequence of exactly four items whose exact
    types are ``str, int, int, int`` (so a ``bool`` does not count as an
    ``int``).

    Returns:
        ``True`` for such a sequence, ``False`` for anything else, including
        values that cannot be iterated at all.
    """
    try:
        return [type(x) for x in sexp] == [str, int, int, int]
    except TypeError:
        # sexp is an atom that cannot be iterated (number, symbol, None...)
        return False
# Attribute strings that keep_id_attr rejects, i.e. identifier attributes
# that are left out of the comparison.
IGNORE_ID_ATTRS = [
    "W:unused_variable:N",
    "extraction:array_make",
    "extraction:array",
    "induction",
    "mlw:reference_var",
    "infer",
    "useraxiom",
    "W:non_conservative_extension:N",
    "model_trace:flag",
    "model_trace:first_val",
    "model_trace:sec_val",
    "model_trace:TEMP_NAME",
    "model",
    "W:unmodified_variable:N",
]
def keep_id_attr(at):
    """Decide whether the attribute *at* takes part in the comparison.

    Only attributes of the shape

        (ATstr ((attr_string "...") (attr_tag N)))

    whose string is listed in IGNORE_ID_ATTRS are rejected.

    Returns:
        ``False`` for such an ignored attribute, ``True`` for every other
        value, including values that do not have the shape above.
    """
    try:
        # (ATstr ((attr_string "...") (attr_tag N)))
        variant, fields = at
        field0, field1 = fields  # exactly two fields are expected
        field, value = field0
        if variant.value() == 'ATstr' and field.value() == 'attr_string':
            return value not in IGNORE_ID_ATTRS
        else:
            return True
    except (TypeError, ValueError, AttributeError):
        # Not the expected shape: unpacking failed (TypeError, ValueError) or
        # a head is not a symbol with .value() (AttributeError).
        return True
def ignore_id_attrs(sexp):
    """Filter the attribute list of an ``(id_ats (at ...))`` field.

    Returns:
        For a two-element *sexp* whose head is the symbol ``id_ats``, a new
        two-element list with the same head and only the attributes accepted
        by keep_id_attr.  Any other value is returned unchanged.
    """
    try:
        # (id_ats (at ...))
        field, value = sexp
        if field.value() == 'id_ats':
            ats = [at for at in value if keep_id_attr(at)]
            return [field, ats]
        else:
            return sexp
    except (TypeError, ValueError, AttributeError):
        # Not a (symbol, iterable) pair: leave the value untouched.
        return sexp
# Test for sexp (<field_name> _)
def is_field(sexp, field_name):
    """Tell whether *sexp* is a two-element sequence headed by *field_name*.

    The head must be a symbol, i.e. an object whose ``value()`` method
    returns *field_name*.

    Returns:
        ``True`` for such a field, ``False`` for anything else.
    """
    try:
        return len(sexp) == 2 and sexp[0].value() == field_name
    except (TypeError, AttributeError, LookupError):
        # No length, not indexable by 0, or the head is not a symbol.
        return False
# Ordered pairs of head symbols that are treated as interchangeable: when the
# two sides carry such a pair, both heads are dropped and only the remaining
# elements are compared.
_EQUIVALENT_HEADS = frozenset([
    ("Tinfix", "Tinnfix"), ("Tinnfix", "Tinfix"),
    ("Tbinop", "Tbinnop"), ("Tbinnop", "Tbinop"),
    ("Einfix", "Einnfix"), ("Einnfix", "Einfix"),
])


def assert_equal(path, sexp0, sexp1):
    """Check that two s-expressions are equal up to irrelevant differences.

    Differences that are tolerated: locations, ``attr_tag`` fields, NaN
    against NaN, an extra ``PTparen`` / ``Pparen`` / one-element ``Ptuple``
    wrapper on either side, the head pairs of _EQUIVALENT_HEADS, and the
    identifier attributes removed by ignore_id_attrs.

    Args:
        path: list of child indices leading from the root of the first
            s-expression to *sexp0*; it is only used for error reporting.
        sexp0: sub-expression of the first (expected) s-expression.
        sexp1: sub-expression of the second (actual) s-expression.

    Returns:
        ``None`` when the two values are considered equal.

    Raises:
        NotEqual: on the first difference found.  Its path follows the
            descent into *sexp0*; indices below an ``id_ats`` field whose
            attributes were filtered refer to the filtered list.
    """
    if sexp0 == sexp1:
        return
    if is_location(sexp0) and is_location(sexp1):
        return  # Don't bother with locations
    if is_field(sexp0, "attr_tag") and is_field(sexp1, "attr_tag"):
        return  # Don't bother with tags
    if (isinstance(sexp0, float) and isinstance(sexp1, float)
            and math.isnan(sexp0) and math.isnan(sexp1)):
        return  # nan != nan
    if not (isinstance(sexp0, list) and isinstance(sexp1, list)):
        raise NotEqual(path, sexp0, sexp1)

    # Number of leading elements dropped from sexp0 by head stripping; needed
    # to translate indices into the current sexp0 back to the original node.
    offset = 0
    while True:  # Ignore additional parentheses
        try:
            head0 = sexp0[0].value()
            head1 = sexp1[0].value()
            # Tuple assignments evaluate the whole right-hand side first, so
            # a failing index leaves sexp0, path and offset consistent.
            if head0 == "PTparen" and head1 != "PTparen":
                sexp0, path, offset = sexp0[1], path + [offset + 1], 0
            elif head1 == "PTparen" and head0 != "PTparen":
                sexp1 = sexp1[1]
            elif head0 == "Pparen" and head1 != "Pparen":
                sexp0, path, offset = (sexp0[1][0][1],
                                       path + [offset + 1, 0, 1], 0)
            elif head1 == "Pparen" and head0 != "Pparen":
                sexp1 = sexp1[1][0][1]
            elif (head0 == "Ptuple" and len(sexp0[1]) == 1
                  and head1 != "Ptuple"):
                sexp0, path, offset = (sexp0[1][0][1],
                                       path + [offset + 1, 0, 1], 0)
            elif (head1 == "Ptuple" and len(sexp1[1]) == 1
                  and head0 != "Ptuple"):
                sexp1 = sexp1[1][0][1]
            elif (head0, head1) in _EQUIVALENT_HEADS:
                sexp0 = sexp0[1:]
                sexp1 = sexp1[1:]
                offset += 1
            else:
                break
        except (AttributeError, IndexError, TypeError):
            # A head is not a symbol, a list is empty or too short, or an
            # atom was reached: there is nothing more to unwrap.
            break

    if not (isinstance(sexp0, list) and isinstance(sexp1, list)):
        # Unwrapping exposed an atom on one side; no children to compare.
        if sexp0 == sexp1:
            return
        raise NotEqual(path, sexp0, sexp1)

    sexp0 = ignore_id_attrs(sexp0)
    sexp1 = ignore_id_attrs(sexp1)
    if len(sexp0) != len(sexp1):
        raise NotEqual(path, sexp0, sexp1)
    for i, (s0, s1) in enumerate(zip(sexp0, sexp1), start=offset):
        assert_equal(path + [i], s0, s1)
def trace(path, sexp, sexp1):
    """Copy *sexp*, replacing the node located by *path* with an error marker.

    Args:
        path: list of child indices, followed from *sexp* downwards.
        sexp: the s-expression to annotate.
        sexp1: the value reported as FOUND in the marker.

    Returns:
        ``(ERROR (EXPECTED sexp) (FOUND sexp1))`` when *path* is empty; for a
        list, a copy in which the child at index ``path[0]`` is replaced by
        its own trace; ``None`` when *path* is not empty and *sexp* is not a
        list.
    """
    if path == []:
        expected = [sexpdata.Symbol("EXPECTED"), sexp]
        found = [sexpdata.Symbol("FOUND"), sexp1]
        return [sexpdata.Symbol("ERROR"), expected, found]
    if type(sexp) != list:
        # The path leads below an atom: no trace can be built.
        return None
    annotated = []
    for index, child in enumerate(sexp):
        if index == path[0]:
            annotated.append(trace(path[1:], child, sexp1))
        else:
            annotated.append(child)
    return annotated
def test(why3, filename):
    """Check that pretty-printing *filename* preserves its s-expression.

    The file name is printed first, followed by ``ok``, ``FAILED`` or
    ``COMMAND FAILED: ...``.  When the module-level ``debug`` flag is set, a
    failed comparison also dumps an annotated copy of the expected
    s-expression to standard output.

    Args:
        why3: path or name of the why3 executable.
        filename: whyml file to check.

    Returns:
        ``True`` when both s-expressions are considered equal, ``False``
        when they differ or when a why3 command failed.
    """
    print(" {}: ".format(filename), end='', flush=True)
    try:
        expected = read(why3, filename)
        actual = print_and_read(why3, filename)
        assert_equal([], expected, actual)
    except NotEqual as mismatch:
        print("FAILED")
        if debug:
            annotated = trace(mismatch.path, expected, mismatch.sexp1)
            sexpdata.dump(annotated or "NO TRACE", sys.stdout)
        return False
    except CommandFailed as failure:
        print("COMMAND FAILED:", failure.msg)
        return False
    print("ok")
    return True
def main():
    """Command-line entry point: ``test_mlw_printer WHY3 [FILE]...``.

    The first argument is the why3 executable, the remaining ones are the
    whyml files to check.  The process exits with status 0 when every file
    checked passed and with status 1 otherwise.
    """
    if len(sys.argv) < 2:
        # sys.exit with a string prints it on stderr and exits with status 1.
        sys.exit("usage: test_mlw_printer WHY3 [FILE]...")
    why3 = sys.argv[1]
    files = sys.argv[2:]
    # all() stops at the first failing file: the files after it are not run.
    res = all(test(why3, f) for f in files)
    # sys.exit rather than the exit() helper, which only exists when the
    # site module has been loaded.
    sys.exit(0 if res else 1)


if __name__ == "__main__":
    main()