#!/usr/bin/python2
#
# Python functions to read, split and apply patches.
#
# Copyright (C) 2014-2016 Sebastian Lackner
#
# This library is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License as published by the Free Software Foundation; either
# version 2.1 of the License, or (at your option) any later version.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public
# License along with this library; if not, write to the Free Software
# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301, USA
#

import collections
import difflib
import email.header
import hashlib
import itertools
import os
import re
import shutil
import subprocess
import sys
import tempfile

try:
    from cStringIO import StringIO
except ImportError:
    from io import StringIO

# Shared sink for the stdout/stderr of the external 'patch', 'diff' and
# 'git' processes started below. Intentionally kept open for the lifetime
# of the module.
_devnull = open(os.devnull, 'wb')


class PatchParserError(RuntimeError):
    """Unable to parse patch file - either an unimplemented feature, or corrupted patch."""
    pass

class PatchApplyError(RuntimeError):
    """Failed to apply/merge patch."""
    pass

class PatchDiffError(RuntimeError):
    """Failed to compute diff."""
    pass

class CParserError(RuntimeError):
    """Unable to parse C source."""
    pass


class PatchObject(object):
    """A single file patch stored at [offset_begin, offset_end) inside `filename`.

    The patch_* attributes and signed_off_by are copied from the mail header
    dictionary collected by read_patch(); the remaining attributes are filled
    in by _read_single_patch().
    """

    def __init__(self, filename, header):
        self.patch_author   = header.get('author', None)
        self.patch_email    = header.get('email', None)
        self.patch_subject  = header.get('subject', None)
        self.patch_revision = header.get('revision', 1)
        self.signed_off_by  = header.get('signedoffby', [])

        self.filename       = filename
        self.offset_begin   = None
        self.offset_end     = None
        self.is_binary      = False

        self.oldname        = None  # path with the "a/" prefix stripped, or "/dev/null"
        self.newname        = None  # path with the "b/" prefix stripped, or "/dev/null"
        self.modified_file  = None  # newname, or oldname when the file is deleted

        self.oldsha1        = None  # from the "index <old>..<new>" header line
        self.newsha1        = None
        self.newmode        = None  # from "new mode" / "new file mode" header lines

    def read_chunks(self):
        """Iterates over arbitrary sized chunks of this patch.

        Re-opens `self.filename`, so this only works for patches read from a
        real file. Raises IOError if the file is shorter than recorded.
        """
        assert self.offset_end >= self.offset_begin
        with open(self.filename) as fp:
            fp.seek(self.offset_begin)
            remaining = self.offset_end - self.offset_begin
            while remaining > 0:
                buf = fp.read(min(remaining, 16384))
                if buf == "":
                    raise IOError("Unable to extract patch.")
                yield buf
                remaining -= len(buf)

    def read(self):
        """Return the full patch as a string."""
        return "".join(self.read_chunks())


class _PatchReader(object):
    """Line reader with single-line lookahead (peek) and position tracking."""

    def __init__(self, filename, fp=None):
        self.filename = filename
        self.fp = fp if fp is not None else open(filename)
        self.peeked = None  # (position, line) of a peeked but unconsumed line

    def close(self):
        self.fp.close()

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        self.close()

    def seek(self, pos):
        """Change the file cursor position."""
        self.fp.seek(pos)
        self.peeked = None

    def tell(self):
        """Return the current file cursor position."""
        if self.peeked is None:
            return self.fp.tell()
        return self.peeked[0]

    def peek(self):
        """Read one line without changing the file cursor; None at end of file."""
        if self.peeked is None:
            pos = self.fp.tell()
            tmp = self.fp.readline()
            if len(tmp) == 0:
                return None
            self.peeked = (pos, tmp)
        return self.peeked[1]

    def read(self):
        """Read one line from the file, and move the file cursor to the next line.

        Returns None at end of file.
        """
        if self.peeked is None:
            tmp = self.fp.readline()
            if len(tmp) == 0:
                return None
            return tmp
        tmp, self.peeked = self.peeked, None
        return tmp[1]

    def read_hunk(self):
        """Read one hunk from a patch file.

        Returns None if the next line is not a hunk header, otherwise a tuple
        (srcpos, srcdata, dstpos, dstdata) with 0-based start positions and the
        lists of source/destination lines (without their +/-/space prefix).
        Raises PatchParserError for malformed or truncated hunks.
        """
        line = self.peek()
        if line is None or not line.startswith("@@ -"):
            return None

        # Unified format is "@@ -start[,count] +start[,count] @@"; an omitted
        # count means a single line, the start line number is always present.
        r = re.match(r"^@@ -([0-9]+)(?:,([0-9]+))? \+([0-9]+)(?:,([0-9]+))? @@", line)
        if not r:
            raise PatchParserError("Unable to parse hunk header '%s'." % line)
        srcpos = max(int(r.group(1)) - 1, 0)
        dstpos = max(int(r.group(3)) - 1, 0)
        srclines = int(r.group(2)) if r.group(2) is not None else 1
        dstlines = int(r.group(4)) if r.group(4) is not None else 1
        if srclines <= 0 and dstlines <= 0:
            raise PatchParserError("Empty hunk doesn't make sense.")
        self.read()

        srcdata = []
        dstdata = []

        try:
            while srclines > 0 or dstlines > 0:
                line = self.read()
                if line is None:
                    # End of file before the announced number of lines
                    raise PatchParserError("Truncated patch.")
                line = line.rstrip("\n")
                if line[0] == " ":
                    if srclines == 0 or dstlines == 0:
                        raise PatchParserError("Corrupted patch.")
                    srcdata.append(line[1:])
                    dstdata.append(line[1:])
                    srclines -= 1
                    dstlines -= 1
                elif line[0] == "-":
                    if srclines == 0:
                        raise PatchParserError("Corrupted patch.")
                    srcdata.append(line[1:])
                    srclines -= 1
                elif line[0] == "+":
                    if dstlines == 0:
                        raise PatchParserError("Corrupted patch.")
                    dstdata.append(line[1:])
                    dstlines -= 1
                elif line[0] == "\\":
                    pass # ignore
                else:
                    raise PatchParserError("Unexpected line in hunk.")
        except IndexError: # triggered by ""[0]
            raise PatchParserError("Truncated patch.")

        # Skip trailing "\ No newline at end of file" style markers
        while True:
            line = self.peek()
            if line is None or not line.startswith("\\ "):
                break
            self.read()

        return (srcpos, srcdata, dstpos, dstdata)


