#!/usr/bin/python2
#
# Python functions to read, split and apply patches.
#
# Copyright (C) 2014-2017 Sebastian Lackner
#
# This library is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License as published by the Free Software Foundation; either
# version 2.1 of the License, or (at your option) any later version.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public
# License along with this library; if not, write to the Free Software
# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301, USA
#

import collections
import difflib
import email.header
import hashlib
import itertools
import os
import re
import shutil
import stat
import subprocess
import sys
import tempfile

try:
    from cStringIO import StringIO
except ImportError:
    from io import StringIO

_devnull = open(os.devnull, 'wb')

class PatchParserError(RuntimeError):
    """Unable to parse patch file - either an unimplemented feature, or corrupted patch."""
    pass

class PatchApplyError(RuntimeError):
    """Failed to apply/merge patch."""
    pass

class PatchDiffError(RuntimeError):
    """Failed to compute diff."""
    pass

class CParserError(RuntimeError):
    """Unable to parse C source."""
    pass

class PatchObject(object):
    def __init__(self, filename, header):
        self.patch_author       = header.get('author', None)
        self.patch_email        = header.get('email', None)
        self.patch_subject      = header.get('subject', None)
        self.patch_revision     = header.get('revision', 1)
        self.signed_off_by      = header.get('signedoffby', [])

        self.filename           = filename
        self.offset_begin       = None
        self.offset_end         = None
        self.is_binary          = False

        self.oldname            = None
        self.newname            = None
        self.modified_file      = None

        self.oldsha1            = None
        self.newsha1            = None
        self.newmode            = None

    def read_chunks(self):
        """Iterates over arbitrary sized chunks of this patch."""
        assert self.offset_end >= self.offset_begin
        with open(self.filename) as fp:
            fp.seek(self.offset_begin)
            i = self.offset_end - self.offset_begin
            while i > 0:
                buf = fp.read(16384 if i > 16384 else i)
                if buf == "": raise IOError("Unable to extract patch.")
                yield buf
                i -= len(buf)

    def read(self):
        """Return the full patch as a string."""
        return "".join(chunk for chunk in self.read_chunks())

class _PatchReader(object):
    def __init__(self, filename, fp=None):
        self.filename = filename
        self.fp       = fp if fp is not None else open(filename)
        self.peeked   = None

    def close(self):
        self.fp.close()

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        self.close()

    def seek(self, pos):
        """Change the file cursor position."""
        self.fp.seek(pos)
        self.peeked = None

    def tell(self):
        """Return the current file cursor position."""
        if self.peeked is None:
            return self.fp.tell()
        return self.peeked[0]

    def peek(self):
        """Read one line without changing the file cursor."""
        if self.peeked is None:
            pos = self.fp.tell()
            tmp = self.fp.readline()
            if len(tmp) == 0: return None
            self.peeked = (pos, tmp)
        return self.peeked[1]

    def read(self):
        """Read one line from the file, and move the file cursor to the next line."""
        if self.peeked is None:
            tmp = self.fp.readline()
            if len(tmp) == 0: return None
            return tmp
        tmp, self.peeked = self.peeked, None
        return tmp[1]

    def read_multiline(self):
        """Read multiline data from a patch file."""
        line = self.read().rstrip("\r\n")
        while True:
            tmp = self.peek()
            if not tmp.startswith(" "): break
            line += tmp.rstrip("\r\n")
            self.read()
        return line

    def read_hunk(self):
        """Read one hunk from a patch file."""
        line = self.peek()
        if line is None or not line.startswith("@@ -"):
            return None

        r = re.match("^@@ -([0-9]+)(,([0-9]+))? \+([0-9]+)(,([0-9]+))? @@", line)
        if not r: raise PatchParserError("Unable to parse hunk header '%s'." % line)
        srcpos   = max(int(r.group(1)) - 1, 0)
        dstpos   = max(int(r.group(4)) - 1, 0)
        srclines = int(r.group(3)) if r.group(3) else 1
        dstlines = int(r.group(6)) if r.group(6) else 1
        if srclines <= 0 and dstlines <= 0:
            raise PatchParserError("Empty hunk doesn't make sense.")
        self.read()

        srcdata = []
        dstdata = []

        try:
            while srclines > 0 or dstlines > 0:
                line = self.read().rstrip("\n")
                if line[0] == " ":
                    if srclines == 0 or dstlines == 0:
                        raise PatchParserError("Corrupted patch.")
                    srcdata.append(line[1:])
                    dstdata.append(line[1:])
                    srclines -= 1
                    dstlines -= 1
                elif line[0] == "-":
                    if srclines == 0:
                        raise PatchParserError("Corrupted patch.")
                    srcdata.append(line[1:])
                    srclines -= 1
                elif line[0] == "+":
                    if dstlines == 0:
                        raise PatchParserError("Corrupted patch.")
                    dstdata.append(line[1:])
                    dstlines -= 1
                elif line[0] == "\\":
                    pass # ignore
                else:
                    raise PatchParserError("Unexpected line in hunk.")
        except IndexError: # triggered by ""[0]
            raise PatchParserError("Truncated patch.")

        while True:
            line = self.peek()
            if line is None or not line.startswith("\\ "): break
            self.read()

        return (srcpos, srcdata, dstpos, dstdata)

