#!/usr/bin/env python3

# This file is a part of the LinkAhead Project.
#
# Copyright (C) 2020 IndiScale GmbH <www.indiscale.com>
# Copyright (C) 2020 Daniel Hornung <d.hornung@indiscale.com>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

"""Convert an SQL script to C-style declarations (for Doxygen) with the same function signatures."""

import argparse
import sys

import sqlparse


# Doxygen main page, printed once by main() before any converted declarations.
# The backslashes are doubled so the emitted text contains the literal Doxygen
# commands (\mainpage, \section, \htmlonly, \endhtmlonly).
MAINPAGE = """

/*! \\mainpage This is autogenerated documentation from SQL.
 *
 * \\section function Global functions
 *
 * \\htmlonly The functions are listed <A href="globals.html">here</A>.\\endhtmlonly
 */

"""


def _parse_file(filename, outfile=None):
    """Parse *filename* and print a declaration for every documented CREATE.

    A match is an identifier sublist that directly follows a comment sublist,
    where the token right after that comment is the ``CREATE`` keyword. The
    keyword is matched case-insensitively, so ``create`` works as well.

    Parameters
    ----------
    filename : str
        Path of the SQL script to read (decoded as UTF-8).
    outfile : file-like, optional
        Stream to write to. ``None`` (the default) means the current
        ``sys.stdout``, resolved at print time.
    """
    with open(filename, "r", encoding="utf-8") as sqf:
        # parsestream() yields lazily, so materialise it while the file is open.
        statements = list(sqlparse.parsestream(sqf))

    for stat in statements:
        sublists = list(stat.get_sublists())
        # Look at each pair of neighbouring sublists: (comment, identifier).
        for comment, ident in zip(sublists, sublists[1:]):
            if not (isinstance(ident, sqlparse.sql.Identifier)
                    and isinstance(comment, sqlparse.sql.Comment)):
                continue
            after_comment = stat.tokens[stat.tokens.index(comment) + 1]
            # ``normalized`` is the upper-cased value for keyword tokens.
            if after_comment.normalized != "CREATE":
                continue

            print("/" * 78, file=outfile)
            print("// taken from {filename}".format(filename=filename), file=outfile)
            print("/" * 78, file=outfile)
            # Pass the stream explicitly so header and declaration share it.
            _print_header(comment.value, ident, outfile=outfile)


def _strip_whitespace(tokens):
    """Remove whitespace tokens from the tokens list and return that."""
    to_delete = []
    for i, tok in enumerate(tokens):
        if isinstance(tok, sqlparse.sql.Token) and tok.ttype in (
                sqlparse.tokens.Whitespace, sqlparse.tokens.Newline):
            to_delete.append(i)
    while to_delete:
        tokens.pop(to_delete.pop())
    return tokens


def _print_header(comment, ident, outfile=None):
    """Print a C header with the given comment and identifier.

    The stripped comment is printed first. It is followed by a ``void``
    declaration named after the SQL routine, with every parameter typed
    ``void*``. A routine without parameters becomes ``void name();``.

Parameters
----------
comment : str
    Comment text that preceded the CREATE statement.

ident : sqlparse.sql.Identifier
    Identifier that contains the ``name(...)`` function group.

outfile : file-like, optional
    Output stream. ``None`` (the default) writes to the current
    ``sys.stdout``.

Raises
------
ValueError
    If *ident* has no function group followed by a parameter list.
    """
    func = ident.token_matching(
        lambda tok: isinstance(tok, sqlparse.sql.Function), idx=0)
    if func is None:
        raise ValueError("No function definition found in {}".format(ident))
    _strip_whitespace(func.tokens)
    name, paren = func.tokens[0:2]
    if not isinstance(paren, sqlparse.sql.Parenthesis):
        raise ValueError("Expected a parameter list after {}".format(name))

    args = _collect_arg_names(paren)
    # Prefix each name individually so that an empty list yields "()".
    params = ", ".join("void* " + arg for arg in args)
    print(comment.strip(), file=outfile)
    print("void {name}({params});\n".format(name=name, params=params),
          file=outfile)


def _collect_arg_names(paren):
    """Return one argument name per punctuation-delimited chunk of *paren*."""
    _strip_whitespace(paren.tokens)

    # Flatten IdentifierLists (several comma-separated parameters grouped by
    # sqlparse) so their commas become top-level separators.
    flat = []
    for tok in paren.tokens:
        if isinstance(tok, sqlparse.sql.IdentifierList):
            flat.extend(_strip_whitespace(tok.tokens))
        else:
            flat.append(tok)

    args = []
    current = []
    # Skip the opening "(", then split on punctuation tokens ("," and ")").
    for tok in flat[1:]:
        if tok.ttype == sqlparse.tokens.Punctuation:
            if current:
                args.append(_arg_from_tokens(current))
            current = []
        else:
            current.append(tok)
    if current:
        # No trailing punctuation: keep the last chunk instead of hanging.
        args.append(_arg_from_tokens(current))
    return args


def _arg_from_tokens(tokens):
    """Extract and return an argument name."""
    # Remove content of functions
    for i, tok in enumerate(tokens):
        if isinstance(tok, sqlparse.sql.Function):
            tokens[i] = tok.token_first()

    arg_name = "_".join([tok.value for tok in tokens])
    return arg_name.strip()


def _parse_arguments():
    """Parse the arguments."""
    parser = argparse.ArgumentParser(description='')
    parser.add_argument('-i', '--input', help="Input file(s)", required=True, nargs="+")

    return parser.parse_args()


def main():
    """Run the converter from the command line.

    Parses the arguments first, then prints the Doxygen main page, and
    finally converts each input file in the order given.
    """
    options = _parse_arguments()
    print(MAINPAGE)
    for sql_path in options.input:
        _parse_file(filename=sql_path)


if __name__ == "__main__":
    # Run the converter only when executed as a script, not when imported.
    main()