def _read_single_patch(fp, header, oldname=None, newname=None):
    """Internal function to read a single patch from a file.

    `fp` is a _PatchReader positioned at a "diff --git" or "--- " line.
    `oldname`/`newname` are the (still prefixed) names from the git header,
    if any. Returns a PatchObject describing the byte range of the patch.
    """

    patch = PatchObject(fp.filename, header)
    patch.offset_begin = fp.tell()
    patch.oldname = oldname
    patch.newname = newname

    # Skip over initial diff --git header
    line = fp.peek()
    if line is not None and line.startswith("diff --git "):
        fp.read()

    # Read header
    while True:
        line = fp.peek()
        if line is None:
            break
        elif line.startswith("--- "):
            patch.oldname = line[4:].strip()
        elif line.startswith("+++ "):
            patch.newname = line[4:].strip()
        elif line.startswith("old mode") or line.startswith("deleted file mode"):
            pass # ignore
        elif line.startswith("new mode "):
            patch.newmode = line[9:].strip()
        elif line.startswith("new file mode "):
            patch.newmode = line[14:].strip()
        elif line.startswith("new mode") or line.startswith("new file mode"):
            raise PatchParserError("Unable to parse header line '%s'." % line)
        elif line.startswith("copy from") or line.startswith("copy to"):
            raise NotImplementedError("Patch copy header not implemented yet.")
        elif line.startswith("rename "):
            raise NotImplementedError("Patch rename header not implemented yet.")
        elif line.startswith("similarity index") or line.startswith("dissimilarity index"):
            pass # ignore
        elif line.startswith("index "):
            r = re.match(r"^index ([a-fA-F0-9]*)\.\.([a-fA-F0-9]*)", line)
            if not r:
                raise PatchParserError("Unable to parse index header line '%s'." % line)
            patch.oldsha1, patch.newsha1 = r.group(1), r.group(2)
        else:
            break
        fp.read()

    if patch.oldname is None or patch.newname is None:
        raise PatchParserError("Missing old or new name.")
    elif patch.oldname == "/dev/null" and patch.newname == "/dev/null":
        raise PatchParserError("Old and new name is /dev/null?")

    if patch.oldname.startswith("a/"):
        patch.oldname = patch.oldname[2:]
    elif patch.oldname != "/dev/null":
        raise PatchParserError("Old name in patch doesn't start with a/.")

    if patch.newname.startswith("b/"):
        patch.newname = patch.newname[2:]
    elif patch.newname != "/dev/null":
        raise PatchParserError("New name in patch doesn't start with b/.")

    if patch.newname != "/dev/null":
        patch.modified_file = patch.newname
    else:
        patch.modified_file = patch.oldname

    # Decide between binary and textual patch
    if line is None or line.startswith("diff --git ") or line.startswith("--- "):
        # Patch without content (e.g. a pure mode change). Compare the
        # *stripped* names: the git header names always differ in their
        # a/ and b/ prefixes.
        if patch.oldname != patch.newname:
            raise PatchParserError("Stripped old- and new name doesn't match.")

    elif line.startswith("@@ -"):
        while fp.read_hunk() is not None:
            pass

    elif line.rstrip() == "GIT binary patch":
        if patch.oldsha1 is None or patch.newsha1 is None:
            raise PatchParserError("Missing index header, sha1 sums required for binary patch.")
        elif patch.oldname != patch.newname:
            raise PatchParserError("Stripped old- and new name doesn't match for binary patch.")
        fp.read()

        line = fp.read()
        if line is None:
            raise PatchParserError("Unexpected end of file.")
        r = re.match("^(literal|delta) ([0-9]+)", line)
        if not r:
            raise NotImplementedError("Only literal/delta patches are supported.")
        patch.is_binary = True

        # Skip over patch data
        while True:
            line = fp.read()
            if line is None or line.strip() == "":
                break

    else:
        raise PatchParserError("Unknown patch format.")

    patch.offset_end = fp.tell()
    return patch


