import itertools
from random import randint
import time
import inspect
import functools
import copy
import sys
import struct
from Compiler.exceptions import *
from Compiler.config import *
from Compiler import util
from Compiler import tools
from Compiler import program


###
### Opcode constants
###
### Whenever these are changed the corresponding enums in Processor/instruction.h
### MUST also be changed. (+ the documentation)
###
opcodes = dict(
    # Emulation
    CISC = 0,
    # Load/store
    LDI = 0x1,
    LDSI = 0x2,
    LDMC = 0x3,
    LDMS = 0x4,
    STMC = 0x5,
    STMS = 0x6,
    LDMCI = 0x7,
    LDMSI = 0x8,
    STMCI = 0x9,
    STMSI = 0xA,
    MOVC = 0xB,
    MOVS = 0xC,
    PROTECTMEMS = 0xD,
    PROTECTMEMC = 0xE,
    PROTECTMEMINT = 0xF,
    LDMINT = 0xCA,
    STMINT = 0xCB,
    LDMINTI = 0xCC,
    STMINTI = 0xCD,
    PUSHINT = 0xCE,
    POPINT = 0xCF,
    MOVINT = 0xD0,
    # Machine
    LDTN = 0x10,
    LDARG = 0x11,
    REQBL = 0x12,
    STARG = 0x13,
    TIME = 0x14,
    START = 0x15,
    STOP = 0x16,
    USE = 0x17,
    USE_INP = 0x18,
    RUN_TAPE = 0x19,
    JOIN_TAPE = 0x1A,
    CRASH = 0x1B,
    USE_PREP = 0x1C,
    STARTGRIND = 0x1D,
    STOPGRIND = 0x1E,
    NPLAYERS = 0xE2,
    THRESHOLD = 0xE3,
    PLAYERID = 0xE4,
    USE_EDABIT = 0xE5,
    USE_MATMUL = 0x1F,
    ACTIVE = 0xE9,
    CMDLINEARG = 0xEB,
    CALL_TAPE = 0xEC,
    CALL_ARG = 0xED,
    # Addition
    ADDC = 0x20,
    ADDS = 0x21,
    ADDM = 0x22,
    ADDCI = 0x23,
    ADDSI = 0x24,
    SUBC = 0x25,
    SUBS = 0x26,
    SUBML = 0x27,
    SUBMR = 0x28,
    SUBCI = 0x29,
    SUBSI = 0x2A,
    SUBCFI = 0x2B,
    SUBSFI = 0x2C,
    PREFIXSUMS = 0x2D,
    PICKS = 0x2E,
    CONCATS = 0x2F,
    ZIPS = 0x3F,
    # Multiplication/division
    MULC = 0x30,
    MULM = 0x31,
    MULCI = 0x32,
    MULSI = 0x33,
    DIVC = 0x34,
    DIVCI = 0x35,
    MODC = 0x36,
    MODCI = 0x37,
    LEGENDREC = 0x38,
    DIGESTC = 0x39,
    INV2M = 0x3a,
    FLOORDIVC = 0x3b,
    GMULBITC = 0x136,
    GMULBITM = 0x137,
    # Open
    OPEN = 0xA5,
    MULS = 0xA6,
    MULRS = 0xA7,
    DOTPRODS = 0xA8,
    TRUNC_PR = 0xA9,
    MATMULS = 0xAA,
    MATMULSM = 0xAB,
    CONV2DS = 0xAC,
    CHECK = 0xAF,
    PRIVATEOUTPUT = 0xAD,
    # Shuffling
    SECSHUFFLE = 0xFA,
    GENSECSHUFFLE = 0xFB,
    APPLYSHUFFLE = 0xFC,
    DELSHUFFLE = 0xFD,
    INVPERM = 0xFE,
    # Data access
    TRIPLE = 0x50,
    BIT = 0x51,
    SQUARE = 0x52,
    INV = 0x53,
    GBITTRIPLE = 0x154,
    GBITGF2NTRIPLE = 0x155,
    INPUTMASK = 0x56,
    INPUTMASKREG = 0x5C,
    PREP = 0x57,
    DABIT = 0x58,
    EDABIT = 0x59,
    SEDABIT = 0x5A,
    RANDOMS = 0x5B,
    RANDOMFULLS = 0x5D,
    UNSPLIT = 0x5E,
    # Input
    INPUT = 0x60,
    INPUTFIX = 0xF0,
    INPUTFLOAT = 0xF1,
    INPUTMIXED = 0xF2,
    INPUTMIXEDREG = 0xF3,
    RAWINPUT = 0xF4,
    INPUTPERSONAL = 0xF5,
    SENDPERSONAL = 0xF6,
    STARTINPUT = 0x61,
    STOPINPUT = 0x62,  
    READSOCKETC = 0x63,
    READSOCKETS = 0x64,
    WRITESOCKETC = 0x65,
    WRITESOCKETS = 0x66,
    READSOCKETINT = 0x69,
    WRITESOCKETINT = 0x6a,
    WRITESOCKETSHARE = 0x6b,
    LISTEN = 0x6c,
    ACCEPTCLIENTCONNECTION = 0x6d,
    CLOSECLIENTCONNECTION = 0x6e,
    INITCLIENTCONNECTION = 0x6f,
    # Bitwise logic
    ANDC = 0x70,
    XORC = 0x71,
    ORC = 0x72,
    ANDCI = 0x73,
    XORCI = 0x74,
    ORCI = 0x75,
    NOTC = 0x76,
    # Bitwise shifts
    SHLC = 0x80,
    SHRC = 0x81,
    SHLCI = 0x82,
    SHRCI = 0x83,
    SHRSI = 0x84,
    # Branching and comparison
    JMP = 0x90,
    JMPNZ = 0x91,
    JMPEQZ = 0x92,
    EQZC = 0x93,
    LTZC = 0x94,
    LTC = 0x95,
    GTC = 0x96,
    EQC = 0x97,
    JMPI = 0x98,
    # Integers
    BITDECINT = 0x99,
    LDINT = 0x9A,
    ADDINT = 0x9B,
    SUBINT = 0x9C,
    MULINT = 0x9D,
    DIVINT = 0x9E,
    PRINTINT = 0x9F,
    INCINT = 0xD1,
    SHUFFLE = 0xD2,
    # Conversion
    CONVINT = 0xC0,
    CONVMODP = 0xC1,
    GCONVGF2N = 0x1C1,
    # IO
    PRINTMEM = 0xB0,
    PRINTREG = 0XB1,
    RAND = 0xB2,
    PRINTREGPLAIN = 0xB3,
    PRINTREGPLAINS = 0xEA,
    PRINTCHR = 0xB4,
    PRINTSTR = 0xB5,
    PUBINPUT = 0xB6,
    RAWOUTPUT = 0xB7,
    STARTPRIVATEOUTPUT = 0xB8,
    STOPPRIVATEOUTPUT = 0xB9,
    PRINTCHRINT = 0xBA,
    PRINTSTRINT = 0xBB,
    PRINTFLOATPLAIN = 0xBC,
    WRITEFILESHARE = 0xBD,     
    READFILESHARE = 0xBE,
    CONDPRINTSTR = 0xBF,
    PRINTFLOATPREC = 0xE0,
    CONDPRINTPLAIN = 0xE1,
    INTOUTPUT = 0xE6,
    FLOATOUTPUT = 0xE7,
    FIXINPUT = 0xE8,
    GBITDEC = 0x18A,
    GBITCOM = 0x18B,
    # Secure socket
    INITSECURESOCKET = 0x1BA,
    RESPSECURESOCKET = 0x1BB
)


