#!/usr/bin/env python3
import io
import sys
import json
import string
import argparse


class AlphaStackError(Exception):
    """Raised by the VM for runtime errors such as stack underflow."""


# Original (misspelled) name kept as an alias so existing references and
# `except AlphaStackErrror` clauses keep working.
AlphaStackErrror = AlphaStackError


class AlphaStackVM:
    printmodes = {
        'a': 'abcdefghijklmnopqrstuvwxyz',
        'b': 'ABCDEFGHIJKLMNOPQRSTUVWXYZ',
        'c': '0123456789:;<=>?\x1a\x1b\x1c\x1d\x1e\x1f \x7f  ',
        'd': '!\"#$%&\'()*+,-./@[\\]^_`{|}~',
        'e': '\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19'
    }
    base = 26

    @staticmethod
    def is_letter(char):
        return char in string.ascii_lowercase

    def __init__(self):
        self.output = sys.stdout
        self.input = sys.stdin
        self.reset()
        self.iteration_limit = 0

    def set_string_input(self, string):
        self.input = io.StringIO(string)

    def reset(self):
        """
        Resets VM state to default
        """
        self.registers = {
            c: 'a' for c in string.ascii_lowercase
        }
        self.value_stack = []
        self.proc_stack = []
        self.loop_stop = False
        self.proc_stop = False
        self.halted = False

    def capture_output(self):
        self.output = io.StringIO()

    def alpha_to_num(self, char):
        return ord(char) - ord('a')

    def num_to_alpha(self, num):
        return chr(num % self.base + ord('a'))

    def print_char(self, char):
        """
        Modes:

        -   a   b   c   d   e
        a   a   A   0   ! NUL
        b   b   B   1   " x01
        c   c   C   2   # x02
        d   d   D   3   $ x03
        e   e   E   4   % x04
        f   f   F   5   & x05
        g   g   G   6   ' x06
        h   h   H   7   ( x07
        i   i   I   8   ) x08
        j   j   J   9   *  HT
        k   k   K   :   +  LF
        l   l   L   ;   , x0b
        m   m   M   <   - x0c
        n   n   N   =   .  CR
        o   o   O   >   / x0e
        p   p   P   ?   @ x0f
        q   q   Q x1a   [ x10
        r   r   R ESC  \\ x11
        s   s   S x1c   ] x12
        t   t   T x1d   ^ x13
        u   u   U x1e   _ x14
        v   v   V x1f   ` x15
        w   w   W  SP   { x16
        x   x   X x7f   | x17
        y   y   Y       } x18
        z   z   Z       ~ x19
        """
        mode = self.registers['p']
        if mode not in self.printmodes:
            return
        outchar = self.printmodes[mode][self.alpha_to_num(char)]
        self.output.write(outchar)

    def read_char(self):
        char = self.input.read(1)
        if char == '':
            self.registers['t'] = 'z'
            return

        for k, v in self.printmodes.items():
            if char in v:
                self.registers['t'] = k
                self.push(string.ascii_lowercase[v.index(char)])
                return

        self.registers['t'] = 'z'

    def push_token(self, char):
        self.value_stack.append(char)

    def pop_token(self):
        if len(self.value_stack) == 0:
            raise AlphaStackErrror("Stack underflow")
        return self.value_stack.pop(-1)

    def push(self, char):
        return self.push_token(char)

    def pop(self):
        return self.pop_token()

    def pop_num(self):
        """
        Numbers are based on the 'n' register:
        number of alphabytes = register value + 1.
        Highest byte first
        """
        count = self.register_num('n') + 1
        value = 0
        for i in range(count):
            value = value * self.base + self.alpha_to_num(self.pop())
        return value

    def push_num(self, value):
        count = self.register_num('n') + 1
        for i in range(count):
            self.push(self.num_to_alpha(value))
            value = value // self.base

    def compare_token(self, elem, mark):
        return elem == mark

    def define_proc(self):
        mark = self.pop_token()
        proc = []
        while self.value_stack:
            elem = self.pop_token()
            if self.compare_token(elem, mark):
                break
            proc.append(elem)
        self.push_proc(proc)

    def pop_proc(self):
        if len(self.proc_stack) == 0:
            raise AlphaStackErrror("Stack underflow")
        return self.proc_stack.pop(-1)

    def push_proc(self, proc):
        self.proc_stack.append(proc)

    def init_loop(self):
        self.loop_counter = self.iteration_limit

    def exec_procedure(self, proc):
        for char in proc:
            self.exec_instruction(char)
            if self.must_stop_proc():
                break

    def must_stop_proc(self):
        if self.halted or self.loop_stop:
            return True
        if self.proc_stop:
            self.proc_stop = False
            return True
        return False

    def break_procedure(self):
        self.proc_stop = True

    def must_stop_loop(self):
        if self.halted or self.proc_stop:
            return True
        if self.loop_stop:
            self.loop_stop = False
            return True
        self.loop_counter -= 1
        if self.loop_counter == 0:
            return True
        return False

    def break_loop(self):
        self.loop_stop = True

    def repeat(self, count, proc):
        self.init_loop()
        if count == 0:
            while not self.must_stop_loop():
                self.exec_procedure(proc)
        else:
            for i in range(count):
                self.exec_procedure(proc)
                if self.must_stop_loop():
                    break

    def register_num(self, char):
        return self.alpha_to_num(self.registers[char])

    def op_arithmetic(self):
        rhs = self.pop_num()
        lhs = self.pop_num()
        op = self.registers['a']
        result = 0
        match op:
            case 'a':
                result = lhs + rhs
            case 'b':
                result = lhs - rhs
            case 'c':
                result = lhs * rhs
            case 'd':
                result = lhs // rhs
            case 'e':
                result = lhs % rhs
            case 'f':
                result = lhs == rhs
            case 'g':
                result = lhs > rhs
            case 'h':
                result = lhs >= rhs
            case 'i':
                result = lhs <= rhs
            case 'j':
                result = lhs < rhs
            case 'k':
                result = lhs != rhs
        self.push_num(int(result))

    def exec_instruction(self, char):
        """
        num = number, byte counts depends on the value of 'n' register
        val = single letter
        reg = single letter representing a register

        Stack Before Instr  Stack After     Name    Effect
        val             r                   repeat  Pops a procedure and repeats it, 'a' repetitions means infinite loops.
        reg val         s                   set     Sets register value
                        t val?              typewr  Reads from input and sets the 't' register


        mark ... mark   u                   undo    Clear the stack until mark is found

        ... val1 val2   w                   wrap    Wrap values around. val1 is the number of values to wrap. val2 is the amount.
                        x                   execute Pop and execute procedure
        """
        if self.registers['l'] == 'a':
            if char == 'l':
                self.registers['l'] = 'b'
            elif char in string.ascii_lowercase:
                self.push(char)
            return

        match char:
            case 'a':
                self.op_arithmetic()
            case 'b':
                # Break
                self.break_loop()
            case 'c':
                # Copy
                index = self.alpha_to_num(self.pop())
                if index >= len(self.value_stack):
                    raise AlphaStackErrror("Stack underflow")
                self.value_stack.append(self.value_stack[-1 - index])
            case 'd':
                # Define proc
                self.define_proc()
            case 'e':
                # Execute specific
                if len(self.proc_stack) == 0:
                    raise AlphaStackErrror("Stack underflow")
                index = self.register_num('e')
                if index >= len(self.proc_stack):
                    index = -1
                self.exec_procedure(self.proc_stack[index])
            case 'f':
                # Flip
                size = self.alpha_to_num(self.pop())
                args = []
                for i in range(size):
                    args.append(self.pop())

                for arg in args:
                    self.push(arg)
                pass
            case 'g':
                # Get Register
                self.push(self.registers[self.pop()])
            case 'h':
                # Halt
                self.halt()
            case 'i':
                # If / else
                ifelse_count = self.register_num('i')
                iftrue = self.pop_proc()
                cond = self.pop_num()
                to_exec = None
                if cond:
                    to_exec = iftrue
                    condition_matched = True
                for i in range(0, ifelse_count - 1):
                    iftrue = self.pop_proc()
                    cond = self.pop_num()
                    if cond and to_exec is None:
                        to_exec = iftrue
                        condition_matched = True
                if ifelse_count > 0:
                    iffalse = self.pop_proc()
                    if to_exec is None:
                        to_exec = iffalse
                if to_exec is not None:
                    self.exec_procedure(to_exec)
            case 'j':
                pass
            case 'k':
                # Keep procedure
                count = self.register_num('k') + 1
                if count > len(self.proc_stack):
                    raise AlphaStackErrror("Stack underflow")
                for proc in self.proc_stack[-count:]:
                    self.push_proc(proc)
            case 'l':
                # Literal mode
                self.registers['l'] = 'a'
            case 'm':
                # Memory
                self.push(self.registers['m'])
            case 'n':
                self.push_num(len(self.value_stack))
            case 'o':
                # Overwrite
                replace = self.pop_token()
                search = self.pop_token()
                mark = self.pop_token()
                values = []
                keep = self.registers['o'] != 'a'
                while self.value_stack:
                    elem = self.pop_token()
                    if self.compare_token(elem, mark):
                        break
                    values.insert(
                        0,
                        self.replaced_token(elem, replace)
                        if self.compare_token(elem, search)
                        else elem
                    )
                if keep:
                    values.insert(0, mark)
                    values.append(mark)
                for value in values:
                    self.push_token(value)
            case 'p':
                # Print
                self.print_char(self.pop())
            case 'q':
                pass
            case 'r':
                # Repeat
                proc = self.pop_proc()
                count = self.alpha_to_num(self.pop())
                self.repeat(count, proc)
            case 's':
                # Set Register
                val = self.pop()
                self.registers[self.pop()] = val
            case 't':
                # Text input
                self.read_char()
            case 'u':
                # Undo
                mark = self.pop()
                while self.value_stack:
                    elem = self.pop()
                    if elem == mark:
                        break
            case 'v':
                pass
            case 'w':
                # Wrap
                amount = self.alpha_to_num(self.pop())
                size = self.alpha_to_num(self.pop())
                args = []
                for i in range(size):
                    args.insert(0, self.pop())

                for i in range(amount):
                    args.insert(0, args.pop())

                for arg in args:
                    self.push(arg)
            case 'x':
                # Execute
                self.exec_procedure(self.pop_proc())
            case 'y':
                # Yield
                self.break_procedure()
            case 'z':
                pass

    def stack_string(self):
        return "".join(self.value_stack)

    def proc_stack_strings(self):
        return ("".join(proc) for proc in self.proc_stack)

    def debug_print(self):
        print("Stack: " + self.stack_string())
        print("Regs:")
        print(" ".join(r for r in vm.registers.keys()))
        print(" ".join(r for r in vm.registers.values()))
        print("Procs:")
        if self.proc_stack:
            print("\n".join(
                "%s: %s" % (self.num_to_alpha(ind), proc)
                for ind, proc in enumerate(self.proc_stack_strings())
            ))

    def halt(self):
        self.halted = True

    def exec_program(self, code):
        self.reset()
        self.exec_procedure(code)

    def tokenize(self, code):
        for i, char in enumerate(code):
            if self.is_letter(char):
                yield (i, char)

    def replaced_token(self, elem, replace):
        return replace


