#!/usr/bin/python3

"""
Stupidly overengineered script to help decide which commits to cherry-pick.

glibc upstream maintains a release branch for each release
(release/$VER/master) which gets backports of selected commits from
the master branch. However, upstream's policy is just a little less
conservative than feels appropriate for an SRU so I'm not comfortable
just taking all updates from the branch (in particular, many of the
commits that are being backported have not yet been part of a glibc
release).

This script is designed to help decide which patchsets to include in an
upload. The decisions are recorded in a hopefully mostly-readable file
state.txt alongside this file.  Patchsets can be in four states:

 * NEW -- no decision has been made about this yet
 * NO -- we have decided not to include this
 * YES -- this is to be included
 * LATER -- we will probably include this in a subsequent upload

It has subcommands:

 * debian/maint/review import
 * debian/maint/review review
 * debian/maint/review export
 * debian/maint/review dch

## import

This scans the upstream release branch (assumed to be on the 'origin'
remote in the same repo) and adds any commits not yet recorded to the
state.txt file. It applies some heuristics to group commits into
patchsets.

## review

This is an interactive command that allows review of the commits in
state.txt (and the most over-engineered part of all this). There are
keypresses to allow navigation, moving, joining, splitting and
reviewing patchsets (press 'h' to get started).

## export

This synchronizes debian/patches with state.txt.

## dch

This constructs a changelog entry for the newly added patches
(i.e. ones that are not yet mentioned in the changelog).
"""


import attr
import enum
import os
import pathlib
import shutil
import signal
import subprocess
import sys
import termios
import tempfile
import tty
import uuid
from typing import List


# Directory that contains this script (symlinks resolved).
SCRIPT_DIR = pathlib.Path(__file__).resolve().parent
# Two levels above the script directory; used as the working directory for
# git and as the base for the debian/ paths below.
SOURCE_DIR = SCRIPT_DIR.parent.parent
# Review decisions are persisted next to the script.
STATE_FILE = SCRIPT_DIR.joinpath('state.txt')


class ReviewState(enum.Enum):
    """Decision recorded for a patchset.

    The member *name* is what gets written to the state file; the member
    *value* is the one-character marker shown in the review listing.
    """

    NEW = "?"    # no decision has been made yet
    NO = "N"     # decided not to include
    YES = "Y"    # to be included
    LATER = "L"  # probably included in a subsequent upload


@attr.s(auto_attribs=True)
class PatchSet:
    """A group of upstream commits that are reviewed as a single unit.

    ``hashes`` and ``messages`` are parallel lists (one entry per commit).
    ``index`` and ``_state`` are filled in by ``State.add`` and give the
    patchset's position inside the owning collection.
    """

    _state: "State" = None
    index: int = -1

    hashes: List[str] = attr.Factory(list)
    messages: List[str] = attr.Factory(list)
    bugs: List[str] = attr.Factory(list)
    review_state: ReviewState = ReviewState.NEW
    comment: str = ''
    first_patch_name: str = ''

    def prev(self):
        """Return the patchset just before this one, or None for the first."""
        return None if self.index == 0 else self._state[self.index - 1]

    def next(self):
        """Return the patchset just after this one, or None for the last."""
        last_index = len(self._state) - 1
        if self.index < last_index:
            return self._state[self.index + 1]
        return None


