lots of asm-related code and tests

This commit is contained in:
Bryan Bishop 2012-03-24 18:01:37 -05:00
parent f93de7b1bd
commit 3bd84c1dac

View File

@ -4405,6 +4405,463 @@ for map_group_id in map_names.keys():
#set the value in the original dictionary #set the value in the original dictionary
map_names[map_group_id][map_id]["label"] = cleaned_name map_names[map_group_id][map_id]["label"] = cleaned_name
#### asm utilities ####
#these are pulled in from pokered/extras/analyze_incbins.py
#lines of the asm source, one string per line; None means "not loaded yet"
#(analyze_intervals and reset_incbins call load_asm() before using it --
#presumably load_asm() is what fills this in; confirm against its definition)
asm = None
#the INCBIN lines that reference baserom.gbc, as collected by isolate_incbins()
incbin_lines = []
#parsed incbins keyed by asm line number; process_incbins() stores a dict with
#the keys "line_number", "line", "start", "interval" and "end" for each one
processed_incbins = {}
def isolate_incbins():
    """Collect every INCBIN line in asm that mentions baserom.gbc.

    Leading spaces are removed from a line before it is examined and stored.
    The global incbin_lines is rebuilt from scratch and also returned.
    """
    global incbin_lines
    incbin_lines = []
    for raw_line in asm:
        #empty lines and lines made only of spaces both collapse to ""
        candidate = raw_line.lstrip(" ")
        if not candidate:
            continue
        if candidate.startswith("INCBIN") and "baserom.gbc" in candidate:
            incbin_lines.append(candidate)
    return incbin_lines
def process_incbins():
    """Parse every line of incbin_lines into the global processed_incbins.

    Each stored entry is keyed by the line's index in asm and is a dict with
    the keys "line_number", "line", "start", "interval" and "end", where
    end == start + interval. INCBINs whose interval is 0 are not stored.

    processed_incbins is rebuilt from scratch on every call so that entries
    from an earlier parse cannot survive once the asm has changed.

    Raises ValueError when an incbin line cannot be found in asm.
    """
    global incbins, processed_incbins
    incbins = {} #this name was reset by earlier versions; kept for callers
    processed_incbins = {} #reset

    #length of 'INCBIN "baserom.gbc",' which starts every isolated line
    prefix_length = 21

    def find_stripped(target):
        "index of the first asm line equal to target once leading spaces go"
        for position, candidate in enumerate(asm):
            if candidate.lstrip(" ") == target:
                return position
        raise ValueError(target + " is not in asm")

    for incbin in incbin_lines:
        try:
            line_number = asm.index(incbin)
        except ValueError:
            #isolate_incbins strips leading spaces, so an indented INCBIN
            #can only be matched against the stripped asm lines
            line_number = find_stripped(incbin)

        operands = incbin[prefix_length:].split(",")

        #SECURITY NOTE: eval runs text taken from the asm file. It is used
        #because operands may be arithmetic ("$4000 - $3000"); only run this
        #on asm sources you trust.
        start = eval(operands[0].replace("$", "0x"))

        #";" begins an asm comment; turn it into a python comment for eval
        interval_text = operands[1].replace(";", "#")
        interval_text = interval_text.replace("$", "0x").replace("0xx", "0x")
        interval = eval(interval_text)

        #don't add this incbin if the interval is 0
        if interval != 0:
            processed_incbins[line_number] = {
                "line_number": line_number,
                "line": incbin,
                "start": start,
                "interval": interval,
                "end": start + interval,
            }
def reset_incbins():
    """Reset the asm/incbin globals and rebuild them before inserting another diff.

    Clears asm, incbin_lines and processed_incbins, then calls load_asm(),
    isolate_incbins() and process_incbins() in that order.
    """
    #without this declaration the assignments below only create locals and
    #the module-level state is never actually reset
    global asm, incbin_lines, processed_incbins
    asm = None
    incbin_lines = []
    processed_incbins = {}
    load_asm()
    isolate_incbins()
    process_incbins()
def find_incbin_to_replace_for(address, debug=False, rom_file="../baserom.gbc"):
    """returns a line number for which incbin to edit
    if you were to insert bytes into main.asm

    address: an int, or a str holding a hexadecimal number
    debug: print each comparison while searching
    rom_file: path of the rom, used only to bounds-check address

    Returns the processed_incbins key of the incbin that covers address,
    or None when no incbin covers it.
    Raises IndexError when address is negative or beyond the size of rom_file.
    """
    if isinstance(address, str):
        address = int(address, 16)
    if not (0 <= address <= os.lstat(rom_file).st_size):
        raise IndexError("address is out of bounds")
    for incbin_key in processed_incbins:
        incbin = processed_incbins[incbin_key]
        start = incbin["start"]
        end = incbin["end"]
        if debug:
            print("start is: " + str(start))
            print("end is: " + str(end))
            print("address is: " + str(type(address)))
            print("checking.... " + hex(start) + " <= " + hex(address) + " <= " + hex(end))
        #end is start + interval, i.e. the first byte *after* this incbin,
        #so it belongs to whatever follows and must not match here
        if start <= address < end:
            return incbin_key
    return None