vm_types = dict(
    ci = 0,
    sb = 1,
    cb = 2,
    s = 4,
    c = 5,
    sg = 6,
    cg = 7,
)


def int_to_bytes(x):
    """ 32 bit int to big-endian 4 byte conversion. """
    assert(x < 2**32 and x >= -2**32)
    return [(x >> 8*i) % 256 for i in (3,2,1,0)]


global_vector_size_stack = []
global_instruction_type_stack = ['modp']

def check_vector_size(size):
    if isinstance(size, program.curr_tape.Register):
        raise CompilerError('vector size must be known at compile time')

def set_global_vector_size(size):
    stack = global_vector_size_stack
    check_vector_size(size)
    if size == 1 and not stack:
        return
    stack.append(size)

def set_global_instruction_type(t):
    if t == 'modp' or t == 'gf2n':
        global_instruction_type_stack.append(t)
    else:
        raise CompilerError('Invalid type %s for setting global instruction type')

def reset_global_vector_size():
    stack = global_vector_size_stack
    if global_vector_size_stack:
        stack.pop()

def reset_global_instruction_type():
    global_instruction_type_stack.pop()

def get_global_vector_size():
    stack = global_vector_size_stack
    if stack:
        return stack[-1]
    else:
        return 1

def get_global_instruction_type():
    return global_instruction_type_stack[-1]


def vectorize(instruction, global_dict=None):
    """ Decorator to vectorize instructions. """

    if global_dict is None:
        global_dict = inspect.getmodule(instruction).__dict__

    class Vectorized_Instruction(instruction):
        __slots__ = ['size']
        def __init__(self, size, *args, **kwargs):
            self.size = size
            super(Vectorized_Instruction, self).__init__(*args, **kwargs)
            if not kwargs.get('copying', False):
                for arg,f in zip(self.args, self.arg_format):
                    if issubclass(ArgFormats[f], RegisterArgFormat):
                        arg.set_size(size)
        def get_code(self):
            return instruction.get_code(self, self.get_size())
        def get_pre_arg(self):
            try:
                return "%d, " % self.size
            except:
                return "{undef}, "
        def is_vec(self):
            return True
        def get_size(self):
            return self.size
        def expand(self):
            set_global_vector_size(self.size)
            super(Vectorized_Instruction, self).expand()
            reset_global_vector_size()
        def copy(self, size, subs):
            return type(self)(size, *self.get_new_args(size, subs),
                              copying=True)

    @functools.wraps(instruction)
    def maybe_vectorized_instruction(*args, **kwargs):
        size = get_global_vector_size()
        for arg in args:
            try:
                size = arg.size
                break
            except:
                pass
        if size == 1:
            return instruction(*args, **kwargs)
        else:
            return Vectorized_Instruction(size, *args, **kwargs)
    maybe_vectorized_instruction.vec_ins = Vectorized_Instruction
    maybe_vectorized_instruction.std_ins = instruction
    
    vectorized_name = 'v' + instruction.__name__
    Vectorized_Instruction.__name__ = vectorized_name
    global_dict[vectorized_name] = Vectorized_Instruction

    if 'sphinx.extension' in sys.modules:
        return instruction

    global_dict[instruction.__name__ + '_class'] = instruction
    maybe_vectorized_instruction.arg_format = instruction.arg_format
    return maybe_vectorized_instruction


