#!/usr/bin/env python3
import io
import sys
import math
import enum
import argparse
import dataclasses
from typing import Callable


class InstructionArgument(enum.Enum):
    """Kinds of operand an instruction can declare.

    The member values are the short labels printed next to each instruction
    by the ``--specs`` listing.
    """

    # Operand is a register emoji, decoded to its index in the register table.
    Register = 'reg'
    # Operand is an immediate value, decoded to an integer.
    Immediate = 'imm'
    # Operand is a raw emoji handed to the instruction undecoded.
    Emoji = 'emo'


@dataclasses.dataclass
class Instruction:
    """Description of one machine instruction.

    The opcode is normalised with ``clean_emoji`` right after construction,
    so a trailing U+FE0F variation selector in the source literal does not
    become part of the key used to look the instruction up.
    """

    # Emoji identifying the instruction in program memory.
    opcode: str
    # Short assembler-style name (e.g. "add").
    mnemonic: str
    # Human-readable name shown in the specs listing.
    name: str
    # Operand kinds, in the order they follow the opcode in memory.
    arguments: list[InstructionArgument]
    # Called as effect(machine, *decoded_arguments) to execute the instruction.
    effect: Callable

    def __post_init__(self, *a) -> None:
        # Raises Emojiception when the opcode is not a single code point
        # (after removing a trailing U+FE0F).
        self.opcode = clean_emoji(self.opcode)


def aligned_size(size, align):
    """Return the padding needed to round *size* up to a multiple of *align*.

    Despite the name, the result is the number of filler bytes, not the
    padded size itself.  A falsy *align* (``0`` or ``None``) disables
    alignment and yields ``0``.
    """
    remainder = size % align if align else 0
    return align - remainder if remainder else 0


def emoji_length(string):
    """Return how many bytes *string* occupies once encoded as UTF-8."""
    encoded = string.encode("utf8")
    return len(encoded)


def check_same_length(items):
    """Verify that every string in *items* has the same UTF-8 byte length.

    The first element defines the expected length.  An empty collection is
    trivially consistent and is accepted (previously it failed with a bare
    ``IndexError``).

    Raises:
        Emojiception: for the first item whose encoded length differs from
            that of ``items[0]``.
    """
    if not items:
        # Nothing to compare against.
        return

    target = emoji_length(items[0])
    for item in items[1:]:
        length = emoji_length(item)
        if length != target:
            raise Emojiception("%s has length %s" % (item, length))


def check_single(string):
    """Reject strings made of more than one code point.

    Strings of length 0 or 1 pass silently.

    Raises:
        Emojiception: when *string* has more than one code point; the
            message includes the string, its length and its code points in
            hexadecimal.
    """
    length = len(string)
    if length <= 1:
        return

    raise Emojiception("Invalid emoji %s of length %s: %s" % (
        string, length, emoji_to_hexstring(string)
    ))


def emoji_to_hexstring(string):
    """Render the code points of *string* as space-separated hex numbers.

    Each code point is lower-case hexadecimal, zero-padded to at least four
    digits.
    """
    codepoints = [ord(char) for char in string]
    return " ".join(["%04x" % codepoint for codepoint in codepoints])


def clean_emoji(string):
    """Normalise an emoji literal to a single code point.

    A two-character string whose second character is a variation selector
    (U+FE0F emoji presentation or U+FE0E text presentation) is reduced to
    its first character.  Both selectors are accepted because the machine's
    instruction decoder skips both when reading program memory, so neither
    can ever be part of a matched opcode.

    Anything else must already be at most one code point.

    Raises:
        Emojiception: when the string is longer than one code point and is
            not a base character followed by a variation selector.
    """
    if len(string) == 2 and string[1] in ('\ufe0f', '\ufe0e'):
        return string[0]
    check_single(string)
    return string


class Emojiception(Exception):
    """Error raised by the emulator for invalid emoji and machine faults.

    Positional arguments are forwarded to ``Exception``.  The keyword-only
    *extra* is kept on the instance as ``extra`` so raisers can attach a
    payload to the error (``None`` when not supplied).
    """

    def __init__(self, *args, extra=None):
        self.extra = extra
        super().__init__(*args)