def split_incbin_line_into_three(line, start_address, byte_count, rom_file="../baserom.gbc"):
    """
    splits an incbin line into three pieces.
    you can replace the middle one with the new content of length bytecount

    line: key into processed_incbins (the incbin's asm line number)
    start_address: where you want to start inserting bytes
                   (an int, or a str holding a hexadecimal number)
    byte_count: how many bytes you will be inserting
    rom_file: path of the rom, used only to bounds-check start_address

    Returns the replacement INCBIN lines joined by newlines, with no trailing
    newline. The leading piece is left out when start_address is not past
    the start of the original incbin.

    Raises IndexError when start_address is out of bounds, Exception when
    processed_incbins is empty, and KeyError when line is not in it.
    """
    if isinstance(start_address, str):
        start_address = int(start_address, 16)
    #rom_file used to be read here without being defined in this function
    if not (0 <= start_address <= os.lstat(rom_file).st_size):
        raise IndexError("start_address is out of bounds")
    if len(processed_incbins) == 0:
        raise Exception("processed_incbins must be populated")

    original_incbin = processed_incbins[line]
    start = original_incbin["start"]
    end = original_incbin["end"]

    pieces = []

    #bytes in front of the insertion point, written as "start, end1 - end2"
    if start_address - start > 0:
        pieces.append("INCBIN \"baserom.gbc\",$%x,$%x - $%x" % (start, start_address, start))

    #this is the one you will replace with whatever content
    pieces.append("INCBIN \"baserom.gbc\",$%x,%s" % (start_address, str(byte_count)))

    #whatever is left of the original incbin after the inserted bytes
    resume_at = start_address + byte_count
    pieces.append("INCBIN \"baserom.gbc\",$%x,$%x" % (resume_at, end - resume_at))

    return "\n".join(pieces)
def generate_diff_insert(line_number, newline):
    """Return a unified diff that replaces asm[line_number] with newline.

    line_number: index into asm of the line to replace
    newline: replacement text; it may itself contain several lines

    The modified source is written to a temporary file and compared with
    ../main.asm using "diff -u". Both temporary files are removed before
    returning, even when something fails along the way.
    """
    original = "\n".join(line for line in asm)
    newfile = deepcopy(asm)
    newfile[line_number] = newline #possibly inserting multiple lines
    newfile = "\n".join(line for line in newfile)

    original_filename = "ejroqjfoad.temp"
    newfile_filename = "fjiqefo.temp"

    try:
        #NOTE(review): this copy of the unmodified source is written but the
        #diff below compares against ../main.asm instead -- confirm intent
        with open(original_filename, "w") as original_fh:
            original_fh.write(original)

        with open(newfile_filename, "w") as newfile_fh:
            newfile_fh.write(newfile)

        try:
            diffcontent = subprocess.check_output("diff -u ../main.asm " + newfile_filename, shell=True)
        except subprocess.CalledProcessError as exc:
            #diff exits non-zero whenever the files differ; the diff text is
            #still available on the exception
            diffcontent = exc.output
    finally:
        for temp_filename in (original_filename, newfile_filename):
            if os.path.exists(temp_filename):
                os.remove(temp_filename)

    return diffcontent
def apply_diff(diff, try_fixing=True, do_compile=True):
    """Apply a diff to ../main.asm and optionally check that it still builds.

    diff: patch text, written to temp.patch and applied with "patch"
    try_fixing: when the build fails, restore ../main.asm from the backup
                copy (../main1.asm) made before patching
    do_compile: run "make clean" and "make" in ../ after patching

    Returns True when the build succeeds and False when it fails.
    Returns None when do_compile is false, since nothing is checked.
    """
    print("... Applying diff.")

    #write the diff to a file
    with open("temp.patch", "w") as fh:
        fh.write(diff)

    #apply the patch, keeping a backup of the unpatched source
    os.system("cp ../main.asm ../main1.asm")
    os.system("patch ../main.asm temp.patch")

    #remove the patch
    os.system("rm temp.patch")

    #confirm it's working
    if do_compile:
        try:
            subprocess.check_call("cd ../; make clean; LC_CTYPE=C make", shell=True)
            return True
        except Exception:
            #any failure to build counts as a bad patch
            if try_fixing:
                os.system("mv ../main1.asm ../main.asm")
            return False
