#!/usr/bin/env python3
import io
import sys
import json
import string
import argparse


class AlphaStackErrror(Exception):
    """Error raised by the AlphaStack VM, e.g. on value or procedure stack underflow."""


class AphaStackVM:
    printmodes = {
        'a': 'abcdefghijklmnopqrstuvwxyz',
        'b': 'ABCDEFGHIJKLMNOPQRSTUVWXYZ',
        'c': '0123456789:;<=>?\x1a\x1b\x1c\x1d\x1e\x1f \x7f  ',
        'd': '!\"#$%&\'()*+,-./@[\\]^_`{|}~',
        'e': '\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19'
    }
    base = 26

    @staticmethod
    def is_letter(char):
        return char in string.ascii_lowercase

    def __init__(self):
        self.output = sys.stdout
        self.input = sys.stdin
        self.reset()

    def set_string_input(self, string):
        self.input = io.StringIO(string)

    def reset(self):
        """
        Resets VM state to default
        """
        self.registers = {
            c: 'a' for c in string.ascii_lowercase
        }
        self.value_stack = []
        self.proc_stack = []
        self.loop_stop = False
        self.proc_stop = False
        self.halted = False

    def capture_output(self):
        self.output = io.StringIO()

    def alpha_to_num(self, char):
        return ord(char) - ord('a')

    def num_to_alpha(self, num):
        return chr(num % self.base + ord('a'))

    def print_char(self, char):
        """
        Modes:

        -   a   b   c   d   e
        a   a   A   0   ! NUL
        b   b   B   1   " x01
        c   c   C   2   # x02
        d   d   D   3   $ x03
        e   e   E   4   % x04
        f   f   F   5   & x05
        g   g   G   6   ' x06
        h   h   H   7   ( x07
        i   i   I   8   ) x08
        j   j   J   9   *  HT
        k   k   K   :   +  LF
        l   l   L   ;   , x0b
        m   m   M   <   - x0c
        n   n   N   =   .  CR
        o   o   O   >   / x0e
        p   p   P   ?   @ x0f
        q   q   Q x1a   [ x10
        r   r   R ESC  \\ x11
        s   s   S x1c   ] x12
        t   t   T x1d   ^ x13
        u   u   U x1e   _ x14
        v   v   V x1f   ` x15
        w   w   W  SP   { x16
        x   x   X x7f   | x17
        y   y   Y       } x18
        z   z   Z       ~ x19
        """
        mode = self.registers['p']
        if mode not in self.printmodes:
            return
        outchar = self.printmodes[mode][self.alpha_to_num(char)]
        self.output.write(outchar)

    def read_char(self):
        char = self.input.read(1)
        if char == '':
            self.registers['t'] = 'z'
            return

        for k, v in self.printmodes.items():
            if char in v:
                self.registers['t'] = k
                self.push(string.ascii_lowercase[v.index(char)])
                return

        self.registers['t'] = 'z'

    def push(self, char):
        self.value_stack.append(char)

    def pop(self):
        if len(self.value_stack) == 0:
            raise AlphaStackErrror("Stack underflow")
        return self.value_stack.pop(-1)

    def pop_num(self):
        """
        Numbers are based on the 'n' register:
        number of alphabytes = register value + 1.
        Highest byte first
        """
        count = self.register_num('n') + 1
        value = 0
        for i in range(count):
            value = value * self.base + self.alpha_to_num(self.pop())
        return value

    def push_num(self, value):
        count = self.register_num('n') + 1
        for i in range(count):
            self.push(self.num_to_alpha(value))
            value = value // self.base

    def pop_proc(self):
        if len(self.proc_stack) == 0:
            raise AlphaStackErrror("Stack underflow")
        return self.proc_stack.pop(-1)

    def push_proc(self, proc):
        self.proc_stack.append(proc)

    def must_stop_loop(self):
        if self.halted or self.proc_stop:
            return True
        if self.loop_stop:
            self.loop_stop = False
            return True
        return False

    def must_stop_proc(self):
        if self.halted or self.loop_stop:
            return True
        if self.proc_stop:
            self.proc_stop = False
            return True
        return False

    def exec_procedure(self, proc):
        for char in proc:
            self.exec_instruction(char)
            if self.must_stop_proc():
                break

    def register_num(self, char):
        return self.alpha_to_num(self.registers[char])

    def op_arithmetic(self):
        rhs = self.pop_num()
        lhs = self.pop_num()
        op = self.registers['a']
        result = 0
        match op:
            case 'a':
                result = lhs + rhs
            case 'b':
                result = lhs - rhs
            case 'c':
                result = lhs * rhs
            case 'd':
                result = lhs // rhs
            case 'e':
                result = lhs % rhs
            case 'f':
                result = lhs == rhs
            case 'g':
                result = lhs > rhs
            case 'h':
                result = lhs >= rhs
            case 'i':
                result = lhs <= rhs
            case 'j':
                result = lhs < rhs
            case 'k':
                result = lhs != rhs
        self.push_num(int(result))

    def exec_instruction(self, char):
        """
        num = number, byte counts depends on the value of 'n' register
        val = single letter
        reg = single letter representing a register

        Stack Before Instr  Stack After     Name    Effect
        val             r                   repeat  Pops a procedure and repeats it, 'a' repetitions means infinite loops.
        reg val         s                   set     Sets register value
                        t val?              typewr  Reads from input and sets the 't' register


        mark ... mark   u                   undo    Clear the stack until mark is found

        ... val1 val2   w                   wrap    Wrap values around. val1 is the number of values to wrap. val2 is the amount.
                        x                   execute Pop and execute procedure
        """
        if self.registers['l'] == 'a':
            if char == 'l':
                self.registers['l'] = 'b'
            elif char in string.ascii_lowercase:
                self.push(char)
            return

        match char:
            case 'a':
                self.op_arithmetic()
            case 'b':
                # Break
                self.loop_stop = True
            case 'c':
                # Copy
                index = self.alpha_to_num(self.pop())
                if index >= len(self.value_stack):
                    raise AlphaStackErrror("Stack underflow")
                self.push(self.value_stack[-1 - index])
            case 'd':
                # Define proc
                mark = self.pop()
                proc = ""
                while self.value_stack:
                    elem = self.pop()
                    if elem == mark:
                        break
                    proc += elem
                self.push_proc(proc)
            case 'e':
                # Execute specific
                if len(self.proc_stack) == 0:
                    raise AlphaStackErrror("Stack underflow")
                index = self.register_num('e')
                if index >= len(self.proc_stack):
                    index = -1
                self.exec_procedure(self.proc_stack[index])
            case 'f':
                # Flip
                size = self.alpha_to_num(self.pop())
                args = []
                for i in range(size):
                    args.append(self.pop())

                for arg in args:
                    self.push(arg)
                pass
            case 'g':
                # Get Register
                self.push(self.registers[self.pop()])
            case 'h':
                # Halt
                self.halt()
            case 'i':
                # If / else
                ifelse_count = self.register_num('i')
                iftrue = self.pop_proc()
                cond = self.alpha_to_num(self.pop())
                to_exec = None
                if cond:
                    to_exec = iftrue
                    condition_matched = True
                for i in range(0, ifelse_count - 1):
                    iftrue = self.pop_proc()
                    cond = self.alpha_to_num(self.pop())
                    if cond and to_exec is None:
                        to_exec = iftrue
                        condition_matched = True
                if ifelse_count > 0:
                    iffalse = self.pop_proc()
                    if to_exec is None:
                        to_exec = iffalse
                if to_exec is not None:
                    self.exec_procedure(to_exec)
            case 'j':
                pass
            case 'k':
                # Keep procedure
                proc = self.pop_proc()
                self.push_proc(proc)
                self.push_proc(proc)
            case 'l':
                # Literal mode
                self.registers['l'] = 'a'
            case 'm':
                # Memory
                self.push(self.registers['m'])
            case 'n':
                pass
            case 'o':
                # Overwrite
                replace = self.pop()
                search = self.pop()
                mark = self.pop()
                values = []
                keep = self.registers['o'] != 'a'
                while self.value_stack:
                    elem = self.pop()
                    if elem == mark:
                        break
                    values.insert(0, replace if elem == search else elem)
                if keep:
                    values.insert(0, mark)
                    values.append(mark)
                for value in values:
                    self.push(value)
            case 'p':
                # Print
                self.print_char(self.pop())
            case 'q':
                pass
            case 'r':
                # Repeat
                proc = self.pop_proc()
                count = self.alpha_to_num(self.pop())
                if count == 0:
                    while not self.must_stop_loop():
                        self.exec_procedure(proc)
                else:
                    for i in range(count):
                        self.exec_procedure(proc)
                        if self.must_stop_loop():
                            break
            case 's':
                # Set Register
                val = self.pop()
                self.registers[self.pop()] = val
            case 't':
                # Text input
                self.read_char()
            case 'u':
                # Undo
                mark = self.pop()
                while self.value_stack:
                    elem = self.pop()
                    if elem == mark:
                        break
            case 'v':
                pass
            case 'w':
                # Wrap
                amount = self.alpha_to_num(self.pop())
                size = self.alpha_to_num(self.pop())
                args = []
                for i in range(size):
                    args.insert(0, self.pop())

                for i in range(amount):
                    args.insert(0, args.pop())

                for arg in args:
                    self.push(arg)
            case 'x':
                # Execute
                self.exec_procedure(self.pop_proc())
            case 'y':
                # Yield
                self.proc_stop = True
            case 'z':
                pass

    def stack_string(self):
        return "".join(self.value_stack)

    def debug_print(self):
        print("Stack: " + self.stack_string())
        print("Procs:")
        print("\n".join(self.proc_stack))
        print("Regs:")
        print(" ".join(r for r in vm.registers.keys()))
        print(" ".join(r for r in vm.registers.values()))

    def halt(self):
        self.halted = True