class State:
    """Ordered collection of patchsets, persisted as a plain-text file.

    The file consists of blocks separated by blank lines.  The first line
    of a block is a ReviewState member name; the remaining lines are, in
    any order:

      * ``# text``           -- one line of the free-form comment
      * ``LP#<bug>``         -- a bug reference
      * ``FirstPatch: name`` -- file name of the first exported patch
      * ``<hash> <message>`` -- one commit of the patchset

    Every patchset held here has ``_state`` pointing back at this object
    and ``index`` equal to its position in the collection.
    """

    def __init__(self):
        self._patchsets = []
        # Remembered by load() so that save() can default to the same file.
        self.load_path = None

    def __iter__(self):
        return iter(self._patchsets)

    def __len__(self):
        return len(self._patchsets)

    def __getitem__(self, thing):
        return self._patchsets[thing]

    def add(self, patchset, index=None):
        """Append *patchset*, or insert it at *index* when one is given.

        Sets ``patchset._state`` / ``patchset.index`` and renumbers every
        patchset that ends up after an inserted one.
        """
        patchset._state = self
        if index is None:
            patchset.index = len(self._patchsets)
            self._patchsets.append(patchset)
        else:
            patchset.index = index
            self._patchsets.insert(index, patchset)
            for p2 in self._patchsets[index+1:]:
                p2.index += 1

    def remove(self, patchset):
        """Drop *patchset* and renumber the patchsets that followed it."""
        for p2 in self._patchsets[patchset.index+1:]:
            p2.index -= 1
        del self._patchsets[patchset.index]

    def with_state(self, *states):
        """Return the patchsets whose review_state is one of *states*."""
        return [ps for ps in self._patchsets if ps.review_state in states]

    def load(self, path):
        """Read patchsets from *path* and append them to the collection.

        *path* is remembered as the default target for save().  A missing
        file is not an error: nothing is added.

        Raises ValueError if a block starts with an unknown state name.
        """
        self.load_path = path
        try:
            fp = path.open()
        except FileNotFoundError:
            return

        # Split the file into blocks of non-empty, stripped lines.
        blocks = []
        with fp:
            cur = []
            for line in fp:
                line = line.strip()
                if line:
                    cur.append(line)
                elif cur:
                    blocks.append(cur)
                    cur = []
            if cur:
                blocks.append(cur)

        for block in blocks:
            hashes = []
            messages = []
            bugs = []
            comment = ''
            first_patch_name = ''
            try:
                # Look the name up among the enum members only (getattr
                # would also accept unrelated class attributes).
                state = ReviewState[block[0]]
            except KeyError:
                raise ValueError(
                    "unknown review state {!r} in {}".format(
                        block[0], path)) from None
            for line in block[1:]:
                if line.startswith('#'):
                    # save() writes "# text"; strip the marker and at most
                    # one separating space so that a hand-written "#text"
                    # does not lose its first character.
                    text = line[1:]
                    if text.startswith(' '):
                        text = text[1:]
                    comment += text + '\n'
                elif line.startswith("LP#"):
                    bugs.append(line[3:])
                elif line.startswith("FirstPatch: "):
                    first_patch_name = line.split(None, 1)[1]
                else:
                    # A commit with an empty subject is written by save()
                    # as "<hash> ", which strips down to the hash alone.
                    parts = line.split(None, 1)
                    hashes.append(parts[0])
                    messages.append(parts[1] if len(parts) > 1 else '')
            self.add(
                PatchSet(
                    hashes=hashes,
                    review_state=state,
                    comment=comment,
                    messages=messages,
                    bugs=bugs,
                    first_patch_name=first_patch_name))

    def save(self, path=None):
        """Write all patchsets to *path* (default: the path given to load()).

        Patchsets in state YES are written first, then all others; relative
        order within each group is kept.

        The whole file content is rendered before the file is opened, so a
        failure while rendering leaves an existing file untouched.

        Raises ValueError if no path is known, and Exception if a patchset
        has a different number of hashes and messages.
        """
        if path is None:
            path = self.load_path
        if path is None:
            raise ValueError("no path given and load() was never called")

        yes_patches = []
        other_patches = []
        for patchset in self._patchsets:
            if patchset.review_state == ReviewState.YES:
                yes_patches.append(patchset)
            else:
                other_patches.append(patchset)

        chunks = []
        for patchset in yes_patches + other_patches:
            if len(patchset.hashes) != len(patchset.messages):
                raise Exception("erp {}".format(patchset))
            chunks.append(patchset.review_state.name + '\n')
            if patchset.comment:
                for line in patchset.comment.splitlines():
                    chunks.append('# {}\n'.format(line))
            for bug in patchset.bugs:
                chunks.append("LP#{}\n".format(bug))
            if patchset.first_patch_name:
                chunks.append(
                    "FirstPatch: {}\n".format(patchset.first_patch_name))
            for commit, message in zip(patchset.hashes, patchset.messages):
                chunks.append('{} {}\n'.format(commit, message))
            chunks.append('\n')

        # Only now touch the file: opening with 'w' truncates it.
        with path.open('w') as fp:
            fp.write(''.join(chunks))