def _read_single_patch(fp, header, oldname=None, newname=None):
    """Internal function to read a single patch from a file."""

    patch = PatchObject(fp.filename, header)
    patch.offset_begin = fp.tell()
    patch.oldname = oldname
    patch.newname = newname

    # Skip over initial diff --git header
    line = fp.peek()
    if line.startswith("diff --git "):
        fp.read()

    # Read header
    while True:
        line = fp.peek()
        if line is None:
            break

        elif line.startswith("--- "):
            patch.oldname = line[4:].strip()

        elif line.startswith("+++ "):
            patch.newname = line[4:].strip()

        elif line.startswith("old mode") or line.startswith("deleted file mode"):
            pass # ignore

        elif line.startswith("new mode "):
            patch.newmode = line[9:].strip()

        elif line.startswith("new file mode "):
            patch.newmode = line[14:].strip()

        elif line.startswith("new mode") or line.startswith("new file mode"):
            raise PatchParserError("Unable to parse header line '%s'." % line)

        elif line.startswith("copy from") or line.startswith("copy to"):
            raise NotImplementedError("Patch copy header not implemented yet.")

        elif line.startswith("rename "):
            raise NotImplementedError("Patch rename header not implemented yet.")

        elif line.startswith("similarity index") or line.startswith("dissimilarity index"):
            pass # ignore

        elif line.startswith("index "):
            r = re.match("^index ([a-fA-F0-9]*)\.\.([a-fA-F0-9]*)", line)
            if not r: raise PatchParserError("Unable to parse index header line '%s'." % line)
            patch.oldsha1, patch.newsha1 = r.group(1), r.group(2)

        else:
            break
        fp.read()

    if patch.oldname is None or patch.newname is None:
        raise PatchParserError("Missing old or new name.")
    elif patch.oldname == "/dev/null" and patch.newname == "/dev/null":
        raise PatchParserError("Old and new name is /dev/null?")

    if patch.oldname.startswith("a/"):
        patch.oldname = patch.oldname[2:]
    elif patch.oldname != "/dev/null":
        raise PatchParserError("Old name in patch doesn't start with a/.")

    if patch.newname.startswith("b/"):
        patch.newname = patch.newname[2:]
    elif patch.newname != "/dev/null":
        raise PatchParserError("New name in patch doesn't start with b/.")

    if patch.newname != "/dev/null":
        patch.modified_file = patch.newname
    else:
        patch.modified_file = patch.oldname

    # Decide between binary and textual patch
    if line is None or line.startswith("diff --git ") or line.startswith("--- "):
        if oldname != newname:
            raise PatchParserError("Stripped old- and new name doesn't match.")

    elif line.startswith("@@ -"):
        while fp.read_hunk() is not None:
            pass

    elif line.rstrip() == "GIT binary patch":
        if patch.oldsha1 is None or patch.newsha1 is None:
            raise PatchParserError("Missing index header, sha1 sums required for binary patch.")
        elif patch.oldname != patch.newname:
            raise PatchParserError("Stripped old- and new name doesn't match for binary patch.")
        fp.read()

        line = fp.read()
        if line is None: raise PatchParserError("Unexpected end of file.")
        r = re.match("^(literal|delta) ([0-9]+)", line)
        if not r: raise NotImplementedError("Only literal/delta patches are supported.")
        patch.is_binary = True

        # Skip over patch data
        while True:
            line = fp.read()
            if line is None or line.strip() == "":
                break

    else:
        raise PatchParserError("Unknown patch format.")

    patch.offset_end = fp.tell()
    return patch

def _parse_author(author):
    if sys.version_info[0] > 2:
        author = str(email.header.make_header(email.header.decode_header(author)))
    else:
        author = ' '.join([data.decode(format or 'utf-8').encode('utf-8') for \
                          data, format in email.header.decode_header(author)])
    r =  re.match("\"?([^\"]*)\"? <(.*)>", author)
    if r is None: raise NotImplementedError("Failed to parse From - header.")
    return r.group(1).strip(), r.group(2).strip()