class Machine:
    """Emulator for a byte-oriented virtual machine whose programs are emoji.

    Program text and stack live in one ``bytearray``: ``load_program``
    appends the program and ``allocate_stack`` appends the stack region.
    The arithmetic and logic operations take their inputs from registers 0
    and 1 and leave their results in register 2 (and register 3 where a
    second result byte exists).
    """

    def __init__(self, registers, immediates, long_immediate, instructions):
        """Create a halted machine with empty memory.

        Args:
            registers: register emoji; the position of an emoji is the
                register number.
            immediates: emoji standing for immediate values; the position of
                an emoji is its value.
            long_immediate: emoji that introduces a two-nibble immediate.
            instructions: iterable of instruction objects, indexed here by
                their ``opcode``.

        Raises:
            Emojiception: when the register or immediate emoji do not all
                share one UTF-8 length, or *long_immediate* is more than one
                code point.
        """
        self.registers = [0] * len(registers)
        self.immediates = immediates
        self.long_immediate = long_immediate
        self.memory = bytearray()
        self.stack_pointer = 0
        self.program_counter = 0
        self.halted = True
        self.regmoji = registers
        self.instructions = {
            instruction.opcode: instruction
            for instruction in instructions
        }
        self.stack_start = 0
        self.debug = False
        check_same_length(self.immediates)
        check_same_length(self.regmoji)
        check_single(self.long_immediate)

    def align(self, size):
        """Pad memory with zero bytes up to the next multiple of *size*."""
        self.memory += bytearray(aligned_size(len(self.memory), size))

    def load_program(self, program):
        """Append *program* to memory and point the program counter at it.

        *program* may be ``str`` (encoded as UTF-8 first) or a bytes-like
        object.
        """
        self.program_counter = len(self.memory)
        if isinstance(program, str):
            program = program.encode("utf8")
        self.memory += bytearray(program)

    def allocate_stack(self, memory):
        """Append a stack region to memory; the stack starts out empty.

        *memory* is passed to ``bytearray``, so an integer allocates that
        many zero bytes.
        """
        self.stack_start = len(self.memory)
        self.stack_pointer = len(self.memory)
        self.memory += bytearray(memory)

    def write_memory(self, address, byte):
        """Store *byte* at *address* without any access checks."""
        self.memory[address] = byte

    def push(self, byte):
        """Push one byte; the stack grows towards higher addresses.

        Raises:
            Emojiception: when the stack pointer is already at the end of
                memory.
        """
        if self.stack_pointer >= len(self.memory):
            raise Emojiception("Stack overflow")
        self.write_memory(self.stack_pointer, byte)
        self.stack_pointer += 1

    def pop(self):
        """Pop and return the byte on top of the stack.

        Raises:
            Emojiception: when the stack is empty.  The stack pointer is
                left untouched in that case so a following core dump shows
                the state at the time of the fault.
        """
        # Check before moving the pointer: a failed pop must not leave the
        # stack pointer below the start of the stack.
        if self.stack_pointer <= self.stack_start:
            raise Emojiception("Stack underflow")
        self.stack_pointer -= 1
        return self.memory[self.stack_pointer]

    def push_register(self, reg):
        """Push the value of register number *reg*."""
        self.push(self.registers[reg])

    def pop_register(self, reg):
        """Pop the stack into register number *reg*."""
        self.registers[reg] = self.pop()

    def add(self):
        """reg0 + reg1: low byte to reg2, next byte to reg3."""
        result = self.registers[0] + self.registers[1]
        self.registers[2] = result & 0xff
        self.registers[3] = (result >> 8) & 0xff

    def multiply(self):
        """reg0 * reg1: low byte to reg2, next byte to reg3."""
        result = self.registers[0] * self.registers[1]
        self.registers[2] = result & 0xff
        self.registers[3] = (result >> 8) & 0xff

    def subtract(self):
        """reg0 - reg1: result modulo 256 to reg2.

        reg3 receives 0 for a non-negative difference, otherwise the
        magnitude of the negative difference.
        """
        result = self.registers[0] - self.registers[1]
        self.registers[2] = result % 0x100
        self.registers[3] = 0 if result >= 0 else -result

    def divide(self):
        """reg0 / reg1: quotient to reg2, remainder to reg3.

        Raises:
            ZeroDivisionError: when reg1 is zero.
        """
        self.registers[2] = self.registers[0] // self.registers[1]
        self.registers[3] = self.registers[0] % self.registers[1]

    def bit_or(self):
        """reg2 = reg0 | reg1."""
        result = self.registers[0] | self.registers[1]
        self.registers[2] = result

    def bit_and(self):
        """reg2 = reg0 & reg1."""
        result = self.registers[0] & self.registers[1]
        self.registers[2] = result

    def bit_not(self):
        """reg2 = bitwise complement of reg0, masked to one byte."""
        result = (~self.registers[0]) & 0xff
        self.registers[2] = result

    def bit_xor(self):
        """reg2 = reg0 ^ reg1."""
        result = self.registers[0] ^ self.registers[1]
        self.registers[2] = result

    def trig(self):
        """Compute a point on a circle centred on (127, 127).

        reg0 is the angle, where 128 units equal pi radians (so 256 units
        are a full turn); reg1 is the radius.  The rounded x coordinate goes
        to reg2 and the rounded y coordinate to reg3, both masked to a byte.
        """
        angle = self.registers[0] / 128 * math.pi
        radius = self.registers[1]
        x = 127
        y = 127
        self.registers[2] = int(round(x + radius * math.cos(angle))) & 0xff
        self.registers[3] = int(round(y + radius * math.sin(angle))) & 0xff

    def jump(self, address, relative):
        """Set the program counter to *address*, or add it when *relative*."""
        if relative:
            self.program_counter += address
        else:
            self.program_counter = address

    def jump_zero(self, address, relative):
        """Jump like ``jump`` but only when reg2 is zero."""
        if self.registers[2] == 0:
            self.jump(address, relative)

    def compare(self):
        """Compare reg0 with reg1 and store the outcome in reg2.

        reg2 becomes 255 when reg0 < reg1, 1 when reg0 > reg1 and 0 when
        they are equal.
        """
        if self.registers[0] < self.registers[1]:
            self.registers[2] = 255
        elif self.registers[0] > self.registers[1]:
            self.registers[2] = 1
        else:
            self.registers[2] = 0

    def load_immediate(self, regn, value):
        """Store the low byte of *value* in register number *regn*."""
        self.registers[regn] = value & 0xff

    def load_memory(self, reg_target, address):
        """Load the byte at *address* into register number *reg_target*.

        Raises:
            Emojiception: when *address* is outside memory.  Negative
                addresses are rejected explicitly; otherwise Python's
                negative indexing would silently read from the end of
                memory.
        """
        if address < 0 or address >= len(self.memory):
            raise Emojiception("Segmentation Fault (reading %04x)" % address)
        self.registers[reg_target] = self.memory[address]

    def store_memory(self, reg_target, address):
        """Store register number *reg_target* at *address*.

        Only the region from the start of the stack to the end of memory is
        writable.

        Raises:
            Emojiception: when *address* is outside that region.
        """
        if address < self.stack_start or address >= len(self.memory):
            raise Emojiception("Segmentation Fault (writing %04x)" % address)
        self.write_memory(address, self.registers[reg_target])

    def swap(self, rega, regb):
        """Exchange the values of registers *rega* and *regb*."""
        self.registers[rega], self.registers[regb] = self.registers[regb], self.registers[rega]

    def move(self, reg_source, reg_target):
        """Copy register *reg_source* into register *reg_target*."""
        self.registers[reg_target] = self.registers[reg_source]

    def interrupt(self, code):
        """Run the built-in service selected by *code*.

        * ``0``: write reg2 to stdout as a character (followed by a newline
          in debug mode).
        * ``0x0e``: core dump including memory.
        * ``0x0f``: core dump of registers only.

        Any other code is ignored.
        """
        if code == 0:
            sys.stdout.write(chr(self.registers[2]))
            if self.debug:
                sys.stdout.write('\n')
        elif code == 0x0e:
            self.core_dump(True)
        elif code == 0x0f:
            self.core_dump(False)

    def _partial_to_emoji(self, start, stop, show_special=True):
        """Render ``memory[start:stop]`` as text for the memory dump.

        The slice may cut UTF-8 sequences at either end.  Continuation bytes
        seen before any lead or ASCII byte become ``.`` (or are dropped when
        *show_special* is false); a sequence cut off at *stop* is completed
        with bytes from beyond the slice.  A newline is shown as ``\\n``;
        with *show_special*, NUL is shown as ``\\0`` and other control bytes
        as ``\\xNN``.

        Returns ``"?"`` when the cleaned bytes still fail to decode.
        """
        string = self.memory[start:stop]
        cleaned = bytearray()
        # Continuation bytes still owed by the most recent lead byte.
        continuation = 0
        has_start = False
        for byte in string:
            # Continuation byte
            top_two = (byte & 0xc0)
            if top_two == 0x80:
                if has_start:
                    continuation -= 1
                    cleaned.append(byte)
                elif show_special:
                    cleaned.append(ord('.'))
            # Start multibyte
            elif top_two == 0xc0:
                has_start = True
                cleaned.append(byte)
                continuation = 0
                # Count the leading 1 bits after the first one: that is the
                # number of continuation bytes announced by this lead byte.
                byte <<= 1
                while byte & 0x80:
                    byte <<= 1
                    continuation += 1
            # Ascii
            else:
                has_start = True
                continuation = 0
                if byte >= ord(' '):
                    cleaned.append(byte)
                elif byte == 0xa:
                    cleaned += b'\\' + b'n'
                elif show_special:
                    if byte == 0:
                        cleaned += b'\\0'
                    else:
                        # Every remaining value here is a control byte below
                        # the space character.
                        cleaned += ('\\x%02x' % byte).encode("ascii")

        if continuation:
            # The last character runs past the slice: borrow its tail.
            cleaned += self.memory[stop:stop + continuation]

        try:
            return cleaned.decode("utf8")
        except UnicodeDecodeError:
            return "?"

    def core_dump(self, dump_memory=True):
        """Print registers, PC and SP, and optionally a hex dump of memory.

        In the memory dump the byte at the program counter is wrapped in
        ``><`` and the byte at the stack pointer in ``[]``.  Rows that end at
        or before the start of the stack are followed by their text
        rendering.
        """
        sys.stdout.write("\n")
        for n, reg in enumerate(self.registers):
            self._format_value("reg%s %s" % (n, self.regmoji[n]), reg, True)

        self._format_value("PC", self.program_counter, False)
        self._format_value("SP", self.stack_pointer, False)

        if dump_memory:
            sys.stdout.write("\n      ")
            for i in range(16):
                sys.stdout.write("  %x " % i)
            for p, byte in enumerate(self.memory):
                if p % 16 == 0:
                    # A new row starts: finish the previous one with its text
                    # rendering while still inside the program area.
                    if p > 0 and p <= self.stack_start:
                        sys.stdout.write(self._partial_to_emoji(p - 16, p))
                    sys.stdout.write("\n%04x: " % p)

                if self.program_counter == p:
                    wrap = "><"
                elif self.stack_pointer == p:
                    wrap = "[]"
                else:
                    wrap = "  "

                sys.stdout.write("%s%02x%s" % (wrap[0], byte, wrap[1]))
            sys.stdout.write("\n")

    def _format_value(self, header, value, octet):
        """Print one line with *value* in decimal, hex and, if printable, as a character.

        *octet* only controls a single padding space after the header.
        """
        print("%-8s%s%5d   %02x %s" % (
            header,
            "" if octet else " ",
            value,
            value,
            repr(chr(value)) if 0x20 <= value < 0x80 else ''
        ))

    def halt(self):
        """Stop the run loop after the current instruction."""
        self.halted = True

    def _read_byte(self):
        """Return the byte at the program counter and advance past it."""
        self.program_counter += 1
        return self.memory[self.program_counter - 1]

    def _read_emoji(self):
        """Read one UTF-8 encoded character at the program counter.

        Variation selectors U+FE0F and U+FE0E are skipped: the character
        following them is returned instead.

        Raises:
            Emojiception: when the bytes read are not valid UTF-8.
        """
        byte = self._read_byte()
        emoji = [byte]
        if byte & 0x80:
            # Each leading 1 bit after the first announces one more byte.
            byte <<= 1
            while byte & 0x80:
                byte <<= 1
                emoji.append(self._read_byte())

        try:
            string = bytes(emoji).decode("utf8")
        except UnicodeDecodeError:
            raise Emojiception("Invalid instruction: %s" % " ".join("%02x" % c for c in emoji))

        if string == '\ufe0f' or string == '\ufe0e':
            return self._read_emoji()
        return string

    def _get_immediate_nibble(self, emoji):
        """Return the value (table position) of an immediate emoji.

        Raises:
            Emojiception: when *emoji* is not in the immediates table.
        """
        try:
            return self.immediates.index(emoji)
        except ValueError:
            raise Emojiception("Invalid immediate value %s" % emoji)

    def _read_immediate_impl(self, data):
        """Decode an immediate whose first character *data* was already read.

        * The long-immediate marker is followed by two immediate emoji, read
          from program memory and combined as high and low nibble.
        * An ASCII character stands for its own code.
        * Any other character is looked up in the immediates table.
        """
        if data == self.long_immediate:
            high = self._get_immediate_nibble(self._read_emoji())
            low = self._get_immediate_nibble(self._read_emoji())
            return (high << 4) | low
        value = ord(data)
        if value >= 0x80:
            value = self._get_immediate_nibble(data)
        return value

    def read_immediate(self):
        """Read and decode an immediate operand at the program counter."""
        data = self._read_emoji()
        return self._read_immediate_impl(data)

    def read_register(self):
        """Read a register operand and return the register number."""
        return self._reg_id(self._read_emoji())

    def read_argument(self, at):
        """Read one operand of kind *at* from program memory.

        Register operands yield a register number, emoji operands the raw
        character, anything else is decoded as an immediate.
        """
        if at == InstructionArgument.Register:
            return self.read_register()
        elif at == InstructionArgument.Emoji:
            return self._read_emoji()
        else:
            return self.read_immediate()

    def _reg_id(self, emoji):
        """Return the register number of a register emoji.

        Raises:
            Emojiception: when *emoji* does not name a register.
        """
        try:
            return self.regmoji.index(emoji)
        except ValueError:
            raise Emojiception("Invalid register %s" % emoji)

    def reg16(self, reg_high, reg_low):
        """Combine two registers into one value: high byte and low byte."""
        return (self.registers[reg_high] << 8) | self.registers[reg_low]

    def run(self):
        """Fetch and execute instructions until the machine halts.

        In debug mode every instruction is printed before it is executed.
        """
        self.halted = False
        while not self.halted:
            instruction, args = self.fetch_instruction()

            if self.debug:
                print("exec %s %s" % (
                    instruction.opcode,
                    " ".join(self.format_operand(a, t) for a, t in zip(args, instruction.arguments))
                ))

            self.exec_instruction(instruction, args)

    def fetch_instruction(self):
        """Decode the instruction at the program counter.

        Returns:
            ``(instruction, args)`` where *args* holds the decoded operands
            in declaration order.

        Raises:
            Emojiception: when the machine is halted or the opcode is
                unknown (the opcode is then attached as ``extra``).
        """
        if self.halted:
            raise Emojiception("Halted")

        opcode = self._read_emoji()
        instruction = self.instructions.get(opcode)
        if instruction is None:
            raise Emojiception("Illegal instruction %s" % opcode, extra=opcode)

        args = [
            self.read_argument(at)
            for at in instruction.arguments
        ]

        return (instruction, args)

    def exec_instruction(self, instruction, args):
        """Execute *instruction* by calling its effect with this machine and *args*."""
        instruction.effect(self, *args)

    def find_emoji(self, emoji, backwards):
        """Resolve the target address of an emoji-addressed jump.

        When *emoji* is an immediate (or the long-immediate marker, in which
        case two more emoji are consumed from program memory) the target is
        the program counter plus or minus the decoded amount.

        Otherwise *emoji* is a label: memory is searched for its UTF-8
        encoding, forwards from the program counter or backwards in the
        bytes before ``program_counter - 1``, and the address just past the
        match is returned.

        Raises:
            Emojiception: when the label does not occur in the searched
                range.
        """
        if emoji in self.immediates or emoji == self.long_immediate:
            amount = self._read_immediate_impl(emoji)
            if backwards:
                amount = -amount
            return self.program_counter + amount

        coded = emoji.encode("utf8")
        if backwards:
            # The end bound keeps the jump's own operand, which ends at the
            # program counter, out of the search.
            index = self.memory.rfind(coded, 0, self.program_counter - 1)
        else:
            index = self.memory.find(coded, self.program_counter)

        if index == -1:
            raise Emojiception("%s not found from %04x" % (emoji, self.program_counter))

        index += len(coded)
        return index

    def find_next(self, emoji):
        """Resolve *emoji* as a forward jump target (see ``find_emoji``)."""
        return self.find_emoji(emoji, False)

    def find_previous(self, emoji):
        """Resolve *emoji* as a backward jump target (see ``find_emoji``)."""
        return self.find_emoji(emoji, True)

    def format_operand(self, operand, type):
        """Return a decoded operand as text for the debug trace.

        Registers are shown as their emoji, immediates as two hex digits and
        emoji operands unchanged.
        """
        if type == InstructionArgument.Register:
            return self.regmoji[operand]
        elif type == InstructionArgument.Immediate:
            return "%02x" % operand
        return operand


