"""
This module contains all instruction types for arithmetic computation
and general control of the virtual machine such as control flow.
The parameter descriptions refer to the instruction arguments in the
right order.
"""
# All base classes, utility functions etc. should go in
# instructions_base.py instead. This is for two reasons:
# 1) Easier generation of documentation
# 2) Ensures that 'from instructions import *' will only import assembly
# instructions and nothing else.
#
# Note: every instruction should have a suitable docstring for
# auto-generation of documentation
import itertools
import operator
import math
from . import tools
from random import randint
from functools import reduce
from Compiler.config import *
from Compiler.exceptions import *
import Compiler.instructions_base as base
# avoid naming collision with input instruction
_python_input = input
###
### Load and store instructions
###
[docs]@base.gf2n
@base.vectorize
class ldi(base.Instruction):
""" Assign (constant) immediate value to clear register (vector).
:param: destination (cint)
:param: value (int)
"""
__slots__ = []
code = base.opcodes['LDI']
arg_format = ['cw','i']
[docs]@base.gf2n
@base.vectorize
class ldsi(base.Instruction):
""" Assign (constant) immediate value to secret register (vector).
:param: destination (sint)
:param: value (int)
"""
__slots__ = []
code = base.opcodes['LDSI']
arg_format = ['sw','i']
[docs]@base.gf2n
@base.vectorize
class ldmc(base.DirectMemoryInstruction, base.ReadMemoryInstruction):
""" Assign clear memory value(s) to clear register (vector) by
immediate address. The vectorized version starts at the base
address and then iterates the memory address.
:param: destination (cint)
:param: memory address base (int)
"""
__slots__ = []
code = base.opcodes['LDMC']
arg_format = ['cw','long']
[docs]@base.gf2n
@base.vectorize
class ldms(base.DirectMemoryInstruction, base.ReadMemoryInstruction):
""" Assign secret memory value(s) to secret register (vector) by
immediate address. The vectorized version starts at the base
address and then iterates the memory address.
:param: destination (sint)
:param: memory address base (int)
"""
__slots__ = []
code = base.opcodes['LDMS']
arg_format = ['sw','long']
[docs]@base.gf2n
@base.vectorize
class stmc(base.DirectMemoryWriteInstruction):
""" Assign clear register (vector) to clear memory value(s) by
immediate address. The vectorized version starts at the base
address and then iterates the memory address.
:param: source (cint)
:param: memory address base (int)
"""
__slots__ = []
code = base.opcodes['STMC']
arg_format = ['c','long']
[docs]@base.gf2n
@base.vectorize
class stms(base.DirectMemoryWriteInstruction):
""" Assign secret register (vector) to secret memory value(s) by
immediate address. The vectorized version starts at the base
address and then iterates the memory address.
:param: source (sint)
:param: memory address base (int)
"""
__slots__ = []
code = base.opcodes['STMS']
arg_format = ['s','long']
[docs]@base.vectorize
class ldmint(base.DirectMemoryInstruction, base.ReadMemoryInstruction):
""" Assign clear integer memory value(s) to clear integer register
(vector) by immediate address. The vectorized version starts at
the base address and then iterates the memory address.
:param: destination (regint)
:param: memory address base (int)
"""
__slots__ = []
code = base.opcodes['LDMINT']
arg_format = ['ciw','long']
[docs]@base.vectorize
class stmint(base.DirectMemoryWriteInstruction):
""" Assign clear integer register (vector) to clear integer memory
value(s) by immediate address. The vectorized version starts at
the base address and then iterates the memory address.
:param: source (regint)
:param: memory address base (int)
"""
__slots__ = []
code = base.opcodes['STMINT']
arg_format = ['ci','long']
[docs]@base.vectorize
class ldmci(base.ReadMemoryInstruction, base.IndirectMemoryInstruction):
""" Assign clear memory value(s) to clear register (vector) by
register address. The vectorized version starts at the base
address and then iterates the memory address.
:param: destination (cint)
:param: memory address base (regint)
"""
code = base.opcodes['LDMCI']
arg_format = ['cw','ci']
direct = staticmethod(ldmc)
[docs]@base.vectorize
class ldmsi(base.ReadMemoryInstruction, base.IndirectMemoryInstruction):
""" Assign secret memory value(s) to secret register (vector) by
register address. The vectorized version starts at the base
address and then iterates the memory address.
:param: destination (sint)
:param: memory address base (regint)
"""
code = base.opcodes['LDMSI']
arg_format = ['sw','ci']
direct = staticmethod(ldms)
[docs]@base.vectorize
class stmci(base.WriteMemoryInstruction, base.IndirectMemoryInstruction):
""" Assign clear register (vector) to clear memory value(s) by
register address. The vectorized version starts at the base
address and then iterates the memory address.
:param: source (cint)
:param: memory address base (regint)
"""
code = base.opcodes['STMCI']
arg_format = ['c','ci']
direct = staticmethod(stmc)
[docs]@base.vectorize
class stmsi(base.WriteMemoryInstruction, base.IndirectMemoryInstruction):
""" Assign secret register (vector) to secret memory value(s) by
register address. The vectorized version starts at the base
address and then iterates the memory address.
:param: source (sint)
:param: memory address base (regint)
"""
code = base.opcodes['STMSI']
arg_format = ['s','ci']
direct = staticmethod(stms)
[docs]@base.vectorize
class ldminti(base.ReadMemoryInstruction, base.IndirectMemoryInstruction):
""" Assign clear integer memory value(s) to clear integer register
(vector) by register address. The vectorized version starts at the
base address and then iterates the memory address.
:param: destination (regint)
:param: memory address base (regint)
"""
code = base.opcodes['LDMINTI']
arg_format = ['ciw','ci']
direct = staticmethod(ldmint)
[docs]@base.vectorize
class stminti(base.WriteMemoryInstruction, base.IndirectMemoryInstruction):
""" Assign clear integer register (vector) to clear integer memory
value(s) by register address. The vectorized version starts at the
base address and then iterates the memory address.
:param: source (regint)
:param: memory address base (regint)
"""
code = base.opcodes['STMINTI']
arg_format = ['ci','ci']
direct = staticmethod(stmint)
@base.vectorize
class gldmci(base.ReadMemoryInstruction, base.IndirectMemoryInstruction):
r""" Assigns register $c_i$ the value in memory \verb+C[cj]+. """
code = base.opcodes['LDMCI'] + 0x100
arg_format = ['cgw','ci']
direct = staticmethod(gldmc)
@base.vectorize
class gldmsi(base.ReadMemoryInstruction, base.IndirectMemoryInstruction):
r""" Assigns register $s_i$ the value in memory \verb+S[cj]+. """
code = base.opcodes['LDMSI'] + 0x100
arg_format = ['sgw','ci']
direct = staticmethod(gldms)
@base.vectorize
class gstmci(base.WriteMemoryInstruction, base.IndirectMemoryInstruction):
r""" Sets \verb+C[cj]+ to be the value $c_i$. """
code = base.opcodes['STMCI'] + 0x100
arg_format = ['cg','ci']
direct = staticmethod(gstmc)
@base.vectorize
class gstmsi(base.WriteMemoryInstruction, base.IndirectMemoryInstruction):
r""" Sets \verb+S[cj]+ to be the value $s_i$. """
code = base.opcodes['STMSI'] + 0x100
arg_format = ['sg','ci']
direct = staticmethod(gstms)
[docs]@base.gf2n
@base.vectorize
class movc(base.Instruction):
""" Copy clear register (vector).
:param: destination (cint)
:param: source (cint)
"""
__slots__ = []
code = base.opcodes['MOVC']
arg_format = ['cw','c']
[docs]@base.gf2n
@base.vectorize
class movs(base.Instruction):
""" Copy secret register (vector).
:param: destination (cint)
:param: source (cint)
"""
__slots__ = []
code = base.opcodes['MOVS']
arg_format = ['sw','s']
[docs]@base.vectorize
class movint(base.Instruction):
""" Copy clear integer register (vector).
:param: destination (regint)
:param: source (regint)
"""
__slots__ = []
code = base.opcodes['MOVINT']
arg_format = ['ciw','ci']
[docs]@base.vectorize
class pushint(base.StackInstruction):
""" Pushes clear integer register to the thread-local stack.
Considered obsolete.
:param: source (regint)
"""
code = base.opcodes['PUSHINT']
arg_format = ['ci']
[docs]@base.vectorize
class popint(base.StackInstruction):
""" Pops from the thread-local stack to clear integer register.
Considered obsolete.
:param: destination (regint)
"""
code = base.opcodes['POPINT']
arg_format = ['ciw']
###
### Machine
###
[docs]@base.vectorize
class ldtn(base.Instruction):
""" Store the number of the current thread in clear integer register.
:param: destination (regint)
"""
code = base.opcodes['LDTN']
arg_format = ['ciw']
[docs]@base.vectorize
class ldarg(base.Instruction):
""" Store the argument passed to the current thread in clear integer
register.
:param: destination (regint)
"""
code = base.opcodes['LDARG']
arg_format = ['ciw']
[docs]@base.vectorize
class starg(base.Instruction):
""" Copy clear integer register to the thread argument.
:param: source (regint)
"""
code = base.opcodes['STARG']
arg_format = ['ci']
[docs]@base.vectorize
class cmdlinearg(base.Instruction):
""" Load command-line argument.
:param: dest (regint)
:param: index (regint)
"""
code = base.opcodes['CMDLINEARG']
arg_format = ['ciw','ci']
[docs]@base.gf2n
class reqbl(base.Instruction):
""" Requirement on computation modulus. Minimal bit length of prime if
positive, minus exact bit length of power of two if negative.
:param: requirement (int)
"""
code = base.opcodes['REQBL']
arg_format = ['int']
[docs]class active(base.Instruction):
""" Indicate whether program is compatible with malicious-security
protocols.
:param: 0 for no, 1 for yes
"""
code = base.opcodes['ACTIVE']
arg_format = ['int']
[docs]class time(base.IOInstruction):
""" Output time since start of computation. """
code = base.opcodes['TIME']
arg_format = []
[docs]class start(base.Instruction):
""" Start timer.
:param: timer number (int)
"""
code = base.opcodes['START']
arg_format = ['i']
[docs]class stop(base.Instruction):
""" Stop timer.
:param: timer number (int)
"""
code = base.opcodes['STOP']
arg_format = ['i']
[docs]class use(base.Instruction):
""" Offline data usage. Necessary to avoid reusage while using
preprocessing from files. Also used to multithreading for expensive
preprocessing.
:param: domain (0: integer, 1: :math:`\mathrm{GF}(2^n)`, 2: bit)
:param: type (0: triple, 1: square, 2: bit, 3: inverse, 6: daBit)
:param: number (int, -1 for unknown)
"""
code = base.opcodes['USE']
arg_format = ['int','int','long']
@classmethod
def get_usage(cls, args):
from .program import field_types, data_types
from .util import find_in_dict
return {(find_in_dict(field_types, args[0].i),
find_in_dict(data_types, args[1].i)):
args[2].i}
[docs]class use_inp(base.Instruction):
""" Input usage. Necessary to avoid reusage while using
preprocessing from files.
:param: domain (0: integer, 1: :math:`\mathrm{GF}(2^n)`, 2: bit)
:param: input player (int)
:param: number (int, -1 for unknown)
"""
code = base.opcodes['USE_INP']
arg_format = ['int','int','long']
@classmethod
def get_usage(cls, args):
from .program import field_types, data_types
from .util import find_in_dict
return {(find_in_dict(field_types, args[0].i), 'input', args[1].i):
args[2].i}
[docs]class use_edabit(base.Instruction):
""" edaBit usage. Necessary to avoid reusage while using
preprocessing from files. Also used to multithreading for expensive
preprocessing.
:param: loose/strict (0/1)
:param: length (int)
:param: number (int, -1 for unknown)
"""
code = base.opcodes['USE_EDABIT']
arg_format = ['int','int','long']
@classmethod
def get_usage(cls, args):
return {('sedabit' if args[0].i else 'edabit', args[1].i): args[2].i}
[docs]class use_matmul(base.Instruction):
""" Matrix multiplication usage. Used for multithreading of
preprocessing.
:param: number of left-hand rows (int)
:param: number of left-hand columns/right-hand rows (int)
:param: number of right-hand columns (int)
:param: number (int, -1 for unknown)
"""
code = base.opcodes['USE_MATMUL']
arg_format = ['int','int','int','long']
@classmethod
def get_usage(cls, args):
return {('matmul', tuple(arg.i for arg in args[:3])): args[3].i}
[docs]class run_tape(base.Instruction):
""" Start tape/bytecode file in another thread.
:param: number of arguments to follow (multiple of three)
:param: virtual machine thread number (int)
:param: tape number (int)
:param: tape argument (int)
:param: (repeat the last three)...
"""
code = base.opcodes['RUN_TAPE']
arg_format = tools.cycle(['int','int','int'])
[docs]class join_tape(base.Instruction):
""" Join thread.
:param: virtual machine thread number (int)
"""
code = base.opcodes['JOIN_TAPE']
arg_format = ['int']
[docs]class call_tape(base.DoNotEliminateInstruction):
""" Start tape/bytecode file in same thread. Arguments/return values
starting from :py:obj:`direction` are optional.
:param: tape number (int)
:param: arg (regint)
:param: direction (0 for argument, 1 for return value)
:param: register type (see :py:obj:`vm_types`)
:param: register size (int)
:param: destination register
:param: source register
:param: (repeat from direction)
"""
code = base.opcodes['CALL_TAPE']
arg_format = tools.chain(['int', 'ci'],
tools.cycle(['int','int','int','*w','*']))
@staticmethod
def type_check(reg, type_id):
assert base.vm_types[reg.reg_type] == type_id
def __init__(self, *args, **kwargs):
super(call_tape, self).__init__(*args, **kwargs)
for i in range(2, len(args), 5):
for reg in args[i + 3:i + 5]:
self.type_check(reg, args[i + 1])
assert reg.size == args[i + 2]
assert args[i] in (0, 1)
assert args[i + 4 - args[i]].program == program.curr_tape
assert args[i + 3 + args[i]].program == program.tapes[args[0]]
[docs] def get_def(self):
# hide registers from called tape
for i in range(2, len(self.args), 5):
if self.args[i]:
yield self.args[i + 3]
[docs] def get_used(self):
# hide registers from called tape
yield self.args[1]
for i in range(2, len(self.args), 5):
if not self.args[i]:
yield self.args[i + 4]
def add_usage(self, req_node):
req_node.num += program.tapes[self.args[0]].req_tree.aggregate()
[docs]class call_arg(base.DoNotEliminateInstruction, base.VectorInstruction):
""" Pseudo instruction for arguments in connection with
:py:class:`call_tape`.
:param: destination (register)
:param: register type (see :py:obj:`vm_types`)
"""
code = base.opcodes['CALL_ARG']
arg_format = ['*w','int']
def __init__(self, *args, **kwargs):
super(call_arg, self).__init__(*args, **kwargs)
for i in range(0, len(args), 2):
call_tape.type_check(args[i], args[i + 1])
[docs]class crash(base.IOInstruction):
""" Crash runtime if the value in the register is not zero.
:param: Crash condition (regint)"""
code = base.opcodes['CRASH']
arg_format = ['ci']
class start_grind(base.IOInstruction):
code = base.opcodes['STARTGRIND']
arg_format = []
class stop_grind(base.IOInstruction):
code = base.opcodes['STOPGRIND']
arg_format = []
[docs]@base.gf2n
class use_prep(base.Instruction):
""" Custom preprocessed data usage.
:param: tag (16 bytes / 4 units, cut off at first zero byte)
:param: number of items to use (int, -1 for unknown)
"""
code = base.opcodes['USE_PREP']
arg_format = ['str','long']
@classmethod
def get_usage(cls, args):
return {('gf2n' if cls.__name__ == 'guse_prep' else 'modp',
args[0].str): args[1].i}
[docs]class nplayers(base.Instruction):
""" Store number of players in clear integer register.
:param: destination (regint)
"""
code = base.opcodes['NPLAYERS']
arg_format = ['ciw']
[docs]class threshold(base.Instruction):
""" Store maximal number of corrupt players in clear integer register.
:param: destination (regint)
"""
code = base.opcodes['THRESHOLD']
arg_format = ['ciw']
[docs]class playerid(base.Instruction):
""" Store current player number in clear integer register.
:param: destination (regint)
"""
code = base.opcodes['PLAYERID']
arg_format = ['ciw']
###
### Basic arithmetic
###
[docs]@base.gf2n
@base.vectorize
class addc(base.AddBase):
""" Clear addition.
:param: result (cint)
:param: summand (cint)
:param: summand (cint)
"""
__slots__ = []
code = base.opcodes['ADDC']
arg_format = ['cw','c','c']
[docs]@base.gf2n
@base.vectorize
class adds(base.AddBase):
""" Secret addition.
:param: result (sint)
:param: summand (sint)
:param: summand (sint)
"""
__slots__ = []
code = base.opcodes['ADDS']
arg_format = ['sw','s','s']
[docs]@base.gf2n
@base.vectorize
class addm(base.AddBase):
""" Mixed addition.
:param: result (sint)
:param: summand (sint)
:param: summand (cint)
"""
__slots__ = []
code = base.opcodes['ADDM']
arg_format = ['sw','s','c']
[docs]@base.gf2n
@base.vectorize
class subc(base.SubBase):
""" Clear subtraction.
:param: result (cint)
:param: first operand (cint)
:param: second operand (cint)
"""
__slots__ = []
code = base.opcodes['SUBC']
arg_format = ['cw','c','c']
[docs]@base.gf2n
@base.vectorize
class subs(base.SubBase):
""" Secret subtraction.
:param: result (sint)
:param: first operand (sint)
:param: second operand (sint)
"""
__slots__ = []
code = base.opcodes['SUBS']
arg_format = ['sw','s','s']
[docs]@base.gf2n
@base.vectorize
class subml(base.SubBase):
""" Subtract clear from secret value.
:param: result (sint)
:param: first operand (sint)
:param: second operand (cint)
"""
__slots__ = []
code = base.opcodes['SUBML']
arg_format = ['sw','s','c']
[docs]@base.gf2n
@base.vectorize
class submr(base.SubBase):
""" Subtract secret from clear value.
:param: result (sint)
:param: first operand (cint)
:param: second operand (sint)
"""
__slots__ = []
code = base.opcodes['SUBMR']
arg_format = ['sw','c','s']
[docs]@base.vectorize
class prefixsums(base.Instruction):
""" Prefix sum.
:param: result (sint)
:param: input (sint)
"""
__slots__ = []
code = base.opcodes['PREFIXSUMS']
arg_format = ['sw','s']
[docs]class picks(base.VectorInstruction):
""" Extract part of vector.
:param: result (sint)
:param: input (sint)
:param: start offset (int)
:param: step
"""
__slots__ = []
code = base.opcodes['PICKS']
arg_format = ['sw','s','int','int']
def __init__(self, *args):
super(picks, self).__init__(*args)
assert 0 <= args[2] < len(args[1])
assert 0 <= args[2] + args[3] * (len(args[0]) - 1) < len(args[1])
[docs]class concats(base.VectorInstruction):
""" Concatenate vectors.
:param: result (sint)
:param: start offset (int)
:param: input (sint)
:param: (repeat from offset)...
"""
__slots__ = []
code = base.opcodes['CONCATS']
arg_format = tools.chain(['sw'], tools.cycle(['int','s']))
def __init__(self, *args):
super(concats, self).__init__(*args)
assert len(args) % 2 == 1
assert len(args[0]) == sum(args[1::2])
for i in range(1, len(args), 2):
assert args[i] == len(args[i + 1])
[docs]class zips(base.Instruction):
""" Zip vectors.
:param: result (sint)
:param: operand (sint)
:param: operand (sint)
"""
__slots__ = []
code = base.opcodes['ZIPS']
arg_format = ['sw','s','s']
is_vec = lambda self: True
def __init__(self, *args):
super(zips, self).__init__(*args)
assert len(args[0]) == len(args[1]) + len(args[2])
assert len(args[1]) == len(args[2])
def get_code(self):
return super(zips, self).get_code(len(self.args[1]))
[docs]@base.gf2n
@base.vectorize
class mulc(base.MulBase):
""" Clear multiplication.
:param: result (cint)
:param: factor (cint)
:param: factor (cint)
"""
__slots__ = []
code = base.opcodes['MULC']
arg_format = ['cw','c','c']
[docs]@base.gf2n
@base.vectorize
class mulm(base.MulBase):
""" Multiply secret and clear value.
:param: result (sint)
:param: factor (sint)
:param: factor (cint)
"""
__slots__ = []
code = base.opcodes['MULM']
arg_format = ['sw','s','c']
[docs]@base.gf2n
@base.vectorize
class divc(base.InvertInstruction):
""" Clear division.
:param: result (cint)
:param: dividend (cint)
:param: divisor (cint)
"""
__slots__ = []
code = base.opcodes['DIVC']
arg_format = ['cw','c','c']
[docs]@base.gf2n
@base.vectorize
class floordivc(base.Instruction):
""" Clear integer floor division.
:param: result (cint)
:param: dividend (cint)
:param: divisor (cint)
"""
__slots__ = []
code = base.opcodes['FLOORDIVC']
arg_format = ['cw','c','c']
[docs]@base.gf2n
@base.vectorize
class modc(base.Instruction):
""" Clear modular reduction.
:param: result (cint)
:param: dividend (cint)
:param: divisor (cint)
"""
__slots__ = []
code = base.opcodes['MODC']
arg_format = ['cw','c','c']
[docs]@base.vectorize
class inv2m(base.InvertInstruction):
""" Inverse of power of two modulo prime (the computation modulus).
:param: result (cint)
:param: exponent (int)
"""
__slots__ = []
code = base.opcodes['INV2M']
arg_format = ['cw','int']
[docs]@base.vectorize
class legendrec(base.Instruction):
""" Clear Legendre symbol computation (a/p) over prime p
(the computation modulus).
:param: result (cint)
:param: a (int)
"""
__slots__ = []
code = base.opcodes['LEGENDREC']
arg_format = ['cw','c']
[docs]@base.vectorize
class digestc(base.Instruction):
""" Clear truncated hash computation.
:param: result (cint)
:param: input (cint)
:param: byte length of hash value used (int)
"""
__slots__ = []
code = base.opcodes['DIGESTC']
arg_format = ['cw','c','int']
###
### Bitwise operations
###
[docs]@base.gf2n
@base.vectorize
class andc(base.Instruction):
""" Logical AND of clear (vector) registers.
:param: result (cint)
:param: operand (cint)
:param: operand (cint)
"""
__slots__ = []
code = base.opcodes['ANDC']
arg_format = ['cw','c','c']
[docs]@base.gf2n
@base.vectorize
class orc(base.Instruction):
""" Logical OR of clear (vector) registers.
:param: result (cint)
:param: operand (cint)
:param: operand (cint)
"""
__slots__ = []
code = base.opcodes['ORC']
arg_format = ['cw','c','c']
[docs]@base.gf2n
@base.vectorize
class xorc(base.Instruction):
""" Logical XOR of clear (vector) registers.
:param: result (cint)
:param: operand (cint)
:param: operand (cint)
"""
__slots__ = []
code = base.opcodes['XORC']
arg_format = ['cw','c','c']
[docs]@base.vectorize
class notc(base.Instruction):
""" Clear logical NOT of a constant number of bits of clear
(vector) register.
:param: result (cint)
:param: operand (cint)
:param: bit length (int)
"""
__slots__ = []
code = base.opcodes['NOTC']
arg_format = ['cw','c', 'int']
@base.vectorize
class gnotc(base.Instruction):
r""" Clear logical NOT $cg_i = \lnot cg_j$ """
__slots__ = []
code = (1 << 8) + base.opcodes['NOTC']
arg_format = ['cgw','cg']
def is_gf2n(self):
return True
@base.vectorize
class gbitdec(base.Instruction):
r""" Store every $n$-th bit of $cg_i$ in $cg_j, \dots$. """
__slots__ = []
code = base.opcodes['GBITDEC']
arg_format = tools.chain(['cg', 'int'], itertools.repeat('cgw'))
def is_g2fn(self):
return True
def has_var_args(self):
return True
@base.vectorize
class gbitcom(base.Instruction):
r""" Store the bits $cg_j, \dots$ as every $n$-th bit of $cg_i$. """
__slots__ = []
code = base.opcodes['GBITCOM']
arg_format = tools.chain(['cgw', 'int'], itertools.repeat('cg'))
def is_g2fn(self):
return True
def has_var_args(self):
return True
###
### Arithmetic with immediate values
###
[docs]@base.gf2n
@base.vectorize
class addci(base.ClearImmediate):
""" Addition of clear register (vector) and (constant) immediate value.
:param: result (cint)
:param: summand (cint)
:param: summand (int)
"""
__slots__ = []
code = base.opcodes['ADDCI']
op = '__add__'
[docs]@base.gf2n
@base.vectorize
class addsi(base.SharedImmediate):
""" Addition of secret register (vector) and (constant) immediate value.
:param: result (cint)
:param: summand (cint)
:param: summand (int)
"""
__slots__ = []
code = base.opcodes['ADDSI']
op = '__add__'
[docs]@base.gf2n
@base.vectorize
class subci(base.ClearImmediate):
""" Subtraction of (constant) immediate value from clear register (vector).
:param: result (cint)
:param: first operand (cint)
:param: second operand (int)
"""
__slots__ = []
code = base.opcodes['SUBCI']
op = '__sub__'
[docs]@base.gf2n
@base.vectorize
class subsi(base.SharedImmediate):
""" Subtraction of (constant) immediate value from secret
register (vector).
:param: result (sint)
:param: first operand (sint)
:param: second operand (int)
"""
__slots__ = []
code = base.opcodes['SUBSI']
op = '__sub__'
[docs]@base.gf2n
@base.vectorize
class subcfi(base.ClearImmediate):
""" Subtraction of clear register (vector) from (constant)
immediate value.
:param: result (cint)
:param: first operand (int)
:param: second operand (cint)
"""
__slots__ = []
code = base.opcodes['SUBCFI']
op = '__rsub__'
[docs]@base.gf2n
@base.vectorize
class subsfi(base.SharedImmediate):
""" Subtraction of secret register (vector) from (constant)
immediate value.
:param: result (sint)
:param: first operand (int)
:param: second operand (sint)
"""
__slots__ = []
code = base.opcodes['SUBSFI']
op = '__rsub__'
[docs]@base.gf2n
@base.vectorize
class mulci(base.ClearImmediate):
""" Multiplication of clear register (vector) and (constant)
immediate value.
:param: result (cint)
:param: factor (cint)
:param: factor (int)
"""
__slots__ = []
code = base.opcodes['MULCI']
op = '__mul__'
[docs]@base.gf2n
@base.vectorize
class mulsi(base.SharedImmediate):
""" Multiplication of secret register (vector) and (constant)
immediate value.
:param: result (sint)
:param: factor (sint)
:param: factor (int)
"""
__slots__ = []
code = base.opcodes['MULSI']
op = '__mul__'
[docs]@base.gf2n
@base.vectorize
class divci(base.InvertInstruction, base.ClearImmediate):
""" Division of secret register (vector) and (constant) immediate value.
:param: result (cint)
:param: dividend (cint)
:param: divisor (int)
"""
__slots__ = []
code = base.opcodes['DIVCI']
[docs]@base.gf2n
@base.vectorize
class modci(base.ClearImmediate):
""" Modular reduction of clear register (vector) and (constant)
immediate value.
:param: result (cint)
:param: dividend (cint)
:param: divisor (int)
"""
__slots__ = []
code = base.opcodes['MODCI']
op = '__mod__'
[docs]@base.gf2n
@base.vectorize
class andci(base.ClearImmediate):
""" Logical AND of clear register (vector) and (constant)
immediate value.
:param: result (cint)
:param: operand (cint)
:param: operand (int)
"""
__slots__ = []
code = base.opcodes['ANDCI']
op = '__and__'
[docs]@base.gf2n
@base.vectorize
class xorci(base.ClearImmediate):
""" Logical XOR of clear register (vector) and (constant)
immediate value.
:param: result (cint)
:param: operand (cint)
:param: operand (int)
"""
__slots__ = []
code = base.opcodes['XORCI']
op = '__xor__'
[docs]@base.gf2n
@base.vectorize
class orci(base.ClearImmediate):
""" Logical OR of clear register (vector) and (constant)
immediate value.
:param: result (cint)
:param: operand (cint)
:param: operand (int)
"""
__slots__ = []
code = base.opcodes['ORCI']
op = '__or__'
###
### Shift instructions
###
[docs]@base.gf2n
@base.vectorize
class shlc(base.Instruction):
""" Bitwise left shift of clear register (vector).
:param: result (cint)
:param: first operand (cint)
:param: second operand (cint)
"""
__slots__ = []
code = base.opcodes['SHLC']
arg_format = ['cw','c','c']
[docs]@base.gf2n
@base.vectorize
class shrc(base.Instruction):
""" Bitwise right shift of clear register (vector).
:param: result (cint)
:param: first operand (cint)
:param: second operand (cint)
"""
__slots__ = []
code = base.opcodes['SHRC']
arg_format = ['cw','c','c']
[docs]@base.gf2n
@base.vectorize
class shlci(base.ClearShiftInstruction):
""" Bitwise left shift of clear register (vector) by (constant)
immediate value.
:param: result (cint)
:param: first operand (cint)
:param: second operand (int)
"""
__slots__ = []
code = base.opcodes['SHLCI']
op = '__lshift__'
[docs]@base.gf2n
@base.vectorize
class shrci(base.ClearShiftInstruction):
""" Bitwise right shift of clear register (vector) by (constant)
immediate value.
:param: result (cint)
:param: first operand (cint)
:param: second operand (int)
"""
__slots__ = []
code = base.opcodes['SHRCI']
op = '__rshift__'
[docs]@base.gf2n
@base.vectorize
class shrsi(base.ClearShiftInstruction):
""" Bitwise right shift of secret register (vector) by (constant)
immediate value. This only makes sense in connection with
protocols allowing local share conversion (i.e., based on additive
secret sharing modulo a power of two). Moreover, the result is not
a secret sharing of the right shift of the secret value but needs
to be corrected using the overflow. This is explained by `Dalskov
et al. <https://eprint.iacr.org/2020/1330>`_ in the appendix.
:param: result (sint)
:param: first operand (sint)
:param: second operand (int)
"""
__slots__ = []
code = base.opcodes['SHRSI']
arg_format = ['sw','s','i']
###
### Data access instructions
###
[docs]@base.gf2n
@base.vectorize
class triple(base.DataInstruction):
""" Store fresh random triple(s) in secret register (vectors).
:param: factor (sint)
:param: factor (sint)
:param: product (sint)
"""
__slots__ = []
code = base.opcodes['TRIPLE']
arg_format = ['sw','sw','sw']
data_type = 'triple'
@base.vectorize
class gbittriple(base.DataInstruction):
r""" Load secret variables $s_i$, $s_j$ and $s_k$
with the next GF(2) multiplication triple. """
__slots__ = []
code = base.opcodes['GBITTRIPLE']
arg_format = ['sgw','sgw','sgw']
data_type = 'bittriple'
field_type = 'gf2n'
def is_gf2n(self):
return True
@base.vectorize
class gbitgf2ntriple(base.DataInstruction):
r""" Load secret variables $s_i$, $s_j$ and $s_k$
with the next GF(2) and GF(2^n) multiplication triple. """
code = base.opcodes['GBITGF2NTRIPLE']
arg_format = ['sgw','sgw','sgw']
data_type = 'bitgf2ntriple'
field_type = 'gf2n'
def is_gf2n(self):
return True
[docs]@base.gf2n
@base.vectorize
class bit(base.DataInstruction):
""" Store fresh random triple(s) in secret register (vectors).
:param: destination (sint)
"""
__slots__ = []
code = base.opcodes['BIT']
arg_format = ['sw']
data_type = 'bit'
[docs]@base.vectorize
class dabit(base.DataInstruction):
""" Store fresh random daBit(s) in secret register (vectors).
:param: arithmetic part (sint)
:param: binary part (sbit)
"""
__slots__ = []
code = base.opcodes['DABIT']
arg_format = ['sw', 'sbw']
field_type = 'modp'
data_type = 'dabit'
[docs]@base.vectorize
class edabit(base.Instruction):
""" Store fresh random loose edaBit(s) in secret register (vectors).
The length is the first argument minus one.
:param: number of arguments to follow / number of bits plus two (int)
:param: arithmetic (sint)
:param: binary (sbit)
:param: (binary)...
"""
__slots__ = []
code = base.opcodes['EDABIT']
arg_format = tools.chain(['sw'], itertools.repeat('sbw'))
field_type = 'modp'
def add_usage(self, req_node):
req_node.increment(('edabit', len(self.args) - 1), self.get_size())
[docs]@base.vectorize
class sedabit(base.Instruction):
""" Store fresh random strict edaBit(s) in secret register (vectors).
The length is the first argument minus one.
:param: number of arguments to follow / number of bits plus two (int)
:param: arithmetic (sint)
:param: binary (sbit)
:param: (binary)...
"""
__slots__ = []
code = base.opcodes['SEDABIT']
arg_format = tools.chain(['sw'], itertools.repeat('sbw'))
field_type = 'modp'
def add_usage(self, req_node):
req_node.increment(('sedabit', len(self.args) - 1), self.get_size())
[docs]@base.vectorize
class randoms(base.Instruction):
""" Store fresh length-restricted random shares(s) in secret register
(vectors). This is only implemented for protocols that also implement
local share conversion with :py:obj:`~Compiler.GC.instructions.split`.
:param: destination (sint)
:param: length (int)
"""
__slots__ = []
code = base.opcodes['RANDOMS']
arg_format = ['sw','int']
field_type = 'modp'
[docs]@base.vectorize
class randomfulls(base.DataInstruction):
""" Store share(s) of a fresh secret random element in secret
register (vectors).
:param: destination (sint)
"""
__slots__ = []
code = base.opcodes['RANDOMFULLS']
arg_format = ['sw']
field_type = 'modp'
data_type = 'random'
def get_repeat(self):
return len(self.args)
[docs]@base.gf2n
@base.vectorize
class square(base.DataInstruction):
""" Store fresh random square(s) in secret register (vectors).
:param: value (sint)
:param: square (sint)
"""
__slots__ = []
code = base.opcodes['SQUARE']
arg_format = ['sw','sw']
data_type = 'square'
[docs]@base.gf2n
@base.vectorize
class inverse(base.DataInstruction):
""" Store fresh random inverse(s) in secret register (vectors).
:param: value (sint)
:param: inverse (sint)
"""
__slots__ = []
code = base.opcodes['INV']
arg_format = ['sw','sw']
data_type = 'inverse'
def __init__(self, *args, **kwargs):
if program.options.ring and not self.is_gf2n():
raise CompilerError('random inverse in ring not implemented')
base.DataInstruction.__init__(self, *args, **kwargs)
[docs]@base.gf2n
@base.vectorize
class prep(base.Instruction):
""" Store custom preprocessed data in secret register (vectors).
:param: number of arguments to follow (int)
:param: tag (16 bytes / 4 units, cut off at first zero byte)
:param: destination (sint)
:param: (repeat destination)...
"""
__slots__ = []
code = base.opcodes['PREP']
arg_format = tools.chain(['str'], itertools.repeat('sw'))
gf2n_arg_format = tools.chain(['str'], itertools.repeat('sgw'))
field_type = 'modp'
def add_usage(self, req_node):
req_node.increment((self.field_type, self.args[0]), self.get_size())
def has_var_args(self):
return True
###
### I/O
###
@base.gf2n
@base.vectorize
class asm_input(base.TextInputInstruction):
r""" Receive input from player $p$ and put in register $s_i$. """
__slots__ = []
code = base.opcodes['INPUT']
arg_format = tools.cycle(['sw', 'p'])
field_type = 'modp'
def get_players(self):
return self.args[1::2]
class inputmixed_base(base.TextInputInstruction, base.DynFormatInstruction):
__slots__ = []
field_type = 'modp'
# the following has to match TYPE: (N_DEST, N_PARAM)
types = {
0: (1, 0),
1: (1, 1),
2: (4, 1)
}
type_ids = {
'int': 0,
'fix': 1,
'float': 2
}
def __init__(self, name, *args):
type_id = self.type_ids[name]
super(inputmixed_base, self).__init__(type_id, *args)
@classmethod
def dynamic_arg_format(self, args):
yield 'int'
for i, t in self.bases(iter(args)):
for j in range(self.types[t][0]):
yield 'sw'
for j in range(self.types[t][1]):
yield 'int'
yield self.player_arg_type
yield 'int'
@classmethod
def bases(self, args):
i = 0
while True:
try:
t = next(args)
except StopIteration:
return
yield i, t
n = sum(self.types[t])
i += n + 2
for j in range(n + 1):
next(args)
[docs]class personal_base(base.Instruction, base.Mergeable):
__slots__ = []
field_type = 'modp'
def __init__(self, *args):
super(personal_base, self).__init__(*args)
for i in range(0, len(args), 4):
assert args[i + 2].size == args[i]
assert args[i + 3].size == args[i]
def add_usage(self, req_node):
for i in range(0, len(self.args), 4):
player = self.args[i + 1]
req_node.increment((self.field_type, 'input', player), \
self.args[i])
[docs]class privateoutput(personal_base, base.DataInstruction):
""" Private output to cint.
:param: vector size (int)
:param: player (int)
:param: destination (cint)
:param: source (sint)
:param: (repeat from vector size)...
"""
__slots__ = []
code = base.opcodes['PRIVATEOUTPUT']
arg_format = tools.cycle(['int','p','cw','s'])
data_type = 'open'
def add_usage(self, req_node):
personal_base.add_usage(self, req_node)
base.DataInstruction.add_usage(self, req_node)
def get_repeat(self):
return sum(self.args[::4])
[docs]class sendpersonal(base.Instruction, base.Mergeable):
""" Private input from cint.
:param: vector size (int)
:param: destination player (int)
:param: destination (cint)
:param: source player (int)
:param: source (cint)
:param: (repeat from vector size)...
"""
__slots__ = []
code = base.opcodes['SENDPERSONAL']
arg_format = tools.cycle(['int','p','cw','p','c'])
def __init__(self, *args):
super(sendpersonal, self).__init__(*args)
for i in range(0, len(args), 5):
assert args[i + 2].size == args[i]
assert args[i + 4].size == args[i]
[docs]@base.gf2n
@base.vectorize
class print_reg(base.IOInstruction):
""" Debugging output of clear register (vector).
:param: source (cint)
:param: comment (4 bytes / 1 unit)
"""
__slots__ = []
code = base.opcodes['PRINTREG']
arg_format = ['c','i']
def __init__(self, reg, comment=''):
super(print_reg_class, self).__init__(reg, self.str_to_int(comment))
[docs]@base.gf2n
@base.vectorize
class print_reg_plain(base.IOInstruction):
""" Output clear register.
:param: source (cint)
"""
__slots__ = []
code = base.opcodes['PRINTREGPLAIN']
arg_format = ['c']
[docs]@base.gf2n
class print_reg_plains(base.IOInstruction):
""" Output secret register.
:param: source (sint)
"""
__slots__ = []
code = base.opcodes['PRINTREGPLAINS']
arg_format = ['s']
[docs]class cond_print_plain(base.IOInstruction):
""" Conditionally output clear register (with precision).
Outputs :math:`x \cdot 2^p` where :math:`p` is the precision.
:param: condition (cint, no output if zero)
:param: source (cint)
:param: precision (cint)
"""
code = base.opcodes['CONDPRINTPLAIN']
arg_format = ['c', 'c', 'c']
def __init__(self, *args, **kwargs):
base.Instruction.__init__(self, *args, **kwargs)
self.size = args[1].size
args[2].set_size(self.size)
def get_code(self):
return base.Instruction.get_code(self, self.size)
[docs]class print_int(base.IOInstruction):
""" Output clear integer register.
:param: source (regint)
"""
__slots__ = []
code = base.opcodes['PRINTINT']
arg_format = ['ci']
[docs]@base.vectorize
class print_float_plain(base.IOInstruction):
""" Output floating-number from clear registers.
:param: significand (cint)
:param: exponent (cint)
:param: zero bit (cint, zero output if bit is one)
:param: sign bit (cint, negative output if bit is one)
:param: NaN (cint, regular number if zero)
"""
__slots__ = []
code = base.opcodes['PRINTFLOATPLAIN']
arg_format = ['c', 'c', 'c', 'c', 'c']
[docs]class print_float_prec(base.IOInstruction):
""" Set number of digits after decimal point for
:py:obj:`~Compiler.instructions.print_float_plain`.
:param: number of digits (int)
"""
__slots__ = []
code = base.opcodes['PRINTFLOATPREC']
arg_format = ['int']
[docs]class print_char(base.IOInstruction):
""" Output a single byte.
:param: byte (int)
"""
code = base.opcodes['PRINTCHR']
arg_format = ['int']
def __init__(self, ch):
super(print_char, self).__init__(ch)
[docs]class print_char4(base.IOInstruction):
""" Output four bytes.
:param: four bytes (int)
"""
code = base.opcodes['PRINTSTR']
arg_format = ['int']
def __init__(self, val):
super(print_char4, self).__init__(self.str_to_int(val))
[docs]class cond_print_str(base.IOInstruction):
""" Conditionally output four bytes.
:param: condition (cint, no output if zero)
:param: four bytes (int)
"""
code = base.opcodes['CONDPRINTSTR']
arg_format = ['c', 'int']
def __init__(self, cond, val):
super(cond_print_str, self).__init__(cond, self.str_to_int(val))
[docs]class readsocketc(base.IOInstruction):
""" Read a variable number of clear values in internal representation
from socket for a specified client id and store them in clear registers.
:param: number of arguments to follow / number of inputs minus one (int)
:param: client id (regint)
:param: vector size (int)
:param: destination (cint)
:param: (repeat destination)...
"""
__slots__ = []
code = base.opcodes['READSOCKETC']
arg_format = tools.chain(['ci','int'], itertools.repeat('cw'))
def has_var_args(self):
return True
[docs]class readsockets(base.IOInstruction):
""" Read a variable number of secret shares (potentially with MAC)
from a socket for a client id and store them in registers. If the
protocol uses MACs, the client should be different for every party.
:param: client id (regint)
:param: vector size (int)
:param: source (sint)
:param: (repeat source)...
"""
__slots__ = []
code = base.opcodes['READSOCKETS']
arg_format = tools.chain(['ci','int'], itertools.repeat('sw'))
def has_var_args(self):
return True
[docs]class readsocketint(base.IOInstruction):
""" Read a variable number of 32-bit integers from socket for a
specified client id and store them in clear integer registers.
:param: number of arguments to follow / number of inputs minus one (int)
:param: client id (regint)
:param: vector size (int)
:param: destination (regint)
:param: (repeat destination)...
"""
__slots__ = []
code = base.opcodes['READSOCKETINT']
arg_format = tools.chain(['ci','int'], itertools.repeat('ciw'))
def has_var_args(self):
return True
class writesocketc(base.IOInstruction):
"""
Write a variable number of clear GF(p) values from registers into socket
for a specified client id, message_type
"""
__slots__ = []
code = base.opcodes['WRITESOCKETC']
arg_format = tools.chain(['ci', 'int', 'int'], itertools.repeat('c'))
def has_var_args(self):
return True
[docs]class writesockets(base.IOInstruction):
""" Write a variable number of secret shares (potentially with MAC)
from registers into a socket for a specified client id. If the
protocol uses MACs, the client should be different for every party.
:param: number of arguments to follow
:param: client id (regint)
:param: message type (must be 0)
:param: vector size (int)
:param: source (sint)
:param: (repeat source)...
"""
__slots__ = []
code = base.opcodes['WRITESOCKETS']
arg_format = tools.chain(['ci', 'int', 'int'], itertools.repeat('s'))
def has_var_args(self):
return True
[docs]class writesocketshare(base.IOInstruction):
""" Write a variable number of shares (without MACs) from secret
registers into socket for a specified client id.
:param: client id (regint)
:param: message type (must be 0)
:param: vector size (int)
:param: source (sint)
:param: (repeat source)...
"""
__slots__ = []
code = base.opcodes['WRITESOCKETSHARE']
arg_format = tools.chain(['ci', 'int', 'int'], itertools.repeat('s'))
def has_var_args(self):
return True
class writesocketint(base.IOInstruction):
"""
Write a variable number of 32-bit ints from registers into socket
for a specified client id, message_type
"""
__slots__ = []
code = base.opcodes['WRITESOCKETINT']
arg_format = tools.chain(['ci', 'int', 'int'], itertools.repeat('ci'))
def has_var_args(self):
return True
[docs]class listen(base.IOInstruction):
""" Open a server socket on a party-specific port number and listen for
client connections (non-blocking).
:param: port number (regint)
"""
__slots__ = []
code = base.opcodes['LISTEN']
arg_format = ['ci']
[docs]class acceptclientconnection(base.IOInstruction):
""" Wait for a connection at the given port and write socket handle
to clear integer register.
:param: client id destination (regint)
:param: port number (regint)
"""
__slots__ = []
code = base.opcodes['ACCEPTCLIENTCONNECTION']
arg_format = ['ciw', 'ci']
[docs]class initclientconnection(base.IOInstruction):
""" Initialize connection.
:param: client id destination (regint)
:param: port number (regint)
:param: my client id (regint)
:param: hostname (variable string)
"""
__slots__ = []
code = base.opcodes['INITCLIENTCONNECTION']
arg_format = ['ciw', 'ci', 'ci', 'varstr']
[docs]class closeclientconnection(base.IOInstruction):
""" Close connection to client.
:param: client id (regint)
"""
__slots__ = []
code = base.opcodes['CLOSECLIENTCONNECTION']
arg_format = ['ci']
[docs]class writesharestofile(base.IOInstruction):
""" Write shares to ``Persistence/Transactions-P<playerno>.data``
(appending at the end).
:param: number of arguments to follow / number of shares plus one (int)
:param: position (regint, -1 for appending)
:param: source (sint)
:param: (repeat from source)...
"""
__slots__ = []
code = base.opcodes['WRITEFILESHARE']
arg_format = tools.chain(['ci'], itertools.repeat('s'))
def has_var_args(self):
return True
[docs]class readsharesfromfile(base.IOInstruction):
""" Read shares from ``Persistence/Transactions-P<playerno>.data``.
:param: number of arguments to follow / number of shares plus two (int)
:param: starting position in number of shares from beginning (regint)
:param: destination for final position, -1 for eof reached, or -2 for file not found (regint)
:param: destination for share (sint)
:param: (repeat from destination for share)...
"""
__slots__ = []
code = base.opcodes['READFILESHARE']
arg_format = tools.chain(['ci', 'ciw'], itertools.repeat('sw'))
def has_var_args(self):
return True
@base.gf2n
@base.vectorize
class raw_output(base.PublicFileIOInstruction):
r""" Raw output of register \verb|ci| to file. """
__slots__ = []
code = base.opcodes['RAWOUTPUT']
arg_format = ['c']
[docs]@base.vectorize
class intoutput(base.PublicFileIOInstruction):
""" Binary integer output.
:param: player (int)
:param: regint
"""
__slots__ = []
code = base.opcodes['INTOUTPUT']
arg_format = ['p','ci']
[docs]@base.vectorize
class floatoutput(base.PublicFileIOInstruction):
""" Binary floating-point output.
:param: player (int)
:param: significand (cint)
:param: exponent (cint)
:param: zero bit (cint)
:param: sign bit (cint)
"""
__slots__ = []
code = base.opcodes['FLOATOUTPUT']
arg_format = ['p','c','c','c','c']
[docs]@base.vectorize
class rand(base.Instruction):
""" Store insecure random value of specified length in clear integer
register (vector).
:param: destination (regint)
:param: length (regint)
"""
__slots__ = []
code = base.opcodes['RAND']
arg_format = ['ciw','ci']
###
### Integer operations
###
[docs]@base.vectorize
class ldint(base.Instruction):
""" Store (constant) immediate value in clear integer register (vector).
:param: destination (regint)
:param: immediate (int)
"""
__slots__ = []
code = base.opcodes['LDINT']
arg_format = ['ciw', 'i']
[docs]@base.vectorize
class addint(base.IntegerInstruction):
""" Clear integer register (vector) addition.
:param: result (regint)
:param: summand (regint)
:param: summand (regint)
"""
__slots__ = []
code = base.opcodes['ADDINT']
op = operator.add
[docs]@base.vectorize
class subint(base.IntegerInstruction):
""" Clear integer register (vector) subtraction.
:param: result (regint)
:param: first operand (regint)
:param: second operand (regint)
"""
__slots__ = []
code = base.opcodes['SUBINT']
op = operator.sub
[docs]@base.vectorize
class mulint(base.IntegerInstruction):
""" Clear integer register (element-wise vector) multiplication.
:param: result (regint)
:param: factor (regint)
:param: factor (regint)
"""
__slots__ = []
code = base.opcodes['MULINT']
op = operator.mul
[docs]@base.vectorize
class divint(base.IntegerInstruction):
""" Clear integer register (element-wise vector) division with floor
rounding.
:param: result (regint)
:param: dividend (regint)
:param: divisor (regint)
"""
__slots__ = []
code = base.opcodes['DIVINT']
op = operator.floordiv
[docs]@base.vectorize
class bitdecint(base.Instruction):
""" Clear integer bit decomposition.
:param: number of arguments to follow / number of bits minus one (int)
:param: source (regint)
:param: destination for least significant bit (regint)
:param: (destination for one bit higher)...
"""
__slots__ = []
code = base.opcodes['BITDECINT']
arg_format = tools.chain(['ci'], itertools.repeat('ciw'))
[docs]class incint(base.VectorInstruction):
""" Create incremental clear integer vector. For example, vector size 10,
base 1, increment 2, repeat 3, and wrap 2 produces the following::
(1, 1, 1, 3, 3, 3, 1, 1, 1, 3)
This is because the first number is always the :py:obj:`base`,
every number is repeated :py:obj:`repeat` times, after which
:py:obj:`increment` is added, and after :py:obj:`wrap` increments
the number returns to :py:obj:`base`.
:param: destination (regint)
:param: base (non-vector regint)
:param: increment (int)
:param: repeat (int)
:param: wrap (int)
"""
__slots__ = []
code = base.opcodes['INCINT']
arg_format = ['ciw', 'ci', 'i', 'i', 'i']
def __init__(self, *args, **kwargs):
assert len(args[1]) == 1
if len(args) == 3:
args = list(args) + [1, len(args[0])]
super(incint, self).__init__(*args, **kwargs)
[docs]class shuffle(base.VectorInstruction):
""" Randomly shuffles clear integer vector with public randomness.
:param: destination (regint)
:param: source (regint)
"""
__slots__ = []
code = base.opcodes['SHUFFLE']
arg_format = ['ciw','ci']
def __init__(self, *args, **kwargs):
super(shuffle, self).__init__(*args, **kwargs)
assert len(args[0]) == len(args[1])
###
### Clear comparison instructions
###
[docs]@base.vectorize
class eqzc(base.UnaryComparisonInstruction):
""" Clear integer zero test. The result is 1 for true and 0 for false.
:param: destination (regint)
:param: operand (regint)
"""
__slots__ = []
code = base.opcodes['EQZC']
[docs]@base.vectorize
class ltzc(base.UnaryComparisonInstruction):
""" Clear integer less than zero test. The result is 1 for true
and 0 for false.
:param: destination (regint)
:param: operand (regint)
"""
__slots__ = []
code = base.opcodes['LTZC']
[docs]@base.vectorize
class ltc(base.IntegerInstruction):
""" Clear integer less-than comparison. The result is 1 if the
first operand is less and 0 otherwise.
:param: destination (regint)
:param: first operand (regint)
:param: second operand (regint)
"""
__slots__ = []
code = base.opcodes['LTC']
op = operator.lt
[docs]@base.vectorize
class gtc(base.IntegerInstruction):
""" Clear integer greater-than comparison. The result is 1 if the
first operand is greater and 0 otherwise.
:param: destination (regint)
:param: first operand (regint)
:param: second operand (regint)
"""
__slots__ = []
code = base.opcodes['GTC']
op = operator.gt
[docs]@base.vectorize
class eqc(base.IntegerInstruction):
""" Clear integer equality test. The result is 1 if the operands
are equal and 0 otherwise.
:param: destination (regint)
:param: first operand (regint)
:param: second operand (regint)
"""
__slots__ = []
code = base.opcodes['EQC']
op = operator.eq
###
### Jumps etc
###
[docs]class jmp(base.JumpInstruction):
""" Unconditional relative jump in the bytecode (compile-time parameter).
The parameter is added to the regular jump of one after every
instruction. This means that a jump of 0 results in a no-op while
a jump of -1 results in an infinite loop.
:param: number of instructions (int)
"""
__slots__ = []
code = base.opcodes['JMP']
arg_format = ['int']
jump_arg = 0
[docs]class jmpi(base.JumpInstruction):
""" Unconditional relative jump in the bytecode (run-time parameter).
The parameter is added to the regular jump of one after every
instruction. This means that a jump of 0 results in a no-op while
a jump of -1 results in an infinite loop.
:param: number of instructions (regint)
"""
__slots__ = []
code = base.opcodes['JMPI']
arg_format = ['ci']
jump_arg = 0
[docs]class jmpnz(base.JumpInstruction):
""" Conditional relative jump in the bytecode.
The parameter is added to the regular jump of one after every
instruction. This means that a jump of 0 results in a no-op while
a jump of -1 results in an infinite loop.
:param: condition (regint, only jump if not zero)
:param: number of instructions (int)
"""
__slots__ = []
code = base.opcodes['JMPNZ']
arg_format = ['ci', 'int']
jump_arg = 1
[docs]class jmpeqz(base.JumpInstruction):
""" Conditional relative jump in the bytecode.
The parameter is added to the regular jump of one after every
instruction. This means that a jump of 0 results in a no-op while
a jump of -1 results in an infinite loop.
:param: condition (regint, only jump if zero)
:param: number of instructions (int)
"""
__slots__ = []
code = base.opcodes['JMPEQZ']
arg_format = ['ci', 'int']
jump_arg = 1
###
### Conversions
###
[docs]@base.gf2n
@base.vectorize
class convint(base.Instruction):
""" Convert clear integer register (vector) to clear register (vector).
:param: destination (cint)
:param: source (regint)
"""
__slots__ = []
code = base.opcodes['CONVINT']
arg_format = ['cw', 'ci']
[docs]@base.vectorize
class convmodp(base.Instruction):
""" Convert clear integer register (vector) to clear register
(vector). If the bit length is zero, the unsigned conversion is
used, otherwise signed conversion is used. This makes a difference
when computing modulo a prime :math:`p`. Signed conversion of
:math:`p-1` results in -1 while signed conversion results in
:math:`(p-1) \mod 2^{64}`.
:param: destination (regint)
:param: source (cint)
:param: bit length (int)
"""
__slots__ = []
code = base.opcodes['CONVMODP']
arg_format = ['ciw', 'c', 'int']
def __init__(self, *args, **kwargs):
if len(args) == len(self.arg_format):
super(convmodp_class, self).__init__(*args)
return
bitlength = kwargs.get('bitlength')
bitlength = program.bit_length if bitlength is None else bitlength
if bitlength > 64:
raise CompilerError('%d-bit conversion requested ' \
'but integer registers only have 64 bits' % \
bitlength)
super(convmodp_class, self).__init__(*(args + (bitlength,)))
@base.vectorize
class gconvgf2n(base.Instruction):
""" Convert from clear modp register $c_j$ to integer register $ci_i$. """
__slots__ = []
code = base.opcodes['GCONVGF2N']
arg_format = ['ciw', 'cg']
###
### Other instructions
###
# rename 'open' to avoid conflict with built-in open function
[docs]@base.gf2n
@base.vectorize
class asm_open(base.VarArgsInstruction, base.DataInstruction):
""" Reveal secret registers (vectors) to clear registers (vectors).
:param: number of argument to follow (odd number)
:param: check after opening (0/1)
:param: destination (cint)
:param: source (sint)
:param: (repeat the last two)...
"""
__slots__ = []
code = base.opcodes['OPEN']
arg_format = tools.chain(['int'], tools.cycle(['cw','s']))
data_type = 'open'
def get_repeat(self):
return (len(self.args) - 1) // 2
def merge(self, other):
self.args[0] |= other.args[0]
self.args += other.args[1:]
[docs]@base.gf2n
class muls(base.VarArgsInstruction, base.DataInstruction, base.Ciscable):
""" (Element-wise) multiplication of secret registers (vectors).
:param: number of arguments to follow (multiple of four)
:param: vector size (int)
:param: result (sint)
:param: factor (sint)
:param: factor (sint)
:param: (repeat the last four)...
"""
__slots__ = []
code = base.opcodes['MULS']
arg_format = tools.cycle(['int','sw','s','s'])
data_type = 'triple'
is_vec = lambda self: True
def __init__(self, *args, **kwargs):
super(muls_class, self).__init__(*args, **kwargs)
for i in range(0, len(args), 4):
for j in range(3):
assert args[i + j + 1].size == args[i]
def get_repeat(self):
return sum(self.args[::4])
# def expand(self):
# s = [program.curr_block.new_reg('s') for i in range(9)]
# c = [program.curr_block.new_reg('c') for i in range(3)]
# triple(s[0], s[1], s[2])
# subs(s[3], self.args[1], s[0])
# subs(s[4], self.args[2], s[1])
# asm_open(c[0], s[3])
# asm_open(c[1], s[4])
# mulm(s[5], s[1], c[0])
# mulm(s[6], s[0], c[1])
# mulc(c[2], c[0], c[1])
# adds(s[7], s[2], s[5])
# adds(s[8], s[7], s[6])
# addm(self.args[0], s[8], c[2])
# compatibility
try:
vmuls = muls_class
muls_bak = muls
muls = lambda *args: muls_bak(args[0].size, *args)
vgmuls = gmuls_class = gmuls
gmuls = lambda *args: gmuls_class(args[0].size, *args)
except NameError:
pass
[docs]@base.gf2n
class mulrs(base.VarArgsInstruction, base.DataInstruction):
""" Constant-vector multiplication of secret registers.
:param: number of arguments to follow (multiple of four)
:param: vector size (int)
:param: result (sint)
:param: vector factor (sint)
:param: constant factor (sint)
:param: (repeat the last four)...
"""
__slots__ = []
code = base.opcodes['MULRS']
arg_format = tools.cycle(['int','sw','s','s'])
data_type = 'triple'
is_vec = lambda self: True
def __init__(self, res, x, y):
assert y.size == 1
assert res.size == x.size
base.Instruction.__init__(self, res.size, res, x, y)
def get_repeat(self):
return sum(self.args[::4])
[docs] def get_def(self):
return sum((arg.get_all() for arg in self.args[1::4]), [])
[docs] def get_used(self):
return sum((arg.get_all()
for arg in self.args[2::4] + self.args[3::4]), [])
[docs]@base.gf2n
@base.vectorize
class dotprods(base.VarArgsInstruction, base.DataInstruction,
base.DynFormatInstruction):
""" Dot product of secret registers (vectors).
Note that the vectorized version works element-wise.
:param: number of arguments to follow (int)
:param: twice the dot product length plus two (I know...)
:param: result (sint)
:param: first factor (sint)
:param: first factor (sint)
:param: second factor (sint)
:param: second factor (sint)
:param: (remaining factors)...
:param: (repeat from dot product length)...
"""
__slots__ = []
code = base.opcodes['DOTPRODS']
data_type = 'triple'
def __init__(self, *args):
flat_args = []
for i in range(0, len(args), 3):
res, x, y = args[i:i+3]
assert len(x) == len(y)
flat_args += [2 * len(x) + 2, res]
for x, y in zip(x, y):
flat_args += [x, y]
base.Instruction.__init__(self, *flat_args)
@classmethod
def dynamic_arg_format(self, args):
field = 'g' if self.is_gf2n() else ''
yield 'int'
for i, n in self.bases(args):
yield 's' + field + 'w'
assert n > 2
for j in range(n - 2):
yield 's' + field
yield 'int'
@property
def gf2n_arg_format(self):
return self.arg_format()
def get_repeat(self):
return sum(self.args[i] // 2 - 1
for i, n in self.bases(iter(self.args)))
[docs] def get_def(self):
return [self.args[i + 1] for i, n in self.bases(iter(self.args))]
[docs] def get_used(self):
for i, n in self.bases(iter(self.args)):
for reg in self.args[i + 2:i + self.args[i]]:
yield reg
class matmul_base(base.DataInstruction):
data_type = 'triple'
is_vec = lambda self: True
def get_repeat(self):
return reduce(operator.mul, self.args[3:6])
[docs]class matmuls(matmul_base, base.Mergeable):
""" Secret matrix multiplication from registers. All matrices are
represented as vectors in row-first order.
:param: result (sint vector)
:param: first factor (sint vector)
:param: second factor (sint vector)
:param: number of rows in first factor and result (int)
:param: number of columns in first factor and rows in second factor (int)
:param: number of columns in second factor and result (int)
"""
code = base.opcodes['MATMULS']
arg_format = itertools.cycle(['sw','s','s','int','int','int'])
def get_repeat(self):
return sum(reduce(operator.mul, self.args[i + 3:i + 6])
for i in range(0, len(self.args), 6))
[docs]class matmulsm(matmul_base, base.Mergeable):
""" Secret matrix multiplication reading directly from memory.
:param: result (sint vector in row-first order)
:param: base address of first factor (regint value)
:param: base address of second factor (regint value)
:param: number of rows in first factor and result (int)
:param: number of columns in first factor and rows in second factor (int)
:param: number of columns in second factor and result (int)
:param: rows of first factor to use (regint vector, length as number of rows in first factor)
:param: columns of first factor to use (regint vector, length as number of columns in the first factor)
:param: rows of second factor to use (regint vector, length as number of columns in the first factor)
:param: columns of second factor to use (regint vector, length as number of columns in the second factor)
:param: total number of columns in the first factor, equal to used number of columns when all columns are used (int)
:param: total number of columns in the second factor, equal to used number of columns when all columns are used (int)
"""
code = base.opcodes['MATMULSM']
arg_format = itertools.cycle(['sw','ci','ci','int','int','int','ci','ci','ci','ci',
'int','int'])
def __init__(self, *args,
first_factor_base_addresses=None,
second_factor_base_addresses=None,
indices_values=None,
**kwargs):
matmul_base.__init__(self, *args, **kwargs)
for matmul_index in range(len(args) // 12):
for i in range(2):
assert args[12 * matmul_index + 6 + i].size == args[12 * matmul_index + 3 + i]
for i in range(2):
assert args[12 * matmul_index + 8 + i].size == args[12 * matmul_index + 4 + i]
# These are used to reconstruct that accessed memory addresses in the allocator.
self.first_factor_base_addresses = first_factor_base_addresses
self.second_factor_base_addresses = second_factor_base_addresses
self.indices_values = indices_values
if first_factor_base_addresses is not None:
assert len(first_factor_base_addresses) == len(second_factor_base_addresses)
if indices_values is not None:
assert len(indices_values) == 4 * len(first_factor_base_addresses)
def add_usage(self, req_node):
super(matmulsm, self).add_usage(req_node)
for i in range(0, len(self.args), 12):
req_node.increment(('matmul', (self.args[i + 3], self.args[i + 4], self.args[i + 5])), 1)
def get_repeat(self):
return sum(reduce(operator.mul, self.args[i + 3:i + 6])
for i in range(0, len(self.args), 12))
[docs]class conv2ds(base.DataInstruction, base.VarArgsInstruction, base.Mergeable):
""" Secret 2D convolution.
:param: number of arguments to follow (int)
:param: result (sint vector in row-first order)
:param: inputs (sint vector in row-first order)
:param: weights (sint vector in row-first order)
:param: output height (int)
:param: output width (int)
:param: input height (int)
:param: input width (int)
:param: weight height (int)
:param: weight width (int)
:param: stride height (int)
:param: stride width (int)
:param: number of channels (int)
:param: padding height (int)
:param: padding width (int)
:param: batch size (int)
:param: repeat from result...
"""
code = base.opcodes['CONV2DS']
arg_format = itertools.cycle(['sw','s','s','int','int','int','int','int',
'int','int','int','int','int','int','int'])
data_type = 'triple'
is_vec = lambda self: True
def __init__(self, *args, **kwargs):
super(conv2ds, self).__init__(*args, **kwargs)
assert args[0].size == args[3] * args[4] * args[14]
assert args[1].size == args[5] * args[6] * args[11] * args[14]
assert args[2].size == args[7] * args[8] * args[11]
def get_repeat(self):
args = self.args
return sum(args[i+3] * args[i+4] * args[i+7] * args[i+8] * \
args[i+11] * args[i+14] for i in range(0, len(args), 15))
def add_usage(self, req_node):
super(conv2ds, self).add_usage(req_node)
for i in range(0, len(self.args), 15):
args = self.args[i:i + 15]
req_node.increment(('matmul', (1, args[7] * args[8] * args[11],
args[14] * args[3] * args[4])), 1)
[docs]@base.vectorize
class trunc_pr(base.VarArgsInstruction):
""" Probabilistic truncation if supported by the protocol.
:param: number of arguments to follow (multiple of four)
:param: destination (sint)
:param: source (sint)
:param: bit length of source (int)
:param: number of bits to truncate (int)
"""
__slots__ = []
code = base.opcodes['TRUNC_PR']
arg_format = tools.cycle(['sw','s','int','int'])
[docs]class shuffle_base(base.DataInstruction):
n_relevant_parties = 2
@staticmethod
def logn(n):
return int(math.ceil(math.log(n, 2)))
@classmethod
def n_swaps(cls, n):
logn = cls.logn(n)
return logn * 2 ** logn - 2 ** logn + 1
def add_gen_usage(self, req_node, n):
# hack for unknown usage
req_node.increment(('bit', 'inverse'), float('inf'))
# minimal usage with two relevant parties
logn = self.logn(n)
n_switches = self.n_swaps(n)
for i in range(self.n_relevant_parties):
req_node.increment((self.field_type, 'input', i), n_switches)
# multiplications for bit check
req_node.increment((self.field_type, 'triple'),
n_switches * self.n_relevant_parties)
def add_apply_usage(self, req_node, n, record_size):
req_node.increment(('bit', 'inverse'), float('inf'))
logn = self.logn(n)
n_switches = self.n_swaps(n) * self.n_relevant_parties
if n != 2 ** logn:
record_size += 1
req_node.increment((self.field_type, 'triple'),
n_switches * record_size)
[docs]@base.gf2n
class secshuffle(base.VectorInstruction, shuffle_base):
""" Secure shuffling.
:param: destination (sint)
:param: source (sint)
"""
__slots__ = []
code = base.opcodes['SECSHUFFLE']
arg_format = ['sw','s','int']
def __init__(self, *args, **kwargs):
super(secshuffle_class, self).__init__(*args, **kwargs)
assert len(args[0]) == len(args[1])
assert len(args[0]) > args[2]
def add_usage(self, req_node):
self.add_gen_usage(req_node, len(self.args[0]))
self.add_apply_usage(req_node, len(self.args[0]), self.args[2])
[docs]class gensecshuffle(shuffle_base):
""" Generate secure shuffle to bit used several times.
:param: destination (regint)
:param: size (int)
"""
__slots__ = []
code = base.opcodes['GENSECSHUFFLE']
arg_format = ['ciw','int']
def add_usage(self, req_node):
self.add_gen_usage(req_node, self.args[1])
[docs]class applyshuffle(base.VectorInstruction, shuffle_base):
""" Generate secure shuffle to bit used several times.
:param: destination (sint)
:param: source (sint)
:param: number of elements to be treated as one (int)
:param: handle (regint)
:param: reverse (0/1)
"""
__slots__ = []
code = base.opcodes['APPLYSHUFFLE']
arg_format = ['sw','s','int','ci','int']
def __init__(self, *args, **kwargs):
super(applyshuffle, self).__init__(*args, **kwargs)
assert len(args[0]) == len(args[1])
assert len(args[0]) > args[2]
def add_usage(self, req_node):
self.add_apply_usage(req_node, len(self.args[0]), self.args[2])
[docs]class delshuffle(base.Instruction):
""" Delete secure shuffle.
:param: handle (regint)
"""
code = base.opcodes['DELSHUFFLE']
arg_format = ['ci']
[docs]class inverse_permutation(base.VectorInstruction, shuffle_base):
""" Calculate the inverse permutation of a secret permutation.
:param: destination (sint)
:param: source (sint)
"""
__slots__ = []
code = base.opcodes['INVPERM']
arg_format = ['sw', 's']
def __init__(self, *args, **kwargs):
super(inverse_permutation, self).__init__(*args, **kwargs)
assert len(args[0]) == len(args[1])
def add_usage(self, req_node):
self.add_gen_usage(req_node, len(self.args[0]))
self.add_apply_usage(req_node, len(self.args[0]), 1)
[docs]class check(base.Instruction):
"""
Force MAC check in current thread and all idle thread if current
thread is the main thread.
"""
__slots__ = []
code = base.opcodes['CHECK']
arg_format = []
###
### CISC-style instructions
###
@base.gf2n
@base.vectorize
class sqrs(base.CISC):
""" Secret squaring $s_i = s_j \cdot s_j$. """
__slots__ = []
arg_format = ['sw', 's']
def expand(self):
if program.options.ring:
return muls(self.args[0], self.args[1], self.args[1])
s = [program.curr_block.new_reg('s') for i in range(6)]
c = [program.curr_block.new_reg('c') for i in range(2)]
square(s[0], s[1])
subs(s[2], self.args[1], s[0])
asm_open(False, c[0], s[2])
mulc(c[1], c[0], c[0])
mulm(s[3], self.args[1], c[0])
adds(s[4], s[3], s[3])
adds(s[5], s[1], s[4])
subml(self.args[0], s[5], c[1])
# placeholder for documentation
[docs]class cisc:
""" Meta instruction for emulation. This instruction is only generated
when using ``-K`` with ``compile.py``. The header looks as follows:
:param: number of arguments after name plus one
:param: name (16 bytes, zero-padded)
Currently, the following names are supported:
LTZ
Less than zero.
:param: number of arguments in this unit (must be 6)
:param: vector size
:param: result (sint)
:param: input (sint)
:param: bit length
:param: (ignored)
:param: (repeat)...
Trunc
Truncation.
:param: number of arguments in this unit (must be 8)
:param: vector size
:param: result (sint)
:param: input (sint)
:param: bit length
:param: number of bits to truncate
:param: (ignored)
:param: 0 for unsigned or 1 for signed
:param: (repeat)...
FPDiv
Fixed-point division. Division by zero results in zero without error.
:param: number of arguments in this unit (must be at least 7)
:param: vector size
:param: result (sint)
:param: dividend (sint)
:param: divisor (sint)
:param: (ignored)
:param: fixed-point precision
:param: (repeat)...
exp2_fx
Fixed-point power of two.
:param: number of arguments in this unit (must be at least 6)
:param: vector size
:param: result (sint)
:param: exponent (sint)
:param: (ignored)
:param: fixed-point precision
:param: (repeat)...
log2_fx
Fixed-point logarithm with base 2.
:param: number of arguments in this unit (must be at least 6)
:param: vector size
:param: result (sint)
:param: input (sint)
:param: (ignored)
:param: fixed-point precision
:param: (repeat)...
"""
code = base.opcodes['CISC']
# hack for circular dependency
from Compiler import comparison