# Upstream release series; interpolated into the name of the release
# branch (origin/release/<V>/master) that the import command scans.
V = "2.35"


def getch():
    """Read and return a single character from stdin.

    The terminal is switched to cbreak mode for the duration of the read
    and its previous settings are restored afterwards, even on error.
    """
    stdin_fd = sys.stdin.fileno()
    saved_attrs = termios.tcgetattr(stdin_fd)
    try:
        tty.setcbreak(stdin_fd)
        return sys.stdin.read(1)
    finally:
        termios.tcsetattr(stdin_fd, termios.TCSADRAIN, saved_attrs)


def getoption(prompt, keymap):
    """Print *prompt* and read keys until they select an entry of *keymap*.

    *keymap* is a nested dict: a key maps either to a result or to another
    dict holding the continuation of a multi-character sequence.  A key
    that is not valid at the current position discards the sequence typed
    so far and matching restarts from the top level.

    Returns the first non-dict value reached.
    """
    sys.stdout.write(prompt)
    sys.stdout.flush()
    node = keymap
    while True:
        key = getch()
        if key in node:
            node = node[key]
            if not isinstance(node, dict):
                return node
        else:
            node = keymap


def getupdown():
    """Ask for a direction; returns 'up', 'down' or 'cancel'."""
    choices = {'u': 'up', 'd': 'down', 'c': 'cancel'}
    return getoption('[(u)p (d)own (c)ancel]: ', choices)


def git(cmd):
    """Run ``git`` with the argument list *cmd* inside SOURCE_DIR.

    LESS is set to 'RX' in the child's environment.  A child terminated by
    SIGPIPE is not treated as an error; any other failure raises
    subprocess.CalledProcessError.
    """
    child_env = dict(os.environ)
    child_env['LESS'] = 'RX'
    try:
        subprocess.check_call(['git'] + cmd, cwd=SOURCE_DIR, env=child_env)
    except subprocess.CalledProcessError as err:
        if err.returncode != -signal.SIGPIPE:
            raise


def git_output(cmd):
    """Run ``git`` with arguments *cmd* in SOURCE_DIR; return stdout as str.

    Raises subprocess.CalledProcessError on a non-zero exit status.
    """
    argv = ['git'] + cmd
    return subprocess.check_output(argv, cwd=SOURCE_DIR, encoding='utf-8')


# Registry of subcommand name -> implementing function, filled by @command.
cmds = {}


def command(name=None):
    """Return a decorator that registers a function as a subcommand.

    The function is stored in ``cmds`` under *name*, or under its own
    ``__name__`` when *name* is None, and is returned unchanged.
    """
    def register(func):
        key = func.__name__ if name is None else name
        cmds[key] = func
        return func
    return register


@command("import")
def import_(state):
    """Add commits from the upstream release branch that are not in *state*.

    Commits are read oldest first.  A commit whose commit timestamp is at
    most 60 seconds after that of the previously handled new commit joins
    that commit's patchset; otherwise it starts a new patchset.  Commits
    whose hash is already recorded are skipped.  The number of patchsets
    added is printed.
    """
    known = {commit for ps in state._patchsets for commit in ps.hashes}

    log_text = git_output([
        'log', '--reverse', '--format=%H %ct %s',
        # If I'd written this script before 2.35 got uploaded, this line would
        # be:
        # f'glibc-{V}..origin/release/{V}/master'
        # but here we hard code the last revision that became part of
        # debian/patches/git-updates.diff:
        f'e30c1c73da3f220f5bf0063a1d0344f8280311e7..origin/release/{V}/master'
        ])

    added = 0
    window_end = 0
    group_hashes = []
    group_messages = []
    for entry in log_text.splitlines():
        commit, stamp, subject = entry.split(None, 2)
        if commit in known:
            continue
        stamp = int(stamp)
        if stamp > window_end:
            # Too far from the previous commit: close the running group.
            if group_hashes:
                state.add(
                    PatchSet(hashes=group_hashes, messages=group_messages))
                added += 1
            group_hashes = []
            group_messages = []
        group_hashes.append(commit)
        group_messages.append(subject)
        window_end = stamp + 60
    if group_hashes:
        state.add(PatchSet(hashes=group_hashes, messages=group_messages))
        added += 1
    print("Imported patch sets: {}".format(added))