def _parse_author(author):
    """Split a (possibly RFC 2047 encoded) 'Name <email>' string into (name, email)."""
    if sys.version_info[0] > 2:
        author = str(email.header.make_header(email.header.decode_header(author)))
    else:
        author = ' '.join([data.decode(charset or 'utf-8').encode('utf-8')
                           for data, charset in email.header.decode_header(author)])
    r = re.match("\"?([^\"]*)\"? <(.*)>", author)
    if r is None:
        raise NotImplementedError("Failed to parse From - header.")
    return r.group(1).strip(), r.group(2).strip()


def _parse_subject(subject):
    """Strip "[PATCH ...]" and version markers; return (subject, revision).

    The revision defaults to 1 when no "v3", "try 3", "rev 3" or "take 3"
    marker is found.
    """
    version = "(v|try|rev|take) *([0-9]+)"
    subject = subject.strip()
    if subject.endswith("."):
        subject = subject[:-1]
    r = re.match("^\\[PATCH([^]]*)\\](.*)$", subject, re.IGNORECASE)
    if r is not None:
        subject = r.group(2).strip()
        r = re.search(version, r.group(1), re.IGNORECASE)
        if r is not None:
            return subject, int(r.group(2))
    r = re.match("^(.*)\\(%s\\)$" % version, subject, re.IGNORECASE)
    if r is not None:
        return r.group(1).strip(), int(r.group(3))
    r = re.match("^(.*)\\[%s\\]$" % version, subject, re.IGNORECASE)
    if r is not None:
        return r.group(1).strip(), int(r.group(3))
    r = re.match("^(.*)[.,] +%s$" % version, subject, re.IGNORECASE)
    if r is not None:
        return r.group(1).strip(), int(r.group(3))
    r = re.match("^([^:]+) %s: (.*)$" % version, subject, re.IGNORECASE)
    if r is not None:
        return "%s: %s" % (r.group(1), r.group(4)), int(r.group(3))
    r = re.match("^(.*) +%s$" % version, subject, re.IGNORECASE)
    if r is not None:
        return r.group(1).strip(), int(r.group(3))
    r = re.match("^(.*)\\(resend\\)$", subject, re.IGNORECASE)
    if r is not None:
        return r.group(1).strip(), 1
    return subject, 1


def read_patch(filename, fp=None):
    """Iterates over all patches contained in a file, and returns PatchObject objects.

    `fp` may be an already opened file-like object; `filename` is then only
    recorded in the PatchObjects. The reader (including a caller-supplied
    `fp`) is closed when iteration finishes.
    """

    header = {}
    with _PatchReader(filename, fp) as fp:
        while True:
            line = fp.peek()
            if line is None:
                break

            elif line.startswith("From: "):
                header['author'], header['email'] = _parse_author(line[6:])
                header.pop('signedoffby', None)
                fp.read()

            elif line.startswith("Subject: "):
                subject = line[9:].rstrip("\r\n")
                fp.read()
                # Folded header: continuation lines start with a space
                while True:
                    line = fp.peek()
                    if line is None or not line.startswith(" "):
                        break
                    subject += line.rstrip("\r\n")
                    fp.read()
                subject, revision = _parse_subject(subject)
                if not subject.endswith("."):
                    subject += "."
                # Capitalize the first letter after the "component: " prefix
                subject = re.sub('^([^:]*: *)([a-z])',
                                 lambda x: "%s%s" % (x.group(1), x.group(2).upper()),
                                 subject, count=1)
                header['subject'], header['revision'] = subject, revision
                header.pop('signedoffby', None)

            elif line.startswith("Signed-off-by: "):
                if 'signedoffby' not in header:
                    header['signedoffby'] = []
                header['signedoffby'].append(_parse_author(line[15:]))
                fp.read()

            elif line.startswith("diff --git "):
                tmp = line.strip().split(" ")
                if len(tmp) != 4:
                    raise PatchParserError("Unable to parse git diff header line '%s'." % line)
                yield _read_single_patch(fp, header, tmp[2].strip(), tmp[3].strip())

            elif line.startswith("--- "):
                yield _read_single_patch(fp, header)

            elif line.startswith("@@ -") or line.startswith("+++ "):
                raise PatchParserError("Patch didn't start with a git or diff header.")

            else:
                fp.read()


