import io
import sys
import copy
import enum
import typing
import argparse
import dataclasses


class Control(enum.Enum):
    """Markers for instructions that the interpreter handles itself.

    A member is stored as an instruction's ``effect`` in place of a
    callable; the interpreter matches on it instead of calling it.
    """

    FuncDef = 1
    FuncCall = 2
    Return = 3


@dataclasses.dataclass
class Instruction:
    """One entry of the instruction set: a token and what it does."""

    # Canonical spelling of the instruction, written with "yip"/"yap".
    token: str
    # C-like pseudocode summarising the instruction.
    c_like: str
    # Either a callable invoked with the Machine, or a Control member for
    # instructions the interpreter handles itself.
    effect: typing.Callable
    # Human-readable description of the instruction.
    description: str


@dataclasses.dataclass
class Token:
    """A whitespace-delimited word read from the program source."""

    # Text of the token; empty when the end of the input was reached.
    token: str
    # Character offset of the token's first character, as counted by the
    # tokenizer.
    start: int
    # Character offset just past the token's last character.
    end: int
    # Line and column recorded just before the token's first character was
    # read (both counters start at 1).
    line: int
    col: int


@dataclasses.dataclass
class InstructionInstance:
    """An instruction as it occurs in a loaded program."""

    # The instruction the source token resolved to.
    instruction: Instruction
    # The source token, kept for error positions and token lookups.
    token: Token


@dataclasses.dataclass
class Machine:
    """State of the virtual machine the instructions operate on.

    The machine has a ``hold`` register, a growable ``memory`` list and a
    ``stack`` index pointing at the current memory cell. Instructions that
    want to jump do not move the program counter themselves; they record the
    request in ``jump_args`` for the interpreter to act on.
    """

    stack: int = 0
    hold: int = 1
    memory: list[int] = dataclasses.field(default_factory=list)

    def __init__(self):
        self.reset()

    def reset(self):
        """Restore the initial state: one zeroed cell, stack 0, hold 1."""
        self.memory = []
        self.stack = 0
        self.hold = 1
        self.jump_args = None
        self.allocate_byte()

    def print(self, ch):
        """Write ``ch`` to standard output."""
        sys.stdout.write(ch)

    def inc_stack(self):
        """Advance the stack index, growing memory when it reaches the end."""
        self.stack += 1
        if self.stack == len(self.memory):
            self.allocate_byte()

    def allocate_byte(self):
        """Append one zeroed cell to memory."""
        self.memory.append(0)

    def dec_stack(self):
        """Move the stack index back by one cell.

        Raises:
            YipYapInternalError: if the stack index is already at the first
                cell. The index is left untouched in that case, so the
                machine stays usable (e.g. from the debugger) after the
                error instead of addressing ``memory[-1]``.
        """
        if self.stack <= 0:
            raise YipYapInternalError("You yapped too many times")
        self.stack -= 1

    def set_hold(self, value):
        """Store ``value`` in hold, wrapped into the range 0-255."""
        self.hold = value % 256

    def swap(self):
        """Exchange hold with the current memory cell."""
        self.hold, self.memory[self.stack] = self.memory[self.stack], self.hold

    def jump(self, direction, condition, target, offset):
        """Record a jump request for the interpreter to carry out.

        Args:
            direction: positive to search forward, otherwise backward.
            condition: value hold must equal for the jump to happen, or
                ``None`` for an unconditional jump.
            target: token text of the instruction to jump to.
            offset: added to the index of the found instruction.
        """
        self.jump_args = (direction, condition, target, offset)

    @property
    def head(self):
        """Value of the memory cell at the stack index."""
        return self.memory[self.stack]