# Shortcut label -> human readable option name (shown by the help option).
review_opts = {}
# Key sequence -> function implementing the option.
review_keymap = {}
# Symbolic shortcut labels and the key sequences they stand for.
keys = {'up': '\033[A', 'down': '\033[B'}


def opt(shortcut=None):
    """Return a decorator registering a function as a review option.

    The option is bound to *shortcut*, or to the first letter of the
    function name when no shortcut is given.  A shortcut listed in
    ``keys`` is translated to its key sequence for ``review_keymap``.
    The decorated function is returned wrapped in ``staticmethod``.
    """
    def register(func):
        label = func.__name__[0] if shortcut is None else shortcut
        review_opts[label] = func.__name__.replace('_', ' ')
        sequence = keys.get(label, label)
        review_keymap[sequence] = func
        return staticmethod(func)
    return register


class ReviewOpts:
    """Key-bound actions of the interactive review loop.

    Every function is registered through @opt and receives the patchset
    currently displayed.  The return value tells the review loop what to
    show next: an index into the state (clamped by the loop to the valid
    range), or None to stay at the current index.
    """

    @opt()
    def diff(patchset):
        """Show the commits of the patchset including their diffs."""
        git(['show'] + patchset.hashes)

    # Bound explicitly to 'L': the default first-letter shortcut 'l' is
    # also the default for later() below, which is registered afterwards
    # and would silently replace this entry.
    @opt('L')
    def log(patchset):
        """Show the commit messages of the patchset without diffs."""
        git(['show', '--no-patch'] + patchset.hashes)

    @opt('up')
    def prev(patchset):
        """Go to the previous patchset."""
        return patchset.index - 1

    @opt('down')
    def next(patchset):
        """Go to the next patchset."""
        return patchset.index + 1

    @opt('N')
    def next_unreviewed(patchset):
        """Go to the next patchset after this one that is still NEW."""
        for ps in patchset._state[patchset.index+1:]:
            if ps.review_state == ReviewState.NEW:
                return ps.index
        return None

    @opt()
    def first(patchset):
        """Go to the first patchset."""
        return 0

    @opt()
    def end(patchset):
        """Go to the last patchset."""
        return len(patchset._state) - 1

    @opt('/')
    def find(patchset):
        """Go to the next patchset with a commit message containing a text.

        Only patchsets after the current one are searched.
        """
        needle = input('search for: ')
        if not needle:
            return
        for ps in patchset._state[patchset.index+1:]:
            for message in ps.messages:
                if needle in message:
                    return ps.index

    @opt()
    def comment(patchset):
        """Edit the patchset's comment in vim."""
        with tempfile.NamedTemporaryFile('w') as fp:
            fp.write(patchset.comment)
            fp.flush()
            subprocess.run(['vim', fp.name])
            # Re-open by name to read what the editor wrote.
            with open(fp.name) as fp2:
                patchset.comment = fp2.read()

    @opt('b')
    def edit_bugs(patchset):
        """Edit the patchset's bug list in vim, one bug per line."""
        with tempfile.NamedTemporaryFile('w') as fp:
            for bug in patchset.bugs:
                fp.write(bug + "\n")
            fp.flush()
            subprocess.run(['vim', fp.name])
            with open(fp.name) as fp2:
                # Blank lines are not bugs; keeping them would write
                # empty "LP#" entries to the state file.
                stripped = [line.strip() for line in fp2]
                patchset.bugs = [bug for bug in stripped if bug]

    @opt()
    def no(patchset):
        """Mark the patchset NO and go to the next one."""
        patchset.review_state = ReviewState.NO
        return patchset.index + 1

    @opt('?')
    def new(patchset):
        """Mark the patchset NEW and go to the next one."""
        patchset.review_state = ReviewState.NEW
        return patchset.index + 1

    @opt()
    def yes(patchset):
        """Mark the patchset YES and go to the next one."""
        patchset.review_state = ReviewState.YES
        return patchset.index + 1

    @opt()
    def later(patchset):
        """Mark the patchset LATER and go to the next one."""
        patchset.review_state = ReviewState.LATER
        return patchset.index + 1

    @opt()
    def split(patchset):
        """Split the patchset in two after a chosen commit.

        With exactly two commits the split point is implied; with more,
        the commits are listed with numbers and the user is asked after
        which one to split.  The new patchset gets the trailing commits
        and the same review state, and is inserted right after this one.
        """
        if len(patchset.hashes) <= 1:
            return
        if len(patchset.hashes) > 2:
            subprocess.run('clear')
            print()
            # This is the module-level log() function, not the option above.
            log(patchset, SPLIT_FORMAT)
            while True:
                at = input("Split after: ")
                if not at:
                    return
                try:
                    at = int(at)
                except ValueError:
                    continue
                break
        else:
            at = 1
        if at <= 0 or at >= len(patchset.hashes):
            return
        new_hashes = patchset.hashes[at:]
        del patchset.hashes[at:]
        new_messages = patchset.messages[at:]
        del patchset.messages[at:]
        patchset._state.add(
            PatchSet(
                hashes=new_hashes,
                messages=new_messages,
                review_state=patchset.review_state),
            patchset.index+1)

    @opt()
    def join(patchset):
        """Merge the patchset with its upper or lower neighbour.

        The lower of the two patchsets is absorbed into the upper one:
        its commits, bugs and comment are appended and it is removed from
        the state.  Nothing happens when there is no neighbour in the
        chosen direction.
        """
        dir = getupdown()
        if dir == 'cancel':
            return
        elif dir == 'down':
            # Joining down is the same as joining the next patchset up.
            patchset = patchset.next()
            r = None
        else:
            r = patchset.index - 1
        if patchset is None:
            # 'down' on the last patchset.
            return
        target = patchset.prev()
        if target is None:
            # 'up' on the first patchset.
            return
        target.hashes += patchset.hashes
        target.messages += patchset.messages
        # Keep the review notes of the absorbed patchset as well.
        target.bugs += [bug for bug in patchset.bugs if bug not in target.bugs]
        if patchset.comment:
            if target.comment and not target.comment.endswith('\n'):
                target.comment += '\n'
            target.comment += patchset.comment
        patchset._state.remove(patchset)
        return r

    @opt()
    def move(patchset):
        """Swap the patchset with its upper or lower neighbour.

        Nothing happens when there is no neighbour in the chosen direction.
        """
        i = patchset.index
        ps = patchset._state._patchsets
        dir = getupdown()
        if dir == 'cancel':
            return
        elif dir == 'up':
            r = patchset.index - 1
            d = -1
        elif dir == 'down':
            r = None
            d = 1
        j = i + d
        if j < 0 or j >= len(ps):
            # Without this check, index -1 would swap with the last entry
            # and index len(ps) would raise IndexError.
            return
        other = ps[j]
        other.index, patchset.index = patchset.index, other.index
        ps[i], ps[j] = ps[j], ps[i]
        return r

    @opt()
    def quit(patchset):
        """Save the state and exit."""
        patchset._state.save()
        sys.exit()

    @opt('S')
    def save(patchset):
        """Save the state and continue reviewing."""
        patchset._state.save()

    @opt('x')
    def exit_no_save(patchset):
        """Exit without saving."""
        sys.exit()

    @opt()
    def help(patchset):
        """List all options and wait for a key press."""
        for k in sorted(review_opts):
            print(k, '-', review_opts[k])
        getch()