def gf2n(instruction):
    """ Decorator to create GF_2^n instruction corresponding to a given
        modp instruction.

        Adds the new GF_2^n instruction to the globals dictionary. Also adds a
        vectorized GF_2^n instruction if a modp version exists. """
    global_dict = inspect.getmodule(instruction).__dict__

    if 'v' + instruction.__name__ in global_dict:
        vectorized = True
    else:
        vectorized = False

    if isinstance(instruction, type) and issubclass(instruction, Instruction):
        instruction_cls = instruction
    else:
        try:
            instruction_cls = global_dict[instruction.__name__ + '_class']
        except KeyError:
            raise CompilerError('Cannot decorate instruction %s' % instruction)

    def reformat(arg_format):
        if isinstance(arg_format, list):
            __format = []
            for __f in arg_format:
                if __f in ('int', 'long', 'p', 'ci', 'ciw', 'str'):
                    __format.append(__f)
                else:
                    __format.append(__f[0] + 'g' + __f[1:])
            arg_format[:] = __format
        else:
            for __f in arg_format.args:
                reformat(__f)

    class GF2N_Instruction(instruction_cls):
        __doc__ = instruction_cls.__doc__.replace('c_', 'c^g_').replace('s_', 's^g_')
        __slots__ = []
        field_type = 'gf2n'
        if isinstance(instruction_cls.code, int):
            code = (1 << 8) + instruction_cls.code

        # set modp registers in arg_format to GF2N registers
        if 'gf2n_arg_format' in instruction_cls.__dict__:
            arg_format = instruction_cls.gf2n_arg_format
        elif isinstance(instruction_cls.arg_format, itertools.repeat):
            __f = next(instruction_cls.arg_format)
            if __f not in ('int', 'long', 'p'):
                arg_format = itertools.repeat(__f[0] + 'g' + __f[1:])
        elif isinstance(instruction_cls.arg_format, property):
            pass
        else:
            arg_format = copy.deepcopy(instruction_cls.arg_format)
            reformat(arg_format)

        @classmethod
        def is_gf2n(self):
            return True

        def expand(self):
            set_global_instruction_type('gf2n')
            super(GF2N_Instruction, self).expand()
            reset_global_instruction_type()

    GF2N_Instruction.__name__ = 'g' + instruction_cls.__name__
    if vectorized:
        vec_GF2N = vectorize(GF2N_Instruction, global_dict)

    @functools.wraps(instruction)
    def maybe_gf2n_instruction(*args, **kwargs):
        if get_global_instruction_type() == 'gf2n':
            if vectorized:
                return vec_GF2N(*args, **kwargs)
            else:
                return GF2N_Instruction(*args, **kwargs)
        else:
            return instruction(*args, **kwargs)
    
    # If instruction is vectorized, new GF2N instruction must also be
    if vectorized:
        global_dict[GF2N_Instruction.__name__] = vec_GF2N
    else:
        global_dict[GF2N_Instruction.__name__] = GF2N_Instruction

    if 'sphinx.extension' in sys.modules:
        return instruction

    global_dict[instruction.__name__ + '_class'] = instruction_cls
    maybe_gf2n_instruction.arg_format = instruction.arg_format
    return maybe_gf2n_instruction
    #return instruction

class Mergeable:
    pass