def index(seq, f):
    """return the index of the first item in seq
    where f(item) == True.

    Returns None when no item matches. seq may be any iterable."""
    return next((position for position, item in enumerate(seq) if f(item)), None)
def is_probably_pointer(input):
    """Return True when input can be parsed as a hexadecimal number.

    Anything int(input, 16) rejects, including non-string values,
    gives False.
    """
    try:
        int(input, 16)
    except (ValueError, TypeError):
        #ValueError: not hexadecimal text; TypeError: not text at all
        return False
    return True
def analyze_intervals():
    """List the processed incbins from largest interval to smallest.

    asm is loaded and the incbins are isolated and processed first when
    that has not happened yet.
    """
    global asm, processed_incbins
    if asm == None:
        load_asm()
    if processed_incbins == {}:
        isolate_incbins()
        process_incbins()

    def interval_of(key):
        return processed_incbins[key]["interval"]

    #sort ascending, then walk the keys backwards
    smallest_first = sorted(processed_incbins, key=interval_of)
    return [processed_incbins[key] for key in reversed(smallest_first)]
def write_all_labels(all_labels):
    """Write all_labels to labels.json (in the current directory) as JSON."""
    #serialize first so a failure here cannot leave a truncated labels.json
    serialized = json.dumps(all_labels)
    with open("labels.json", "w") as fh:
        fh.write(serialized)
def remove_quoted_text(line):
    """Strip quoted spans, together with their quote characters, from line.

    Double-quoted spans are removed first and single-quoted spans second.
    A quote character that occurs an odd number of times is left alone.
    """
    for quote in ("\"", "'"):
        #drop one opening/closing pair per pass while the quotes balance
        while line.count(quote) > 0 and line.count(quote) % 2 == 0:
            opening = line.find(quote)
            closing = line.find(quote, opening + 1)
            line = line[:opening] + line[closing + 1:]
    return line
def line_has_comment_address(line, returnable=None):
    """checks that a given line has a comment
    with a valid address

    Only the first space-separated token of the comment is examined. It may
    be a plain hexadecimal number ("1", "$00FF", "0x00FF") or a bank:offset
    pair ("3:300", "$3:$300", "0x3:0x300").

    line: one line of asm source
    returnable: optional dict that receives the results under the keys
                "bank", "offset" and "address"; all three are set to None
                first and only filled in when the function returns True

    Returns True when an address was found and False otherwise. A malformed
    number in the comment gives False rather than raising.
    """
    #a fresh dict per call, so callers never share one mutable default
    if returnable is None:
        returnable = {}

    #first set the bank/offset to nada
    returnable["bank"] = None
    returnable["offset"] = None
    returnable["address"] = None

    #only valid characters are 0-9a-f (text is lowercased before the check)
    valid = "0123456789abcdef"
    duds = ["$", "0x", "x"]

    def parse_hex(text):
        "int(text, 16), or None when text is not a hexadecimal number"
        try:
            return int(text, 16)
        except ValueError:
            return None

    def parse_piece(piece):
        "parse one side of a bank:offset token, or None when it is invalid"
        #it can't have both "$" and "x"
        if "$" in piece and "x" in piece:
            return None
        piece = piece.replace("$", "0x")
        for c in piece.replace("x", ""):
            if c not in valid:
                return None
        return parse_hex(piece)

    #check if there is a comment in this line
    if ";" not in line:
        return False

    #first throw away anything in quotes
    if (line.count("\"") % 2 == 0 and line.count("\"") != 0) \
    or (line.count("'") % 2 == 0 and line.count("'") != 0):
        line = remove_quoted_text(line)

    #check if there is still a comment in this line after quotes removed
    if ";" not in line:
        return False

    #there must be something other than spaces after the last semicolon
    line = line.rstrip(" ")
    if line[-1] == ";":
        return False

    #split it up into the main comment part, without surrounding spaces
    comment = line[line.find(";")+1:].strip(" ")

    #use the first token in the comment
    token = comment.split(" ")[0]

    if token in ["0x", "$", "x", ":"]:
        return False

    bank, offset = None, None

    #process a token with a A:B format
    if ":" in token: #3:3F0A, $3:$3F0A, 0x3:0x3F0A, 3:3F0A
        pieces = token.split(":")
        bank_piece = pieces[0].lower()
        offset_piece = pieces[1].lower()

        #filter out blanks/duds
        if bank_piece in duds or offset_piece in duds:
            return False

        bank = parse_piece(bank_piece)
        if bank is None:
            return False
        offset = parse_piece(offset_piece)
        if offset is None:
            return False
    #can't have both "$" and "x" in the number
    elif "$" in token and "x" in token:
        return False
    elif "x" in token and "0x" not in token: #it should be 0x
        return False
    else:
        if "$" in token:
            offset = parse_hex(token.replace("$", "0x"))
        elif "0x" in token:
            offset = parse_hex(token)
        else: #might just be "1" at this point
            token = token.lower()
            #check if there are bad characters
            for c in token:
                if c not in valid:
                    return False
            offset = parse_hex(token)
        if offset is None:
            return False
        bank = calculate_bank(offset)

    returnable["bank"] = bank
    returnable["offset"] = offset
    returnable["address"] = calculate_pointer(offset, bank=bank)
    return True