class StepperVM(AlphaStackVM):
    """
    This VM doesn't perform loops, instead pushes stuff onto the stack.

    It is the single-stepping variant used by the debugger. Procedures and
    loops are not run recursively. They are scheduled as frames on an
    explicit execution stack. next_token() hands out one token at a time,
    and the caller runs it with exec_token().

    Frames are (kind, tokens, count) tuples:
      - ('p', tokens, 0): a procedure; tokens are consumed from the front.
      - ('r', proc, count): a loop; count is the number of iterations still
        to schedule, or -1 for an infinite loop.

    Tokens are (source_index, letter) pairs, and so are value-stack entries.
    Pushed values carry the source index of the token that pushed them.
    """

    def reset(self):
        super().reset()
        self.current_token = (-1, '')
        self.exec_stack = []
        # True when next_token() already dropped the frame of the procedure
        # the current token came from (it was that procedure's last token).
        self._current_proc_finished = False

    def push(self, char):
        """Push *char* tagged with the source index of the current token."""
        self.value_stack.append((self.current_token[0], char))

    def pop(self):
        """Pop a value-stack entry and return just its letter."""
        popped = self.pop_token()
        return popped[1]

    def exec_procedure(self, proc):
        """Schedule a copy of *proc* (an iterable of tokens) as a new frame."""
        self.exec_stack.append(('p', list(proc), 0))

    def compare_token(self, elem, mark):
        """Compare two entries by letter, ignoring their source index."""
        return elem[1] == mark[1]

    def break_procedure(self):
        """
        Stop the procedure the current token belongs to ('y').

        If that token was the procedure's last one, next_token() has already
        dropped the frame, so there is nothing left to stop. Popping here
        would wrongly abort the enclosing frame, or underflow at top level.
        """
        if self._current_proc_finished or not self.exec_stack:
            return
        self.exec_stack.pop()

    def repeat(self, count, proc):
        """Schedule a loop frame; a count of 0 is stored as -1 (infinite)."""
        self.exec_stack.append((
            'r',
            proc,
            count if count > 0 else -1
        ))

    def break_loop(self):
        """Discard frames up to and including the innermost loop frame."""
        while len(self.exec_stack) > 0:
            exec_info = self.exec_stack.pop()
            if exec_info[0] == 'r':
                break

    def exec_token(self, token):
        """Execute *token* (a (source_index, letter) pair)."""
        self.current_token = token
        self.exec_instruction(token[1])

    def tokens_to_string(self, proc):
        """Render a sequence of tokens as a string of their letters."""
        return "".join(c[1] for c in proc)

    def stack_string(self):
        return self.tokens_to_string(self.value_stack)

    def proc_stack_strings(self):
        return map(self.tokens_to_string, self.proc_stack)

    def debug_print(self):
        super().debug_print()
        print("Exec:")
        for kind, proc, count in self.exec_stack:
            print("%s %2i %s" % (kind, count, self.tokens_to_string(proc)))

    def halt(self):
        """Drop every frame; the next next_token() call marks the VM halted."""
        self.exec_stack.clear()

    def exec_program(self, code):
        """Reset the VM and schedule the letters of *code* as the main frame."""
        self.reset()
        self.exec_procedure(self.tokenize(code))

    def next_token(self):
        """
        Return the next token to execute, or None (and set halted) when no
        frames remain.
        """
        while self.exec_stack:
            kind, proc, count = self.exec_stack[-1]

            if kind == 'r':
                if count > 0:
                    # Consume one iteration and schedule the loop body.
                    self.exec_stack[-1] = ('r', proc, count - 1)
                    self.exec_procedure(proc)
                elif count == -1:
                    # Infinite loop: schedule the body again.
                    self.exec_procedure(proc)
                else:
                    # End of the loop
                    self.exec_stack.pop()
                continue

            if not proc:
                # Empty procedure
                self.exec_stack.pop()
                continue

            token = proc.pop(0)
            self._current_proc_finished = not proc
            if self._current_proc_finished:
                # Drop exhausted frames eagerly so tail calls ('x' as the
                # last token) do not grow the execution stack.
                self.exec_stack.pop()
            return token

        self.halted = True
        return None

    def op_arithmetic(self):
        # Tag the result with the position of the left operand's lowest
        # digit, so the debugger can trace where it came from.
        count = self.register_num('n') + 1
        lhs_index = len(self.value_stack) - count * 2
        if lhs_index >= 0:
            self.current_token = self.value_stack[lhs_index]
        super().op_arithmetic()

    def replaced_token(self, elem, replace):
        """Keep *elem*'s source index but take *replace*'s letter."""
        return (elem[0], replace[1])