def build_machine(cls=Machine):
    """Build a machine configured with the Exemoji instruction set.

    Args:
        cls: class (or factory) called as
            ``cls(registers, immediates, long_immediate, instructions)``;
            lets callers substitute a ``Machine`` subclass.

    Returns:
        The object produced by *cls*, set up with eight registers, sixteen
        immediate emoji (values 0 to 15), the long-immediate marker and the
        instruction table below.
    """
    Reg = InstructionArgument.Register
    Imm = InstructionArgument.Immediate
    Emo = InstructionArgument.Emoji
    Op = Instruction

    return cls(
        ["🐞", "🐱", "🐲", "🐦", "🐯", "🦄", "🦜", "🐻"],
        # 0     1     2     3    4     5     6    7     8     9    a-10 b-11  c-12  d-13  e-14  f-15
        ["🍎", "🍌", "🍐", "🫐", "🍊", "🍇", "🍉", "🥝", "🍍", "🍓",
         "🍒", "🍋", "🍈", "🥥", "🥭", "🍆"],
        "🍽",
        [
            # Register transfer and stack.
            Op("🐸", "swap", "Swap",         [Reg, Reg], lambda m, rega, regb:   m.swap(rega, regb)),
            Op("🍑", "set",  "Assign",       [Reg, Imm], lambda m, reg, imm:     m.load_immediate(reg, imm)),
            Op("🎁", "mov",  "Move",         [Reg, Reg], lambda m, regD, regS:   m.move(regS, regD)),
            Op("🫸", "push", "Push",         [Reg],      lambda m, reg:          m.push_register(reg)),
            Op("💥", "pop",  "Pop",          [Reg],      lambda m, reg:          m.pop_register(reg)),

            # Memory access.
            # Disabled instruction, kept for reference:
            # Op("😵", "ldpc", "Load from PC", [Reg, Reg],
            #     lambda m, rDest, rOff:  m.load_memory(rDest, m.program_counter + m.registers[rOff])),
            Op("💾", "ldsp", "Load from SP", [Reg, Reg],
                lambda m, rDest, rOff:  m.load_memory(rDest, m.stack_pointer - m.registers[rOff])),
            Op("🧠", "ldl",  "Load Long",    [Reg, Reg, Reg],
                lambda m, rDest, rH, rL: m.load_memory(rDest, m.reg16(rH, rL))),
            Op("🖊️", "stl",  "Store Long",   [Reg, Reg, Reg],
                lambda m, rDest, rH, rL: m.store_memory(rDest, m.reg16(rH, rL))),

            # Arithmetic and logic on the fixed registers.
            Op("🍦️", "trig", "Trigonometry", [],         lambda m:               m.trig()),
            Op("😢", "add",  "Add",          [],         lambda m:               m.add()),
            Op("🦊", "sub",  "Subtract",     [],         lambda m:               m.subtract()),
            Op("🐰", "mul",  "Multiply",     [],         lambda m:               m.multiply()),
            Op("🖕", "div",  "Divide",       [],         lambda m:               m.divide()),
            Op("🦁", "or",   "Or",           [],         lambda m:               m.bit_or()),
            Op("🪢", "not",  "Not",          [],         lambda m:               m.bit_not()),
            Op("🦝", "and",  "And",          [],         lambda m:               m.bit_and()),
            Op("🍕", "xor",  "Xor",          [],         lambda m:               m.bit_xor()),
            Op("🧐", "cmp",  "Compare",      [],         lambda m:               m.compare()),

            # Control flow.
            Op("👉", "jrf",  "Jump forward", [Emo],      lambda m, emo:          m.jump(m.find_next(emo), False)),
            Op("👈", "jrb",  "Jump back",    [Emo],      lambda m, emo:          m.jump(m.find_previous(emo), False)),
            Op("🤌", "jrz",  "Jump 0 fwd",   [Emo],      lambda m, emo:          m.jump_zero(m.find_next(emo), False)),
            Op("😐", "ign",  "Ignore",       [Emo],      lambda m, emo:          None),
            Op("🦸", "jmp",  "Long Jump",    [Imm, Imm], lambda m, iH, iL:       m.jump((iH << 8) | iL, False)),

            # call pushes the return address high byte first, so ret pops
            # the low byte first.
            Op("🛫", "call", "Call",         [Imm, Imm], lambda m, iH, iL: (
                m.push((m.program_counter >> 8) & 0xff),
                m.push(m.program_counter & 0xff),
                m.jump((iH << 8) | iL, False)
            )),
            Op("🛬", "ret",  "Pop PC",       [],         lambda m:             m.jump(m.pop() | (m.pop() << 8), False)),

            Op("🫠", "halt", "Halt",         [],         lambda m:               m.halt()),
            Op("🤖", "int",  "Interrupt",    [Imm],      lambda m, imm:          m.interrupt(imm)),
        ]
    )