def compile_keymap(keymap):
    """Turn a flat {key sequence: value} map into a nested per-key map.

    Each level of the result maps one character either directly to a
    value (for a sequence that ends there) or to a dict for the remaining
    characters.

    Raises Exception when one sequence is a proper prefix of another.
    """
    by_first = {}
    for sequence, action in keymap.items():
        head, tail = sequence[:1], sequence[1:]
        by_first.setdefault(head, {})[tail] = action

    compiled = {}
    for head, tails in by_first.items():
        if '' not in tails:
            # Every sequence continues: recurse on the remainders.
            compiled[head] = compile_keymap(tails)
        elif len(tails) == 1:
            # The sequence ends here and nothing else shares the prefix.
            compiled[head] = tails['']
        else:
            raise Exception(
                  "key definitions for %s clash" % (tails.values(),))
    return compiled


# Finalised here, after the ReviewOpts class body has registered every
# option: the prompt lists the option labels, and the flat key-sequence
# map is replaced by the nested form that getoption() walks.
review_prompt = '[' + ','.join(sorted(review_opts)) + ']: '
review_keymap = compile_keymap(review_keymap)


def _cc(r, g, b):
    """Return the parameter text '8;2;<r>;<g>;<b>' for use with _sgr()."""
    return f'8;2;{r};{g};{b}'