def cisc(function, n_outputs=1):
    class MergeCISC(Mergeable):
        instructions = {}
        functions = {}

        def __init__(self, *args, **kwargs):
            self.args = args
            self.kwargs = kwargs
            self.security = program._security
            self.calls = [(args, kwargs)]
            self.params = []
            self.used = []
            for arg in self.args[n_outputs:]:
                if isinstance(arg, program.curr_tape.Register):
                    self.used.append(arg)
                    self.params.append(type(arg))
                else:
                    self.params.append(arg)
            self.function = function
            self.caller = None
            program.curr_block.instructions.append(self)

        def get_def(self):
            return sum(([call[0][i] for call in self.calls]
                        for i in range(n_outputs)), [])

        def get_used(self):
            return self.used

        def is_vec(self):
            return True

        def merge_id(self):
            return self.function, tuple(self.params), \
                tuple(sorted(self.kwargs.items())), self.security

        def merge(self, other):
            self.calls += other.calls
            self.used += other.used

        def get_size(self):
            return self.args[0].vector_size()

        def new_instructions(self, size, regs):
            if self.merge_id() not in self.instructions:
                from Compiler.program import Tape
                tape = Tape(self.function.__name__, program)
                old_tape = program.curr_tape
                program.curr_tape = tape
                block = tape.BasicBlock(tape, None, None)
                tape.active_basicblock = block
                set_global_vector_size(None)
                args = []
                for arg in self.args:
                    try:
                        args.append(arg.new_vector(size=None))
                    except:
                        args.append(arg)
                program.options.cisc = False
                old_security = program._security
                program.security = self.security
                self.function(*args, **self.kwargs)
                program.security = old_security
                program.options.cisc = True
                reset_global_vector_size()
                program.curr_tape = old_tape
                for x, bl in tape.req_bit_length.items():
                    old_tape.require_bit_length(
                        bl - 1, x, tape.bit_length_reason if x == 'p' else '')
                from Compiler.allocator import Merger
                merger = Merger(block, program.options,
                                tuple(program.to_merge))
                for i in range(n_outputs):
                    args[i].can_eliminate = False
                merger.eliminate_dead_code()
                assert int(program.options.max_parallel_open) == 0, \
                    'merging restriction not compatible with ' \
                    'mergeable CISC instructions'
                n_rounds = merger.longest_paths_merge()
                filtered = filter(lambda x: x is not None, block.instructions)
                self.instructions[self.merge_id()] = list(filtered), args, \
                                                     n_rounds
            template, args, self.n_rounds = self.instructions[self.merge_id()]
            subs = util.dict_by_id()
            from Compiler import types
            for arg, reg in zip(args, regs):
                if isinstance(arg, program.curr_tape.Register):
                    subs[arg] = reg
            set_global_vector_size(size)
            for inst in template:
                inst.copy(size, subs)
            reset_global_vector_size()

        class Arg:
            def __init__(self, reg):
                from Compiler.GC.types import bits
                self.type = type(reg)
                self.binary = isinstance(reg, bits)
                self.reg = reg
            def new(self, size):
                if self.binary:
                    return self.type()
                else:
                    return self.type(size=size)
            def load(self):
                return self.reg
            def store(self, reg):
                if self.type != type(None):
                    self.reg.update(reg)
            def is_real(self):
                return self.reg is not None

        def base_key(self, size, new_regs):
            return size, tuple(
                arg for arg, reg in zip(self.args, new_regs) if reg is None), \
                tuple(type(reg) for reg in new_regs)

        @staticmethod
        def get_name(key):
            return '_'.join(['%s(%d)' % (function.__name__, key[0])] +
                            [str(x) for x in key[1]])

        def expand_to_function(self, size, new_regs):
            key = self.base_key(size, new_regs) + (program.curr_tape,)
            if key not in self.functions:
                args = [self.Arg(x) for x in new_regs]
                from Compiler import library, types
                @library.function_block
                def f():
                    res = [arg.new(size) for arg in args[:n_outputs]]
                    self.new_instructions(
                        size, res + [arg.load() for arg in args[n_outputs:]])
                    for reg, arg in zip(res, args):
                        arg.store(reg)
                f.name = self.get_name(key)
                self.functions[key] = f, args
            f, args = self.functions[key]
            for i in range(len(new_regs) - n_outputs):
                args[n_outputs + i].store(new_regs[n_outputs + i])
            f()
            for i in range(n_outputs):
                new_regs[i].link(args[i].load())

        def expand_to_tape(self, size, new_regs):
            key = self.base_key(size, new_regs)
            args = [self.Arg(x) for x in new_regs]
            if key not in self.functions:
                from Compiler import library, types
                @library.function_call_tape
                def f(*in_args):
                    res = [arg.new(size) for arg in args[:n_outputs]]
                    in_args = list(in_args)
                    my_args = list(res)
                    for arg in args[n_outputs:]:
                        if arg.is_real():
                            my_args.append(in_args.pop(0))
                        else:
                            my_args.append(arg.reg)
                    self.new_instructions(size, my_args)
                    return res
                f.name =  self.get_name(key)
                self.functions[key] = f
            f = self.functions[key]
            in_args = filter(lambda arg: arg.is_real(), args[n_outputs:])
            res = util.tuplify(f(*(arg.load() for arg in in_args)))
            for i in range(n_outputs):
                new_regs[i].link(res[i])

        def expand_merged(self, skip):
            if function.__name__ in skip:
                good = True
                for call in self.calls:
                    if not good:
                        break
                    for i in range(n_outputs):
                        for arg in call[0]:
                            if isinstance(arg, program.curr_tape.Register) and \
                               not issubclass(type(self.calls[0][0][0]),
                                              type(arg)):
                                good = False
                if good:
                    return program.curr_block.instructions.append(self)
            if program.verbose:
                print('expanding', self.function.__name__)
            tape = program.curr_tape
            tape.start_new_basicblock(name='pre-' + self.name())
            size = sum(call[0][0].vector_size() for call in self.calls)
            new_regs = []
            for i, arg in enumerate(self.args):
                try:
                    if i < n_outputs:
                        new_regs.append(arg.new_vector(size=size))
                    else:
                        new_regs.append(type(arg).concat(
                            call[0][i] for call in self.calls))
                        assert new_regs[-1].vector_size() == size
                except (TypeError, AttributeError):
                    if not isinstance(arg, (int, type(None))):
                        raise
                    new_regs.append(None)
                except:
                    print([call[0][0].vector_size() for call in self.calls])
                    raise
            if program.cisc_to_function and \
               (program.curr_tape.singular or program.n_running_threads):
                if not program.use_tape_calls and not program.force_cisc_tape:
                    self.expand_to_function(size, new_regs)
                else:
                    self.expand_to_tape(size, new_regs)
            else:
                self.new_instructions(size, new_regs)
                program.curr_block.n_rounds += self.n_rounds - 1
            base = 0
            for call in self.calls:
                for i in range(n_outputs):
                    reg = call[0][i]
                    reg.copy_from_part(new_regs[i], base, reg.vector_size())
                base += reg.vector_size()
            tape.start_new_basicblock(name='post-' + self.name())

        def add_usage(self, *args):
            pass

        def get_bytes(self):
            assert len(self.kwargs) < 2
            res = LongArgFormat.encode(opcodes['CISC'])
            res += int_to_bytes(sum(len(x[0]) + 2 for x in self.calls) + 1)
            name = self.function.__name__
            String.check(name)
            res += String.encode(name)
            for call in self.calls:
                call[1].pop('nearest', None)
                assert not call[1]
                res += int_to_bytes(len(call[0]) + 2)
                res += int_to_bytes(call[0][0].size)
                for arg in call[0]:
                    res += self.arg_to_bytes(arg)
            return bytearray(res)

        @classmethod
        def arg_to_bytes(self, arg):
            if arg is None:
                return int_to_bytes(0)
            try:
                return int_to_bytes(arg.i)
            except:
                return int_to_bytes(arg)

        def name(self):
            return self.function.__name__

        def __str__(self):
            return self.function.__name__ + ' ' + ', '.join(
                str(x) for x in itertools.chain(call[0] for call in self.calls))

    MergeCISC.__name__ = function.__name__

    def wrapper(*args, **kwargs):
        same_sizes = True
        for arg in args:
            try:
                same_sizes &= arg.size == args[0].size
            except:
                pass
        if program.use_cisc() and same_sizes:
            return MergeCISC(*args, **kwargs)
        else:
            return function(*args, **kwargs)
    return wrapper