if __name__ == "__main__":
    # Command line entry point: print the specs, dump memory, or run a program.
    parser = argparse.ArgumentParser(description="Exemoji architecture emulator")
    parser.add_argument("--memory", help="Stack memory to allocate in bytes", type=int, default=128)
    parser.add_argument("--specs", help="Print the architecture specs and exit", action="store_true")
    parser.add_argument("--debug", help="Show debug info", action="store_true")
    parser.add_argument("--dump", help="Dump the memory instead of running", action="store_true")
    parser.add_argument("--align", help="Input program alignment in bytes", type=int, default=16)
    parser.add_argument("--raw", help="Treat the positional argument as program text, not a file name",
                        action="store_true")
    parser.add_argument("exec", help="File to execute", nargs="?")
    args = parser.parse_args()

    machine = build_machine()

    if args.specs:
        print("Registers:")
        for i, e in enumerate(machine.regmoji):
            print("    reg%s %s" % (i, e))

        print("\nImmediates:")
        for i, e in enumerate(machine.immediates):
            print("    %x %s" % (i, e))

        print("\n2 nibble immediate marker:")
        print("    %s" % machine.long_immediate)

        print("\nInstructions:")
        # Pad the names so the opcode column lines up.
        name_max = max((len(instr.name) for instr in machine.instructions.values()), default=0)

        for instr in machine.instructions.values():
            print("    %s %s %s" % (instr.name.ljust(name_max), instr.opcode, " ".join(a.value for a in instr.arguments)))

        sys.exit(0)

    # The positional argument is optional only so that --specs works without
    # it; every other mode needs a program.
    if args.exec is None:
        parser.error("a program is required: a file name, '-' for stdin, or program text with --raw")

    if args.raw:
        bytecode = args.exec.encode("utf8")
    elif args.exec == "-":
        bytecode = sys.stdin.read().encode("utf8")
    else:
        with open(args.exec, "rb") as f:
            bytecode = f.read()

    machine.load_program(bytecode)
    machine.align(args.align)
    machine.allocate_stack(args.memory)
    machine.debug = args.debug

    if args.dump:
        machine.core_dump()
        sys.exit(0)

    try:
        machine.run()
        if machine.debug:
            machine.core_dump()
    except BaseException:
        # Deliberately broad: dump the machine state on any failure,
        # including an interrupt from the keyboard, then let the error
        # propagate unchanged.
        machine.core_dump()
        raise