def _parse_subject(subject):
    version = "(v|try|rev|take) *([0-9]+)"
    subject = subject.strip()
    if subject.endswith("."): subject = subject[:-1]
    r = re.match("^\\[PATCH([^]]*)\\](.*)$", subject, re.IGNORECASE)
    if r is not None:
        subject = r.group(2).strip()
        r = re.search(version, r.group(1), re.IGNORECASE)
        if r is not None: return subject, int(r.group(2))
    r = re.match("^(.*)\\(%s\\)$" % version, subject, re.IGNORECASE)
    if r is not None: return r.group(1).strip(), int(r.group(3))
    r = re.match("^(.*)\\[%s\\]$" % version, subject, re.IGNORECASE)
    if r is not None: return r.group(1).strip(), int(r.group(3))
    r = re.match("^(.*)[.,] +%s$" % version, subject, re.IGNORECASE)
    if r is not None: return r.group(1).strip(), int(r.group(3))
    r = re.match("^([^:]+) %s: (.*)$" % version, subject, re.IGNORECASE)
    if r is not None: return "%s: %s" % (r.group(1), r.group(4)), int(r.group(3))
    r = re.match("^(.*) +%s$" % version, subject, re.IGNORECASE)
    if r is not None: return r.group(1).strip(), int(r.group(3))
    r = re.match("^(.*)\\(resend\\)$", subject, re.IGNORECASE)
    if r is not None: return r.group(1).strip(), 1
    return subject, 1

def read_patch(filename, fp=None):
    """Iterates over all patches contained in a file, and returns PatchObject objects."""

    header = {}
    with _PatchReader(filename, fp) as fp:
        while True:
            line = fp.peek()
            if line is None:
                break

            elif line.startswith("From: "):
                author = fp.read_multiline()[6:]
                header['author'], header['email'] = _parse_author(author)
                header.pop('signedoffby', None)

            elif line.startswith("Subject: "):
                subject = fp.read_multiline()[9:]
                subject, revision = _parse_subject(subject)
                if not subject.endswith("."): subject += "."
                subject = re.sub('^([^:]*: *)([a-z])', lambda x: "%s%s" %
                                 (x.group(1), x.group(2).upper()), subject, 1)
                header['subject'], header['revision'] = subject, revision
                header.pop('signedoffby', None)

            elif line.startswith("Signed-off-by: "):
                if 'signedoffby' not in header:
                    header['signedoffby'] = []
                header['signedoffby'].append(_parse_author(line[15:]))
                fp.read()

            elif line.startswith("diff --git "):
                tmp = line.strip().split(" ")
                if len(tmp) != 4: raise PatchParserError("Unable to parse git diff header line '%s'." % line)
                yield _read_single_patch(fp, header, tmp[2].strip(), tmp[3].strip())

            elif line.startswith("--- "):
                yield _read_single_patch(fp, header)

            elif line.startswith("@@ -") or line.startswith("+++ "):
                raise PatchParserError("Patch didn't start with a git or diff header.")

            else:
                fp.read()

def apply_patch(original, patchfile, reverse=False, fuzz=2):
    """Apply a patch with optional fuzz - uses the commandline 'patch' utility."""

    result = tempfile.NamedTemporaryFile(mode='w+', delete=False)
    try:
        # We open the file again to avoid race-conditions with multithreaded reads
        with open(original.name) as fp:
            shutil.copyfileobj(fp, result)
        result.close()

        cmdline = ["patch", "--no-backup-if-mismatch", "--force", "--silent", "-r", "-"]
        if reverse:   cmdline.append("--reverse")
        if fuzz != 2: cmdline.append("--fuzz=%d" % fuzz)
        cmdline += [result.name, patchfile.name]

        exitcode = subprocess.call(cmdline, stdout=_devnull, stderr=_devnull)
        if exitcode != 0:
            raise PatchApplyError("Failed to apply patch (exitcode %d)." % exitcode)

        # Hack - we can't keep the file open while patching ('patch' might rename/replace
        # the file), so create a new _TemporaryFileWrapper object for the existing path.
        return tempfile._TemporaryFileWrapper(file=open(result.name, 'r+'), \
                                              name=result.name, delete=True)
    except:
        os.unlink(result.name)
        raise