class InstructionSet:
    """A collection of instructions that can be respelled per program.

    Instructions are declared with the canonical words "yip" and "yap";
    ``compile`` maps them onto whatever two words a program uses instead.
    """

    def __init__(self, commands):
        self.commands = commands

    def __iter__(self):
        return iter(self.commands)

    def __len__(self):
        return len(self.commands)

    def compile(self, yip, yap):
        """Return a dict from respelled token text to its instruction."""
        return {
            self.get_token(command.token, yip, yap): command
            for command in self.commands
        }

    @staticmethod
    def yip_or_yap(value, yip, yap):
        """Return ``yip`` when ``value`` is "yip", otherwise ``yap``."""
        if value == "yip":
            return yip
        return yap

    @staticmethod
    def get_token(token, yip, yap):
        """Respell a canonical token using the program's two words.

        The token is lower-cased and stripped of "?"/"!" to find its word(s);
        words longer than three characters are translated three characters
        at a time. The original's title/upper casing is then re-applied, and
        the original's last character is appended unless it is "p".
        """
        base = token.lower().strip("?!")
        if len(base) > 3:
            out = "".join(
                InstructionSet.yip_or_yap(base[pos:pos + 3], yip, yap)
                for pos in range(0, len(base), 3)
            )
        else:
            out = InstructionSet.yip_or_yap(base, yip, yap)

        if token.istitle():
            out = out.title()
        elif token.isupper():
            out = out.upper()

        suffix = token[-1]
        if suffix != "p":
            out += suffix
        return out


class YipYapError(Exception):
    """Error in a program, tagged with the source position it refers to."""

    def __init__(self, msg, line, col):
        super().__init__(msg)
        self.line, self.col = line, col


class YipYapInternalError(Exception):
    """Error raised without a source position; carries only a message."""


class Interpreter:
    def __init__(self, instructions, machine=None):
        self.machine = machine or Machine()
        self.instructions = instructions
        self.file_instructions = []
        self.functions = {}
        self.program_counter = 0
        self.call_stack = []
        self.line = 0
        self.col = 0

    def get_char(self, file):
        self.col += 1
        self.token_pos += 1
        ch = file.read(1)
        if ch == "\n":
            self.col = 1
            self.line += 1
        return ch

    def next_token(self, file):
        token = ""

        ch = " "
        while ch.isspace():
            line = self.line
            col = self.col
            ch = self.get_char(file)

        start = self.token_pos - 1
        while ch != '' and not ch.isspace():
            token += ch
            ch = self.get_char(file)
        end = self.token_pos - 1

        if token.lower() in ("owo", "uwu"):
            while ch != '\n':
                ch = self.get_char(file)
            return self.next_token(file)

        return Token(token, start, end, line, col)

    def error(self, msg, token):
        raise YipYapError(msg, token.line, token.col) from None

    def init_execution(self, file):
        self.machine.reset()
        self.token_pos = 0
        self.line = self.col = 1
        yip = self.next_token(file)
        yap = self.next_token(file)
        if yip.token == yap.token or yap.token == "":
            self.error("Expected yip yap", yap)

        self.machine.yip_begin = self.instructions.get_token("yip?", yip.token, yap.token)
        self.machine.yip_end = self.instructions.get_token("yap!", yip.token, yap.token)
        self.commands = self.instructions.compile(yip.token, yap.token)

        self.load_program(file)

    def load_program(self, file):
        self.file_instructions = []
        self.functions = {}
        func_def = False

        while True:
            token = self.next_token(file)
            if not token.token:
                break

            cmd = self.commands.get(token.token)
            if cmd is None:
                self.error("%s is not yip nor yap" % token.token, token)

            instance = InstructionInstance(cmd, token)

            self.file_instructions.append(instance)
            if func_def:
                self.functions[token.token] = len(self.file_instructions)
                func_def = None
            else:
                func_def = cmd.effect == Control.FuncDef

        self.program_counter = 0

    @property
    def halted(self):
        return self.program_counter >= len(self.file_instructions)

    def halt(self):
        self.program_counter = len(self.file_instructions)

    def fetch_instruction(self):
        cmd = self.file_instructions[self.program_counter]
        self.program_counter += 1
        return cmd

    def push_call(self, caller, callee, pos):
        self.call_stack.append(self.program_counter + 1)
        self.program_counter = pos

    def pop_call(self):
        self.program_counter = self.call_stack.pop()

    def exec_instruction(self, cmd):
        match cmd.instruction.effect:
            case Control.Return:
                if len(self.call_stack) == 0:
                    self.halt()
                else:
                    self.pop_call()
                return
            case Control.FuncCall:
                token = self.file_instructions[self.program_counter].token
                pos = self.functions.get(token.token)
                if pos is None:
                    self.halt()
                else:
                    self.push_call(cmd.token, token, pos)
                return
            case Control.FuncDef:
                return

        try:
            cmd.instruction.effect(self.machine)
        except YipYapInternalError as e:
            self.error(str(e), cmd.token)

        if self.machine.jump_args:
            (direction, condition, target, offset) = self.machine.jump_args
            self.machine.jump_args = None
            perform_jump = True
            if condition is not None:
                perform_jump = self.machine.hold == condition
            if perform_jump:
                dest = list_index(self.file_instructions, target, self.program_counter, direction)
                if dest == -1:
                    self.error("Kobold can't jump far enough", cmd.token)
                self.program_counter = dest + offset

    def step(self):
        if self.halted:
            return

        cmd = self.fetch_instruction()
        self.exec_instruction(cmd)