def _sgr(param):
    """Return the escape sequence ESC[0;3<param>m.

    The literal '3' is joined directly with *param*, so 9 yields '39'
    and a _cc() result yields '38;2;<r>;<g>;<b>'.
    """
    return f'\x1b[0;3{param}m'


YELLOW = _sgr(_cc(255, 255, 0))
DIM = _sgr(_cc(128, 128, 128))
DEFAULT = _sgr(9)


# str.format() templates consumed by log(); fields: i, status, hash, message.
DIM_FORMAT = DIM + '{status[0]} {hash} {message}' + DEFAULT
DEFAULT_FORMAT = '{status[0]} ' + YELLOW + '{hash}' + DEFAULT + ' {message}'
SPLIT_FORMAT = '{i} ' + YELLOW + '{hash}' + DEFAULT + ' {message}'


def log(patchset, format):
    """Print one line per commit of *patchset*, rendered with *format*.

    *format* is a str.format() template that may use the fields ``i``
    (1-based position of the commit), ``status`` (value of the review
    state), ``hash`` (first seven characters) and ``message``.  Nothing
    is printed when *patchset* is None.
    """
    if patchset is not None:
        commits = zip(patchset.hashes, patchset.messages)
        for number, (commit, subject) in enumerate(commits, start=1):
            line = format.format(
                i=number,
                status=patchset.review_state.value,
                hash=commit[:7],
                message=subject)
            print(line)


@command()
def review(state):
    """Interactively review the patchsets in *state*.

    Starts at the first patchset that is still NEW (or at the first
    patchset if none is).  Each round clears the screen, shows the
    current patchset between its neighbours, reads an option key and
    runs the option.  An option may return the index to show next; it is
    clamped to the valid range.  The loop only ends when an option exits
    the program.
    """
    if len(state) == 0:
        # Nothing to index into below; bail out instead of crashing.
        print("No patch sets to review.")
        return
    for index, patchset in enumerate(state):
        if patchset.review_state == ReviewState.NEW:
            break
    else:
        index = 0
    while True:
        patchset = state._patchsets[index]
        subprocess.run('clear')
        print()
        log(patchset.prev(), DIM_FORMAT)
        log(patchset, DEFAULT_FORMAT)
        log(patchset.next(), DIM_FORMAT)
        print()
        print('current state: ', patchset.review_state)
        if patchset.comment:
            print('comment:')
            print(patchset.comment)
        if patchset.bugs:
            print('bugs:', ','.join(patchset.bugs))
        print()
        meth = getoption(review_prompt, review_keymap)
        print()
        target = meth(patchset)
        if target is not None:
            # Clamp to [0, len(state) - 1].
            index = max(0, min(target, len(state) - 1))