def _preprocess_source(fp):
    """Simple C preprocessor to determine where we can safely add #ifdef instructions."""

    _re_state0 = re.compile("(\"|/[/*])")
    _re_state1 = re.compile("(\\\\\"|\\\\\\\\|\")")
    _re_state2 = re.compile("\\*/")

    # We need to read the original file, and figure out where lines can be splitted
    lines = []
    for line in fp:
        lines.append(line.rstrip("\n"))

    split = set([0])
    state = 0

    i = 0
    while i < len(lines):

        # Read a full line (and handle line continuation)
        line = lines[i]
        i += 1
        while line.endswith("\\"):
            if i >= len(lines):
                raise CParserError("Unexpected end of file.")
            line = line[:-1] + lines[i]
            i += 1

        # To find out where we can add our #ifdef tags we use a simple
        # statemachine. This allows finding the beginning of a multiline
        # instruction or comment.
        j = 0
        while True:

            # State 0: No context
            if state == 0:
                match = _re_state0.search(line, j)
                if match is None: break

                if match.group(0) == "\"":
                    state = 1 # Begin of string
                elif match.group(0) == "/*":
                    state = 2 # Begin of comment
                elif match.group(0) == "//":
                    break # Rest of the line is a comment, which can be safely ignored
                else:
                    assert 0

            # State 1: Inside of string
            elif state == 1:
                match = _re_state1.search(line, j)
                if match is None:
                    raise CParserError("Line ended in the middle of a string.")

                if match.group(0) == "\"":
                    state = 0 # End of string
                elif match.group(0) != "\\\"" and match.group(0) != "\\\\":
                    assert 0

            # State 2: Multiline comment
            elif state == 2:
                match = _re_state2.search(line, j)
                if match is None: break

                if match.group(0) == "*/":
                    state = 0 # End of comment
                else:
                    assert 0

            else:
                raise CParserError("Internal state error.")
            j = match.end()

        # Only in state 0 (no context) we can split here
        if state == 0:
            split.add(i)

    # Ensure that the last comment is properly terminated
    if state != 0:
        raise CParserError("Unexpected end of file.")
    return lines, split