def list_index(list, find, position, direction):
    """Return the index of the nearest entry whose token text is ``find``.

    The search starts at ``position`` (inclusive) and runs towards the end
    when ``direction`` is positive, otherwise towards the start; a backward
    start beyond the last entry is clamped to it. Returns -1 when nothing
    matches.
    """
    if direction > 0:
        candidates = range(position, len(list))
    else:
        candidates = range(min(position, len(list) - 1), -1, -1)
    return next(
        (index for index in candidates if list[index].token.token == find),
        -1,
    )


def byte_to_char(val):
    """Return the printable ASCII character for ``val``, or "" if there is none.

    Printable ASCII is 0x20 (space) through 0x7e ("~"). 0x7f (DEL) is a
    control character and is excluded: it has no width, so it would break
    the column alignment of the memory dump.
    """
    if 0x20 <= val <= 0x7e:
        return chr(val)
    return ""


def print_memory(memory, offset, stack):
    """Write a three-row dump of ``memory`` to standard output.

    Row one holds each cell's address in hex (``offset`` plus its index),
    followed by "<" for the cell whose address equals ``stack`` and a space
    otherwise. Row two holds the cell values in hex, row three their
    characters as given by ``byte_to_char``.
    """
    out = sys.stdout
    addresses = range(offset, offset + len(memory))

    out.write("".join(
        "%02x%s" % (address, "<" if address == stack else " ")
        for address in addresses
    ))
    out.write("\n")

    out.write("".join("%02x " % value for value in memory))
    out.write("\n")

    out.write("".join("%2s " % byte_to_char(value) for value in memory))
    out.write("\n")


def debug_machine(machine):
    """Print the machine's registers followed by a dump of its memory.

    Hold and the current cell value are shown in hex, in decimal and as a
    character; the stack index in hex and decimal. Memory is dumped in
    chunks of 0x20 cells.
    """
    print("")

    hold = machine.hold
    print("Hold: %(v)02x %(v)3d %(c)s" % {"v": hold, "c": byte_to_char(hold)})

    head = machine.head
    print("SV:   %(v)02x %(v)3d %(c)s" % {"v": head, "c": byte_to_char(head)})

    print("SP:   %(v)04x %(v)d" % {"v": machine.stack})
    print("")

    chunk_size = 0x20
    chunk_start = 0
    while chunk_start < len(machine.memory):
        chunk = machine.memory[chunk_start:chunk_start + chunk_size]
        print_memory(chunk, chunk_start, machine.stack)
        chunk_start += chunk_size


@dataclasses.dataclass
class StackFrame:
    """Call-stack entry recorded by the debugger for each function call."""

    # Token of the call instruction.
    caller: Token
    # Token naming the called function.
    callee: Token
    # Deep copy of the machine taken when the call was made.
    machine: Machine
    # Program counter at the time of the call; returning resumes one past it.
    program_counter: int