def _build_arg_parser():
    """Create the command-line parser for the AlphaStack interpreter."""
    arg_parser = argparse.ArgumentParser()

    # Program text: a file path or an inline string (the file wins).
    arg_parser.add_argument("--file", "-f", help="Source as file")
    arg_parser.add_argument("source", nargs="?", help="Source as string")
    # Program input read by the 't' instruction instead of stdin.
    arg_parser.add_argument("--input", "-i", help="Input string")
    # JSON file of test cases; switches the script into test mode.
    arg_parser.add_argument("--test", help="Test file")
    # Debugging aids.
    arg_parser.add_argument(
        "--debug", "-d", action="store_true", help="Print debug info"
    )
    arg_parser.add_argument(
        "--step", "-s", action="store_true", help="Step each instruction"
    )
    return arg_parser


parser = _build_arg_parser()


def _show_token_location(code, token):
    """Print the source line containing *token* with a caret under it."""
    pos = max(0, token[0])
    line_start = code.rfind('\n', 0, pos) + 1
    line_end = code.find('\n', pos)
    if line_end < 0:
        line_end = len(code)
    print(code[line_start:line_end])
    print(" " * (pos - line_start) + "^")


def execute_program(vm, code, debug):
    """
    Run the program text *code* on *vm*.

    A plain AlphaStackVM runs the whole program at once; with *debug* its
    final state is printed afterwards. A StepperVM is driven token by token.
    With *debug*, the screen is cleared before each step, the current source
    position and VM state are shown, and the user is prompted:
      - 'q' quits;
      - 'c' stops pausing until a step ends with register 'l' holding
        something other than 'a' (i.e. outside literal mode);
      - anything else executes the next step.
    In stepping debug mode, program output is captured and shown with each
    step instead of being written to stdout.
    """
    vm.exec_program(code)

    if not isinstance(vm, StepperVM):
        if debug:
            print()
            vm.debug_print()
        return

    if debug:
        vm.capture_output()

    pause = True
    while not vm.halted:
        token = vm.next_token()
        if token is None:
            break

        if debug:
            # Clear screen
            print("\033[2J\033[H")
            print()
            _show_token_location(code, token)

        vm.exec_token(token)

        if not debug:
            continue

        print("Token %s" % (token,))
        print("Output: %r" % vm.output.getvalue())
        vm.debug_print()

        if not pause:
            pause = vm.register_num('l')

        if pause:
            try:
                ch = input()
            except EOFError:
                # No more interactive input: behave as if 'q' was typed.
                break
            if ch == "q":
                break
            elif ch == "c":
                pause = False