def ret_cisc(function):
    def instruction(res, *args, **kwargs):
        res.mov(res, function(*args, **kwargs))
    instruction.__name__ = function.__name__
    instruction = cisc(instruction)

    def wrapper(*args, **kwargs):
        from Compiler import types
        if not (program.options.cisc and isinstance(args[0], types._register)):
            return function(*args, **kwargs)
        for arg in args:
            if isinstance(arg, types._secret):
                res_type = type(arg)
                break
        res = res_type(size=args[0].size)
        instruction(res, *args, **kwargs)
        return res
    return wrapper

def sfix_cisc(function):
    from Compiler.types import sfix, sint, cfix, copy_doc
    def instruction(res, arg, k, f, *args):
        assert k is not None
        assert f is not None
        old = sfix.k, sfix.f, cfix.k, cfix.f
        sfix.k, sfix.f, cfix.k, cfix.f = [None] * 4
        res.mov(res, function(sfix._new(arg, k=k, f=f), *args).v)
        sfix.k, sfix.f, cfix.k, cfix.f = old
    instruction.__name__ = function.__name__
    instruction = cisc(instruction)

    def wrapper(*args, **kwargs):
        if isinstance(args[0], sfix) and program.options.cisc:
            for arg in args[1:]:
                assert util.is_constant(arg)
            assert not kwargs
            assert args[0].size == args[0].v.size
            k = args[0].k
            f = args[0].f
            res = sfix._new(sint(size=args[0].size), k=k, f=f)
            instruction(res.v, args[0].v, k, f, *args[1:])
            return res
        else:
            return function(*args, **kwargs)
    copy_doc(wrapper, function)
    return wrapper

bit_instructions = {}

def bit_cisc(function):
    def wrapper(a, k, m, *args, **kwargs):
        key = function, m
        if key not in bit_instructions:
            def instruction(*args, **kwargs):
                res = function(*args[m:], **kwargs)
                for x, y in zip(res, args):
                    x.mov(y, x)
            instruction.__name__ = '%s(%d)' % (function.__name__, m)
            bit_instructions[key] = cisc(instruction, m)
        from Compiler.types import sintbit
        res = [sintbit() for i in range(m)]
        bit_instructions[function, m](*res, a, k, m, *args, **kwargs)
        return res
    return wrapper

class RegType(object):
    """ enum-like static class for Register types """
    ClearModp = 'c'
    SecretModp = 's'
    ClearGF2N = 'cg'
    SecretGF2N = 'sg'
    ClearInt = 'ci'

    Types = [ClearModp, SecretModp, ClearGF2N, SecretGF2N, ClearInt]

    @staticmethod
    def create_dict(init_value_fn):
        """ Create a dictionary with all the RegTypes as keys """
        res = defaultdict(init_value_fn)
        # initialization for legacy
        for t in RegType.Types:
            res[t]
        return res

class ArgFormat(object):
    is_reg = False

    @classmethod
    def check(cls, arg):
        return NotImplemented

    @classmethod
    def encode(cls, arg):
        return NotImplemented

class RegisterArgFormat(ArgFormat):
    is_reg = True

    @classmethod
    def check(cls, arg):
        if not isinstance(arg, program.curr_tape.Register):
            raise ArgumentError(arg, 'Invalid register argument')
        if arg.program != program.curr_tape:
            raise ArgumentError(arg, 'Register from other tape, trace: %s' % \
                                    util.format_trace(arg.caller) +
                                '\nMaybe use MemValue')
        if arg.reg_type != cls.reg_type:
            raise ArgumentError(arg, "Wrong register type '%s', expected '%s'" % \
                                    (arg.reg_type, cls.reg_type))

    @classmethod
    def encode(cls, arg):
        assert arg.i >= 0
        return int_to_bytes(arg.i)

    def __init__(self, f):
        self.i = struct.unpack('>I', f.read(4))[0]

    def __str__(self):
        return self.reg_type + str(self.i)

class ClearModpAF(RegisterArgFormat):
    reg_type = RegType.ClearModp
    name = 'cint'

class SecretModpAF(RegisterArgFormat):
    reg_type = RegType.SecretModp
    name = 'sint'

class ClearGF2NAF(RegisterArgFormat):
    reg_type = RegType.ClearGF2N
    name = 'cgf2n'

class SecretGF2NAF(RegisterArgFormat):
    reg_type = RegType.SecretGF2N
    name = 'sgf2n'

class ClearIntAF(RegisterArgFormat):
    reg_type = RegType.ClearInt
    name = 'regint'