class Debugger(Interpreter):
    def __init__(self, *args):
        super().__init__(*args)
        self.breakpoints = []
        self.paused = True
        self.break_at = None
        self.last_cmd = ""
        self.last_token = None
        self.lines = []

    def debug_dump(self, cmd):
        print(cmd.token)
        debug_machine(self.machine)

    def backtrace_line(self, i, token, machine):
        sys.stderr.write("#%s %s:%s %s\n" % (i, token.line, token.col, token.token))

    def backtrace(self):
        sys.stderr.write("Stack trace\n")
        for i, frame in enumerate(self.call_stack):
            self.backtrace_line(i, frame.callee, frame.machine)
        self.backtrace_line(i + 1, self.last_token, self.machine)

    def push_call(self, caller, callee, pos):
        self.call_stack.append(StackFrame(
            caller,
            callee,
            copy.deepcopy(self.machine),
            self.program_counter,
        ))
        self.program_counter = pos

    def pop_call(self):
        frame = self.call_stack.pop()
        self.program_counter = frame.program_counter + 1

    def init_execution(self, file):
        super().init_execution(file)
        file.seek(0)
        self.lines = list(file)

    def print_lines(self, start, end, higlight):
        for i in range(start - 1, end):
            if 0 <= i < len(self.lines):
                sys.stderr.write("%s%3d %s" % (">" if i + 1 == higlight else " ", i + 1, self.lines[i]))

    def line_to_token(self, arg):
        line = int(arg)
        for inst in self.file_instructions:
            if inst.token.line >= line:
                return inst.token
        return None

    def step(self):
        if self.halted:
            return
        try:
            cmd = self.fetch_instruction()
            self.last_token = cmd.token
            self.exec_instruction(cmd)
            if cmd.token.start == self.break_at:
                self.paused = True
                self.break_at = None
            else:
                for bp in self.breakpoints:
                    if bp.start == cmd.token.start:
                        self.paused = True
                        break
        except YipYapError as e:
            sys.stderr.write("Error: %s:%s: %s\n" % (e.line, e.col, e))
            self.backtrace()
            self.paused = True

        if self.paused:
            self.debug_dump(cmd)

            list_line = self.last_token.line
            self.print_lines(list_line - 2, list_line + 2, list_line)

            while True:
                dbcmd = input("(ydb) ").split(" ")
                if dbcmd[0] == '':
                    dbcmd = self.last_cmd

                self.last_cmd = dbcmd
                dbargs = dbcmd[1:]
                dbcmd = dbcmd[0]

                match dbcmd:
                    case "c"|"r"|"unt":
                        self.paused = False
                        if dbargs:
                            if dbargs[0] == "+1":
                                dbargs[0] = self.last_token.line + 1
                            bp = self.line_to_token(dbargs[0])
                            self.break_at = bp.start
                        return
                    case "s":
                        return
                    case "n":
                        bp = self.line_to_token(self.last_token.line + 1)
                        self.break_at = bp.start if bp else None
                        return
                    case "p":
                        self.debug_dump(cmd)
                    case "cl":
                        breakpoints = []
                        lines = set(map(int, dbargs))
                        for bp in self.breakpoints:
                            if bp.line not in line:
                                breakpoints.append(bp)
                        self.breakpoints = breakpoints
                    case "b":
                        if dbargs:
                            for arg in dbargs:
                                bp = self.line_to_token(arg)
                                if bp:
                                    self.breakpoints.append(bp)
                        else:
                            print(self.breakpoints)
                    case "bt":
                        self.backtrace()
                    case "q":
                        sys.exit(0)
                    case "l":
                        self.print_lines(list_line - 5, list_line + 5, list_line)
                        list_line += 10