def apply_patch(original, patchfile, reverse=False, fuzz=2):
    """Apply a patch with optional fuzz - uses the commandline 'patch' utility.

    `original` and `patchfile` must expose a `.name` path. Returns a new
    temporary file object (deleted when closed) containing the patched copy
    of `original`; `original` itself is not modified. Raises PatchApplyError
    if 'patch' exits with a non-zero code.
    """

    result = tempfile.NamedTemporaryFile(mode='w+', delete=False)
    try:
        # We open the file again to avoid race-conditions with multithreaded reads
        with open(original.name) as fp:
            shutil.copyfileobj(fp, result)
        result.close()

        cmdline = ["patch", "--no-backup-if-mismatch", "--force", "--silent", "-r", "-"]
        if reverse:
            cmdline.append("--reverse")
        if fuzz != 2:
            cmdline.append("--fuzz=%d" % fuzz)
        cmdline += [result.name, patchfile.name]

        exitcode = subprocess.call(cmdline, stdout=_devnull, stderr=_devnull)
        if exitcode != 0:
            raise PatchApplyError("Failed to apply patch (exitcode %d)." % exitcode)

        # Hack - we can't keep the file open while patching ('patch' might rename/replace
        # the file), so create a new _TemporaryFileWrapper object for the existing path.
        return tempfile._TemporaryFileWrapper(file=open(result.name, 'r+'),
                                              name=result.name, delete=True)

    except BaseException:
        # Clean up the delete=False temporary file on any failure, then re-raise.
        result.close()
        os.unlink(result.name)
        raise


def _preprocess_source(fp):
    """Simple C preprocessor to determine where we can safely add #ifdef instructions.

    `fp` is any iterable of lines. Returns (lines, split) where `lines` are
    the input lines without trailing newline and `split` is the set of line
    indices before which an #ifdef may be inserted (not inside a string,
    multi-line comment or backslash-continued line).
    """

    _re_state0 = re.compile("(\"|/[/*])")
    _re_state1 = re.compile("(\\\\\"|\\\\\\\\|\")")
    _re_state2 = re.compile("\\*/")

    # We need to read the original file, and figure out where lines can be splitted
    lines = []
    for line in fp:
        lines.append(line.rstrip("\n"))

    split = set([0])
    state = 0

    i = 0
    while i < len(lines):

        # Read a full line (and handle line continuation)
        line = lines[i]
        i += 1
        while line.endswith("\\"):
            if i >= len(lines):
                raise CParserError("Unexpected end of file.")
            line = line[:-1] + lines[i]
            i += 1

        # To find out where we can add our #ifdef tags we use a simple
        # statemachine. This allows finding the beginning of a multiline
        # instruction or comment.
        j = 0
        while True:

            # State 0: No context
            if state == 0:
                match = _re_state0.search(line, j)
                if match is None:
                    break

                if match.group(0) == "\"":
                    state = 1 # Begin of string
                elif match.group(0) == "/*":
                    state = 2 # Begin of comment
                elif match.group(0) == "//":
                    break # Rest of the line is a comment, which can be safely ignored
                else:
                    assert 0

            # State 1: Inside of string
            elif state == 1:
                match = _re_state1.search(line, j)
                if match is None:
                    raise CParserError("Line ended in the middle of a string.")

                if match.group(0) == "\"":
                    state = 0 # End of string
                elif match.group(0) != "\\\"" and match.group(0) != "\\\\":
                    assert 0

            # State 2: Multiline comment
            elif state == 2:
                match = _re_state2.search(line, j)
                if match is None:
                    break

                if match.group(0) == "*/":
                    state = 0 # End of comment
                else:
                    assert 0

            else:
                raise CParserError("Internal state error.")
            j = match.end()

        # Only in state 0 (no context) we can split here
        if state == 0:
            split.add(i)

    # Ensure that the last comment is properly terminated
    if state != 0:
        raise CParserError("Unexpected end of file.")
    return lines, split