if __name__ == "__main__":
    args = parser.parse_args()
    # --step selects the single-stepping VM (needed for the debugger).
    if args.step:
        vm = StepperVM()
    else:
        vm = AlphaStackVM()

    if args.test is not None:
        # Test mode: a JSON list of cases with "name", "code", "stack",
        # "output" and an optional "input".
        with open(args.test, encoding="utf-8") as f:
            data = json.load(f)
        fail = False
        for item in data:
            print(item["name"])
            vm.capture_output()
            vm.set_string_input(item.get("input", ""))
            try:
                execute_program(vm, item["code"], False)
                if vm.stack_string() != item["stack"]:
                    raise Exception("Stack mismatch")

                if vm.output.getvalue() != item["output"]:
                    raise Exception("Output mismatch")
            except Exception as e:
                print(e)
                print(item)
                vm.debug_print()
                print("Output: %r" % vm.output.getvalue())
                fail = True

        if fail:
            print("!! Some tests failed !!", file=sys.stderr)
            sys.exit(1)

    else:
        try:
            source = args.source
            if args.file:
                with open(args.file) as f:
                    source = f.read()

            if source is None:
                # parser.error() exits via SystemExit, which the
                # `except Exception` below intentionally does not catch.
                parser.error("no program given: pass SOURCE or --file")

            if args.input is not None:
                vm.set_string_input(args.input)

            execute_program(vm, source, args.debug)
        except Exception:
            # Show the VM state at the point of failure, then re-raise with
            # the original traceback.
            vm.debug_print()
            raise