# The instruction set, written with the canonical "yip"/"yap" spelling.
# Each entry is: token, C-like pseudocode, effect, description. The effect is
# either a callable taking the Machine or a Control member.
instructions = InstructionSet([
    Instruction("yip",  "stack++",                  lambda m: m.inc_stack(),                    "Increase stack position"),
    Instruction("yap",  "stack--",                  lambda m: m.dec_stack(),                    "Decrease stack position"),
    Instruction("yip?", "if (!hold) goto next yap!",lambda m: m.jump(1, 0, m.yip_end, 1),       "Jump after the next yap! if hold is zero"),
    Instruction("yap?", "hold = min(*stack, hold)", lambda m: m.set_hold(min(m.hold, m.head)),  "Set hold to the stack value if it's less than hold"),
    Instruction("yip!", "hold -= *stack",           lambda m: m.set_hold(m.hold - m.head),      "Subtract the stack value from hold"),
    Instruction("yap!", "goto prev yip?",           lambda m: m.jump(-1, None, m.yip_begin, 0), "Jump to the previous yip?"),
    Instruction("Yip",  "print(chr(hold))",         lambda m: m.print(chr(m.hold)),             "Output hold as character"),
    Instruction("Yap",  "hold += *stack",           lambda m: m.set_hold(m.hold + m.head),      "Add stack value to hold"),
    Instruction("Yip!", "print(hold)",              lambda m: m.print(str(m.hold)),             "Output hold as a number"),
    Instruction("Yap!", "return",                   Control.Return,                             "Returns from a function to its invocation point or terminate the script"),
    Instruction("Yip?", "func() {}",                Control.FuncDef,                            "Defines a function whose name is the next token"),
    Instruction("Yap?", "func()",                   Control.FuncCall,                           "Invoke the function represented by the next token"),
    Instruction("yipyip", "hold = *stack",          lambda m: m.set_hold(m.head),               "Assign the stack value to hold"),
    Instruction("yipyap", "swap(hold, *stack)",     lambda m: m.swap(),                         "Swap hold and stack value"),
    Instruction("yapyip", "hold++",                 lambda m: m.set_hold(m.hold + 1),           "Increase hold value by 1"),
    Instruction("yapyap", "hold--",                 lambda m: m.set_hold(m.hold - 1),           "Decrease hold value by 1"),
])


if __name__ == "__main__":
    # Command-line entry point: print the instruction table (--specs) or run
    # a program file, optionally under the debugger (--ydb).
    # NOTE(review): the description text looks carried over from another
    # tool; confirm the intended wording.
    parser = argparse.ArgumentParser(description="Exemoji architecture emulator")
    parser.add_argument("--memory", help="Initial memory content", type=str, default=None)
    parser.add_argument("--specs", help="Print the architecture specs and exit", action="store_true")
    parser.add_argument("--debug", help="Show debug info", action="store_true")
    parser.add_argument("--ydb", "-ydb", help="Run debugger", action="store_true")
    parser.add_argument("exec", help="File to execute", nargs="?")
    args = parser.parse_args()

    if args.specs:
        # Render the instruction set as a table, padded to the widest cell of
        # each column.
        rows = [
            (instruction.token, instruction.c_like, instruction.description)
            for instruction in instructions
        ]
        pads = [max(map(len, column)) for column in zip(*rows)]
        header = ("Name", "Pseudocode", "Description")
        print("| " + " | ".join(cell.ljust(pad) for cell, pad in zip(header, pads)) + " |")
        print("|" + "|".join("-" * (pad + 2) for pad in pads) + "|")
        for row in rows:
            print("| " + " | ".join(cell.ljust(pad) for cell, pad in zip(row, pads)) + " |")
        sys.exit(0)

    if args.exec is None:
        # The positional is optional only so that --specs works without it.
        parser.error("the following arguments are required: exec")

    if args.ydb:
        interpreter = Debugger(instructions)
    else:
        interpreter = Interpreter(instructions)

    try:
        with open(args.exec) as code:
            interpreter.init_execution(code)

        if args.memory:
            # A literal backslash-zero in the argument stands for a NUL cell.
            interpreter.machine.memory = list(map(ord, args.memory.replace("\\0", "\0")))

        while not interpreter.halted:
            interpreter.step()
        sys.stdout.write('\n')

    except YipYapError as e:
        sys.stderr.write("Error: %s:%s: %s\n" % (e.line, e.col, e))
        raise
    finally:
        if args.debug:
            loaded = interpreter.file_instructions
            for name, offset in interpreter.functions.items():
                # A function whose name is the last token has no instruction
                # at its recorded offset.
                first = loaded[offset].token if offset < len(loaded) else None
                print(name, first)

            debug_machine(interpreter.machine)