def generate_ifdef_patch(original, patched, ifdef):
    """Generate a patch which adds #ifdef where necessary to keep both the original and patched version.

    `original` and `patched` are open file objects with a `.name` path.
    Returns None if both files are identical. Otherwise returns a temporary
    file holding a 'git diff' from `patched` to the merged #ifdef version,
    positioned just after the four header lines (at the first hunk).
    Raises PatchDiffError if 'diff' or 'git diff' fail and CParserError if
    the original source cannot be preprocessed.
    """

    #
    # The basic of idea of this algorithm is as following:
    #
    # (1) determine diff between original file and patched file
    # (2) run the preprocessor, to determine where #ifdefs can be safely added
    # (3) use diff and preprocessor information to create a merged version containing #ifdefs
    # (4) create another diff to apply the changes on the patched version
    #

    with tempfile.NamedTemporaryFile(mode='w+') as diff:
        exitcode = subprocess.call(["diff", "-u", original.name, patched.name],
                                   stdout=diff, stderr=_devnull)
        if exitcode == 0:
            return None
        elif exitcode != 1:
            raise PatchDiffError("Failed to compute diff (exitcode %d)." % exitcode)

        # Preprocess the original C source
        original.seek(0)
        lines, split = _preprocess_source(original)

        # Parse the created diff file
        fp = _PatchReader(diff.name, diff)
        fp.seek(0)

        # We expect this output format from 'diff', if this is not the case things might go wrong.
        line = fp.read()
        assert line.startswith("--- ")
        line = fp.read()
        assert line.startswith("+++ ")

        # Each entry is (srcpos, dstpos, srcdata, dstdata); positions are 0-based.
        hunks = []
        while fp.peek() is not None:
            srcpos, srcdata, dstpos, dstdata = fp.read_hunk()

            # Ensure that the patch would really apply in practice
            if lines[srcpos:srcpos + len(srcdata)] != srcdata:
                raise PatchParserError("Patch failed to apply.")

            # Strip common lines from the beginning and end
            while len(srcdata) > 0 and len(dstdata) > 0 and \
                  srcdata[0] == dstdata[0]:
                srcdata.pop(0)
                dstdata.pop(0)
                srcpos += 1
                dstpos += 1

            while len(srcdata) > 0 and len(dstdata) > 0 and \
                  srcdata[-1] == dstdata[-1]:
                srcdata.pop()
                dstdata.pop()

            # Ensure that diff generated valid output
            assert len(srcdata) > 0 or len(dstdata) > 0

            # If this is the first hunk, then check if we have to extend it at the beginning
            if len(hunks) == 0:
                assert srcpos == dstpos
                while srcpos > 0 and srcpos not in split:
                    srcpos -= 1
                    dstpos -= 1
                    srcdata.insert(0, lines[srcpos])
                    dstdata.insert(0, lines[srcpos])
                hunks.append((srcpos, dstpos, srcdata, dstdata))

            # Check if we can merge with the previous hunk
            else:
                prev_srcpos, prev_dstpos, prev_srcdata, prev_dstdata = hunks[-1]
                prev_endpos = prev_srcpos + len(prev_srcdata)

                found = sum(1 for i in range(prev_endpos, srcpos) if i in split)

                # At least two possible splitting positions inbetween
                if found >= 2:
                    while prev_endpos not in split:
                        prev_srcdata.append(lines[prev_endpos])
                        prev_dstdata.append(lines[prev_endpos])
                        prev_endpos += 1
                    while srcpos not in split:
                        # Keep dstpos in sync with the prepended common lines,
                        # like the first-hunk branch does; the merge assertion
                        # below relies on it.
                        srcpos -= 1
                        dstpos -= 1
                        srcdata.insert(0, lines[srcpos])
                        dstdata.insert(0, lines[srcpos])
                    hunks.append((srcpos, dstpos, srcdata, dstdata))

                # Merge hunks
                else:
                    while prev_endpos < srcpos:
                        prev_srcdata.append(lines[prev_endpos])
                        prev_dstdata.append(lines[prev_endpos])
                        prev_endpos += 1
                    assert prev_dstpos + len(prev_dstdata) == dstpos
                    prev_srcdata.extend(srcdata)
                    prev_dstdata.extend(dstdata)

    # We might have to extend the last hunk
    if len(hunks):
        prev_srcpos, prev_dstpos, prev_srcdata, prev_dstdata = hunks[-1]
        prev_endpos = prev_srcpos + len(prev_srcdata)
        while prev_endpos < len(lines) and prev_endpos not in split:
            prev_srcdata.append(lines[prev_endpos])
            prev_dstdata.append(lines[prev_endpos])
            prev_endpos += 1

    # Generate resulting file with #ifdefs
    with tempfile.NamedTemporaryFile(mode='w+') as intermediate:

        pos = 0
        while len(hunks):
            srcpos, dstpos, srcdata, dstdata = hunks.pop(0)
            if pos < srcpos:
                intermediate.write("\n".join(lines[pos:srcpos]))
                intermediate.write("\n")

            if len(srcdata) and len(dstdata):
                intermediate.write("#if defined(%s)\n" % ifdef)
                intermediate.write("\n".join(dstdata))
                intermediate.write("\n#else  /* %s */\n" % ifdef)
                intermediate.write("\n".join(srcdata))
                intermediate.write("\n#endif /* %s */\n" % ifdef)

            elif len(srcdata):
                intermediate.write("#if !defined(%s)\n" % ifdef)
                intermediate.write("\n".join(srcdata))
                intermediate.write("\n#endif /* %s */\n" % ifdef)

            elif len(dstdata):
                intermediate.write("#if defined(%s)\n" % ifdef)
                intermediate.write("\n".join(dstdata))
                intermediate.write("\n#endif /* %s */\n" % ifdef)

            else:
                assert 0

            pos = srcpos + len(srcdata)

        if pos < len(lines):
            intermediate.write("\n".join(lines[pos:]))
            intermediate.write("\n")
        intermediate.flush()

        # Now we can finally compute the diff between the patched file and our intermediate file
        output = tempfile.NamedTemporaryFile(mode='w+')
        exitcode = subprocess.call(["git", "diff", "--no-index", patched.name, intermediate.name],
                                   stdout=output, stderr=_devnull)
        if exitcode != 1: # exitcode 0 cannot (=shouldn't) happen in this situation
            output.close()
            raise PatchDiffError("Failed to compute diff (exitcode %d)." % exitcode)

        output.flush()
        output.seek(0)

        # We expect this output format from 'git diff', if this is not the case things might go wrong.
        line = output.readline()
        assert line.startswith("diff --git ")
        line = output.readline()
        assert line.startswith("index ")
        line = output.readline()
        assert line.startswith("--- ")
        line = output.readline()
        assert line.startswith("+++ ")

    # Return the final diff
    return output