class AnyRegAF(RegisterArgFormat):
    reg_type = '*'
    @staticmethod
    def check(arg):
        assert isinstance(arg, program.curr_tape.Register)

class IntArgFormat(ArgFormat):
    n_bits = 32

    @classmethod
    def check(cls, arg):
        if not arg is None:
            if not isinstance(arg, int):
                raise ArgumentError(arg, 'Expected an integer-valued argument')
            if arg >= 2 ** cls.n_bits or arg < -2 ** cls.n_bits:
                raise ArgumentError(
                    arg, 'Immediate value outside of %d-bit range' % cls.n_bits)

    @classmethod
    def encode(cls, arg):
        return int_to_bytes(arg)

    def __init__(self, f):
        self.i = struct.unpack('>i', f.read(4))[0]

    def __str__(self):
        return str(self.i)

class LongArgFormat(IntArgFormat):
    n_bits = 64

    @classmethod
    def encode(cls, arg):
        return list(struct.pack('>q', arg))

    def __init__(self, f):
        self.i = struct.unpack('>q', f.read(8))[0]

class ImmediateModpAF(IntArgFormat):
    @classmethod
    def check(cls, arg):
        super(ImmediateModpAF, cls).check(arg)

class ImmediateGF2NAF(IntArgFormat):
    @classmethod
    def check(cls, arg):
        # bounds checking for GF(2^n)???
        super(ImmediateGF2NAF, cls).check(arg)

class PlayerNoAF(IntArgFormat):
    @classmethod
    def check(cls, arg):
        if not util.is_constant(arg):
            raise CompilerError('Player number must be known at compile time')
        super(PlayerNoAF, cls).check(arg)
        if arg > 256:
            raise ArgumentError(arg, 'Player number > 256')

class String(ArgFormat):
    length = 16

    @classmethod
    def check(cls, arg):
        if not isinstance(arg, str):
            raise ArgumentError(arg, 'Argument is not string')
        if len(arg) > cls.length:
            raise ArgumentError(arg, 'String longer than %d' % cls.length)
        if '\0' in arg:
            raise ArgumentError(arg, 'String contains zero-byte')

    @classmethod
    def encode(cls, arg):
        return bytearray(arg, 'ascii') + b'\0' * (cls.length - len(arg))

    def __init__(self, f):
        tmp = f.read(16)
        self.str = str(tmp[0:tmp.find(b'\0')], 'ascii')

    def __str__(self):
        return self.str

class VarString(ArgFormat):
    @classmethod
    def check(cls, arg):
        if not isinstance(arg, str):
            raise ArgumentError(arg, 'Argument is not string')

    @classmethod
    def encode(cls, arg):
        return int_to_bytes(len(arg)) + list(bytearray(arg, 'ascii'))

    def __init__(self, f):
        length = IntArgFormat(f).i
        self.str = str(f.read(length), 'ascii')

    def __str__(self):
        return self.str

ArgFormats = {
    'c': ClearModpAF,
    's': SecretModpAF,
    'cw': ClearModpAF,
    'sw': SecretModpAF,
    'cg': ClearGF2NAF,
    'sg': SecretGF2NAF,
    'cgw': ClearGF2NAF,
    'sgw': SecretGF2NAF,
    'ci': ClearIntAF,
    'ciw': ClearIntAF,
    '*': AnyRegAF,
    '*w': AnyRegAF,
    'i': ImmediateModpAF,
    'ig': ImmediateGF2NAF,
    'int': IntArgFormat,
    'long': LongArgFormat,
    'p': PlayerNoAF,
    'str': String,
    'varstr': VarString,
}

def format_str_is_reg(format_str):
    return ArgFormats[format_str].is_reg

def format_str_is_writeable(format_str):
    return format_str_is_reg(format_str) and format_str[-1] == 'w'