# Command-line interface: run a program given as the first argument, or run
# a JSON test file with --test.
parser = argparse.ArgumentParser(description="Run or test AlphaStack programs.")
parser.add_argument(
    "source",
    nargs="?",
    help="Program source code (required unless --test is given)",
)
parser.add_argument("--input", "-i", help="Input string")
parser.add_argument(
    "--test",
    help="JSON test file: a list of cases with name, code, input, stack and output keys",
)
parser.add_argument("--debug", "-d", action="store_true", help="Print debug info")
parser.add_argument("--step", "-s", action="store_true", help="Step each instruction")


if __name__ == "__main__":
    args = parser.parse_args()
    if args.test is not None:
        # Test mode: the file holds a JSON list of cases, each with "name",
        # "code", optional "input", and the expected final "stack" and
        # printed "output".
        with open(args.test, encoding="utf-8") as f:
            data = json.load(f)
        fail = False
        for item in data:
            print(item["name"])
            # Kept as a module-level name: debug_print may refer to it.
            vm = AphaStackVM()
            vm.capture_output()
            vm.set_string_input(item.get("input", ""))
            try:
                vm.exec_procedure(item["code"])

                if vm.stack_string() != item["stack"]:
                    raise Exception("Stack mismatch")

                if vm.output.getvalue() != item["output"]:
                    raise Exception("Output mismatch")
            except Exception as e:
                # Report the failing case and keep going so every case runs.
                print(e)
                print(item)
                vm.debug_print()
                print(f"Output: {vm.output.getvalue()!r}")
                fail = True

        if fail:
            print("!! Some tests failed !!", file=sys.stderr)
            sys.exit(1)

    else:
        if args.source is None:
            # Fail with a usage message instead of a TypeError on None.
            parser.error("a source program is required unless --test is given")
        vm = AphaStackVM()
        try:
            if args.input is not None:
                vm.set_string_input(args.input)
            if args.step:
                # Show the position and VM state after every instruction.
                for i, ch in enumerate(args.source):
                    print(args.source)
                    print(" " * i + "^")
                    vm.exec_instruction(ch)
                    vm.debug_print()
                    if vm.must_stop_proc():
                        break
                    print()
            else:
                vm.exec_procedure(args.source)
                if args.debug:
                    print()
                    vm.debug_print()
        except Exception:
            # Dump the state for diagnosis, then re-raise with the traceback intact.
            vm.debug_print()
            raise
