wonderful world of testing

This commit is contained in:
Bryan Bishop 2012-03-24 21:34:19 -05:00
parent 3bd84c1dac
commit 33d8c7a117
2 changed files with 143 additions and 43 deletions

View File

@ -1,6 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
#utilities to help disassemble pokémon crystal #utilities to help disassemble pokémon crystal
import sys, os, inspect, md5 import sys, os, inspect, md5, json
from copy import copy from copy import copy
#for IntervalMap #for IntervalMap
@ -10,6 +10,9 @@ from itertools import izip
#for testing all this crap #for testing all this crap
import unittest2 as unittest import unittest2 as unittest
if not hasattr(json, "dumps"):
json.dumps = json.write
#table of pointers to map groups #table of pointers to map groups
#each map group contains some number of map headers #each map group contains some number of map headers
map_group_pointer_table = 0x94000 map_group_pointer_table = 0x94000
@ -4435,13 +4438,22 @@ def isolate_incbins():
def process_incbins(): def process_incbins():
"parse incbin lines into memory" "parse incbin lines into memory"
global incbins global asm, incbin_lines, processed_incbins
incbins = {} #reset #load asm if it isn't ready yet
if asm == [] or asm == None:
load_asm()
#get a list of incbins if that hasn't happened yet
if incbin_lines == [] or incbin_lines == None:
isolate_incbins()
#reset the global that this function creates
processed_incbins = {}
#for each incbin..
for incbin in incbin_lines: for incbin in incbin_lines:
#reset this entry
processed_incbin = {} processed_incbin = {}
#get the line number from the global asm line list
line_number = asm.index(incbin) line_number = asm.index(incbin)
#forget about all the leading characters
partial_start = incbin[21:] partial_start = incbin[21:]
start = partial_start.split(",")[0].replace("$", "0x") start = partial_start.split(",")[0].replace("$", "0x")
start = eval(start) start = eval(start)
@ -4456,20 +4468,19 @@ def process_incbins():
end = start + interval end = start + interval
end_hex = hex(end).replace("0x", "$") end_hex = hex(end).replace("0x", "$")
processed_incbin = { processed_incbin = {"line_number": line_number,
"line_number": line_number,
"line": incbin, "line": incbin,
"start": start, "start": start,
"interval": interval, "interval": interval,
"end": end, "end": end, }
}
#don't add this incbin if the interval is 0 #don't add this incbin if the interval is 0
if interval != 0: if interval != 0:
processed_incbins[line_number] = processed_incbin processed_incbins[line_number] = processed_incbin
return processed_incbins
def reset_incbins(): def reset_incbins():
"reset asm before inserting another diff" "reset asm before inserting another diff"
global asm, incbin_lines, processed_incbins
asm = None asm = None
incbin_lines = [] incbin_lines = []
processed_incbins = {} processed_incbins = {}
@ -4580,7 +4591,7 @@ def apply_diff(diff, try_fixing=True, do_compile=True):
#confirm it's working #confirm it's working
if do_compile: if do_compile:
try: try:
subprocess.check_call("cd ../; make clean; LC_CTYPE=C make", shell=True) subprocess.check_call("cd ../; make clean; make", shell=True)
return True return True
except Exception, exc: except Exception, exc:
if try_fixing: if try_fixing:
@ -4592,13 +4603,6 @@ def index(seq, f):
where f(item) == True.""" where f(item) == True."""
return next((i for i in xrange(len(seq)) if f(seq[i])), None) return next((i for i in xrange(len(seq)) if f(seq[i])), None)
def is_probably_pointer(input):
try:
blah = int(input, 16)
return True
except:
return False
def analyze_intervals(): def analyze_intervals():
"""find the largest baserom.gbc intervals""" """find the largest baserom.gbc intervals"""
global asm, processed_incbins global asm, processed_incbins
@ -4614,10 +4618,11 @@ def analyze_intervals():
results.append(processed_incbins[key]) results.append(processed_incbins[key])
return results return results
def write_all_labels(all_labels): def write_all_labels(all_labels, filename="labels.json"):
fh = open("labels.json", "w") fh = open(filename, "w")
fh.write(json.dumps(all_labels)) fh.write(json.dumps(all_labels))
fh.close() fh.close()
return True
def remove_quoted_text(line): def remove_quoted_text(line):
"""get rid of content inside quotes """get rid of content inside quotes
@ -4632,9 +4637,12 @@ def remove_quoted_text(line):
line = line[0:first] + line[second+1:] line = line[0:first] + line[second+1:]
return line return line
def line_has_comment_address(line, returnable={}): def line_has_comment_address(line, returnable={}, bank=None):
"""checks that a given line has a comment """checks that a given line has a comment
with a valid address""" with a valid address, and returns the address in the object.
Note: bank is required if you have a 4-letter-or-less address,
because otherwise there is no way to figure out which bank
is curretly being scanned."""
#first set the bank/offset to nada #first set the bank/offset to nada
returnable["bank"] = None returnable["bank"] = None
returnable["offset"] = None returnable["offset"] = None
@ -4658,7 +4666,7 @@ def line_has_comment_address(line, returnable={}):
if line[-2:] == "; ": if line[-2:] == "; ":
return False return False
#and multiple whitespace doesn't count either #and multiple whitespace doesn't count either
line = line.rstrip(" ") line = line.rstrip(" ").lstrip(" ")
if line[-1] == ";": if line[-1] == ";":
return False return False
#there must be more content after the semicolon #there must be more content after the semicolon
@ -4675,7 +4683,7 @@ def line_has_comment_address(line, returnable={}):
token = comment.split(" ")[0] token = comment.split(" ")[0]
if token in ["0x", "$", "x", ":"]: if token in ["0x", "$", "x", ":"]:
return False return False
bank, offset = None, None offset = None
#process a token with a A:B format #process a token with a A:B format
if ":" in token: #3:3F0A, $3:$3F0A, 0x3:0x3F0A, 3:3F0A if ":" in token: #3:3F0A, $3:$3F0A, 0x3:0x3F0A, 3:3F0A
#split up the token #split up the token
@ -4717,10 +4725,8 @@ def line_has_comment_address(line, returnable={}):
elif "$" in token and not "x" in token: elif "$" in token and not "x" in token:
token = token.replace("$", "0x") token = token.replace("$", "0x")
offset = int(token, 16) offset = int(token, 16)
bank = calculate_bank(offset)
elif "0x" in token and not "$" in token: elif "0x" in token and not "$" in token:
offset = int(token, 16) offset = int(token, 16)
bank = calculate_bank(offset)
else: #might just be "1" at this point else: #might just be "1" at this point
token = token.lower() token = token.lower()
#check if there are bad characters #check if there are bad characters
@ -4728,9 +4734,10 @@ def line_has_comment_address(line, returnable={}):
if c not in valid: if c not in valid:
return False return False
offset = int(token, 16) offset = int(token, 16)
bank = calculate_bank(offset)
if offset == None and bank == None: if offset == None and bank == None:
return False return False
if bank == None:
bank = calculate_bank(offset)
returnable["bank"] = bank returnable["bank"] = bank
returnable["offset"] = offset returnable["offset"] = offset
returnable["address"] = calculate_pointer(offset, bank=bank) returnable["address"] = calculate_pointer(offset, bank=bank)
@ -4773,7 +4780,7 @@ def find_labels_without_addresses():
return without_addresses return without_addresses
label_errors = "" label_errors = ""
def get_labels_between(start_line_id, end_line_id, bank_id): def get_labels_between(start_line_id, end_line_id, bank):
labels = [] labels = []
#label = { #label = {
# "line_number": 15, # "line_number": 15,
@ -4782,6 +4789,8 @@ def get_labels_between(start_line_id, end_line_id, bank_id):
# "offset": 0x5315, # "offset": 0x5315,
# "address": 0x75315, # "address": 0x75315,
#} #}
if asm == None:
load_asm()
sublines = asm[start_line_id : end_line_id + 1] sublines = asm[start_line_id : end_line_id + 1]
for (current_line_offset, line) in enumerate(sublines): for (current_line_offset, line) in enumerate(sublines):
#skip lines without labels #skip lines without labels
@ -4794,7 +4803,7 @@ def get_labels_between(start_line_id, end_line_id, bank_id):
#setup a place to store return values from line_has_comment_address #setup a place to store return values from line_has_comment_address
returnable = {} returnable = {}
#get the address from the comment #get the address from the comment
has_comment = line_has_comment_address(line, returnable=returnable) has_comment = line_has_comment_address(line, returnable=returnable, bank=bank)
#skip this line if it has no address in the comment #skip this line if it has no address in the comment
if not has_comment: continue if not has_comment: continue
#parse data from line_has_comment_address #parse data from line_has_comment_address
@ -4813,7 +4822,7 @@ def get_labels_between(start_line_id, end_line_id, bank_id):
labels.append(label) labels.append(label)
return labels return labels
def scan_for_predefined_labels(): def scan_for_predefined_labels(debug=False):
"""looks through the asm file for labels at specific addresses, """looks through the asm file for labels at specific addresses,
this relies on the label having its address after. ex: this relies on the label having its address after. ex:
@ -4825,8 +4834,9 @@ def scan_for_predefined_labels():
addresses, but faster to write this script. rgbasm would be able addresses, but faster to write this script. rgbasm would be able
to grab all label addresses better than this script.. to grab all label addresses better than this script..
""" """
bank_intervals = {} global all_labels
all_labels = [] all_labels = []
bank_intervals = {}
#figure out line numbers for each bank #figure out line numbers for each bank
for bank_id in range(0x7F+1): for bank_id in range(0x7F+1):
@ -4836,29 +4846,34 @@ def scan_for_predefined_labels():
abbreviation = "0" abbreviation = "0"
abbreviation_next = "1" abbreviation_next = "1"
#calculate the start/stop line numbers for this bank
start_line_id = index(asm, lambda line: "\"bank" + abbreviation + "\"" in line) start_line_id = index(asm, lambda line: "\"bank" + abbreviation + "\"" in line)
if bank_id != 0x7F:
if bank_id != 0x2c:
end_line_id = index(asm, lambda line: "\"bank" + abbreviation_next + "\"" in line) end_line_id = index(asm, lambda line: "\"bank" + abbreviation_next + "\"" in line)
end_line_id += 1
else: else:
end_line_id = len(asm) - 1 end_line_id = len(asm) - 1
print "bank" + abbreviation + " starts at " + str(start_line_id) + " to " + str(end_line_id) if debug:
output = "bank" + abbreviation + " starts at "
output += str(start_line_id)
output += " to "
output += str(end_line_id)
print output
bank_intervals[bank_id] = { #store the start/stop line number for this bank
"start": start_line_id, bank_intervals[bank_id] = {"start": start_line_id,
"end": end_line_id, "end": end_line_id,}
} #for each bank..
for bank_id in bank_intervals.keys(): for bank_id in bank_intervals.keys():
#get the start/stop line number
bank_data = bank_intervals[bank_id] bank_data = bank_intervals[bank_id]
start_line_id = bank_data["start"] start_line_id = bank_data["start"]
end_line_id = bank_data["end"] end_line_id = bank_data["end"]
#get all labels between these two lines
labels = get_labels_between(start_line_id, end_line_id, bank_id) labels = get_labels_between(start_line_id, end_line_id, bank_id)
#bank_intervals[bank_id]["labels"] = labels #bank_intervals[bank_id]["labels"] = labels
all_labels.extend(labels) all_labels.extend(labels)
write_all_labels(all_labels) write_all_labels(all_labels)
return all_labels return all_labels
@ -5112,6 +5127,9 @@ class TestAsmList(unittest.TestCase):
self.assertTrue(x(";3:FFAA")) self.assertTrue(x(";3:FFAA"))
self.assertFalse(x('hello world "how are you today;0x1"')) self.assertFalse(x('hello world "how are you today;0x1"'))
self.assertTrue(x('hello world "how are you today:0x1";1')) self.assertTrue(x('hello world "how are you today:0x1";1'))
returnable = {}
self.assertTrue(x("hello_world: ; 0x4050", returnable=returnable, bank=5))
self.assertTrue(returnable["address"] == 0x14050)
def test_line_has_label(self): def test_line_has_label(self):
x = line_has_label x = line_has_label
self.assertTrue(x("hi:")) self.assertTrue(x("hi:"))
@ -5135,6 +5153,88 @@ class TestAsmList(unittest.TestCase):
labels = find_labels_without_addresses() labels = find_labels_without_addresses()
self.failUnless(len(labels) == 0) self.failUnless(len(labels) == 0)
asm = None asm = None
def test_get_labels_between(self):
global asm
x = get_labels_between#(start_line_id, end_line_id, bank)
asm = ["HelloWorld: ;1",
"hi:",
"no label on this line",
]
labels = x(0, 2, 0x12)
self.assertEqual(len(labels), 1)
self.assertEqual(labels[0]["label"], "HelloWorld")
del asm
def test_scan_for_predefined_labels(self):
#label keys: line_number, bank, label, offset, address
load_asm()
all_labels = scan_for_predefined_labels()
label_names = [x["label"] for x in all_labels]
self.assertIn("GetFarByte", label_names)
self.assertIn("AddNTimes", label_names)
self.assertIn("CheckShininess", label_names)
def test_write_all_labels(self):
"""dumping json into a file"""
filename = "test_labels.json"
#remove the current file
if os.path.exists(filename):
os.system("rm " + filename)
#make up some labels
labels = []
#fake label 1
label = {"line_number": 5, "bank": 0, "label": "SomeLabel", "address": 0x10}
labels.append(label)
#fake label 2
label = {"line_number": 15, "bank": 2, "label": "SomeOtherLabel", "address": 0x9F0A}
labels.append(label)
#dump to file
write_all_labels(labels, filename=filename)
#open the file and read the contents
file_handler = open(filename, "r")
contents = file_handler.read()
file_handler.close()
#parse into json
obj = json.read(contents)
#begin testing
self.assertEqual(len(obj), len(labels))
self.assertEqual(len(obj), 2)
self.assertEqual(obj, labels)
def test_isolate_incbins(self):
global asm
asm = ["123", "456", "789", "abc", "def", "ghi",
'INCBIN "baserom.gbc",$12DA,$12F8 - $12DA',
"jkl",
'INCBIN "baserom.gbc",$137A,$13D0 - $137A']
lines = isolate_incbins()
self.assertIn(asm[6], lines)
self.assertIn(asm[8], lines)
for line in lines:
self.assertIn("baserom", line)
def test_process_incbins(self):
global incbin_lines, processed_incbins, asm
incbin_lines = ['INCBIN "baserom.gbc",$12DA,$12F8 - $12DA',
'INCBIN "baserom.gbc",$137A,$13D0 - $137A']
asm = copy(incbin_lines)
asm.insert(1, "some other random line")
processed_incbins = process_incbins()
self.assertEqual(len(processed_incbins), len(incbin_lines))
self.assertEqual(processed_incbins[0]["line"], incbin_lines[0])
self.assertEqual(processed_incbins[2]["line"], incbin_lines[1])
def test_reset_incbins(self):
global asm, incbin_lines, processed_incbins
#temporarily override the functions
global load_asm, isolate_incbins, process_incbins
temp1, temp2, temp3 = load_asm, isolate_incbins, process_incbins
def load_asm(): pass
def isolate_incbins(): pass
def process_incbins(): pass
#call reset
reset_incbins()
#check the results
self.assertTrue(asm == [] or asm == None)
self.assertTrue(incbin_lines == [])
self.assertTrue(processed_incbins == {})
#reset the original functions
load_asm, isolate_incbins, process_incbins = temp1, temp2, temp3
class TestMapParsing(unittest.TestCase): class TestMapParsing(unittest.TestCase):
#def test_parse_warp_bytes(self): #def test_parse_warp_bytes(self):
# pass #or raise NotImplementedError, bryan_message # pass #or raise NotImplementedError, bryan_message

View File

@ -24,7 +24,7 @@ GetFarByte: ; 0x304d
INCBIN "baserom.gbc",$305d,$30fe-$305d INCBIN "baserom.gbc",$305d,$30fe-$305d
AddNTimes ; 0x30fe AddNTimes: ; 0x30fe
and a and a
ret z ret z
.loop .loop