class Instruction(object):
    """
    Base class for a RISC-type instruction. Has methods for checking arguments,
    getting byte encoding, emulating the instruction, etc.
    """
    __slots__ = ['args', 'arg_format', 'code', 'caller']
    count = 0
    code_length = 10

    def __init__(self, *args, **kwargs):
        """ Create an instruction and append it to the program list. """
        self.args = list(args)
        if not kwargs.get('copying', False):
            self.check_args()
        if kwargs.get('add_to_prog', True):
            program.curr_block.instructions.append(self)
        if program.DEBUG:
            self.caller = [frame[1:] for frame in inspect.stack()[1:]]
        else:
            self.caller = None
        
        Instruction.count += 1
        if Instruction.count % 100000 == 0:
            print("Compiled %d lines at" % self.__class__.count, time.asctime())
            sys.stdout.flush()
            if Instruction.count > 10 ** 7:
                print("Compilation produced more that 10 million instructions. "
                      "Consider using './compile.py -l' or replacing for loops "
                      "with @for_range_opt: "
                      "https://mp-spdz.readthedocs.io/en/latest/Compiler.html#"
                      "Compiler.library.for_range_opt")

    def get_code(self, prefix=0):
        return (prefix << self.code_length) + self.code

    def get_encoding(self):
        enc = LongArgFormat.encode(self.get_code())
        # add the number of registers if instruction flagged as has var args
        if self.has_var_args():
            enc += int_to_bytes(len(self.args))
        for arg,format in zip(self.args, self.arg_format):
            enc += ArgFormats[format].encode(arg)
        return enc
    
    def get_bytes(self):
        try:
            return bytearray(self.get_encoding())
        except TypeError:
            raise CompilerError('cannot encode %s/%s' % (self, self.get_encoding()))
    
    def check_args(self):
        """ Check the args match up with that specified in arg_format """
        try:
            if len(self.args) != len(self.arg_format):
                raise CompilerError('Incorrect number of arguments for instruction %s' % (self))
        except TypeError:
            pass
        for n,(arg,f) in enumerate(zip(self.args, self.arg_format)):
            try:
                ArgFormats[f].check(arg)
            except ArgumentError as e:
                raise CompilerError('Invalid argument %d "%s" to instruction: %s'
                    % (n, e.arg, self) + '\n' + e.msg)
            except KeyError as e:
                raise CompilerError('Unknown argument %s for instruction %s' % (f, self))
    
    def get_used(self):
        """ Return the set of registers that are read in this instruction. """
        return (arg for arg,w in zip(self.args, self.arg_format) if \
            format_str_is_reg(w) and not format_str_is_writeable(w))
    
    def get_def(self):
        """ Return the set of registers that are written to in this instruction. """
        return (arg for arg,w in zip(self.args, self.arg_format) if \
            format_str_is_writeable(w))
    
    def get_pre_arg(self):
        return ""

    def has_var_args(self):
        try:
            len(self.arg_format)
            return False
        except:
            return True

    def is_vec(self):
        return False

    @classmethod
    def is_gf2n(self):
        return False

    def get_size(self):
        return 1

    def add_usage(self, req_node):
        pass

    def merge_id(self):
        return type(self), self.get_size()

    def merge(self, other):
        if self.get_size() != other.get_size():
            # merge as non-vector instruction
            self.args = self.expand_vector_args() + other.expand_vector_args()
            if self.is_vec():
                self.size = 1
        else:
            self.args += other.args

    def expand_vector_args(self):
        if self.is_vec() and self.get_size() != 1:
            for arg in self.args:
                arg.create_vector_elements()
                res = sum(list(zip(*self.args)), ())
                return list(res)
        else:
            return self.args

    def expand_merged(self, skip):
        program.curr_block.instructions.append(self)

    def get_new_args(self, size, subs):
        new_args = []
        for arg, f in zip(self.args, self.arg_format):
            if arg in subs:
                new_args.append(subs[arg])
            elif arg is None:
                new_args.append(size)
            else:
                if format_str_is_writeable(f):
                    new_args.append(arg.copy())
                    subs[arg] = new_args[-1]
                else:
                    new_args.append(arg)
        return new_args

    def copy(self, *args, **kwargs):
        raise CompilerError("%s instruction not compatible with CISC-style "
                            "merging. Compile with '-O'." % type(self))

    @staticmethod
    def get_usage(args):
        return {}

    # String version of instruction attempting to replicate encoded version
    def __str__(self):
        
        if self.has_var_args():
            varargCount = str(len(self.args)) + ', '
        else:
            varargCount = ''

        return self.__class__.__name__ + ' ' + self.get_pre_arg() + varargCount + ', '.join(str(a) for a in self.args)

    def __repr__(self):
        return self.__class__.__name__ + '(' + self.get_pre_arg() + ','.join(str(a) for a in self.args) + ')'

class ParsedInstruction:
    reverse_opcodes = {}

    def __init__(self, f):
        cls = type(self)
        from Compiler import instructions
        from Compiler.GC import instructions as gc_inst
        if not cls.reverse_opcodes:
            for module in instructions, gc_inst:
                for x, y in inspect.getmodule(module).__dict__.items():
                    if inspect.isclass(y) and y.__name__[0] != 'v':
                        try:
                            cls.reverse_opcodes[y.code] = y
                        except AttributeError:
                            pass
        read = lambda: struct.unpack('>I', f.read(4))[0]
        full_code = struct.unpack('>Q', f.read(8))[0]
        self.code = full_code % (1 << Instruction.code_length)
        self.size = full_code >> Instruction.code_length
        self.type = cls.reverse_opcodes[self.code]
        t = self.type
        name = t.__name__
        try:
            n_args = len(t.arg_format)
            self.var_args = False
        except:
            n_args = read()
            self.var_args = True
        try:
            arg_format = iter(t.arg_format)
        except:
            if name == 'cisc':
                arg_format = itertools.chain(['str'], itertools.repeat('int'))
            else:
                def arg_iter():
                    i = 0
                    while True:
                        try:
                            yield self.args[i].i
                        except AttributeError:
                            yield None
                        i += 1
                arg_format = t.dynamic_arg_format(arg_iter())
        self.args = []
        for i in range(n_args):
            self.args.append(ArgFormats[next(arg_format)](f))

    def __str__(self):
        name = self.type.__name__
        res = name + ' '
        if self.size > 1:
            res = 'v' + res + str(self.size) + ', '
        if self.var_args:
            res += str(len(self.args)) + ', '
        res += ', '.join(str(arg) for arg in self.args)
        return res

    def get_usage(self):
        return self.type.get_usage(self.args)

class VarArgsInstruction(Instruction):
    def has_var_args(self):
        return True

class VectorInstruction(Instruction):
    __slots__ = []
    is_vec = lambda self: True
    vector_index = 0

    def get_code(self):
        return super(VectorInstruction, self).get_code(
            len(self.args[self.vector_index]))

class Ciscable(Instruction):
    def copy(self, size, subs):
        return type(self)(*self.get_new_args(size, subs), copying=True)