@command()
def export(state):
    """Regenerate debian/patches/upstream-updates from the YES patchsets.

    The directory is recreated from scratch with one ``git format-patch``
    file per commit.  File names start with the four-digit position of
    the patchset among the YES patchsets, followed by a two-digit commit
    number when the patchset has more than one commit.  The series file
    is rewritten with the generated patches first, followed by all
    entries that do not belong to upstream-updates/.  The name of each
    patchset's first patch is stored in ``first_patch_name``.
    """
    # sync the state with debian/patches/upstream-updates
    series = SOURCE_DIR.joinpath('debian/patches/series')
    with series.open() as fp:
        new_series = [
            line for line in fp if not line.startswith('upstream-updates/')]

    updates = SOURCE_DIR.joinpath('debian/patches/upstream-updates')
    if updates.is_dir():
        shutil.rmtree(updates)
    updates.mkdir()

    to_export = state.with_state(ReviewState.YES)

    new_patches = []

    for index1, ps in enumerate(to_export, 1):
        for index2, hash in enumerate(ps.hashes, 1):
            with tempfile.TemporaryDirectory() as tdir:
                tdir = pathlib.Path(tdir)
                git(['format-patch', '--quiet', '-1', hash, '-o', str(tdir)])
                patch = list(tdir.iterdir())[0]
                prefix = f'{index1:04}'
                if len(ps.hashes) > 1:
                    prefix += f'-{index2:02}'
                # Replace format-patch's own 4-digit number with ours.
                new_name = f'{prefix}{patch.name[4:]}'
                target = updates.joinpath(new_name)
                # shutil.move also works when the temporary directory is
                # on a different filesystem than the source tree, where a
                # plain rename fails.
                shutil.move(str(patch), str(target))
                print('generated', target.relative_to(SOURCE_DIR))
            new_patches.append(f'upstream-updates/{new_name}\n')
            if index2 == 1:
                ps.first_patch_name = new_name

    with series.open('w') as fp:
        fp.write(''.join(new_patches))
        # Separate the generated entries from the rest with a blank line,
        # unless the rest is empty or already starts with one.
        if new_series and new_series[0].strip():
            fp.write('\n')
        fp.write(''.join(new_series))


@command()
def dch(state):
    """Add a changelog entry listing patches not yet mentioned there.

    A patch counts as mentioned when its file name occurs anywhere in
    debian/changelog.  If any are unmentioned, ``dch`` is run to add a
    placeholder entry, which is then replaced by the list of patches;
    the first patch of a patchset with bugs gets an "(LP: #...)" suffix.

    Raises subprocess.CalledProcessError if ``dch`` fails and
    RuntimeError if the placeholder cannot be found afterwards.
    """
    updates = SOURCE_DIR.joinpath('debian/patches/upstream-updates')
    if not updates.is_dir():
        # Nothing has been exported, so there is nothing to announce.
        return
    changelog = SOURCE_DIR.joinpath('debian/changelog')

    patches = {patch.name for patch in updates.iterdir()}
    with changelog.open() as fp:
        for line in fp:
            # Drop every patch named on this line, not just one of them.
            patches -= {p for p in patches if p in line}
    if not patches:
        return

    # Let dch create the entry with a unique marker as its text, then
    # locate the marker line and substitute the real content.
    u = str(uuid.uuid4())
    subprocess.run(['dch', u], cwd=SOURCE_DIR, check=True)

    changelog_lines = []
    target_line = None
    with changelog.open() as fp:
        for i, line in enumerate(fp):
            if u in line:
                target_line = i
            changelog_lines.append(line)
    if target_line is None:
        # Never fall back to some arbitrary line: that would overwrite
        # existing changelog content.
        raise RuntimeError(
            "placeholder entry not found in {}".format(changelog))

    content = '  * Cherry-pick patches from upstream maintenance branch:\n'
    by_patch_name = {
        ps.first_patch_name: ps for ps in state if ps.first_patch_name
        }
    for patch in sorted(patches):
        suffix = ''
        bugs = getattr(by_patch_name.get(patch), 'bugs', None)
        if bugs:
            suffix = ' (LP: {})'.format(",".join("#" + bug for bug in bugs))
        content += '    - ' + patch + suffix + '\n'
    changelog_lines[target_line] = content
    with changelog.open('w') as fp:
        fp.write(''.join(changelog_lines))


# Script entry: exactly one argument naming a registered subcommand is
# required; anything else prints the usage line and exits with status 1.
if len(sys.argv) != 2 or sys.argv[1] not in cmds:
    print("review import|review|export|dch")
    sys.exit(1)
cmd = cmds[sys.argv[1]]
global_state = State()
global_state.load(STATE_FILE)
cmd(global_state)
# Persist whatever the subcommand changed.  Not reached when the
# subcommand raised or called sys.exit().
global_state.save()