def line_has_label(line):
    """returns True if the line has an asm label

    Quoted text and any trailing comment are ignored. What remains must be
    longer than one character, contain ":" but not "::", and must not start
    with a double quote.

    Raises Exception when line is not a str.
    """
    if not isinstance(line, str):
        raise Exception("can't check this type of object")
    line = line.strip(" ")
    line = remove_quoted_text(line)
    #drop the comment; after this the line can no longer contain ";"
    if ";" in line:
        line = line.split(";")[0]
    if len(line) <= 1:
        return False
    if ":" not in line:
        return False
    #an unbalanced quote survives remove_quoted_text
    if line[0] == "\"":
        return False
    if "::" in line:
        return False
    return True
def get_label_from_line(line):
    """Return the text before the first ":" of line.

    Lines that line_has_label() rejects give None.
    """
    if line_has_label(line):
        #everything up to the first colon
        return line.split(":")[0]
    return None
def find_labels_without_addresses():
    """Scan asm for labelled lines whose comment holds no valid address.

    Returns a list of dicts with the keys "line_number", "line" and "label",
    in source order.
    """
    unmarked = []
    for line_number, line in enumerate(asm):
        if not line_has_label(line):
            continue
        label = get_label_from_line(line)
        if line_has_comment_address(line):
            continue
        unmarked.append({
            "line_number": line_number,
            "line": line,
            "label": label,
        })
    return unmarked
#NOTE(review): nothing in the visible code reads or writes this; presumably it
#is meant to accumulate label-related error text -- confirm before relying on it
label_errors = ""
def get_labels_between(start_line_id, end_line_id, bank_id):
    """Collect the addressed labels on asm lines start_line_id..end_line_id.

    Both line ids are inclusive. bank_id is accepted but not used; the bank
    reported for a label comes from the address in its comment.

    Each result looks like:
        {
            "line_number": 15,
            "bank": 32,
            "label": "PalletTownText1",
            "offset": 0x5315,
            "address": 0x75315,
        }
    Labels without a valid address in their comment are left out.
    """
    found = []
    for position, line in enumerate(asm[start_line_id : end_line_id + 1]):
        #skip lines without labels
        if not line_has_label(line):
            continue

        line_id = start_line_id + position
        line_label = get_label_from_line(line)

        #line_has_comment_address reports bank/offset/address through this
        parsed = {}
        if not line_has_comment_address(line, returnable=parsed):
            #no address in the comment
            continue

        found.append({
            "line_number": line_id,
            "bank": parsed["bank"],
            "label": line_label,
            "offset": parsed["offset"],
            "address": parsed["address"],
        })
    return found
def scan_for_predefined_labels():
    """looks through the asm file for labels at specific addresses,
    this relies on the label having its address after. ex:

    ViridianCity_h: ; 0x18357 to 0x18384 (45 bytes) (bank=6) (id=1)
    PalletTownText1: ; 4F96 0x18f96
    ViridianCityText1: ; 0x19102

    It would be more productive to use rgbasm to spit out all label
    addresses, but faster to write this script. rgbasm would be able
    to grab all label addresses better than this script..

    Banks 0x00 through 0x7F are located by the first line containing
    "bankN" in double quotes, N being the bank number in uppercase hex.
    A bank runs up to the line that starts the next bank found, and the
    last bank found runs to the end of asm. Banks whose marker is missing
    are skipped.

    All labels are written out with write_all_labels() and also returned.
    """
    last_bank_id = 0x7F
    all_labels = []

    #figure out which line each bank starts on
    bank_starts = []
    for bank_id in range(last_bank_id + 1):
        marker = "\"bank" + ("%X" % bank_id) + "\""
        start_line_id = index(asm, lambda line: marker in line)
        if start_line_id is not None:
            bank_starts.append((bank_id, start_line_id))

    for position, (bank_id, start_line_id) in enumerate(bank_starts):
        if position + 1 < len(bank_starts):
            #stop at the line that starts the next bank
            end_line_id = bank_starts[position + 1][1]
        else:
            #nothing follows the last bank, so it runs to the end of the file
            end_line_id = len(asm) - 1

        print("bank" + ("%X" % bank_id) + " starts at " + str(start_line_id) + " to " + str(end_line_id))

        labels = get_labels_between(start_line_id, end_line_id, bank_id)
        all_labels.extend(labels)

    write_all_labels(all_labels)
    return all_labels