def generate_ifdef_patch(original, patched, ifdef):
    """Generate a patch which adds #ifdef where necessary to keep both the original and patched version."""

    #
    # The basic of idea of this algorithm is as following:
    #
    # (1) determine diff between original file and patched file
    # (2) run the preprocessor, to determine where #ifdefs can be safely added
    # (3) use diff and preprocessor information to create a merged version containing #ifdefs
    # (4) create another diff to apply the changes on the patched version
    #

    with tempfile.NamedTemporaryFile(mode='w+') as diff:
        exitcode = subprocess.call(["git", "diff", "--no-index", "--minimal", "-U1", original.name, patched.name],
                                   stdout=diff, stderr=_devnull)
        if exitcode == 0:
            return None
        elif exitcode != 1:
            raise PatchDiffError("Failed to compute diff (exitcode %d)." % exitcode)

        # Preprocess the original C source
        original.seek(0)
        lines, split = _preprocess_source(original)

        # Parse the created diff file
        fp = _PatchReader(diff.name, diff)
        fp.seek(0)

        # We expect this output format from 'git diff', if this is not the case things might go wrong.
        line = fp.read()
        assert line.startswith("diff --git ")
        line = fp.read()
        assert line.startswith("index ")
        line = fp.read()
        assert line.startswith("--- ")
        line = fp.read()
        assert line.startswith("+++ ")

        hunks = []
        while fp.peek() is not None:
            srcpos, srcdata, dstpos, dstdata = fp.read_hunk()

            # Ensure that the patch would really apply in practice
            if lines[srcpos:srcpos + len(srcdata)] != srcdata:
                raise PatchParserError("Patch failed to apply.")

            # Strip common lines from the beginning and end
            while len(srcdata) > 0 and len(dstdata) > 0 and \
                    srcdata[0] == dstdata[0]:
                srcdata.pop(0)
                dstdata.pop(0)
                srcpos += 1
                dstpos += 1

            while len(srcdata) > 0 and len(dstdata) > 0 and \
                    srcdata[-1] == dstdata[-1]:
                srcdata.pop()
                dstdata.pop()

            # Ensure that diff generated valid output
            assert len(srcdata) > 0 or len(dstdata) > 0

            # If this is the first hunk, then check if we have to extend it at the beginning
            if len(hunks) == 0:
                assert srcpos == dstpos
                while srcpos > 0 and srcpos not in split:
                    srcpos -= 1
                    dstpos -= 1
                    srcdata.insert(0, lines[srcpos])
                    dstdata.insert(0, lines[srcpos])
                hunks.append((srcpos, dstpos, srcdata, dstdata))

            # Check if we can merge with the previous hunk
            else:
                prev_srcpos, prev_dstpos, prev_srcdata, prev_dstdata = hunks[-1]
                prev_endpos = prev_srcpos + len(prev_srcdata)

                found = 0
                for i in xrange(prev_endpos, srcpos):
                    if i in split:
                        found += 1

                # At least two possible splitting positions inbetween
                if found >= 2:
                    while prev_endpos not in split:
                        prev_srcdata.append(lines[prev_endpos])
                        prev_dstdata.append(lines[prev_endpos])
                        prev_endpos += 1

                    while srcpos not in split:
                        srcpos -= 1
                        srcdata.insert(0, lines[srcpos])
                        dstdata.insert(0, lines[srcpos])
                    hunks.append((srcpos, dstpos, srcdata, dstdata))

                # Merge hunks
                else:
                    while prev_endpos < srcpos:
                        prev_srcdata.append(lines[prev_endpos])
                        prev_dstdata.append(lines[prev_endpos])
                        prev_endpos += 1
                    assert prev_dstpos + len(prev_dstdata) == dstpos
                    prev_srcdata.extend(srcdata)
                    prev_dstdata.extend(dstdata)

            # Ready with this hunk
            pass

        # We might have to extend the last hunk
        if len(hunks):
            prev_srcpos, prev_dstpos, prev_srcdata, prev_dstdata = hunks[-1]
            prev_endpos = prev_srcpos + len(prev_srcdata)

            while prev_endpos < len(lines) and prev_endpos not in split:
                prev_srcdata.append(lines[prev_endpos])
                prev_dstdata.append(lines[prev_endpos])
                prev_endpos += 1

    # Generate resulting file with #ifdefs
    with tempfile.NamedTemporaryFile(mode='w+') as intermediate:

        pos = 0
        while len(hunks):
            srcpos, dstpos, srcdata, dstdata = hunks.pop(0)
            if pos < srcpos:
                intermediate.write("\n".join(lines[pos:srcpos]))
                intermediate.write("\n")

            if len(srcdata) and len(dstdata):
                intermediate.write("#if !defined(%s)\n" % ifdef)
                intermediate.write("\n".join(srcdata))
                intermediate.write("\n#else  /* %s */\n" % ifdef)
                intermediate.write("\n".join(dstdata))
                intermediate.write("\n#endif /* %s */\n" % ifdef)

            elif len(srcdata):
                intermediate.write("#if !defined(%s)\n" % ifdef)
                intermediate.write("\n".join(srcdata))
                intermediate.write("\n#endif /* %s */\n" % ifdef)

            elif len(dstdata):
                intermediate.write("#if defined(%s)\n" % ifdef)
                intermediate.write("\n".join(dstdata))
                intermediate.write("\n#endif /* %s */\n" % ifdef)

            else:
                assert 0
            pos = srcpos + len(srcdata)

        if pos < len(lines):
            intermediate.write("\n".join(lines[pos:]))
            intermediate.write("\n")
        intermediate.flush()

        # Now we can finally compute the diff between the original file and our intermediate file
        diff = tempfile.NamedTemporaryFile(mode='w+')
        exitcode = subprocess.call(["git", "diff", "--no-index", "--minimal", original.name, intermediate.name],
                                   stdout=diff, stderr=_devnull)
        if exitcode != 1: # exitcode 0 cannot (=shouldn't) happen in this situation
            raise PatchDiffError("Failed to compute diff (exitcode %d)." % exitcode)

        diff.flush()
        diff.seek(0)

        # We expect this output format from 'git diff', if this is not the case things might go wrong.
        line = diff.readline()
        assert line.startswith("diff --git ")
        line = diff.readline()
        assert line.startswith("index ")
        line = diff.readline()
        assert line.startswith("--- ")
        line = diff.readline()
        assert line.startswith("+++ ")

    # Return the final diff
    return diff

def escape_sh(text):
    """Escape string for usage in shell '...' quotes."""
    return text.replace("'", "'\\''")

def escape_c(text):
    """Escape string for usage in C "..." quotes."""
    return text.replace("\\", "\\\\").replace("\"", "\\\"")