class DynFormatInstruction(Instruction):
    __slots__ = []

    @property
    def arg_format(self):
        return self.dynamic_arg_format(iter(self.args))

    @classmethod
    def bases(self, args):
        i = 0
        while True:
            try:
                n = next(args)
            except StopIteration:
                return
            yield i, n
            i += n
            for j in range(n - 1):
                next(args)

###
### Basic arithmetic
###

class AddBase(Instruction):
    __slots__ = []

class SubBase(Instruction):
    __slots__ = []

class MulBase(Instruction):
    __slots__ = []

###
### Basic arithmetic with immediate values
###

class ImmediateBase(Instruction):
    __slots__ = ['op']

class SharedImmediate(ImmediateBase):
    __slots__ = []
    arg_format = ['sw', 's', 'i']

class ClearImmediate(ImmediateBase):
    __slots__ = []
    arg_format = ['cw', 'c', 'i']


###
### Memory access instructions
###

class MemoryInstruction(Instruction):
    __slots__ = ['_protect']
    def __init__(self, *args, **kwargs):
        super(MemoryInstruction, self).__init__(*args, **kwargs)
        self._protect = program._protect_memory

class DirectMemoryInstruction(MemoryInstruction):
    __slots__ = []
    def __init__(self, *args, **kwargs):
        super(DirectMemoryInstruction, self).__init__(*args, **kwargs)

class IndirectMemoryInstruction(MemoryInstruction):
    __slots__ = []

    def get_direct(self, address):
        return self.direct(self.args[0], address, add_to_prog=False)

class ReadMemoryInstruction(MemoryInstruction):
    __slots__ = []

class WriteMemoryInstruction(MemoryInstruction):
    __slots__ = []

class DirectMemoryWriteInstruction(DirectMemoryInstruction, \
                                       WriteMemoryInstruction):
    __slots__ = []
    def __init__(self, *args, **kwargs):
        if not program.curr_tape.singular:
            raise CompilerError('Direct memory writing prevented in threads')
        super(DirectMemoryWriteInstruction, self).__init__(*args, **kwargs)

###
### I/O instructions
###

class DoNotEliminateInstruction(Instruction):
    """ What do you think? """
    __slots__ = []

class IOInstruction(DoNotEliminateInstruction):
    """ Instruction that uses stdin/stdout during runtime. These are linked
    to prevent instruction reordering during optimization. """
    __slots__ = []

    @classmethod
    def str_to_int(cls, s):
        """ Convert a 4 character string to an integer. """
        try:
            s = bytearray(s, 'utf8')
        except:
            pass
        if len(s) > 4:
            raise CompilerError('String longer than 4 characters')
        n = 0
        for c in reversed(s.ljust(4)):
            n <<= 8
            n += c
        return n

class AsymmetricCommunicationInstruction(DoNotEliminateInstruction):
    """ Instructions involving sending from or to only one party. """
    __slots__ = []

class RawInputInstruction(AsymmetricCommunicationInstruction):
    """ Raw input instructions. """
    __slots__ = []

class PublicFileIOInstruction(DoNotEliminateInstruction):
    """ Instruction to reads/writes public information from/to files. """
    __slots__ = []

class TextInputInstruction(VarArgsInstruction, DoNotEliminateInstruction):
    """ Input from text file or stdin """
    __slots__ = []

    def add_usage(self, req_node):
        for player in self.get_players():
            req_node.increment((self.field_type, 'input', player), \
                               self.get_size())

###
### Data access instructions
###

class DataInstruction(Instruction):
    __slots__ = []
    field_type = 'modp'

    def add_usage(self, req_node):
        req_node.increment((self.field_type, self.data_type),
                           self.get_size() * self.get_repeat())

    def get_repeat(self):
        return 1

###
### Integer operations
### 

class IntegerInstruction(Instruction):
    """ Base class for integer operations. """
    __slots__ = []
    arg_format = ['ciw', 'ci', 'ci']

class StackInstruction(DoNotEliminateInstruction):
    """ Base class for thread-local stack instructions. """
    __slots__ = []

###
### Clear comparison instructions
###

class UnaryComparisonInstruction(Instruction):
    """ Base class for unary comparisons. """
    __slots__ = []
    arg_format = ['ciw', 'ci']

### 
### Clear shift instructions
### 

class ClearShiftInstruction(ClearImmediate):
    __slots__ = []

    def check_args(self):
        super(ClearShiftInstruction, self).check_args()
        if self.args[2] < 0:
            raise CompilerError('negative shift')

###
### Jumps etc
###

class JumpInstruction(Instruction):
    __slots__ = ['jump_arg']

    def set_relative_jump(self, value):
        if value == -1:
            raise CompilerError('Jump by -1 would cause infinite loop')
        self.args[self.jump_arg] = value

    def get_relative_jump(self):
        return self.args[self.jump_arg]


class CISC(Instruction):
    """
    Base class for a CISC instruction.
    
    Children must implement expand(self) to process the instruction.
    """
    __slots__ = []
    code = None

    def __init__(self, *args):
        self.args = args
        self.check_args()
        self.expand()
    
    def expand(self):
        """ Expand this into a sequence of RISC instructions. """
        raise NotImplementedError('expand method must be implemented')


class InvertInstruction(Instruction):
    __slots__ = []

    def __init__(self, *args, **kwargs):
        if program.options.ring and not self.is_gf2n():
            raise CompilerError('inverse undefined in rings')
        super(InvertInstruction, self).__init__(*args, **kwargs)