#### generic testing #### #### generic testing ####
class TestCram(unittest.TestCase): class TestCram(unittest.TestCase):
@ -4615,6 +5072,69 @@ class TestAsmList(unittest.TestCase):
self.assertEquals(len(base), asm.length()) self.assertEquals(len(base), asm.length())
self.assertEquals(len(base), len(list(asm))) self.assertEquals(len(base), len(list(asm)))
self.assertEquals(len(asm), asm.length()) self.assertEquals(len(asm), asm.length())
def test_remove_quoted_text(self):
    """remove_quoted_text drops quoted spans and leaves other text alone."""
    x = remove_quoted_text
    self.assertEqual(x("hello world"), "hello world")
    self.assertEqual(x("hello \"world\""), "hello ")
    double_quoted = 'hello world "testing 123"'
    self.assertNotEqual(x(double_quoted), double_quoted)
    single_quoted = "hello world 'testing 123'"
    self.assertNotEqual(x(single_quoted), single_quoted)
    #assertFalse instead of the deprecated failIf alias
    self.assertFalse("testing" in x(single_quoted))
def test_line_has_comment_address(self):
    """line_has_comment_address accepts only comments that begin with an address."""
    x = line_has_comment_address

    #no comment at all, or a comment with nothing usable in it
    rejected = [
        "",
        ";",
        ";;;",
        ":;",
        ":;:",
        ";:",
        " ",
        " " * 5,
        " " * 10,
        "hello world",
        "hello_world",
        "hello_world:",
        "hello_world:;",
        "hello_world: ;",
        "hello_world: ; ",
        "hello_world: ;" + " " * 5,
        "hello_world: ;" + " " * 10,
    ]
    for line in rejected:
        self.assertFalse(x(line))

    #plain numbers and bank:offset pairs in each supported notation
    accepted = [
        ";1",
        ";F",
        ";$00FF",
        ";0x00FF",
        "; 0x00FF",
        ";$3:$300",
        ";0x3:$300",
        ";$3:0x300",
        ";3:300",
        ";3:FFAA",
    ]
    for line in accepted:
        self.assertTrue(x(line))

    #a semicolon inside quotes does not start a comment
    self.assertFalse(x('hello world "how are you today;0x1"'))
    self.assertTrue(x('hello world "how are you today:0x1";1'))
def test_line_has_label(self):
    """line_has_label recognises labels and rejects comments and bare colons."""
    x = line_has_label
    for labelled in ("hi:", "Hello: ", "MyLabel: ; test xyz"):
        self.assertTrue(x(labelled))
    for unlabelled in (":", ";HelloWorld:", "::::", ":;:;:;:::"):
        self.assertFalse(x(unlabelled))
def test_get_label_from_line(self):
    """get_label_from_line returns the label text, or None without a label."""
    x = get_label_from_line
    expectations = [
        ("HelloWorld: ", "HelloWorld"),
        ("HiWorld:", "HiWorld"),
        ("HiWorld", None),
    ]
    for line, expected in expectations:
        self.assertEqual(x(line), expected)
def test_find_labels_without_addresses(self):
    """find_labels_without_addresses reports only labels lacking an address."""
    global asm
    try:
        asm = ["hello_world: ; 0x1", "hello_world2: ;"]
        labels = find_labels_without_addresses()
        self.assertEqual(labels[0]["label"], "hello_world2")
        asm = ["hello world: ;1", "hello_world: ;2"]
        labels = find_labels_without_addresses()
        self.assertEqual(len(labels), 0)
    finally:
        #always clear the fake source, even when an assertion above fails,
        #so it cannot leak into other tests
        asm = None
class TestMapParsing(unittest.TestCase): class TestMapParsing(unittest.TestCase):
#def test_parse_warp_bytes(self): #def test_parse_warp_bytes(self):
# pass #or raise NotImplementedError, bryan_message # pass #or raise NotImplementedError, bryan_message