if __name__ == "__main__":
    import unittest

    # Basic tests for _parse_author() and _parse_subject()
    class PatchParserTests(unittest.TestCase):
        def test_author(self):
            author = _parse_author("Author Name <author@email.com>")
            self.assertEqual(author, ("Author Name", "author@email.com"))

            author = _parse_author("=?UTF-8?q?Author=20Name?= <author@email.com>")
            self.assertEqual(author, ("Author Name", "author@email.com"))

        def test_subject(self):
            subject = _parse_subject("[PATCH v3] component: Subject.")
            self.assertEqual(subject, ("component: Subject", 3))

            subject = _parse_subject("[PATCH] component: Subject (v3).")
            self.assertEqual(subject, ("component: Subject", 3))

            subject = _parse_subject("[PATCH] component: Subject (try 3).")
            self.assertEqual(subject, ("component: Subject", 3))

            subject = _parse_subject("[PATCH] component: Subject (take 3).")
            self.assertEqual(subject, ("component: Subject", 3))

            subject = _parse_subject("[PATCH] component: Subject (rev 3).")
            self.assertEqual(subject, ("component: Subject", 3))

            subject = _parse_subject("[PATCH] component: Subject [v3].")
            self.assertEqual(subject, ("component: Subject", 3))

            subject = _parse_subject("[PATCH] component: Subject, v3.")
            self.assertEqual(subject, ("component: Subject", 3))

            subject = _parse_subject("[PATCH] component: Subject v3.")
            self.assertEqual(subject, ("component: Subject", 3))

            subject = _parse_subject("[PATCH] component: Subject (resend).")
            self.assertEqual(subject, ("component: Subject", 1))

    # Basic tests for read_patch()
    class PatchReaderTests(unittest.TestCase):
        def test_simple(self):
            with open("tests/simple.patch") as fp:
                source = fp.read().split("\n")

            # Test formatted git patch with author and subject
            patchfile = tempfile.NamedTemporaryFile(mode='w+')
            patchfile.write("\n".join(source))
            patchfile.flush()

            patches = list(read_patch(patchfile.name))
            self.assertEqual(len(patches), 1)
            self.assertEqual(patches[0].patch_author,   "Author Name")
            self.assertEqual(patches[0].patch_email,    "author@email.com")
            self.assertEqual(patches[0].patch_subject,  "component: Replace arg1 with arg2.")
            self.assertEqual(patches[0].patch_revision, 3)
            self.assertEqual(patches[0].signed_off_by,  [("Author Name", "author@email.com"),
                                                         ("Other Developer", "other@email.com")])
            self.assertEqual(patches[0].filename,       patchfile.name)
            self.assertEqual(patches[0].is_binary,      False)
            self.assertEqual(patches[0].modified_file,  "test.txt")

            lines = patches[0].read().rstrip("\n").split("\n")
            self.assertEqual(lines, source[10:23])

            # Test with git diff
            del source[0:10]
            self.assertTrue(source[0].startswith("diff --git"))
            patchfile = tempfile.NamedTemporaryFile(mode='w+')
            patchfile.write("\n".join(source))
            patchfile.flush()

            patches = list(read_patch(patchfile.name))
            self.assertEqual(len(patches), 1)
            self.assertEqual(patches[0].patch_author,   None)
            self.assertEqual(patches[0].patch_email,    None)
            self.assertEqual(patches[0].patch_subject,  None)
            self.assertEqual(patches[0].patch_revision, 1)
            self.assertEqual(patches[0].signed_off_by,  [])
            self.assertEqual(patches[0].filename,       patchfile.name)
            self.assertEqual(patches[0].is_binary,      False)
            self.assertEqual(patches[0].modified_file, "test.txt")

            lines = patches[0].read().rstrip("\n").split("\n")
            self.assertEqual(lines, source[:13])

            # Test with unified diff
            del source[0:2]
            self.assertTrue(source[0].startswith("---"))
            patchfile = tempfile.NamedTemporaryFile(mode='w+')
            patchfile.write("\n".join(source))
            patchfile.flush()

            patches = list(read_patch(patchfile.name))
            self.assertEqual(len(patches), 1)
            self.assertEqual(patches[0].patch_author,   None)
            self.assertEqual(patches[0].patch_email,    None)
            self.assertEqual(patches[0].patch_subject,  None)
            self.assertEqual(patches[0].patch_revision, 1)
            self.assertEqual(patches[0].signed_off_by,  [])
            self.assertEqual(patches[0].filename,       patchfile.name)
            self.assertEqual(patches[0].is_binary,      False)
            self.assertEqual(patches[0].modified_file,  "test.txt")

            lines = patches[0].read().rstrip("\n").split("\n")
            self.assertEqual(lines, source[:11])

            # Test with StringIO buffer
            fp = StringIO("\n".join(source))
            patches = list(read_patch("unknown.patch", fp))
            self.assertEqual(len(patches), 1)
            self.assertEqual(patches[0].patch_author,   None)
            self.assertEqual(patches[0].patch_email,    None)
            self.assertEqual(patches[0].patch_subject,  None)
            self.assertEqual(patches[0].patch_revision, 1)
            self.assertEqual(patches[0].signed_off_by,  [])
            self.assertEqual(patches[0].filename,       "unknown.patch")
            self.assertEqual(patches[0].is_binary,      False)
            self.assertEqual(patches[0].modified_file,  "test.txt")

        def test_multi(self):
            with open("tests/multi.patch") as fp:
                source = fp.read().split("\n")

            patchfile = tempfile.NamedTemporaryFile(mode='w+')
            patchfile.write("\n".join(source))
            patchfile.flush()

            patches = list(read_patch(patchfile.name))
            self.assertEqual(len(patches), 3)

            self.assertEqual(patches[0].patch_author,   "Author Name")
            self.assertEqual(patches[0].patch_email,    "author@email.com")
            self.assertEqual(patches[0].patch_subject,  "component: Replace arg1 with arg2.")
            self.assertEqual(patches[0].patch_revision, 3)
            self.assertEqual(patches[0].signed_off_by,  [("Author Name", "author@email.com"),
                                                         ("Other Developer", "other@email.com")])
            self.assertEqual(patches[0].filename,       patchfile.name)
            self.assertEqual(patches[0].is_binary,      False)
            self.assertEqual(patches[0].modified_file,  "other_test.txt")

            lines = patches[0].read().rstrip("\n").split("\n")
            self.assertEqual(lines, source[11:24])

            self.assertEqual(patches[1].patch_author,   "Author Name")
            self.assertEqual(patches[1].patch_email,    "author@email.com")
            self.assertEqual(patches[1].patch_subject,  "component: Replace arg1 with arg2.")
            self.assertEqual(patches[1].patch_revision, 3)
            self.assertEqual(patches[1].signed_off_by,  [("Author Name", "author@email.com"),
                                                         ("Other Developer", "other@email.com")])
            self.assertEqual(patches[1].filename,       patchfile.name)
            self.assertEqual(patches[1].is_binary,      False)
            self.assertEqual(patches[1].modified_file,  "test.txt")

            lines = patches[1].read().rstrip("\n").split("\n")
            self.assertEqual(lines, source[24:46])

            self.assertEqual(patches[2].patch_author,   "Other Developer")
            self.assertEqual(patches[2].patch_email,    "other@email.com")
            self.assertEqual(patches[2].patch_subject,  "component: Replace arg2 with arg3.")
            self.assertEqual(patches[2].patch_revision, 4)
            self.assertEqual(patches[2].signed_off_by,  [("Other Developer", "other@email.com")])
            self.assertEqual(patches[2].filename,       patchfile.name)
            self.assertEqual(patches[2].is_binary,      False)
            self.assertEqual(patches[2].modified_file,  "test.txt")

            lines = patches[2].read().rstrip("\n").split("\n")
            self.assertEqual(lines, source[58:71])

    # Basic tests for apply_patch()
    class PatchApplyTests(unittest.TestCase):
        def test_apply(self):
            source = ["line1();", "line2();", "line3();",
                      "function(arg1);",
                      "line5();", "line6();", "line7();"]
            original = tempfile.NamedTemporaryFile(mode='w+')
            original.write("\n".join(source + [""]))
            original.flush()

            source = ["@@ -1,7 +1,7 @@",
                      " line1();", " line2();", " line3();",
                      "-function(arg1);",
                      "+function(arg2);",
                      " line5();", " line6();", " line7();"]
            patchfile = tempfile.NamedTemporaryFile(mode='w+')
            patchfile.write("\n".join(source + [""]))
            patchfile.flush()

            expected = ["line1();", "line2();", "line3();",
                        "function(arg2);",
                        "line5();", "line6();", "line7();"]
            result = apply_patch(original, patchfile, fuzz=0)
            lines = result.read().rstrip("\n").split("\n")
            self.assertEqual(lines, expected)

            expected = ["line1();", "line2();", "line3();",
                        "function(arg1);",
                        "line5();", "line6();", "line7();"]
            result = apply_patch(result, patchfile, reverse=True, fuzz=0)
            lines = result.read().rstrip("\n").split("\n")
            self.assertEqual(lines, expected)

    # Basic tests for _preprocess_source()
    class PreprocessorTests(unittest.TestCase):
        def test_preprocessor(self):
            source = ["int a; // comment 1",
                      "int b; // comment 2 \\",
                      "          comment 3 \\",
                      "          comment 4",
                      "int c; // comment with \"quotes\"",
                      "int d; // comment with /* c++ comment */",
                      "int e; /* multi \\",
                      "          line",
                      "          comment */",
                      "char *x = \"\\\\\";",
                      "char *y = \"abc\\\"def\";",
                      "char *z = \"multi\" \\",
                      "          \"line\"",
                      "          \"string\";"]
            lines, split = _preprocess_source(source)
            self.assertEqual(lines, source)
            self.assertEqual(split, set([0, 1, 4, 5, 6, 9, 10, 11, 13, 14]))

    # Basic tests for generate_ifdef_patch()
    class GenerateIfdefPatchTests(unittest.TestCase):
        def test_ifdefined(self):
            source = ["line1();", "line2();", "line3();",
                      "function(arg1, \\",
                      "         arg2, \\",
                      "         arg3);",
                      "line5();", "line6();", "line7();"]
            source1 = tempfile.NamedTemporaryFile(mode='w+')
            source1.write("\n".join(source + [""]))
            source1.flush()

            source = ["line1();", "line2();", "line3();",
                      "function(arg1, \\",
                      "         new_arg2, \\",
                      "         arg3);",
                      "line5();", "line6();", "line7();"]
            source2  = tempfile.NamedTemporaryFile(mode='w+')
            source2.write("\n".join(source + [""]))
            source2.flush()

            diff = generate_ifdef_patch(source1, source1, "PATCHED")
            self.assertEqual(diff, None)

            diff = generate_ifdef_patch(source2, source2, "PATCHED")
            self.assertEqual(diff, None)

            expected = ["@@ -1,9 +1,15 @@",
                        " line1();", " line2();", " line3();",
                        "+#if !defined(PATCHED)",
                        " function(arg1, \\",
                        "          arg2, \\",
                        "          arg3);",
                        "+#else  /* PATCHED */",
                        "+function(arg1, \\",
                        "+         new_arg2, \\",
                        "+         arg3);",
                        "+#endif /* PATCHED */",
                        " line5();", " line6();", " line7();"]
            diff = generate_ifdef_patch(source1, source2, "PATCHED")
            lines = diff.read().rstrip("\n").split("\n")
            self.assertEqual(lines, expected)

            expected = ["@@ -1,9 +1,15 @@",
                        " line1();", " line2();", " line3();",
                        "+#if !defined(PATCHED)",
                        " function(arg1, \\",
                        "          new_arg2, \\",
                        "          arg3);",
                        "+#else  /* PATCHED */",
                        "+function(arg1, \\",
                        "+         arg2, \\",
                        "+         arg3);",
                        "+#endif /* PATCHED */",
                        " line5();", " line6();", " line7();"]
            diff = generate_ifdef_patch(source2, source1, "PATCHED")
            lines = diff.read().rstrip("\n").split("\n")
            self.assertEqual(lines, expected)

    # Basic tests for escape_sh()
    class EscapeShellTests(unittest.TestCase):
        ascii = "".join([chr(i) for i in range(32, 127)] + \
                        ["\\" + chr(i) for i in range(32, 127)])

        def test_bash(self):
            script = tempfile.NamedTemporaryFile(mode='w+', delete=False)
            try:
                script.write("#!/bin/bash\nprintf '%%s\\n' '%s'\n" % escape_sh(self.ascii))
                script.close()

                st = os.stat(script.name)
                os.chmod(script.name, st.st_mode | stat.S_IEXEC)

                result = subprocess.check_output([script.name])
                if not isinstance(result, str): result = result.decode('utf-8')
                self.assertEqual(result.rstrip("\n"), self.ascii)
            finally:
                os.unlink(script.name)

        @unittest.skipUnless(os.path.exists("/bin/dash"), "requires Dash")
        def test_dash(self):
            script = tempfile.NamedTemporaryFile(mode='w+', delete=False)
            try:
                script.write("#!/bin/dash\nprintf '%%s\\n' '%s'\n" % escape_sh(self.ascii))
                script.close()

                st = os.stat(script.name)
                os.chmod(script.name, st.st_mode | stat.S_IEXEC)

                result = subprocess.check_output([script.name])
                if not isinstance(result, str): result = result.decode('utf-8')
                self.assertEqual(result.rstrip("\n"), self.ascii)
            finally:
                os.unlink(script.name)

    # Basic tests for escape_c()
    class EscapeCTests(unittest.TestCase):
        ascii = "".join([chr(i) for i in range(32, 127)] + \
                        ["\\" + chr(i) for i in range(32, 127)])

        def test_c(self):
            source = tempfile.NamedTemporaryFile(mode='w+', suffix='.c', delete=False)
            try:
                source.write("#include <stdio.h>\n")
                source.write("int main() { printf(\"%%s\", \"%s\"); return 0; }\n" % escape_c(self.ascii))
                source.close()

                compiled = tempfile.NamedTemporaryFile(mode='w+', delete=False)
                try:
                    compiled.close()
                    subprocess.call(["gcc", source.name, "-o", compiled.name])

                    result = subprocess.check_output([compiled.name])
                    if not isinstance(result, str): result = result.decode('utf-8')
                    self.assertEqual(result.rstrip("\n"), self.ascii)
                finally:
                    os.unlink(compiled.name)
            finally:
                os.unlink(source.name)

    unittest.main()