if __name__ == "__main__":
    import unittest

    # Basic tests for _parse_author() and _parse_subject()
    class PatchParserTests(unittest.TestCase):
        def test_author(self):
            author = _parse_author("Author Name <author@email.com>")
            self.assertEqual(author, ("Author Name", "author@email.com"))

            author = _parse_author("=?UTF-8?q?Author=20Name?= <author@email.com>")
            self.assertEqual(author, ("Author Name", "author@email.com"))

        def test_subject(self):
            subject = _parse_subject("[PATCH v3] component: Subject.")
            self.assertEqual(subject, ("component: Subject", 3))

            subject = _parse_subject("[PATCH] component: Subject (v3).")
            self.assertEqual(subject, ("component: Subject", 3))

            subject = _parse_subject("[PATCH] component: Subject (try 3).")
            self.assertEqual(subject, ("component: Subject", 3))

            subject = _parse_subject("[PATCH] component: Subject (take 3).")
            self.assertEqual(subject, ("component: Subject", 3))

            subject = _parse_subject("[PATCH] component: Subject (rev 3).")
            self.assertEqual(subject, ("component: Subject", 3))

            subject = _parse_subject("[PATCH] component: Subject [v3].")
            self.assertEqual(subject, ("component: Subject", 3))

            subject = _parse_subject("[PATCH] component: Subject, v3.")
            self.assertEqual(subject, ("component: Subject", 3))

            subject = _parse_subject("[PATCH] component: Subject v3.")
            self.assertEqual(subject, ("component: Subject", 3))

            subject = _parse_subject("[PATCH] component: Subject (resend).")
            self.assertEqual(subject, ("component: Subject", 1))

    # Basic tests for read_patch()
    class PatchReaderTests(unittest.TestCase):
        def test_simple(self):
            with open("tests/simple.patch") as fp:
                source = fp.read().split("\n")

            # Test formatted git patch with author and subject
            patchfile = tempfile.NamedTemporaryFile(mode='w+')
            patchfile.write("\n".join(source))
            patchfile.flush()

            patches = list(read_patch(patchfile.name))
            self.assertEqual(len(patches), 1)
            self.assertEqual(patches[0].patch_author, "Author Name")
            self.assertEqual(patches[0].patch_email, "author@email.com")
            self.assertEqual(patches[0].patch_subject, "component: Replace arg1 with arg2.")
            self.assertEqual(patches[0].patch_revision, 3)
            self.assertEqual(patches[0].signed_off_by, [("Author Name", "author@email.com"),
                                                         ("Other Developer", "other@email.com")])
            self.assertEqual(patches[0].filename, patchfile.name)
            self.assertEqual(patches[0].is_binary, False)
            self.assertEqual(patches[0].modified_file, "test.txt")

            lines = patches[0].read().rstrip("\n").split("\n")
            self.assertEqual(lines, source[10:23])

            # Test with git diff
            del source[0:10]
            self.assertTrue(source[0].startswith("diff --git"))
            patchfile = tempfile.NamedTemporaryFile(mode='w+')
            patchfile.write("\n".join(source))
            patchfile.flush()

            patches = list(read_patch(patchfile.name))
            self.assertEqual(len(patches), 1)
            self.assertEqual(patches[0].patch_author, None)
            self.assertEqual(patches[0].patch_email, None)
            self.assertEqual(patches[0].patch_subject, None)
            self.assertEqual(patches[0].patch_revision, 1)
            self.assertEqual(patches[0].signed_off_by, [])
            self.assertEqual(patches[0].filename, patchfile.name)
            self.assertEqual(patches[0].is_binary, False)
            self.assertEqual(patches[0].modified_file, "test.txt")

            lines = patches[0].read().rstrip("\n").split("\n")
            self.assertEqual(lines, source[:13])

            # Test with unified diff
            del source[0:2]
            self.assertTrue(source[0].startswith("---"))
            patchfile = tempfile.NamedTemporaryFile(mode='w+')
            patchfile.write("\n".join(source))
            patchfile.flush()

            patches = list(read_patch(patchfile.name))
            self.assertEqual(len(patches), 1)
            self.assertEqual(patches[0].patch_author, None)
            self.assertEqual(patches[0].patch_email, None)
            self.assertEqual(patches[0].patch_subject, None)
            self.assertEqual(patches[0].patch_revision, 1)
            self.assertEqual(patches[0].signed_off_by, [])
            self.assertEqual(patches[0].filename, patchfile.name)
            self.assertEqual(patches[0].is_binary, False)
            self.assertEqual(patches[0].modified_file, "test.txt")

            lines = patches[0].read().rstrip("\n").split("\n")
            self.assertEqual(lines, source[:11])

            # Test with StringIO buffer
            fp = StringIO("\n".join(source))
            patches = list(read_patch("unknown.patch", fp))
            self.assertEqual(len(patches), 1)
            self.assertEqual(patches[0].patch_author, None)
            self.assertEqual(patches[0].patch_email, None)
            self.assertEqual(patches[0].patch_subject, None)
            self.assertEqual(patches[0].patch_revision, 1)
            self.assertEqual(patches[0].signed_off_by, [])
            self.assertEqual(patches[0].filename, "unknown.patch")
            self.assertEqual(patches[0].is_binary, False)
            self.assertEqual(patches[0].modified_file, "test.txt")

        def test_multi(self):
            with open("tests/multi.patch") as fp:
                source = fp.read().split("\n")

            patchfile = tempfile.NamedTemporaryFile(mode='w+')
            patchfile.write("\n".join(source))
            patchfile.flush()

            patches = list(read_patch(patchfile.name))
            self.assertEqual(len(patches), 3)

            self.assertEqual(patches[0].patch_author, "Author Name")
            self.assertEqual(patches[0].patch_email, "author@email.com")
            self.assertEqual(patches[0].patch_subject, "component: Replace arg1 with arg2.")
            self.assertEqual(patches[0].patch_revision, 3)
            self.assertEqual(patches[0].signed_off_by, [("Author Name", "author@email.com"),
                                                         ("Other Developer", "other@email.com")])
            self.assertEqual(patches[0].filename, patchfile.name)
            self.assertEqual(patches[0].is_binary, False)
            self.assertEqual(patches[0].modified_file, "other_test.txt")

            lines = patches[0].read().rstrip("\n").split("\n")
            self.assertEqual(lines, source[11:24])

            self.assertEqual(patches[1].patch_author, "Author Name")
            self.assertEqual(patches[1].patch_email, "author@email.com")
            self.assertEqual(patches[1].patch_subject, "component: Replace arg1 with arg2.")
            self.assertEqual(patches[1].patch_revision, 3)
            self.assertEqual(patches[1].signed_off_by, [("Author Name", "author@email.com"),
                                                         ("Other Developer", "other@email.com")])
            self.assertEqual(patches[1].filename, patchfile.name)
            self.assertEqual(patches[1].is_binary, False)
            self.assertEqual(patches[1].modified_file, "test.txt")

            lines = patches[1].read().rstrip("\n").split("\n")
            self.assertEqual(lines, source[24:46])

            self.assertEqual(patches[2].patch_author, "Other Developer")
            self.assertEqual(patches[2].patch_email, "other@email.com")
            self.assertEqual(patches[2].patch_subject, "component: Replace arg2 with arg3.")
            self.assertEqual(patches[2].patch_revision, 4)
            self.assertEqual(patches[2].signed_off_by, [("Other Developer", "other@email.com")])
            self.assertEqual(patches[2].filename, patchfile.name)
            self.assertEqual(patches[2].is_binary, False)
            self.assertEqual(patches[2].modified_file, "test.txt")

            lines = patches[2].read().rstrip("\n").split("\n")
            self.assertEqual(lines, source[58:71])

        def test_edge_cases(self):
            # Mode-only git diff without any hunks
            source = "diff --git a/run.sh b/run.sh\nold mode 100644\nnew mode 100755\n"
            patches = list(read_patch("unknown.patch", StringIO(source)))
            self.assertEqual(len(patches), 1)
            self.assertEqual(patches[0].modified_file, "run.sh")
            self.assertEqual(patches[0].newmode, "100755")

            # Hunk header with omitted line counts ("@@ -2 +2 @@" = one line at line 2)
            source = "--- a/t.txt\n+++ b/t.txt\n@@ -2 +2 @@\n-b\n+c\n"
            patches = list(read_patch("unknown.patch", StringIO(source)))
            self.assertEqual(len(patches), 1)
            self.assertEqual(patches[0].modified_file, "t.txt")

            # Hunk truncated by end of file
            source = "--- a/t.txt\n+++ b/t.txt\n@@ -1,2 +1,2 @@\n-a\n"
            with self.assertRaises(PatchParserError):
                list(read_patch("unknown.patch", StringIO(source)))

            # Subject header as the last line of the file
            source = "Subject: [PATCH] component: Subject.\n"
            self.assertEqual(list(read_patch("unknown.patch", StringIO(source))), [])

    # Basic tests for apply_patch()
    class PatchApplyTests(unittest.TestCase):
        def test_apply(self):
            source = ["line1();", "line2();", "line3();",
                      "function(arg1);",
                      "line5();", "line6();", "line7();"]
            original = tempfile.NamedTemporaryFile(mode='w+')
            original.write("\n".join(source + [""]))
            original.flush()

            source = ["@@ -1,7 +1,7 @@",
                      " line1();",
                      " line2();",
                      " line3();",
                      "-function(arg1);",
                      "+function(arg2);",
                      " line5();",
                      " line6();",
                      " line7();"]
            patchfile = tempfile.NamedTemporaryFile(mode='w+')
            patchfile.write("\n".join(source + [""]))
            patchfile.flush()

            expected = ["line1();", "line2();", "line3();",
                        "function(arg2);",
                        "line5();", "line6();", "line7();"]
            result = apply_patch(original, patchfile, fuzz=0)
            lines = result.read().rstrip("\n").split("\n")
            self.assertEqual(lines, expected)

            expected = ["line1();", "line2();", "line3();",
                        "function(arg1);",
                        "line5();", "line6();", "line7();"]
            result = apply_patch(result, patchfile, reverse=True, fuzz=0)
            lines = result.read().rstrip("\n").split("\n")
            self.assertEqual(lines, expected)

    # Basic tests for _preprocess_source()
    class PreprocessorTests(unittest.TestCase):
        def test_preprocessor(self):
            source = ["int a; // comment 1",
                      "int b; // comment 2 \\",
                      " comment 3 \\",
                      " comment 4",
                      "int c; // comment with \"quotes\"",
                      "int d; // comment with /* c++ comment */",
                      "int e; /* multi \\",
                      " line",
                      " comment */",
                      "char *x = \"\\\\\";",
                      "char *y = \"abc\\\"def\";",
                      "char *z = \"multi\" \\",
                      " \"line\"",
                      " \"string\";"]
            lines, split = _preprocess_source(source)
            self.assertEqual(lines, source)
            self.assertEqual(split, set([0, 1, 4, 5, 6, 9, 10, 11, 13, 14]))

    # Basic tests for generate_ifdef_patch()
    class GenerateIfdefPatchTests(unittest.TestCase):
        def test_ifdefined(self):
            source = ["line1();", "line2();", "line3();",
                      "function(arg1, \\",
                      "         arg2, \\",
                      "         arg3);",
                      "line5();", "line6();", "line7();"]
            source1 = tempfile.NamedTemporaryFile(mode='w+')
            source1.write("\n".join(source + [""]))
            source1.flush()

            source = ["line1();", "line2();", "line3();",
                      "function(arg1, \\",
                      "         new_arg2, \\",
                      "         arg3);",
                      "line5();", "line6();", "line7();"]
            source2 = tempfile.NamedTemporaryFile(mode='w+')
            source2.write("\n".join(source + [""]))
            source2.flush()

            diff = generate_ifdef_patch(source1, source1, "PATCHED")
            self.assertEqual(diff, None)

            diff = generate_ifdef_patch(source2, source2, "PATCHED")
            self.assertEqual(diff, None)

            expected = ["@@ -1,9 +1,15 @@",
                        " line1();",
                        " line2();",
                        " line3();",
                        "+#if defined(PATCHED)",
                        " function(arg1, \\",
                        "          new_arg2, \\",
                        "          arg3);",
                        "+#else  /* PATCHED */",
                        "+function(arg1, \\",
                        "+         arg2, \\",
                        "+         arg3);",
                        "+#endif /* PATCHED */",
                        " line5();",
                        " line6();",
                        " line7();"]
            diff = generate_ifdef_patch(source1, source2, "PATCHED")
            lines = diff.read().rstrip("\n").split("\n")
            self.assertEqual(lines, expected)

            expected = ["@@ -1,9 +1,15 @@",
                        " line1();",
                        " line2();",
                        " line3();",
                        "+#if defined(PATCHED)",
                        " function(arg1, \\",
                        "          arg2, \\",
                        "          arg3);",
                        "+#else  /* PATCHED */",
                        "+function(arg1, \\",
                        "+         new_arg2, \\",
                        "+         arg3);",
                        "+#endif /* PATCHED */",
                        " line5();",
                        " line6();",
                        " line7();"]
            diff = generate_ifdef_patch(source2, source1, "PATCHED")
            lines = diff.read().rstrip("\n").split("\n")
            self.assertEqual(lines, expected)

    unittest.main()