"""
This module defines all types available in high-level programs. These
include basic types such as secret integers or floating-point numbers
and container types. A single instance of the former uses one or more
so-called registers in the virtual machine while the latter use the
so-called memory. For every register type, there is a corresponding
dedicated memory.
Registers are used for computation, allocated on an ongoing basis,
and thread-specific. The memory is allocated statically and shared
between threads. This means that memory-based types such as
:py:class:`Array` can be used to transfer information between threads.
Note that creating memory-based types outside the main thread is not
supported.
If viewing this documentation in processed form, many function signatures
appear generic because of the use of decorators. See the source code for the
correct signature.
Basic types
-----------
All basic types can be used as vectors, that is one instance representing
several values, with all operations being executed element-wise. For
example, the following computes ten multiplications of integers input
by party 0 and 1::
sint.get_input_from(0, size=10) * sint.get_input_from(1, size=10)
The following types are available in arithmetic circuits and with
reduced functionality in binary circuits.
.. autosummary::
:nosignatures:
sint
sfix
regint
The following types are only available in arithmetic circuits.
.. autosummary::
:nosignatures:
cint
cfix
sfloat
sgf2n
cgf2n
personal
Container types
---------------
.. autosummary::
:nosignatures:
MemValue
Array
Matrix
MultiArray
"""
from Compiler.program import Tape
from Compiler.exceptions import *
from Compiler.instructions import *
from Compiler.instructions_base import *
from .floatingpoint import two_power
from . import comparison, floatingpoint
import math
from . import util
from . import instructions
from .util import is_zero, is_one
import operator
from functools import reduce
import re
class ClientMessageType:
""" Enum to define type of message sent to external client. Each may be array of length n."""
# No client message type to be sent, for backwards compatibility - virtual machine relies on this value
NoType = 0
# 3 x sint x n
TripleShares = 1
# 1 x cint x n
ClearModpInt = 2
# 1 x regint x n
Int32 = 3
# 1 x cint (fixed point left shifted by precision) x n
ClearModpFix = 4
class MPCThread(object):
def __init__(self, target, name, args = [], runtime_arg = 0,
single_thread = False, finalize = True):
""" Create a thread from a callable object. """
if not callable(target):
raise CompilerError('Target %s for thread %s is not callable' % (target,name))
self.name = name
self.target = target
self.args = args
self.runtime_arg = runtime_arg
self.running = 0
self.tape_handle = program.new_tape(target, args, name,
single_thread=single_thread,
finalize=finalize)
self.tape = program.tapes[self.tape_handle]
self.run_handles = []
def start(self, runtime_arg = None):
self.running += 1
self.run_handles.append(program.run_tape(self.tape_handle, \
runtime_arg or self.runtime_arg))
def join(self):
if not self.running:
raise CompilerError('Thread %s is not running' % self.name)
self.running -= 1
program.join_tape(self.run_handles.pop(0))
def copy_doc(a, b):
try:
a.__doc__ = b.__doc__
except:
pass
def no_doc(operation):
def wrapper(*args, **kwargs):
return operation(*args, **kwargs)
return wrapper
def vectorize(operation):
def vectorized_operation(self, *args, **kwargs):
if len(args):
from .GC.types import bits
if (isinstance(args[0], Tape.Register) or isinstance(args[0], sfloat)) \
and not isinstance(args[0], bits) \
and args[0].size != self.size:
if min(args[0].size, self.size) == 1:
size = max(args[0].size, self.size)
self = self.expand_to_vector(size)
args = list(args)
args[0] = args[0].expand_to_vector(size)
else:
raise VectorMismatch('Different vector sizes of operands: %d/%d'
% (self.size, args[0].size))
set_global_vector_size(self.size)
try:
res = operation(self, *args, **kwargs)
finally:
reset_global_vector_size()
return res
copy_doc(vectorized_operation, operation)
return vectorized_operation
def vectorize_max(operation):
def vectorized_operation(self, *args, **kwargs):
size = self.size
for arg in args:
try:
size = max(size, arg.size)
except AttributeError:
pass
set_global_vector_size(size)
try:
res = operation(self, *args, **kwargs)
finally:
reset_global_vector_size()
return res
copy_doc(vectorized_operation, operation)
return vectorized_operation
def vectorized_classmethod(function):
def vectorized_function(cls, *args, **kwargs):
size = None
if 'size' in kwargs:
size = kwargs.pop('size')
if size is not None:
set_global_vector_size(size)
try:
res = function(cls, *args, **kwargs)
finally:
reset_global_vector_size()
else:
res = function(cls, *args, **kwargs)
return res
copy_doc(vectorized_function, function)
return classmethod(vectorized_function)
def vectorize_init(function):
def vectorized_init(*args, **kwargs):
size = None
if len(args) > 1 and isinstance(args[1], (_register, sfloat, cfix)):
size = args[1].size
if 'size' in kwargs and kwargs['size'] is not None \
and kwargs['size'] != size:
raise CompilerError('Mismatch in vector size')
if 'size' in kwargs and kwargs['size'] is not None:
size = kwargs['size']
if size is not None:
set_global_vector_size(size)
try:
res = function(*args, **kwargs)
finally:
reset_global_vector_size()
else:
res = function(*args, **kwargs)
return res
copy_doc(vectorized_init, function)
return vectorized_init
def set_instruction_type(operation):
def instruction_typed_operation(self, *args, **kwargs):
set_global_instruction_type(self.instruction_type)
try:
res = operation(self, *args, **kwargs)
finally:
reset_global_instruction_type()
return res
copy_doc(instruction_typed_operation, operation)
return instruction_typed_operation
def read_mem_value(operation):
def read_mem_operation(self, *args, **kwargs):
if len(args) > 0 and isinstance(args[0], MemValue):
args = (args[0].read(),) + args[1:]
return operation(self, *args, **kwargs)
copy_doc(read_mem_operation, operation)
return read_mem_operation
def type_comp(operation):
def type_check(self, other, *args, **kwargs):
if not isinstance(other, (type(self), int, regint, self.clear_type)):
return NotImplemented
return operation(self, other, *args, **kwargs)
copy_doc(type_check, operation)
return type_check
def inputmixed(*args):
# helper to cover both cases
if isinstance(args[-1], int):
instructions.inputmixed(*args)
else:
instructions.inputmixedreg(*(args[:-1] + (regint.conv(args[-1]),)))
class _number(Tape._no_truth):
""" Number functionality. """
def square(self):
""" Square. """
return self * self
def __add__(self, other):
""" Optimized addition.
:param other: any compatible type """
if is_zero(other):
return self
else:
return self.add(other)
def __mul__(self, other):
""" Optimized multiplication.
:param other: any compatible type """
if is_zero(other):
return 0
elif is_one(other):
return self
else:
try:
return self.mul(other)
except VectorMismatch:
if type(self) != type(other) and 1 in (self.size, other.size):
# try reverse multiplication
return NotImplemented
else:
raise
__radd__ = __add__
__rmul__ = __mul__
@vectorize
def __pow__(self, exp):
""" Exponentation through square-and-multiply.
:param exp: any type allowing bit decomposition """
if isinstance(exp, int) and exp >= 0:
if exp == 0:
return self.__class__(1)
exp = bin(exp)[3:]
res = self
for i in exp:
res = res.square()
if i == '1':
res *= self
return res
elif isinstance(exp, _int):
bits = exp.bit_decompose()
powers = [self]
while len(powers) < len(bits):
powers.append(powers[-1] ** 2)
multiplicands = [b.if_else(p, 1) for b, p in zip(bits, powers)]
res = util.tree_reduce(operator.mul, multiplicands)
return res
else:
from .mpc_math import pow_fx
return pow_fx(self, exp)
def mul_no_reduce(self, other, res_params=None):
return self * other
def reduce_after_mul(self):
return self
def pow2(self, bit_length=None):
return 2**self
def min(self, other):
""" Minimum.
:param other: any compatible type """
return (self < other).if_else(self, other)
def max(self, other):
""" Maximum.
:param other: any compatible type """
return (self < other).if_else(other, self)
@classmethod
def dot_product(cls, a, b):
from Compiler.library import for_range_opt_multithread
res = MemValue(cls(0))
l = min(len(a), len(b))
xx = [a, b]
for i, x in enumerate((a, b)):
if not isinstance(x, Array):
xx[i] = Array(l, cls)
xx[i].assign(x)
aa, bb = xx
@for_range_opt_multithread(None, l)
def _(i):
res.iadd(res.value_type.conv(aa[i] * bb[i]))
return res.read()
def __abs__(self):
""" Absolute value. """
return (self < 0).if_else(-self, self)
@staticmethod
def popcnt_bits(bits):
return sum(bits)
def zero_if_not(self, condition):
return condition * self
def iadd(self, other):
""" Addition assignment. This uses :py:func:`update` internally. """
self.update(self + other)
class _int(Tape._no_truth):
""" Integer functionality. """
@staticmethod
def bit_adder(*args, **kwargs):
""" Binary adder in arithmetic circuits.
:param a: summand (list of 0/1 in compatible type)
:param b: summand (list of 0/1 in compatible type)
:param carry_in: input carry (default 0)
:param get_carry: add final carry to output
:returns: list of 0/1 in relevant type
"""
return intbitint.bit_adder(*args, **kwargs)
@staticmethod
def ripple_carry_adder(*args, **kwargs):
return intbitint.ripple_carry_adder(*args, **kwargs)
def if_else(self, a, b):
""" MUX on bit in arithmetic circuits::
print_ln('%s', sint(0).if_else(sint(2), sint(3)).reveal())
print_ln('%s', sint(1).if_else(sint(4), sint(5)).reveal())
This will output::
3
4
:param a/b: any type supporting the necessary operations
:return: a if :py:obj:`self` is 1, b if :py:obj:`self` is 0, undefined otherwise
:rtype: depending on operands, secret if any of them is """
if hasattr(a, 'for_mux'):
f, a, b = a.for_mux(b)
else:
f = lambda x: x
return f(self * (a - b) + b)
def cond_swap(self, a, b):
""" Swapping in arithmetic circuits.
:param a/b: any type supporting the necessary operations
:return: ``(a, b)`` if :py:obj:`self` is 0, ``(b, a)`` if :py:obj:`self` is 1, and undefined otherwise
:rtype: depending on operands, secret if any of them is """
prod = self * (a - b)
return a - prod, b + prod
def bit_xor(self, other):
""" Single-bit XOR in arithmetic circuits.
:param self/other: 0 or 1 (any compatible type)
:return: type depends on inputs (secret if any of them is) """
if util.is_constant(other):
if other:
return 1 - self
else:
return self
return self + other - 2 * self * other
def bit_or(self, other):
""" Single-bit OR in arithmetic circuits.
:param self/other: 0 or 1 (any compatible type)
:return: type depends on inputs (secret if any of them is) """
if util.is_constant(other):
if other:
return 1
else:
return 0
return self + other - self * other
def bit_and(self, other):
""" Single-bit AND in arithmetic circuits.
:param self/other: 0 or 1 (any compatible type)
:rtype: depending on inputs (secret if any of them is) """
return self * other
def bit_not(self):
""" Single-bit NOT in arithmetic circuits. """
return 1 - self
def half_adder(self, other):
""" Half adder in arithmetic circuits.
:param self/other: 0 or 1 (any compatible type)
:return: binary sum, carry
:rtype: depending on inputs, secret if any is """
carry = self * other
return self + other - 2 * carry, carry
@staticmethod
def long_one():
return 1
class _bit(Tape._no_truth):
""" Binary functionality. """
def bit_xor(self, other):
""" XOR in binary circuits.
:param self/other: 0 or 1 (any compatible type)
:rtype: depending on inputs (secret if any of them is) """
return self ^ other
def bit_and(self, other):
""" AND in binary circuits.
:param self/other: 0 or 1 (any compatible type)
:rtype: depending on inputs (secret if any of them is) """
return self & other
def bit_or(self, other):
""" OR in binary circuits.
:param self/other: 0 or 1 (any compatible type)
:return: type depends on inputs (secret if any of them is) """
return self ^ other - self & other
def bit_not(self):
""" NOT in binary circuits. """
return ~self
def half_adder(self, other):
""" Half adder in binary circuits.
:param self/other: 0 or 1 (any compatible type)
:return: binary sum, carry
:rtype: depending on inputs (secret if any of them is) """
return self ^ other, self & other
def carry_out(self, a, b):
s = a ^ b
return a ^ (s & (self ^ a))
def cond_swap(self, a, b):
prod = self * (a ^ b)
return a ^ prod, b ^ prod
class _gf2n(_bit):
r""" :math:`\mathrm{GF}(2^n)` functionality. """
def if_else(self, a, b):
r""" MUX in :math:`\mathrm{GF}(2^n)` circuits. Similar to :py:meth:`_int.if_else`. """
return b ^ self * self.hard_conv(a ^ b)
def cond_swap(self, a, b, t=None):
r""" Swapping in :math:`\mathrm{GF}(2^n)`. Similar to :py:meth:`_int.if_else`. """
prod = self * self.hard_conv(a ^ b)
res = a ^ prod, b ^ prod
if t is None:
return res
else:
return tuple(t.conv(r) for r in res)
def bit_xor(self, other):
r""" XOR in :math:`\mathrm{GF}(2^n)` circuits.
:param self/other: 0 or 1 (any compatible type)
:rtype: depending on inputs (secret if any of them is) """
return self ^ other
def bit_not(self):
return self ^ 1
class _structure(Tape._no_truth):
""" Interface for type-dependent container types. """
MemValue = classmethod(lambda cls, value: MemValue(cls.conv(value)))
""" Type-dependent memory value. """
@classmethod
def Array(cls, size, *args, **kwargs):
""" Type-dependent array. Example:
.. code::
a = sint.Array(10)
"""
return Array(size, cls, *args, **kwargs)
@classmethod
def Matrix(cls, rows, columns, *args, **kwargs):
""" Type-dependent matrix. Example:
.. code::
a = sint.Matrix(10, 10)
"""
return Matrix(rows, columns, cls, *args, **kwargs)
@classmethod
def Tensor(cls, shape, **kwargs):
"""
Type-dependent tensor of any dimension::
a = sfix.Tensor([10, 10])
"""
if len(shape) == 1:
return Array(shape[0], cls, **kwargs)
elif len(shape) == 2:
return Matrix(*shape, cls, **kwargs)
else:
return MultiArray(shape, cls, **kwargs)
@classmethod
def row_matrix_mul(cls, row, matrix, res_params=None):
res = type(row[0].mul_no_reduce(
matrix[0][0], res_params=res_params))(0, size=matrix.sizes[1])
@library.for_range_opt(len(row))
def _(k):
res.iadd(row[k].mul_no_reduce(matrix[k].get_vector(), res_params))
return res.reduce_after_mul()
@staticmethod
def mem_size():
return 1
def size_for_mem(self):
return self.size
@classmethod
def arg_type(cls):
if issubclass(cls, _register):
return cls.reg_type
if issubclass(cls, (cfix, sfix)):
return cls.int_type.reg_type
raise CompilerError('type not supported as argument: %s' % cls)
class _secret_structure(_structure):
@classmethod
def input_tensor_from(cls, player, shape):
""" Input tensor secretly from player.
:param player: int/regint/cint
:param shape: tensor shape
"""
res = cls.Tensor(shape)
res.input_from(player)
return res
@classmethod
def input_tensor_from_client(cls, client_id, shape):
""" Input tensor secretly from client.
:param client_id: client identifier (public)
:param shape: tensor shape
"""
res = cls.Tensor(shape)
res.assign_vector(cls.receive_from_client(1, client_id,
size=res.total_size())[0])
return res
@classmethod
def input_tensor_via(cls, player, content=None, shape=None, binary=True,
one_hot=False, skip_input=False, n_bytes=None):
"""
Input tensor-like data via a player. This overwrites the input
file for the relevant player. The following returns an
:py:class:`sint` matrix of dimension 2 by 2::
M = [[1, 2], [3, 4]]
sint.input_tensor_via(0, M)
Make sure to copy ``Player-Data/Input-P<player>-0`` or
``Player-Data/Input-Binary-P<player>-0`` if running
on another host.
:param player: player to input via (int)
:param content: nested Python list or numpy array (binary mode only) or
left out if not available
:param shape: shape if content not given
:param binary: binary mode (bool)
:param one_hot: one-hot encoding (bool)
"""
if program.curr_tape != program.tapes[0]:
raise CompilerError('only available in main thread')
if content is not None:
if isinstance(content, (_vectorizable, Tape.Register)):
raise CompilerError('cannot input data already in the VM')
requested_shape = shape
if binary:
import numpy
content = numpy.array(content)
if issubclass(cls, _fix):
min_k = \
math.ceil(math.log(abs(content).max() or 1, 2)) + cls.f + 1
if cls.k < min_k:
raise CompilerError(
"data outside fixed-point range, "
"use 'sfix.set_precision(%d, %d)'" % (cls.f, min_k))
if binary == 2:
t = numpy.double
else:
t = numpy.single
else:
if n_bytes == 1:
t = numpy.int8
else:
t = numpy.int64
if one_hot:
content = numpy.eye(content.max() + 1)[content]
content = content.astype(t)
f = program.get_binary_input_file(player)
f.write(content.tobytes())
f.flush()
shape = content.shape
else:
shape = []
tmp = content
while True:
try:
shape.append(len(tmp))
tmp = tmp[0]
except:
break
if not program.input_files.get(player, None):
program.input_files[player] = open(
'Player-Data/Input-P%d-0' % player, 'w')
f = program.input_files[player]
def traverse(content, level):
assert len(content) == shape[level]
if level == len(shape) - 1:
for x in content:
f.write(' ')
f.write(str(x))
else:
for x in content:
traverse(x, level + 1)
traverse(content, 0)
f.write('\n')
f.flush()
if requested_shape is not None and \
list(shape) != list(requested_shape):
raise CompilerError('content contradicts shape')
if not skip_input:
res = cls.Tensor(shape)
res.input_from(player, binary=binary, n_bytes=n_bytes)
return res
class _vec(Tape._no_truth):
def link(self, other):
assert len(self.v) == len(other.v)
for x, y in zip(self.v, other.v):
x.link(y)
class _register(Tape.Register, _number, _structure):
@staticmethod
def n_elements():
return 1
@classmethod
def new_vector(cls, size):
return cls(size=size)
def vector_size(self):
return self.size
@vectorized_classmethod
def conv(cls, val):
if isinstance(val, MemValue):
val = val.read()
if isinstance(val, cls):
return val
elif not isinstance(val, (_register, _vec)):
try:
return type(val)(cls.conv(v) for v in val)
except TypeError:
pass
except CompilerError:
pass
return cls(val)
@vectorized_classmethod
@read_mem_value
def hard_conv(cls, val):
if type(val) == cls:
return val
elif not isinstance(val, _register):
try:
return val.hard_conv_me(cls)
except AttributeError:
try:
return type(val)(cls.hard_conv(v) for v in val)
except (TypeError, CompilerError):
pass
return cls(val)
@vectorized_classmethod
@set_instruction_type
@read_mem_value
def _load_mem(cls, address, direct_inst, indirect_inst):
if isinstance(address, _register):
if address.size > 1:
size = address.size
else:
size = get_global_vector_size()
res = cls(size=size)
indirect_inst(res, cls._expand_address(address,
get_global_vector_size()))
else:
res = cls()
direct_inst(res, address)
return res
@staticmethod
def _expand_address(address, size):
address = regint.conv(address)
if size > 1 and address.size == 1:
res = regint(size=size)
incint(res, address, 1)
return res
else:
return address
@read_mem_value
@set_instruction_type
def _store_in_mem(self, address, direct_inst, indirect_inst):
if isinstance(address, _register):
indirect_inst(self, self._expand_address(address, self.size))
else:
direct_inst(self, address)
@classmethod
def prep_res(cls, other):
return cls()
@classmethod
def bit_compose(cls, bits):
""" Compose value from bits.
:param bits: iterable of any type implementing left shift """
return sum(cls.conv(b) << i for i,b in enumerate(bits))
@classmethod
def malloc(cls, size, creator_tape=None, **kwargs):
""" Allocate memory (statically).
:param size: compile-time (int) """
return program.malloc(size, cls, creator_tape=creator_tape, **kwargs)
@classmethod
def free(cls, addr):
program.free(addr, cls.reg_type)
@set_instruction_type
def __init__(self, reg_type, val, size):
from .GC.types import sbits
if isinstance(val, (tuple, list)):
size = len(val)
elif isinstance(val, sbits):
size = val.n
super(_register, self).__init__(reg_type, program.curr_tape, size=size)
if isinstance(val, int):
self.load_int(val)
elif isinstance(val, (tuple, list)):
for i, x in enumerate(val):
if util.is_constant(x):
self[i].load_int(x)
else:
self[i].load_other(x)
elif val is not None:
try:
self.load_other(val)
except:
raise CompilerError(
"cannot convert '%s' to '%s'" % (type(val), type(self)))
def _new_by_number(self, i, size=1):
res = type(self)(size=size)
res.i = i
res.program = self.program
return res
def sizeof(self):
return self.size
def extend(self, n):
return self
def expand_to_vector(self, size=None):
if size is None:
size = get_global_vector_size()
if self.size == size:
return self
assert self.size == 1
return self._expand_to_vector(size)
def _expand_to_vector(self, size):
res = type(self)(size=size)
for i in range(size):
self.mov(res[i], self)
return res
def copy_from_part(self, source, base, size):
set_global_vector_size(size)
self.mov(self, source.get_vector(base, size))
reset_global_vector_size()
@classmethod
def concat(cls, parts):
parts = list(parts)
res = cls(size=sum(len(part) for part in parts))
base = 0
for reg in parts:
set_global_vector_size(reg.size)
reg.mov(res.get_vector(base, reg.size), reg)
reset_global_vector_size()
base += reg.size
return res
class _arithmetic_register(_register):
""" Arithmetic circuit type. """
def __init__(self, *args, **kwargs):
if program.options.garbled:
raise CompilerError('functionality only available in arithmetic circuits')
super(_arithmetic_register, self).__init__(*args, **kwargs)
@classmethod
def get_type(cls, length):
return cls
@staticmethod
def two_power(n, size=None):
return floatingpoint.two_power(n)
def Norm(self, k, f, simplex_flag=False):
return library.Norm(self, k, f, simplex_flag=simplex_flag)
class _clear(_arithmetic_register):
""" Clear domain-dependent type. """
__slots__ = []
mov = staticmethod(movc)
@set_instruction_type
@vectorize
def load_other(self, val):
if isinstance(val, type(self)):
movc(self, val)
else:
self.convert_from(val)
@vectorize
@read_mem_value
def convert_from(self, val):
if not isinstance(val, regint):
val = regint(val)
convint(self, val)
@set_instruction_type
@vectorize
def print_reg(self, comment=''):
print_reg(self, comment)
@set_instruction_type
@vectorize
def print_reg_plain(self):
""" Output. """
print_reg_plain(self)
@set_instruction_type
@vectorize
def raw_output(self):
raw_output(self)
@vectorize
def binary_output(self, player=None):
""" Write 64-bit signed integer to
``Player-Data/Binary-Output-P<playerno>-<threadno>``.
:param player: only output on given player (default all)
"""
regint(self).binary_output(player)
@set_instruction_type
@read_mem_value
@vectorize
def clear_op(self, other, c_inst, ci_inst, reverse=False):
cls = self.__class__
res = self.prep_res(other)
if isinstance(other, regint):
other = cls(other)
if isinstance(other, cls):
if reverse:
c_inst(res, other, self)
else:
c_inst(res, self, other)
elif isinstance(other, int):
if self.in_immediate_range(other):
ci_inst(res, self, other)
else:
if reverse:
c_inst(res, cls(other), self)
else:
c_inst(res, self, cls(other))
else:
return NotImplemented
return res
@set_instruction_type
@read_mem_value
@vectorize
def coerce_op(self, other, inst, reverse=False):
cls = self.__class__
res = cls()
if isinstance(other, (int, regint)):
other = cls(other)
elif not isinstance(other, cls):
return NotImplemented
if reverse:
inst(res, other, self)
else:
inst(res, self, other)
return res
def add(self, other):
""" Addition of public values.
:param other: convertible type (at least same as :py:obj:`self` and regint/int) """
return self.clear_op(other, addc, addci)
def mul(self, other):
""" Multiplication of public values.
:param other: convertible type (at least same as :py:obj:`self` and regint/int) """
return self.clear_op(other, mulc, mulci)
def __sub__(self, other):
""" Subtraction of public values.
:param other: convertible type (at least same as :py:obj:`self` and regint/int) """
return self.clear_op(other, subc, subci)
def __rsub__(self, other):
return self.clear_op(other, subc, subcfi, True)
__rsub__.__doc__ = __sub__.__doc__
def field_div(self, other):
""" Field division of public values. Not available for
computation modulo a power of two.
:param other: convertible type (at least same as :py:obj:`self` and regint/int) """
try:
return other._rfield_div(self)
except AttributeError:
return self.clear_op(other, divc, divci)
def __and__(self, other):
""" Bit-wise AND of public values.
:param other: convertible type (at least same as :py:obj:`self` and regint/int) """
return self.clear_op(other, andc, andci)
def __xor__(self, other):
""" Bit-wise XOR of public values.
:param other: convertible type (at least same as :py:obj:`self` and regint/int) """
return self.clear_op(other, xorc, xorci)
def __or__(self, other):
""" Bit-wise OR of public values.
:param other: convertible type (at least same as :py:obj:`self` and regint/int) """
return self.clear_op(other, orc, orci)
__rand__ = __and__
__rxor__ = __xor__
__ror__ = __or__
def reveal(self):
""" Identity. """
return self
[docs]class cint(_clear, _int):
"""
Clear integer in same domain as secure computation (depends on
protocol). A number operators are supported (``+, -, *, /, //, **,
%, ^, &, |, ~, ==, !=, <<, >>``), returning either
:py:class:`cint` if the other operand is public (cint/regint/int)
or :py:class:`sint` if the other operand is
:py:class:`sint`. Comparison operators (``==, !=, <, <=, >, >=``)
are also supported, returning :py:func:`regint`. Comparisons and
``~`` require that the value is within the global bit length. The
same holds for :py:func:`abs`. ``/`` runs field division if the
modulus is a prime while ``//`` runs integer floor
division. ``**`` requires the exponent to be compile-time integer
or the base to be two.
This type is restricted to arithmetic circuits due to the fact
that only arithmetic protocols offer communication-less
public-private integer operations.
:param val: initialization (cint/regint/int/cgf2n or list thereof)
:param size: vector size (int), defaults to 1 or size of list
"""
__slots__ = []
instruction_type = 'modp'
reg_type = 'c'
[docs] @vectorized_classmethod
def read_from_socket(cls, client_id, n=1):
""" Receive clear value(s) from client.
:param client_id: Client id (regint)
:param n: number of values (default 1)
:param size: vector size (default 1)
:returns: cint (if n=1) or list of cint
"""
res = [cls() for i in range(n)]
readsocketc(client_id, get_global_vector_size(), *res)
if n == 1:
return res[0]
else:
return res
[docs] @classmethod
def write_to_socket(self, client_id, values, message_type=ClientMessageType.NoType):
""" Send a list of clear values to a client.
:param client_id: Client id (regint)
:param values: list of cint
"""
for value in values:
assert(value.size == values[0].size)
writesocketc(client_id, message_type, values[0].size, *values)
[docs] @vectorized_classmethod
def load_mem(cls, address, mem_type=None):
""" Load from memory by public address. """
return cls._load_mem(address, ldmc, ldmci)
[docs] def store_in_mem(self, address):
""" Store in memory by public address. """
self._store_in_mem(address, stmc, stmci)
@staticmethod
def in_immediate_range(value, regint=False):
if program.options.ring and not regint:
if abs(value) > 2 ** int(program.options.ring):
raise CompilerError('value outside range for domain')
return value < 2**31 and value >= -2**31
@vectorize_init
def __init__(self, val=None, size=None):
super(cint, self).__init__('c', val=val, size=size)
@vectorize
def load_int(self, val):
if val:
# +1 for sign
bit_length = 1 + int(math.ceil(math.log(abs(val))))
if program.options.ring:
assert(bit_length <= int(program.options.ring))
elif program.options.field:
program.curr_tape.require_bit_length(bit_length)
if self.in_immediate_range(val):
ldi(self, val)
else:
max = 2**31 - 1
sign = abs(val) // val
val = abs(val)
chunks = []
while val:
mod = val % max
val = (val - mod) // max
chunks.append(mod)
sum = cint(sign * chunks.pop())
for i,chunk in enumerate(reversed(chunks)):
sum *= max
if i == len(chunks) - 1:
addci(self, sum, sign * chunk)
elif chunk:
sum += sign * chunk
def load_other(self, val):
if isinstance(val, cfix):
val = val.v.round(val.k, val.f)
super(cint, self).load_other(val)
[docs] @vectorize
def to_regint(self, n_bits=64, dest=None):
""" Convert to regint.
:param n_bits: bit length (int)
:return: regint """
dest = regint() if dest is None else dest
convmodp(dest, self, bitlength=n_bits)
return dest
def __mod__(self, other):
""" Clear modulo.
:param other: cint/regint/int """
return self.clear_op(other, modc, modci)
def __rmod__(self, other):
""" Clear modulo.
:param other: cint/regint/int """
return self.coerce_op(other, modc, True)
def __floordiv__(self, other):
return self.coerce_op(other, floordivc)
def __rfloordiv__(self, other):
return self.coerce_op(other, floordivc, True)
def __truediv__(self, other):
""" Clear fixed-point division.
:param other: any compatible type """
if isinstance(other, cint):
return other.__rtruediv__(self)
try:
return cfix._new(self) / cfix._new(cint(other))
except:
return NotImplemented
def __rtruediv__(self, other):
return cfix._new(other) / cfix._new(self)
[docs] @vectorize
def less_than(self, other, bit_length):
""" Clear comparison for particular bit length.
:param other: cint/regint/int
:param bit_length: signed bit length of inputs
:return: 0/1 (regint), undefined if inputs outside range """
if not isinstance(other, (cint, regint, int)):
return NotImplemented
if bit_length <= 64:
return regint(self) < regint(other)
else:
sint.require_bit_length(bit_length + 1)
diff = self - other
diff += 1 << bit_length
shifted = diff >> bit_length
res = 1 - regint(shifted & 1)
return res
def __lt__(self, other):
""" Clear comparison.
:param other: cint/regint/int
:return: 0/1 (regint) """
return self.less_than(other, program.bit_length)
@vectorize
def __gt__(self, other):
if isinstance(other, (cint, regint, int)):
return self.conv(other) < self
else:
return NotImplemented
def __le__(self, other):
return 1 - (self > other)
def __ge__(self, other):
return 1 - (self < other)
for op in __gt__, __le__, __ge__:
op.__doc__ = __lt__.__doc__
del op
@vectorize
def __eq__(self, other):
""" Clear equality test.
:param other: cint/regint/int
:return: 0/1 (regint) """
if not isinstance(other, (_clear, regint, int)):
return NotImplemented
res = 1
remaining = program.bit_length
while remaining > 0:
if isinstance(other, cint):
o = other.to_regint(min(remaining, 64))
else:
o = other % 2 ** 64
res *= (self.to_regint(min(remaining, 64)) == o)
self >>= 64
other >>= 64
remaining -= 64
return res
def __ne__(self, other):
return 1 - (self == other)
equal = lambda self, other, *args, **kwargs: self.__eq__(other)
def __lshift__(self, other):
""" Clear left shift.
:param other: cint/regint/int """
return self.clear_op(other, shlc, shlci)
def __rshift__(self, other):
""" Clear right shift.
:param other: cint/regint/int """
return self.clear_op(other, shrc, shrci)
def __neg__(self):
""" Clear negation. """
return 0 - self
def __abs__(self):
""" Clear absolute. """
return (self >= 0).if_else(self, -self)
@vectorize
def __invert__(self):
""" Clear inversion using global bit length. """
res = cint()
notc(res, self, program.bit_length)
return res
def __rpow__(self, base):
""" Clear power of two.
:param other: 2 """
if base == 2:
return 1 << self
else:
return NotImplemented
@vectorize
def __rlshift__(self, other):
""" Clear shift.
:param other: cint/regint/int """
return cint(other) << self
@vectorize
def __rrshift__(self, other):
""" Clear shift.
:param other: cint/regint/int """
return cint(other) >> self
[docs] @read_mem_value
def mod2m(self, other, bit_length=None, signed=None):
""" Clear modulo a power of two.
:param other: cint/regint/int """
return self % 2**other
[docs] @read_mem_value
def right_shift(self, other, bit_length=None):
""" Clear shift.
:param other: cint/regint/int """
return self >> other
def round(self, k, m, nearest=None, signed=False):
if signed:
self += 2 ** (k - 1)
self += 2 ** (m - 1)
res = self >> m
if signed:
res -= 2 ** (k - m - 1)
return res
@read_mem_value
def greater_than(self, other, bit_length=None):
return self > other
[docs] @vectorize
def bit_decompose(self, bit_length=None, maybe_mixed=None):
""" Clear bit decomposition.
:param bit_length: number of bits (default is global bit length)
:return: list of cint """
if bit_length == 0:
return []
bit_length = bit_length or program.bit_length
return floatingpoint.bits(self, bit_length)
[docs] @vectorize
def legendre(self):
""" Clear Legendre symbol computation. """
res = cint()
legendrec(res, self)
return res
[docs] @vectorize
def digest(self, num_bytes):
""" Clear hashing (libsodium default). """
res = cint()
digestc(res, self, num_bytes)
return res
[docs] def print_if(self, string):
""" Output if value is non-zero.
:param string: bytearray """
cond_print_str(self, string)
def output_if(self, cond):
cond_print_plain(self.conv(cond), self, cint(0, size=self.size))
[docs]class cgf2n(_clear, _gf2n):
r"""
Clear :math:`\mathrm{GF}(2^n)` value. n is chosen at runtime. A
number operators are supported (``+, -, *, /, **, ^, &, |, ~, ==,
!=, <<, >>``), returning either :py:class:`cgf2n` if the other
operand is public (cgf2n/regint/int) or :py:class:`sgf2n` if the
other operand is secret. The following operators require the other
operand to be a compile-time integer: ``**, <<, >>``. ``*, /, **`` refer
to field multiplication and division.
:param val: initialization (cgf2n/cint/regint/int or list thereof)
:param size: vector size (int), defaults to 1 or size of list
"""
__slots__ = []
instruction_type = 'gf2n'
reg_type = 'cg'
[docs] @classmethod
def bit_compose(cls, bits, step=None):
r""" Clear :math:`\mathrm{GF}(2^n)` bit composition.
:param bits: list of cgf2n
:param step: set every :py:obj:`step`-th bit in output (defaults to 1) """
size = bits[0].size
res = cls(size=size)
vgbitcom(size, res, step or 1, *bits)
return res
[docs] @vectorized_classmethod
def load_mem(cls, address, mem_type=None):
""" Load from memory by public address. """
return cls._load_mem(address, gldmc, gldmci)
[docs] def store_in_mem(self, address):
""" Store in memory by public address. """
self._store_in_mem(address, gstmc, gstmci)
@staticmethod
def in_immediate_range(value):
return value < 2**32 and value >= 0
def __init__(self, val=None, size=None):
super(cgf2n, self).__init__('cg', val=val, size=size)
@vectorize
def load_int(self, val):
if val < 0:
raise CompilerError('Negative GF2n immediate')
if self.in_immediate_range(val):
gldi(self, val)
else:
chunks = []
while val:
mod = val % 2**32
val >>= 32
chunks.append(mod)
sum = cgf2n(chunks.pop())
for i,chunk in enumerate(reversed(chunks)):
sum <<= 32
if i == len(chunks) - 1:
gaddci(self, sum, chunk)
elif chunk:
sum += chunk
def __neg__(self):
""" Identity. """
return self
__truediv__ = _clear.field_div
def __rtruediv__(self, other):
return self.coerce_op(other, divc, True)
@vectorize
def __invert__(self):
""" Clear bit-wise inversion. """
res = cgf2n()
gnotc(res, self)
return res
@vectorize
def __lshift__(self, other):
""" Left shift.
:param other: compile-time (int) """
if isinstance(other, int):
res = cgf2n()
gshlci(res, self, other)
return res
else:
return NotImplemented
@vectorize
def __rshift__(self, other):
""" Right shift.
:param other: compile-time (int) """
if isinstance(other, int):
res = cgf2n()
gshrci(res, self, other)
return res
else:
return NotImplemented
def __eq__(self, other):
if isinstance(other, (cgf2n, int)):
return (regint(self) == regint(other)) * \
(regint(self >> 64) == regint(other >> 64))
else:
return NotImplemented
def __ne__(self, other):
return 1 - (self == other)
[docs] @vectorize
def bit_decompose(self, bit_length=None, step=None):
r""" Clear bit decomposition.
:param bit_length: number of bits (defaults to global :math:`\mathrm{GF}(2^n)` bit length)
:param step: extract every :py:obj:`step`-th bit (defaults to 1) """
bit_length = bit_length or program.galois_length
step = step or 1
res = [type(self)() for _ in range(bit_length // step)]
gbitdec(self, step, *res)
return res
[docs]class regint(_register, _int):
"""
Clear 64-bit integer.
Unlike :py:class:`cint` this is always a 64-bit integer. The type
supports the following operations with :py:class:`regint` or
Python integers, always returning :py:class:`regint`: ``+, -, *, %,
/, //, **, ^, &, |, <<, >>, ==, !=, <, <=, >, >=``. For operations
with other types, see the respective descriptions. Both ``/`` and
``//`` stand for floor division.
:param val: initialization (cint/cgf2n/regint/int or list thereof)
:param size: vector size (int), defaults to 1 or size of list
"""
__slots__ = []
reg_type = 'ci'
instruction_type = 'modp'
mov = staticmethod(movint)
[docs] @vectorized_classmethod
def load_mem(cls, address, mem_type=None):
""" Load from memory by public address. """
return cls._load_mem(address, ldmint, ldminti)
[docs] def store_in_mem(self, address):
""" Store in memory by public address. """
self._store_in_mem(address, stmint, stminti)
[docs] @vectorized_classmethod
def pop(cls):
""" Pop from stack. Made obsolete by :py:func:`update`. """
res = cls()
popint(res)
return res
[docs] @vectorized_classmethod
def push(cls, value):
""" Push to stack. Made obsolete by :py:func:`update`.
:param value: any convertible type """
pushint(cls.conv(value))
[docs] @vectorized_classmethod
def get_random(cls, bit_length):
""" Public insecure randomness.
:param bit_length: number of bits (int)
:param size: vector size (int, default 1)
"""
if isinstance(bit_length, int):
bit_length = regint(bit_length)
res = cls()
rand(res, bit_length)
return res
[docs] @classmethod
def inc(cls, size, base=0, step=1, repeat=1, wrap=None):
"""
Produce :py:class:`regint` vector with certain patterns.
This is particularly useful for :py:meth:`SubMultiArray.direct_mul`.
:param size: Result size
:param base: First value
:param step: Increase step
:param repeat: Repeate this many times
:param wrap: Start over after this many increases
The following produces (1, 1, 1, 3, 3, 3, 5, 5, 5, 7)::
regint.inc(10, 1, 2, 3)
"""
res = regint(size=size)
if wrap is None:
wrap = size
incint(res, cls.conv(base, size=1), step, repeat, wrap)
return res
[docs] @vectorized_classmethod
def read_from_socket(cls, client_id, n=1):
""" Receive clear integer value(s) from client.
:param client_id: Client id (regint)
:param n: number of values (default 1)
:param size: vector size (default 1)
:returns: regint (if n=1) or list of regint
"""
res = [cls() for i in range(n)]
readsocketint(client_id, get_global_vector_size(), *res)
if n == 1:
return res[0]
else:
return res
[docs] @classmethod
def write_to_socket(self, client_id, values, message_type=ClientMessageType.NoType):
""" Send a list of clear integers to a client.
:param client_id: Client id (regint)
:param values: list of regint
"""
for value in values:
assert(value.size == values[0].size)
writesocketint(client_id, message_type, values[0].size, *values)
@vectorize_init
def __init__(self, val=None, size=None):
super(regint, self).__init__(self.reg_type, val=val, size=size)
def load_int(self, val):
if cint.in_immediate_range(val, regint=True):
ldint(self, val)
else:
lower = val % 2**32
upper = val >> 32
if lower >= 2**31:
lower -= 2**32
upper += 1
addint(self, regint(upper) * regint(2**16)**2, regint(lower))
@read_mem_value
def load_other(self, val):
if isinstance(val, cgf2n):
gconvgf2n(self, val)
elif isinstance(val, regint):
addint(self, val, regint(0))
else:
try:
val.to_regint(dest=self)
except AttributeError:
raise CompilerError("Cannot convert '%s' to integer" % \
type(val))
def expand_to_vector(self, size=None):
if size is None:
size = get_global_vector_size()
if self.size == size:
return self
assert self.size == 1
return self.inc(size, self, 0)
@vectorize
@read_mem_value
def int_op(self, other, inst, reverse=False):
if isinstance(other, (int, regint)):
other = self.conv(other)
else:
return NotImplemented
res = regint()
if reverse:
inst(res, other, self)
else:
inst(res, self, other)
return res
def add(self, other):
""" Clear addition.
:param other: regint/cint/int """
return self.int_op(other, addint)
def __sub__(self, other):
""" Clear subtraction.
:param other: regint/cint/int """
return self.int_op(other, subint)
def __rsub__(self, other):
return self.int_op(other, subint, True)
__rsub__.__doc__ = __sub__.__doc__
def mul(self, other):
""" Clear multiplication.
:param other: regint/cint/int """
return self.int_op(other, mulint)
def __neg__(self):
""" Clear negation. """
return 0 - self
def __floordiv__(self, other):
""" Clear integer division (rounding to floor).
:param other: regint/cint/int """
if util.is_constant(other) and other >= 2 ** 64:
return 0
return self.int_op(other, divint)
def __rfloordiv__(self, other):
return self.int_op(other, divint, True)
__rfloordiv__.__doc__ = __floordiv__.__doc__
def __truediv__(self, other):
if isinstance(other, _gf2n):
return NotImplemented
else:
return cint(self) / other
def __rtruediv__(self, other):
return other / cint(self)
def __mod__(self, other):
""" Clear modulo computation.
:param other: regint/cint/int """
if util.is_constant(other) and other >= 2 ** 64:
return self
return self - (self // other) * other
def __rmod__(self, other):
""" Clear modulo computation.
:param other: regint/cint/int """
return regint(other) % self
def __rpow__(self, other):
""" Clear power of two computation.
:param other: regint/cint/int
:rtype: cint """
return other**cint(self)
def __eq__(self, other):
""" Clear comparison.
:param other: regint/cint/int
:return: 0/1 """
return self.int_op(other, eqc, False)
def __ne__(self, other):
return 1 - (self == other)
def __lt__(self, other):
return self.int_op(other, ltc, False)
def __gt__(self, other):
return self.int_op(other, gtc, False)
def __le__(self, other):
return 1 - (self > other)
def __ge__(self, other):
return 1 - (self < other)
for op in __le__, __lt__, __ge__, __gt__, __ne__:
op.__doc__ = __eq__.__doc__
del op
def cint_op(self, other, op):
if isinstance(other, regint):
return regint(op(cint(self), other))
else:
return NotImplemented
def __lshift__(self, other):
""" Clear shift.
:param other: regint/cint/int """
if isinstance(other, int):
return self * 2**other
else:
return self.cint_op(other, operator.lshift)
def __rshift__(self, other):
if isinstance(other, int):
return self // 2**other
else:
return self.cint_op(other, operator.rshift)
def __rlshift__(self, other):
return regint(other << cint(self))
def __rrshift__(self, other):
return regint(other >> cint(self))
for op in __rshift__, __rlshift__, __rrshift__:
op.__doc__ = __lshift__.__doc__
del op
def __and__(self, other):
""" Clear bit-wise AND.
:param other: regint/cint/int """
return self.cint_op(other, operator.and_)
def __or__(self, other):
""" Clear bit-wise OR.
:param other: regint/cint/int """
return self.cint_op(other, operator.or_)
def __xor__(self, other):
""" Clear bit-wise XOR.
:param other: regint/cint/int """
return self.cint_op(other, operator.xor)
__rand__ = __and__
__ror__ = __or__
__rxor__ = __xor__
[docs] def mod2m(self, *args, **kwargs):
""" Clear modulo a power of two.
:rtype: cint """
return cint(self).mod2m(*args, **kwargs)
[docs] @vectorize
def bit_decompose(self, bit_length=None):
""" Clear bit decomposition.
:param bit_length: number of bits (defaults to global bit length)
:return: list of regint """
bit_length = bit_length or min(64, program.bit_length)
if bit_length > 64:
raise CompilerError('too many bits demanded')
res = [regint() for i in range(bit_length)]
bitdecint(self, *res)
return res
[docs] @staticmethod
def bit_compose(bits):
""" Clear bit composition.
:param bits: list of regint/cint/int """
two = regint(2)
res = 0
for bit in reversed(bits):
res *= two
res += bit
return res
[docs] def shuffle(self):
""" Returns insecure shuffle of vector. """
res = regint(size=len(self))
shuffle(res, self)
return res
[docs] def reveal(self):
""" Identity. """
return self
[docs] def print_reg_plain(self):
""" Output. """
print_int(self)
[docs] def print_if(self, string):
""" Output string if value is non-zero.
:param string: Python string """
self._condition().print_if(string)
def output_if(self, cond):
self._condition().output_if(cond)
def _condition(self):
if program.options.binary:
from .GC.types import cbits
return cbits.get_type(64)(self)
else:
return cint(self)
[docs] def binary_output(self, player=None):
""" Write 64-bit signed integer to
``Player-Data/Binary-Output-P<playerno>-<threadno>``.
:param player: only output on given player (default all)
"""
if player == None:
player = -1
if not util.is_constant(player):
raise CompilerError('Player number must be known at compile time')
intoutput(player, self)
[docs]class localint(Tape._no_truth):
""" Local integer that must prevented from leaking into the secure
computation. Uses regint internally.
:param value: initialization, convertible to regint
"""
def __init__(self, value=None):
self._v = regint(value)
self.size = 1
[docs] def output(self):
""" Output. """
self._v.print_reg_plain()
""" Local comparison. """
__lt__ = lambda self, other: localint(self._v < other)
__le__ = lambda self, other: localint(self._v <= other)
__gt__ = lambda self, other: localint(self._v > other)
__ge__ = lambda self, other: localint(self._v >= other)
__eq__ = lambda self, other: localint(self._v == other)
__ne__ = lambda self, other: localint(self._v != other)
__add__ = lambda self, other: localint(self._v + other)
__radd__ = lambda self, other: localint(self._v + other)
[docs]class personal(Tape._no_truth):
""" Value known to one player. Supports operations with public
values and personal values known to the same player. Can be used
with :py:func:`~Compiler.library.print_ln_to`. It is possible to
convert to secret types like :py:class:`sint`.
:param player: player (int)
:param value: cleartext value (cint, cfix, cfloat) or array thereof
"""
def __init__(self, player, value):
assert value is not NotImplemented
assert not isinstance(value, _secret)
while isinstance(value, personal):
assert player == value.player
value = value._v
self.player = player
self._v = value
[docs] @classmethod
def read_int(cls, player, n_bytes=None):
""" Read integer from
``Player-Data/Input-Binary-P<player>-<threadnum>`` only on
party :py:obj:`player`.
:param player: player (int)
:return: personal cint
"""
tmp = cint()
fixinput(player, tmp, n_bytes or 0, 0)
return cls(player, tmp)
[docs] @classmethod
def read_fix(cls, player, f, k, precision):
""" Read fixed-point value from
``Player-Data/Input-Binary-P<player>-<threadnum>`` only on
party :py:obj:`player`.
:param player: player (int)
:param f: fixed-point precision (int)
:param k: fixed-point length (int)
:param precision: input precision (1: single, 2: double)
:return: personal cfix
"""
assert precision in (1, 2)
tmp = cint()
fixinput(player, tmp, f, precision)
return cls(player, cfix._new(tmp, f=f, k=k))
@classmethod
def read_int_from_socket(cls, player, socket, size=1):
return cls.read_from_socket(player, socket, cint, size)
@classmethod
def read_fix_from_socket(cls, player, socket, size=1):
return cls.read_from_socket(player, socket, cfix, size)
@classmethod
def read_from_socket(cls, player, socket, type, size):
tmp = type(size=size)
@library.if_(player == library.get_player_id()._v)
def _():
tmp.link(type.read_from_socket(socket, size=size))
return cls(player, tmp)
[docs] def binary_output(self):
""" Write binary output to
``Player-Data/Binary-Output-P<playerno>-<threadno>`` if
supported by underlying type. Player must be known at compile time."""
self._v.binary_output(self.player)
[docs] def reveal_to(self, player):
""" Pass personal value to another player. """
if isinstance(self._v, Array):
source = self._v[:]
else:
source = self._v
source = cint.conv(source)
res = cint(size=source.size)
sendpersonal(source.size, player, res, self.player, source)
if isinstance(self._v, Array):
res = Array.create_from(res)
return personal(player, res)
[docs] def bit_decompose(self, length=None):
""" Bit decomposition.
:param length: number of bits
"""
return [personal(self.player, x) for x in self._v.bit_decompose(length)]
def _san(self, other):
if isinstance(other, personal):
assert self.player == other.player
return self._v
def _div_san(self):
return self._v.conv((library.get_player_id() == self.player)._v).if_else(self._v, 1)
def __setitem__(self, index, value):
self._san(value)
self._v[index] = value
__getitem__ = lambda self, index: personal(self.player, self._v[index])
__add__ = lambda self, other: personal(self.player, self._san(other) + other)
__sub__ = lambda self, other: personal(self.player, self._san(other) - other)
__mul__ = lambda self, other: personal(self.player, self._san(other) * other)
__pow__ = lambda self, other: personal(self.player, self._san(other) ** other)
__truediv__ = lambda self, other: personal(self.player, self._san(other) / other)
__floordiv__ = lambda self, other: personal(self.player, self._san(other) // other)
__mod__ = lambda self, other: personal(self.player, self._san(other) % other)
__lt__ = lambda self, other: personal(self.player, self._san(other) < other)
__gt__ = lambda self, other: personal(self.player, self._san(other) > other)
__le__ = lambda self, other: personal(self.player, self._san(other) <= other)
__ge__ = lambda self, other: personal(self.player, self._san(other) >= other)
__eq__ = lambda self, other: personal(self.player, self._san(other) == other)
__ne__ = lambda self, other: personal(self.player, self._san(other) != other)
__and__ = lambda self, other: personal(self.player, self._san(other) & other)
__xor__ = lambda self, other: personal(self.player, self._san(other) ^ other)
__or__ = lambda self, other: personal(self.player, self._san(other) | other)
__lshift__ = lambda self, other: personal(self.player, self._san(other) << other)
__rshift__ = lambda self, other: personal(self.player, self._san(other) >> other)
__neg__ = lambda self: personal(self.player, -self._v)
__radd__ = lambda self, other: personal(self.player, other + self._v)
__rsub__ = lambda self, other: personal(self.player, other - self._v)
__rmul__ = lambda self, other: personal(self.player, other * self._v)
__rand__ = lambda self, other: personal(self.player, other & self._v)
__rxor__ = lambda self, other: personal(self.player, other ^ self._v)
__ror__ = lambda self, other: personal(self.player, other | self._v)
__rlshift__ = lambda self, other: personal(self.player, other << self._v)
__rrshift__ = lambda self, other: personal(self.player, other >> self._v)
__rtruediv__ = lambda self, other: personal(self.player, other / self._div_san())
__rfloordiv__ = lambda self, other: personal(self.player, other // self._div_san())
__rmod__ = lambda self, other: personal(self.player, other % self._div_san())
class longint:
def __init__(self, value, length=None, n_limbs=None):
assert length is None or n_limbs is None
if isinstance(value, longint):
if n_limbs is None:
n_limbs = int(math.ceil(length / 64))
assert n_limbs <= len(value.v)
self.v = value.v[:n_limbs]
elif isinstance(value, list):
assert length is None
self.v = value[:]
else:
if length is None:
length = 64 * n_limbs
if isinstance(value, int):
self.v = [(value >> i) for i in range(0, length, 64)]
else:
self.v = [(value >> i).to_regint(0)
for i in range(0, length, 64)]
def coerce(self, other):
return longint(other, n_limbs=len(self.v))
def __eq__(self, other):
return reduce(operator.mul, (x == y for x, y in
zip(self.v, self.coerce(other).v)))
def __add__(self, other):
other = self.coerce(other)
assert len(self.v) == len(other.v)
res = []
carry = 0
for x, y in zip(self.v, other.v):
res.append(x + y + carry)
carry = util.if_else(carry, (res[-1] + 2 ** 63) <= (x + 2 ** 63),
(res[-1] + 2 ** 63) < (x + 2 ** 63))
return longint(res)
__radd__ = __add__
def __sub__(self, other):
return self + -other
def bit_decompose(self, bit_length):
assert bit_length <= 64 * len(self.v)
res = []
for x in self.v:
res += x.bit_decompose(64)
return res[:bit_length]
class _secret(_arithmetic_register, _secret_structure):
__slots__ = []
mov = staticmethod(set_instruction_type(movs))
PreOR = staticmethod(lambda l: floatingpoint.PreORC(l))
PreOp = staticmethod(lambda op, l: floatingpoint.PreOpL(op, l))
@vectorized_classmethod
@set_instruction_type
def get_input_from(cls, player):
""" Secret input from player.
:param player: public (regint/cint/int)
:param size: vector size (int, default 1)
"""
res = cls()
asm_input(res, player)
return res
@vectorized_classmethod
@set_instruction_type
def get_random_triple(cls):
""" Secret random triple according to security model.
:return: :math:`(a, b, ab)`
:param size: vector size (int, default 1)
"""
res = (cls(), cls(), cls())
triple(*res)
return res
@vectorized_classmethod
@set_instruction_type
def get_random_bit(cls):
""" Secret random bit according to security model.
:return: 0/1 50-50
:param size: vector size (int, default 1)
"""
res = cls()
bit(res)
return res
@vectorized_classmethod
@set_instruction_type
def get_random_square(cls):
""" Secret random square according to security model.
:return: :math:`(a, a^2)`
:param size: vector size (int, default 1)
"""
res = (cls(), cls())
square(*res)
return res
@vectorized_classmethod
@set_instruction_type
def get_random_inverse(cls):
""" Secret random inverse tuple according to security model.
:return: :math:`(a, a^{-1})`
:param size: vector size (int, default 1)
"""
res = (cls(), cls())
inverse(*res)
return res
@vectorized_classmethod
@set_instruction_type
def get_random_input_mask_for(cls, player):
""" Secret random input mask according to security model.
:return: mask (sint), mask (personal cint)
:param size: vector size (int, default 1)
"""
res = cls(), personal(player, cls.clear_type())
inputmask(res[0], res[1]._v, player)
return res
@classmethod
@set_instruction_type
def dot_product(cls, x, y):
"""
Secret dot product.
:param x: Iterable of secret values
:param y: Iterable of secret values of same length and type
:rtype: same as inputs
"""
if isinstance(x, cls) and isinstance(y, cls):
assert len(x) == len(y)
res = cls()
matmuls(res, x, y, 1, len(x), 1)
else:
x = list(x)
set_global_vector_size(x[0].size)
res = cls()
dotprods(res, x, y)
reset_global_vector_size()
return res
@classmethod
@set_instruction_type
def row_matrix_mul(cls, row, matrix, res_params=None):
assert len(row) == len(matrix)
size = len(matrix[0])
res = cls(size=size)
dotprods(*sum(([res[j], row, [matrix[k][j] for k in range(len(row))]]
for j in range(size)), []))
return res
@classmethod
@set_instruction_type
def matrix_mul(cls, A, B, n, res_params=None):
assert len(A) % n == 0
assert len(B) % n == 0
size = len(A) * len(B) // n**2
res = cls(size=size)
n_rows = len(A) // n
n_cols = len(B) // n
matmuls(res, A, B, n_rows, n, n_cols)
return res
@staticmethod
def _new(self):
# mirror sfix
return self
@no_doc
def __init__(self, reg_type, val=None, size=None):
if isinstance(val, self.clear_type):
size = val.size
super(_secret, self).__init__(reg_type, val=val, size=size)
@set_instruction_type
@vectorize
def load_int(self, val):
if self.clear_type.in_immediate_range(val):
ldsi(self, val)
else:
self.load_clear(self.clear_type(val))
@vectorize
def load_clear(self, val):
addm(self, self.__class__(0), val)
@set_instruction_type
@read_mem_value
@vectorize
def load_other(self, val):
from Compiler.GC.types import sbits, sbitvec
if isinstance(val, self.clear_type):
self.load_clear(val)
elif isinstance(val, type(self)):
movs(self, val)
elif isinstance(val, sbits):
assert(val.n == self.size)
r = self.get_dabit()
movs(self, r[0].bit_xor((r[1] ^ val).reveal().to_regint_by_bit()))
elif isinstance(val, sbitvec):
movs(self, sint.bit_compose(val))
else:
self.load_clear(self.clear_type(val))
@classmethod
def bit_compose(cls, bits):
""" Compose value from bits.
:param bits: iterable of any type convertible to sint """
from Compiler.GC.types import sbits, sbitintvec
if isinstance(bits, sbits):
bits = bits.bit_decompose()
bits = list(bits)
if (program.use_edabit() or program.use_split()) and isinstance(bits[0], sbits):
if program.use_edabit():
mask = cls.get_edabit(len(bits), strict=True, size=bits[0].n)
else:
tmp = sint(size=bits[0].n)
randoms(tmp, len(bits))
n_overflow_bits = min(program.use_split().bit_length(),
int(program.options.ring) - len(bits))
mask_bits = tmp.bit_decompose(len(bits) + n_overflow_bits,
maybe_mixed=True)
if n_overflow_bits:
overflow = sint.bit_compose(
sint.conv(x) for x in mask_bits[-n_overflow_bits:])
mask = tmp - (overflow << len(bits)), \
mask_bits[:-n_overflow_bits]
else:
mask = tmp, mask_bits
t = sbitintvec.get_type(len(bits) + 1)
masked = t.from_vec(mask[1] + [0]) + t.from_vec(bits + [0])
overflow = masked.v[-1]
masked = cls.bit_compose(x.reveal().to_regint_by_bit() for x in masked.v[:-1])
return masked - mask[0] + (cls(overflow) << len(bits))
else:
return super(_secret, cls).bit_compose(bits)
@set_instruction_type
@read_mem_value
@vectorize
def secret_op(self, other, s_inst, m_inst, si_inst, reverse=False):
res = self.prep_res(other)
cls = type(res)
if isinstance(other, regint):
other = res.clear_type(other)
if isinstance(other, cls):
if reverse:
s_inst(res, other, self)
else:
s_inst(res, self, other)
elif isinstance(other, res.clear_type):
if reverse:
m_inst(res, other, self)
else:
m_inst(res, self, other)
elif isinstance(other, int):
if self.clear_type.in_immediate_range(other):
si_inst(res, self, other)
else:
if reverse:
m_inst(res, res.clear_type(other), self)
else:
m_inst(res, self, res.clear_type(other))
else:
return NotImplemented
return res
def add(self, other):
""" Secret addition.
:param other: any compatible type """
return self.secret_op(other, adds, addm, addsi)
@set_instruction_type
def mul(self, other):
""" Secret multiplication. Either both operands have the same
size or one size 1 for a value-vector multiplication.
:param other: any compatible type """
if isinstance(other, _register) and (1 in (self.size, other.size)) \
and (self.size, other.size) != (1, 1):
x, y = (other, self) if self.size < other.size else (self, other)
if not isinstance(other, _secret):
return y.expand_to_vector(x.size) * x
res = type(self)(size=x.size)
mulrs(res, x, y)
return res
return self.secret_op(other, muls, mulm, mulsi)
def __sub__(self, other):
""" Secret subtraction.
:param other: any compatible type """
return self.secret_op(other, subs, subml, subsi)
def __rsub__(self, other):
return self.secret_op(other, subs, submr, subsfi, True)
__rsub__.__doc__ = __sub__.__doc__
def field_div(self, other):
""" Secret field division.
:param other: any compatible type """
try:
one = self.clear_type(1, size=other.size)
except AttributeError:
one = self.clear_type(1)
return self * one.field_div(other)
@vectorize
def _rfield_div(self, other):
a,b = self.get_random_inverse()
return other * a.field_div((a * self).reveal())
@set_instruction_type
@vectorize
def square(self):
""" Secret square. """
if program.use_square():
res = self.__class__()
sqrs(res, self)
return res
else:
return self * self
@set_instruction_type
def secure_shuffle(self, unit_size=1):
res = type(self)(size=self.size)
secshuffle(res, self, unit_size)
return res
@set_instruction_type
@vectorize
def reveal(self, check=True):
""" Reveal secret value publicly.
:rtype: relevant clear type """
res = self.clear_type()
asm_open(check, res, self)
return res
@set_instruction_type
def reveal_to(self, player):
""" Reveal secret value to :py:obj:`player`.
:param player: int
:returns: :py:class:`personal`
"""
mask = self.get_random_input_mask_for(player)
masked = self + mask[0]
res = personal(player, masked.reveal() - mask[1])
return res
@set_instruction_type
@vectorize
def raw_right_shift(self, length):
""" Local right shift in supported protocols.
In integer-like protocols, the output is potentially off by one.
:param length: number of bits
"""
res = type(self)()
shrsi(res, self, length)
return res
def raw_mod2m(self, m):
return self - (self.raw_right_shift(m) << m)
@set_instruction_type
@vectorize
def output(self):
print_reg_plains(self)
[docs]class sint(_secret, _int):
"""
Secret integer in the protocol-specific domain. It supports
operations with :py:class:`sint`, :py:class:`cint`,
:py:class:`regint`, and Python integers. Operations where one of
the operands is an :py:class:`sint` either result in an
:py:class:`sint` or an :py:class:`sintbit`, the latter for
comparisons.
The following operations work as expected in the computation
domain (modulo a prime or a power of two): ``+, -, *``. ``/``
denotes a fixed-point division.
Comparisons operators (``==, !=, <, <=, >, >=``)
assume that the element in the computation domain represents a
signed integer in a restricted range, see below. The same holds
for ``abs()``, shift operators (``<<, >>``), modulo (``%``), and
exponentation (``**``). Modulo only works if the right-hand
operator is a compile-time power of two.
Most non-linear operations require compile-time parameters for bit
length. It defaults to the global parameters set by
:py:meth:`program.set_bit_length`, and its default
is output at the beginning of the compilation.
If the computation domain is modulo a power of two, the
operands will be truncated to the bit length.
Modulo prime, the behaviour is
undefined and potentially insecure if the operands are longer than
the bit length.
See :ref:`nonlinear` for an overview of how non-linear
computation is implemented.
:param val: initialization (sint/cint/regint/int/cgf2n or list
thereof, sbits/sbitvec/sfix, or :py:class:`personal`)
:param size: vector size (int), defaults to 1 or size of list
When converting :py:class:`~Compiler.GC.types.sbits`, the result is a
vector of bits, and when converting
:py:class:`~Compiler.GC.types.sbitvec`, the result is a vector of values
with bit length equal the length of the input.
Initializing from a :py:class:`personal` value implies the
relevant party inputting their value securely.
"""
__slots__ = []
instruction_type = 'modp'
clear_type = cint
reg_type = 's'
PreOp = staticmethod(floatingpoint.PreOpL)
PreOR = staticmethod(floatingpoint.PreOR)
@classmethod
def get_type(cls, n):
cls.require_bit_length(n or 0)
return cls
@staticmethod
def require_bit_length(n_bits):
if program.options.ring:
if int(program.options.ring) < n_bits:
raise CompilerError('computation modulus too small')
else:
program.curr_tape.require_bit_length(n_bits)
[docs] @vectorized_classmethod
def get_random_int(cls, bits):
""" Secret random n-bit number according to security model.
:param bits: compile-time integer (int)
:param size: vector size (int, default 1)
"""
if program.use_edabit():
return sint.get_edabit(bits, True)[0]
elif program.use_split() > 2 and program.use_split() < 5:
tmp = sint()
randoms(tmp, bits)
x = tmp.split_to_two_summands(bits, True)
carry = comparison.CarryOutRawLE(x[1][:bits], x[0][:bits])
if program.use_split() > 3:
from .GC.types import sbitint
x = sbitint.full_adder(carry, x[0][bits], x[1][bits])
overflow = sint.conv(x[1]) * 2 + sint.conv(x[0])
else:
overflow = sint.conv(carry) + sint.conv(x[0][bits])
return tmp - (overflow << bits)
res = sint()
comparison.PRandInt(res, bits)
return res
[docs] @vectorized_classmethod
def get_random(cls):
""" Secret random ring element according to security model.
:param size: vector size (int, default 1)
"""
res = sint()
randomfulls(res)
return res
[docs] @vectorized_classmethod
def get_dabit(cls):
""" Bit in arithmetic and binary circuit according to security model """
from Compiler.GC.types import sbits
res = cls(), sbits.get_type(get_global_vector_size())()
dabit(*res)
return res
[docs] @vectorized_classmethod
def get_edabit(cls, n_bits, strict=False):
""" Bits in arithmetic and binary circuit """
""" according to security model """
if not program.use_edabit_for(strict, n_bits):
if program.use_dabit:
a, b = zip(*(sint.get_dabit() for i in range(n_bits)))
return sint.bit_compose(a), b
else:
a = [sint.get_random_bit() for i in range(n_bits)]
return sint.bit_compose(a), a
assert n_bits > 0
program.curr_tape.require_bit_length(n_bits - 2)
whole = cls()
size = get_global_vector_size()
from Compiler.GC.types import sbits, sbitvec
bits = [sbits.get_type(size)() for i in range(n_bits)]
if strict:
sedabit(whole, *bits)
else:
edabit(whole, *bits)
return whole, bits
@staticmethod
@vectorize
def bit_decompose_clear(a, n_bits):
return floatingpoint.bits(a, n_bits)
@vectorized_classmethod
def get_raw_input_from(cls, player):
res = cls()
rawinput(player, res)
return res
[docs] @vectorized_classmethod
def receive_from_client(cls, n, client_id, message_type=ClientMessageType.NoType):
""" Securely obtain shares of values input by a client.
This uses the triple-based input protocol introduced by
`Damgård et al. <http://eprint.iacr.org/2015/1006>`_ unless
:py:obj:`program.active` is set to false, in which case
it uses random values to mask the clients' input.
:param n: number of inputs (int)
:param client_id: regint
:param size: vector size (default 1)
:returns: list of sint
"""
if program.active:
# send shares of a triple to client
triples = list(itertools.chain(*(sint.get_random_triple() for i in range(n))))
else:
triples = [sint.get_random() for i in range(n)]
sint.write_shares_to_socket(client_id, triples, message_type)
received = util.tuplify(cint.read_from_socket(client_id, n))
y = [0] * n
for i in range(n):
y[i] = received[i] - triples[i * 3 if program.active else i]
return y
[docs] @classmethod
def reveal_to_clients(cls, clients, values):
""" Reveal securely to clients.
Uses :py:obj:`program.active` to determine whether to use
triples for active security.
:param clients: client ids (list or array)
:param values: list of sint to reveal
"""
set_global_vector_size(values[0].size)
to_send = []
for value in values:
assert(value.size == values[0].size)
r = sint.get_random(size=value.size)
value += r - r.reveal()
if program.active:
r = sint.get_random()
to_send += [value, r, value * r]
else:
to_send += [value]
if isinstance(clients, Array):
n_clients = clients.length
else:
n_clients = len(clients)
set_global_vector_size(1)
clients = Array.create_from(regint.conv(clients))
reset_global_vector_size()
@library.for_range(n_clients)
def loop_body(i):
sint.write_shares_to_socket(clients[i], to_send)
reset_global_vector_size()
[docs] @vectorized_classmethod
def read_from_socket(cls, client_id, n=1):
""" Receive secret-shared value(s) from client.
:param client_id: Client id (regint)
:param n: number of values (default 1)
:param size: vector size of values (default 1)
:returns: sint (if n=1) or list of sint
"""
res = [cls() for i in range(n)]
readsockets(client_id, get_global_vector_size(), *res)
if n == 1:
return res[0]
else:
return res
[docs] @vectorized_classmethod
def write_to_socket(cls, client_id, values,
message_type=ClientMessageType.NoType):
""" Send a list of shares and MAC shares to a client socket.
:param client_id: regint
:param values: list of sint
"""
writesockets(client_id, message_type, values[0].size, *values)
[docs] @vectorize
def write_fully_to_socket(self, client_id,
message_type=ClientMessageType.NoType):
""" Send full secret to socket """
writesockets(client_id, message_type, self.size, self)
@vectorize
def write_share_to_socket(self, client_id, message_type=ClientMessageType.NoType):
""" Send only share to socket """
writesocketshare(client_id, message_type, self.size, self)
[docs] @classmethod
def write_shares_to_socket(cls, client_id, values,
message_type=ClientMessageType.NoType):
""" Send shares of a list of values to a specified client socket.
:param client_id: regint
:param values: list of sint
"""
writesocketshare(client_id, message_type, values[0].size, *values)
[docs] @classmethod
def read_from_file(cls, start, n_items=1, crash_if_missing=True, size=1):
""" Read shares from
``Persistence/Transactions-P<playerno>.data``. See :ref:`this
section <persistence>` for details on the data format.
:param start: starting position in number of shares from beginning (int/regint/cint)
:param n_items: number of items (int)
:param crash_if_missing: crash if file not found (default)
:param size: vector size (int)
:returns: destination for final position, -1 for eof reached, or -2 for file not found (regint)
:returns: list of shares
"""
shares = [cls(size=size) for i in range(n_items)]
stop = regint()
readsharesfromfile(regint.conv(start), stop, *shares)
if crash_if_missing:
library.runtime_error_if(stop == -2, 'Persistence not found')
return stop, shares
[docs] @staticmethod
def write_to_file(shares, position=None):
""" Write shares to ``Persistence/Transactions-P<playerno>.data``
(appending at the end). See :ref:`this section <persistence>`
for details on the data format.
:param shares: (list or iterable of sint)
:param position: start position (int/regint/cint),
defaults to end of file
"""
if isinstance(shares, sint):
shares = [shares]
for share in shares:
assert isinstance(share, sint)
assert share.size == shares[0].size
if position is None:
position = -1
writesharestofile(regint.conv(position), *shares)
[docs] @vectorized_classmethod
def load_mem(cls, address, mem_type=None):
""" Load from memory by public address. """
return cls._load_mem(address, ldms, ldmsi)
[docs] def store_in_mem(self, address):
""" Store in memory by public address. """
self._store_in_mem(address, stms, stmsi)
@classmethod
def direct_matrix_mul(cls, A, B, n, m, l, reduce=None, indices=None, indices_values=None):
if indices is None:
indices = [regint.inc(i) for i in (n, m, m, l)]
indices_values = [list(range(i)) for i in (n, m, m, l)]
res = cls(size=indices[0].size * indices[3].size)
if isinstance(A, int) and isinstance(B, int):
first_factor_base_addresses = [A]
second_factor_base_addresses = [B]
else:
first_factor_base_addresses = None
second_factor_base_addresses = None
matmulsm(res, regint(A), regint(B), len(indices[0]), len(indices[1]),
len(indices[3]), *(list(indices) + [m, l]),
first_factor_base_addresses=first_factor_base_addresses,
second_factor_base_addresses=second_factor_base_addresses,
indices_values=indices_values)
return res
@vectorize_init
def __init__(self, val=None, size=None):
from .GC.types import sbitvec
if isinstance(val, personal):
size = val._v.size
super(sint, self).__init__('s', size=size)
inputpersonal(size, val.player, self, self.clear_type.conv(val._v))
elif isinstance(val, _fix):
super(sint, self).__init__('s', size=val.v.size)
self.load_other(val.v.round(val.k, val.f,
nearest=val.round_nearest))
elif isinstance(val, sbitvec):
super(sint, self).__init__('s', val=val, size=val[0].n)
else:
super(sint, self).__init__('s', val=val, size=size)
@vectorize
def __neg__(self):
""" Secret negation. """
return 0 - self
@vectorize
def __abs__(self):
""" Secret absolute. Uses global parameters for comparison. """
return (self >= 0).if_else(self, -self)
@read_mem_value
@type_comp
@vectorize
def __lt__(self, other, bit_length=None):
""" Secret comparison (signed).
:param other: sint/cint/regint/int
:param bit_length: bit length of input (default: global bit length)
:return: 0/1 (sintbit) """
res = sintbit()
comparison.LTZ(res, self - other,
(bit_length or program.bit_length) + 1)
return res
@read_mem_value
@type_comp
@vectorize
def __gt__(self, other, bit_length=None):
res = sintbit()
comparison.LTZ(res, other - self,
(bit_length or program.bit_length) + 1)
return res
@read_mem_value
@type_comp
def __le__(self, other, bit_length=None):
return 1 - self.greater_than(other, bit_length)
@read_mem_value
@type_comp
def __ge__(self, other, bit_length=None):
return 1 - self.less_than(other, bit_length)
@read_mem_value
@type_comp
@vectorize
def __eq__(self, other, bit_length=None):
return sintbit.conv(
floatingpoint.EQZ(self - other, bit_length or program.bit_length))
@read_mem_value
@type_comp
def __ne__(self, other, bit_length=None):
return 1 - self.equal(other, bit_length)
less_than = __lt__
greater_than = __gt__
less_equal = __le__
greater_equal = __ge__
equal = __eq__
not_equal = __ne__
for op in __gt__, __le__, __ge__, __eq__, __ne__:
op.__doc__ = __lt__.__doc__
del op
@vectorize
def __mod__(self, modulus):
""" Secret modulo computation.
Uses global parameters for bit length and security.
:param modulus: power of two (int) """
if isinstance(modulus, int):
l = math.log(modulus, 2)
if 2**int(round(l)) == modulus:
return self.mod2m(int(l))
raise NotImplementedError('Modulo only implemented for powers of two.')
[docs] @vectorize
@read_mem_value
def mod2m(self, m, bit_length=None, signed=True):
""" Secret modulo power of two.
:param m: secret or public integer (sint/cint/regint/int)
:param bit_length: bit length of input (default: global bit length)
"""
bit_length = bit_length or program.bit_length
if isinstance(m, int):
if m == 0:
return 0
if m >= bit_length:
return self
res = sint()
comparison.Mod2m(res, self, bit_length, m, signed=signed)
else:
res, pow2 = floatingpoint.Trunc(self, bit_length, m,
compute_modulo=True)
return res
@vectorize
def __rpow__(self, base):
""" Secret power computation. Base must be two.
Uses global parameters for bit length and security. """
if base == 2:
return self.pow2()
else:
return NotImplemented
[docs] @vectorize
def pow2(self, bit_length=None):
""" Secret power of two.
:param bit_length: bit length of input (default: global bit length)
"""
return floatingpoint.Pow2(self, bit_length or program.bit_length)
def __lshift__(self, other, bit_length=None):
""" Secret left shift.
:param other: secret or public integer (sint/cint/regint/int)
:param bit_length: bit length of input (default: global bit length)
"""
return self * util.pow2_value(other, bit_length)
@vectorize
@read_mem_value
def __rshift__(self, other, bit_length=None, signed=True):
""" Secret right shift.
:param other: secret or public integer (sint/cint/regint/int)
:param bit_length: bit length of input (default: global bit length)
"""
bit_length = bit_length or program.bit_length
if isinstance(other, int):
if other == 0:
return self
res = sint()
comparison.Trunc(res, self, bit_length, other, signed)
return res
elif isinstance(other, sint):
return floatingpoint.Trunc(self, bit_length, other)
else:
return floatingpoint.Trunc(self, bit_length, sint(other))
left_shift = __lshift__
right_shift = __rshift__
def __rlshift__(self, other):
""" Secret left shift.
Bit length of :py:obj:`self` uses global value.
:param other: secret or public integer (sint/cint/regint/int) """
return other * 2**self
@vectorize
def __rrshift__(self, other):
""" Secret right shift.
:param other: secret or public integer (sint/cint/regint/int) of globale bit length if secret """
return floatingpoint.Trunc(other, program.bit_length, self)
[docs] @vectorize
def bit_decompose(self, bit_length=None, maybe_mixed=False):
""" Secret bit decomposition. """
if bit_length == 0:
return []
bit_length = bit_length or program.bit_length
return program.non_linear.bit_dec(self, bit_length, bit_length,
maybe_mixed)
def TruncMul(self, other, k, m, nearest=False):
return (self * other).round(k, m, nearest, signed=True)
def TruncPr(self, k, m, signed=True):
return floatingpoint.TruncPr(self, k, m, signed=signed)
[docs] @vectorize
def round(self, k, m, nearest=False, signed=False):
""" Truncate and maybe round secret :py:obj:`k`-bit integer
by :py:obj:`m` bits. :py:obj:`m` can be secret if
:py:obj:`nearest` is false, in which case the truncation will be
exact. For public :py:obj:`m`, :py:obj:`nearest` chooses
between nearest rounding (rounding half up) and probablistic
truncation.
:param k: int
:param m: secret or compile-time integer (sint/int)
:param nearest: bool
:param signed: bool """
secret = isinstance(m, sint)
if nearest:
if secret:
raise NotImplementedError()
return comparison.TruncRoundNearest(self, k, m,
signed=signed)
else:
if secret:
return floatingpoint.Trunc(self, k, m)
return self.TruncPr(k, m, signed=signed)
def __truediv__(self, other):
""" Secret fixed-point division.
:param other: any compatible type """
if isinstance(other, sint):
return other.__rtruediv__(self)
try:
return sfix._new(self) / cfix._new(cint(other), f=sfix.f, k=sfix.k)
except:
return NotImplemented
def __rtruediv__(self, other):
return sfix._new(other) / sfix._new(self)
[docs] @vectorize
def int_div(self, other, bit_length=None):
""" Secret integer division. Note that the domain bit length
needs to be about four times the bit length.
:param other: sint
:param bit_length: bit length of input (default: global bit length)
"""
k = bit_length or program.bit_length
tmp = library.IntDiv(self, other, k)
res = type(self)()
comparison.Trunc(res, tmp, 2 * k, k, signed=True)
return res
[docs] @vectorize
def int_mod(self, other, bit_length=None):
""" Secret integer modulo. Note that the domain bit length
needs to be about four times the bit length.
:param other: sint
:param bit_length: bit length of input (default: global bit length)
"""
return self - other * self.int_div(other, bit_length=bit_length)
def trunc_zeros(self, n_zeros, bit_length=None, signed=True):
bit_length = bit_length or program.bit_length
return comparison.TruncZeros(self, bit_length, n_zeros, signed)
def split_to_n_summands(self, length, n):
comparison.require_ring_size(length, 'splitting')
from .GC.types import sbits
from .GC.instructions import split
columns = [[sbits.get_type(self.size)()
for i in range(n)] for i in range(length)]
split(n, self, *sum(columns, []))
return columns
def split_to_two_summands(self, length, get_carry=False):
n = program.use_split()
assert n
columns = self.split_to_n_summands(length, n)
return _bitint.wallace_tree_without_finish(columns, get_carry)
[docs] def reveal_to(self, player):
""" Reveal secret value to :py:obj:`player`.
:param player: public integer (int/regint/cint)
:returns: :py:class:`personal`
"""
if not util.is_constant(player):
secret_mask = sint(size=self.size)
player_mask = cint(size=self.size)
inputmaskreg(secret_mask, player_mask,
regint.conv(player).expand_to_vector(self.size))
return personal(player,
(self + secret_mask).reveal(False) - player_mask)
else:
res = personal(player, self.clear_type(size=self.size))
privateoutput(self.size, player, res._v, self)
return res
[docs] def private_division(self, divisor, active=None, dividend_length=None,
divisor_length=None):
""" Private integer division as per `Veugen and Abspoel
<https://doi.org/10.2478/popets-2021-0073>`_
:param divisor: public (cint/regint) or personal value thereof
:param active: whether to check on the party knowing the
divisor (active security)
:param dividend_length: bit length of the dividend (default:
global bit length)
:param dividend_length: bit length of the divisor (default:
global bit length)
"""
d = divisor
l = divisor_length or program.bit_length
m = dividend_length or program.bit_length
sigma = program.security
min_length = m + l + 2 * sigma + 1
if program.options.ring:
comparison.require_ring_size(min_length, 'private division')
else:
program.curr_tape.require_bit_length(min_length)
r = sint.get_random_int(l + sigma)
r_prime = sint.get_random_int(m + sigma)
r_pprime = sint.get_random_int(l + sigma)
d_shared = sint(d)
h = (r + (r_prime << (l + sigma))) * d_shared
z_shared = ((self << (l + sigma)) + h + r_pprime)
z = z_shared.reveal_to(0)
if active is None:
active = program.active
if active:
z_prime = [sint(x) for x in (z // d).bit_decompose(min_length)]
check = [(x * (1 - x)).reveal() == 0 for x in z_prime]
z_pp = [sint(x) for x in (z % d).bit_decompose(l)]
check += [(x * (1 - x)).reveal() == 0 for x in z_pp]
library.runtime_error_if(sum(check) != len(check),
'private division')
z_pp = sint.bit_compose(z_pp)
beta1 = z_pp.less_than(d_shared, l)
beta2 = z_shared - sint.bit_compose(z_prime) * d_shared - z_pp
library.runtime_error_if(beta1.reveal() != 1, 'private div')
library.runtime_error_if(beta2.reveal() != 0, 'private div')
y_prime = sint.bit_compose(z_prime[:l + sigma])
y = sint.bit_compose(z_prime[l + sigma:])
else:
program.semi_honest()
y = sint(z // (d << (l + sigma)))
y_prime = sint((z // d) % (2 ** (l + sigma)))
b = r.greater_than(y_prime, l + sigma)
w = y - b - r_prime
return w
@staticmethod
def get_secure_shuffle(n):
res = regint()
gensecshuffle(res, n)
return res
@read_mem_value
def secure_permute(self, shuffle, unit_size=1, reverse=False):
res = sint(size=self.size)
applyshuffle(res, self, unit_size, shuffle, reverse)
return res
def inverse_permutation(self):
if program.use_invperm():
# If enabled, we use the low-level INVPERM instruction.
# This instruction has only been implemented for a semi-honest two-party environement.
res = sint(size=self.size)
inverse_permutation(res, self)
else:
shuffle = sint.get_secure_shuffle(len(self))
shuffled = self.secure_permute(shuffle).reveal()
idx = Array.create_from(shuffled)
res = Array.create_from(sint(regint.inc(len(self))))
res.secure_permute(shuffle, reverse=False)
res.assign_slice_vector(idx, res.get_vector())
library.break_point()
res = res.get_vector()
return res
[docs] @vectorize
def prefix_sum(self):
""" Prefix sum. """
res = sint()
prefixsums(res, self)
return res
def sum(self):
res = type(self)(size=1)
picks(res, self.prefix_sum(), len(self) - 1, 0)
return res
def _expand_to_vector(self, size):
res = type(self)(size=size)
picks(res, self, 0, 0)
return res
def copy_from_part(self, source, base, size):
picks(self, source, base, 1)
def get_reverse_vector(self):
res = type(self)(size=self.size)
picks(res, self, self.size - 1, -1)
return res
def get_vector(self, base=0, size=None, skip=1):
if size is None:
size = len(self) - base
if base == 0 and size == len(self):
return self
assert base + size <= len(self)
res = type(self)(size=size)
picks(res, self, base, skip)
return res
@classmethod
def concat(cls, parts):
parts = list(parts)
res = cls(size=sum(len(part) for part in parts))
args = sum(([len(part), part] for part in parts), [])
concats(res, *args)
return res
@classmethod
def zip(cls, *parts):
res = cls(size=sum(len(part) for part in parts))
zips(res, *parts)
return res
[docs]class sintbit(sint):
""" :py:class:`sint` holding a bit, supporting binary operations
(``&, |, ^``). """
@classmethod
def prep_res(cls, other):
return sint()
@read_mem_value
def load_other(self, other):
if isinstance(other, sint):
movs(self, other)
else:
super(sintbit, self).load_other(other)
@vectorize
def __and__(self, other):
if isinstance(other, sintbit):
res = sintbit()
muls(res, self, other)
return res
elif util.is_zero(other):
return 0
elif util.is_one(other):
return self
else:
return NotImplemented
@vectorize
def __or__(self, other):
if isinstance(other, sintbit):
res = sintbit()
adds(res, self, other - self * other)
return res
elif util.is_zero(other):
return self
elif util.is_one(other):
return 1
else:
return NotImplemented
@vectorize
def __xor__(self, other):
if isinstance(other, sintbit):
res = sintbit()
adds(res, self, other - 2 * self * other)
return res
elif util.is_zero(other):
return self
elif util.is_one(other):
res = sintbit()
submr(res, cint(1), self)
return res
else:
return NotImplemented
@vectorize
def __rsub__(self, other):
if util.is_one(other):
res = sintbit()
subsfi(res, self, 1)
return res
else:
return super(sintbit, self).__rsub__(other)
__rand__ = __and__
__rxor__ = __xor__
__ror__ = __or__
[docs]class sgf2n(_secret, _gf2n):
r"""
Secret :math:`\mathrm{GF}(2^n)` value. n is chosen at runtime. A
number operators are supported (``+, -, *, /, **, ^, ~, ==, !=,
<<``), :py:class:`sgf2n`. Operators generally work with
cgf2n/regint/cint/int, except ``**, <<``, which require a
compile-time integer. ``/`` refers to field division. ``*, /,
**`` refer to field multiplication and division.
:param val: initialization (sgf2n/cgf2n/regint/int/cint or list thereof)
:param size: vector size (int), defaults to 1 or size of list
"""
__slots__ = []
instruction_type = 'gf2n'
clear_type = cgf2n
reg_type = 'sg'
long_one = staticmethod(lambda: 1)
@classmethod
def get_raw_input_from(cls, player):
res = cls()
grawinput(player, res)
return res
def add(self, other):
r""" Secret :math:`\mathrm{GF}(2^n)` addition (XOR).
:param other: sg2fn/cgf2n/regint/int """
if isinstance(other, sgf2nint):
return NotImplemented
else:
return super(sgf2n, self).add(other)
def mul(self, other):
r""" Secret :math:`\mathrm{GF}(2^n)` multiplication.
:param other: sg2fn/cgf2n/regint/int """
if isinstance(other, (sgf2nint)):
return NotImplemented
else:
return super(sgf2n, self).mul(other)
[docs] @vectorized_classmethod
def load_mem(cls, address, mem_type=None):
""" Load from memory by public address. """
return cls._load_mem(address, gldms, gldmsi)
[docs] def store_in_mem(self, address):
""" Store in memory by public address. """
self._store_in_mem(address, gstms, gstmsi)
@vectorize_init
def __init__(self, val=None, size=None):
super(sgf2n, self).__init__('sg', val=val, size=size)
def __neg__(self):
""" Identity. """
return self
__truediv__ = _secret.field_div
__rtruediv__ = _secret._rfield_div
@vectorize
def __invert__(self):
""" Secret bit-wise inversion. """
return self ^ cgf2n(2**program.galois_length - 1)
def __xor__(self, other):
""" Secret bit-wise XOR.
:param other: sg2fn/cgf2n/regint/int """
if is_zero(other):
return self
else:
return super(sgf2n, self).add(other)
__rxor__ = __xor__
@vectorize
def __and__(self, other):
""" Secret bit-wise AND.
:param other: sg2fn/cgf2n/regint/int """
if isinstance(other, int):
other_bits = [(other >> i) & 1 \
for i in range(program.galois_length)]
else:
other_bits = other.bit_decompose()
self_bits = self.bit_decompose()
return sum((x * y) << i \
for i,(x,y) in enumerate(zip(self_bits, other_bits)))
__rand__ = __and__
@vectorize
def __lshift__(self, other):
""" Secret left shift py public value.
:param other: regint/cint/int """
return self * cgf2n(1 << other)
[docs] @vectorize
def right_shift(self, other, bit_length=None):
r""" Secret right shift by public value:
:param other: compile-time (int)
:param bit_length: number of bits of :py:obj:`self` (defaults to :math:`\mathrm{GF}(2^n)` bit length) """
bits = self.bit_decompose(bit_length)
return sum(b << i for i,b in enumerate(bits[other:]))
[docs] def equal(self, other, bit_length=None, expand=1):
""" Secret comparison.
:param other: sgf2n/cgf2n/regint/int
:return: 0/1 (sgf2n) """
bits = [1 - bit for bit in (self - other).bit_decompose(bit_length)][::expand]
while len(bits) > 1:
bits.insert(0, bits.pop() * bits.pop())
return bits[0]
[docs] def not_equal(self, other, bit_length=None):
""" Secret comparison. """
return 1 - self.equal(other, bit_length)
not_equal.__doc__ = equal.__doc__
__eq__ = equal
__ne__ = not_equal
[docs] @vectorize
def bit_decompose(self, bit_length=None, step=1):
""" Secret bit decomposition.
:param bit_length: number of bits
:param step: use every :py:obj:`step`-th bit
:return: list of sgf2n """
if bit_length == 0:
return []
bit_length = bit_length or program.galois_length
random_bits = [self.get_random_bit() \
for i in range(0, bit_length, step)]
one = cgf2n(1)
masked = sum([b * (one << (i * step))
for i,b in enumerate(random_bits)], self).reveal(
check=False)
masked_bits = masked.bit_decompose(bit_length,step=step)
return [m + r for m,r in zip(masked_bits, random_bits)]
@vectorize
def bit_decompose_embedding(self):
random_bits = [self.get_random_bit() \
for i in range(8)]
one = cgf2n(1)
wanted_positions = [0, 5, 10, 15, 20, 25, 30, 35]
masked = sum([b * (one << wanted_positions[i])
for i,b in enumerate(random_bits)], self).reveal(
check=False)
return [self.clear_type((masked >> wanted_positions[i]) & one) + r for i,r in enumerate(random_bits)]
for t in (sint, sgf2n):
t.basic_type = t
t.default_type = t
sint.bit_type = sintbit
sgf2n.bit_type = sgf2n
class _bitint(Tape._no_truth):
bits = None
log_rounds = False
linear_rounds = False
comp_result = staticmethod(lambda x: x)
@staticmethod
def half_adder(a, b):
return a.half_adder(b)
@classmethod
def bit_adder(cls, a, b, carry_in=0, get_carry=False):
a, b = list(a), list(b)
a += [0] * (len(b) - len(a))
b += [0] * (len(a) - len(b))
return cls.bit_adder_selection(a, b, carry_in=carry_in,
get_carry=get_carry)
@classmethod
def bit_adder_selection(cls, a, b, carry_in=0, get_carry=False):
if cls.log_rounds:
return cls.carry_lookahead_adder(a, b, carry_in=carry_in,
get_carry=get_carry)
elif cls.linear_rounds:
return cls.ripple_carry_adder(a, b, carry_in=carry_in,
get_carry=get_carry)
else:
return cls.carry_select_adder(a, b, carry_in=carry_in,
get_carry=get_carry)
@classmethod
def carry_lookahead_adder(cls, a, b, fewer_inv=False, carry_in=0,
get_carry=False):
lower = []
a, b = a[:], b[:]
for (ai, bi) in zip(a[:], b[:]):
if is_zero(ai) or is_zero(bi):
lower.append(ai + bi)
a.pop(0)
b.pop(0)
else:
break
carries = cls.get_carries(a, b, fewer_inv=fewer_inv, carry_in=carry_in)
res = lower + cls.sum_from_carries(a, b, carries)
if get_carry:
res += [carries[-1]]
return res
@classmethod
def get_carries(cls, a, b, fewer_inv=False, carry_in=0):
d = [cls.half_adder(ai, bi) for (ai,bi) in zip(a,b)]
carry = floatingpoint.carry
if fewer_inv:
pre_op = floatingpoint.PreOpL2
else:
pre_op = floatingpoint.PreOpL
if d:
carries = list(zip(*pre_op(carry, [(0, carry_in)] + d)))[1]
else:
carries = []
return carries
@staticmethod
def sum_from_carries(a, b, carries):
return [ai.bit_xor(bi).bit_xor(carry) \
for (ai, bi, carry) in zip(a, b, carries)]
@classmethod
def carry_select_adder(cls, a, b, get_carry=False, carry_in=0):
a += [0] * (len(b) - len(a))
b += [0] * (len(a) - len(b))
n = len(a)
for m in range(100):
if sum(range(m + 1)) + 1 >= n:
break
for k in range(m, -1, -1):
if sum(range(m, k - 1, -1)) + 1 >= n:
break
blocks = list(range(m, k, -1))
blocks.append(n - sum(blocks))
blocks.reverse()
#print 'blocks:', blocks
if len(blocks) > 1 and blocks[0] > blocks[1]:
raise Exception('block size not increasing:', blocks)
if sum(blocks) != n:
raise Exception('blocks not summing up: %s != %s' % \
(sum(blocks), n))
res = []
carry = carry_in
cin_one = util.long_one(a + b)
for m in blocks:
aa = a[:m]
bb = b[:m]
a = a[m:]
b = b[m:]
cc = [cls.ripple_carry_adder(aa, bb, i) for i in (0, cin_one)]
for i in range(m):
res.append(util.if_else(carry, cc[1][i], cc[0][i]))
carry = util.if_else(carry, cc[1][m], cc[0][m])
if get_carry:
res += [carry]
return res
@classmethod
def ripple_carry_adder(cls, a, b, carry_in=0, get_carry=True):
carry = carry_in
res = []
for aa, bb in zip(a, b):
cc, carry = cls.full_adder(aa, bb, carry)
res.append(cc)
if get_carry:
res.append(carry)
return res
@staticmethod
def full_adder(a, b, carry):
s = a ^ b
return s ^ carry, a ^ (s & (carry ^ a))
@staticmethod
def bit_comparator(a, b):
long_one = util.long_one(a + b)
op = lambda y,x,*args: (util.if_else(x[1], x[0], y[0]), \
util.if_else(x[1], long_one, y[1]))
return floatingpoint.KOpL(op, [(bi, ai + bi) for (ai,bi) in zip(a,b)])
@classmethod
def bit_less_than(cls, a, b):
x, not_equal = cls.bit_comparator(a, b)
return util.if_else(not_equal, x, 0)
@staticmethod
def get_highest_different_bits(a, b, index):
diff = [ai + bi for (ai,bi) in reversed(list(zip(a,b)))]
preor = floatingpoint.PreOR(diff, raw=True)
highest_diff = [x - y for (x,y) in reversed(list(zip(preor, [0] + preor)))]
raw = sum(map(operator.mul, highest_diff, (a,b)[index]))
return raw.bit_decompose()[0]
def add(self, other):
if type(other) == self.bin_type:
raise CompilerError('Unclear addition')
a = self.bit_decompose()
b = util.bit_decompose(other, self.n_bits)
return self.compose(self.bit_adder(a, b))
@ret_cisc
def mul(self, other):
if type(other) == self.bin_type:
raise CompilerError('Unclear multiplication')
self_bits = self.bit_decompose()
if isinstance(other, int):
other_bits = util.bit_decompose(other, self.n_bits)
bit_matrix = [[x * y for y in self_bits] for x in other_bits]
else:
try:
other_bits = other.bit_decompose()
if len(other_bits) == 1:
return type(self)(other_bits[0] * self)
if len(self_bits) != len(other_bits):
raise NotImplementedError('Multiplication of different lengths')
except AttributeError:
pass
try:
other = self.bin_type(other)
except CompilerError:
return NotImplemented
bit_matrix = self.get_bit_matrix(self_bits, other)
return self.compose(self.wallace_tree_from_matrix(bit_matrix, False))
@classmethod
def wallace_tree_from_matrix(cls, bit_matrix, get_carry=True):
columns = [[_f for _f in (bit_matrix[j][i-j] \
for j in range(min(len(bit_matrix), i + 1))) \
if not is_zero(_f)] \
for i in range(len(bit_matrix[0]))]
return cls.wallace_tree_from_columns(columns, get_carry)
@classmethod
def wallace_tree_without_finish(cls, columns, get_carry=True):
self = cls
columns = [col[:] for col in columns]
while max(len(c) for c in columns) > 2:
new_columns = [[] for i in range(len(columns) + 1)]
for i,col in enumerate(columns):
while len(col) > 2:
s, carry = self.full_adder(*(col.pop() for i in range(3)))
new_columns[i].append(s)
new_columns[i+1].append(carry)
if len(col) == 2:
s, carry = self.half_adder(*(col.pop() for i in range(2)))
new_columns[i].append(s)
new_columns[i+1].append(carry)
else:
new_columns[i].extend(col)
if get_carry:
columns = new_columns
else:
columns = new_columns[:-1]
for col in columns:
col.extend([0] * (2 - len(col)))
return tuple(list(x) for x in zip(*columns))
@classmethod
def wallace_tree_from_columns(cls, columns, get_carry=True):
summands = cls.wallace_tree_without_finish(columns, get_carry)
return cls.bit_adder(*summands)
@classmethod
def wallace_tree(cls, rows):
return cls.wallace_tree_from_columns([list(x) for x in zip(*rows)])
@classmethod
def wallace_reduction(cls, a, b, c, get_carry=True):
assert len(a) == len(b) == len(c)
tmp = zip(*(cls.full_adder(*x) for x in zip(a, b, c)))
sums, carries = (list(x) for x in tmp)
carries = [0] + carries
if get_carry:
sums += [0]
else:
del carries[-1]
return sums, carries
def expand(self, other):
a = self.bit_decompose()
b = util.bit_decompose(other, self.n_bits)
return a, b
def __sub__(self, other):
if type(other) == sgf2n:
raise CompilerError('Unclear subtraction')
from util import bit_not, bit_and, bit_xor
try:
a, b = self.expand(other)
except:
return NotImplemented
n = 1
for x in (a + b):
try:
n = x.n
break
except:
pass
d = [(bit_not(bit_xor(ai, bi), n), bit_and(bit_not(ai, n), bi))
for (ai,bi) in zip(a,b)]
borrow = lambda y,x,*args: \
(bit_and(x[0], y[0]), util.OR(x[1], bit_and(x[0], y[1])))
borrows = (0,) + list(zip(*floatingpoint.PreOpL(borrow, d)))[1]
return self.compose(reduce(util.bit_xor, (ai, bi, borrow)) \
for (ai,bi,borrow) in zip(a,b,borrows))
def __rsub__(self, other):
raise NotImplementedError()
def __truediv__(self, other):
raise NotImplementedError()
def __truerdiv__(self, other):
raise NotImplementedError()
def __lshift__(self, other):
return self.compose(([0] * other + self.bit_decompose())[:self.n_bits])
def __rshift__(self, other):
return self.compose(self.bit_decompose()[other:])
def bit_decompose(self, n_bits=None):
if self.bits is None:
self.bits = self.force_bit_decompose(self.n_bits)
if n_bits is None:
return self.bits[:]
else:
return self.bits[:n_bits] + [self.fill_bit()] * (n_bits - self.n_bits)
def fill_bit(self):
return self.bits[-1]
@staticmethod
def prep_comparison(a, b):
a[-1], b[-1] = b[-1], a[-1]
def comparison(self, other, const_rounds=False, index=None):
a, b = self.expand(other)
self.prep_comparison(a, b)
if const_rounds:
return self.get_highest_different_bits(a, b, index)
else:
return self.bit_comparator(a, b)
def __lt__(self, other):
if program.options.comparison == 'log':
x, not_equal = self.comparison(other)
res = util.if_else(not_equal, x, 0)
else:
res = self.comparison(other, True, 1)
return self.comp_result(res)
def __le__(self, other):
if program.options.comparison == 'log':
x, not_equal = self.comparison(other)
res = util.if_else(not_equal, x, x.long_one())
else:
res = self.comparison(other, True, 0).bit_not()
return self.comp_result(res)
def __ge__(self, other):
return (self < other).bit_not()
def __gt__(self, other):
return (self <= other).bit_not()
def __eq__(self, other, bit_length=None):
diff = self ^ other
diff_bits = [x.bit_not() for x in diff.bit_decompose()[:bit_length]]
return self.comp_result(util.tree_reduce(lambda x, y: x.bit_and(y),
diff_bits))
def __ne__(self, other):
return (self == other).bit_not()
equal = __eq__
def __neg__(self):
bits = self.bit_decompose()
n = 1
for b in bits:
try:
n = x.n
break
except:
pass
return 1 + self.compose(util.bit_not(b, n) for b in bits)
def __abs__(self):
return util.if_else(self.bit_decompose()[-1], -self, self)
less_than = lambda self, other, *args, **kwargs: self < other
greater_than = lambda self, other, *args, **kwargs: self > other
less_equal = lambda self, other, *args, **kwargs: self <= other
greater_equal = lambda self, other, *args, **kwargs: self >= other
equal = lambda self, other, *args, **kwargs: self == other
not_equal = lambda self, other, *args, **kwargs: self != other
class intbitint(_bitint, sint):
@staticmethod
def full_adder(a, b, carry):
s = a.bit_xor(b)
return s.bit_xor(carry), util.if_else(s, carry, a)
@staticmethod
def sum_from_carries(a, b, carries):
return [a[i] + b[i] + carries[i] - 2 * carries[i + 1] \
for i in range(len(a))]
@classmethod
def bit_adder_selection(cls, a, b, carry_in=0, get_carry=False):
if cls.linear_rounds:
return cls.ripple_carry_adder(a, b, carry_in=carry_in)
# experimental cut-off with dead code elimination
elif len(a) < 122 or cls.log_rounds:
return cls.carry_lookahead_adder(a, b, carry_in=carry_in,
get_carry=get_carry)
else:
return cls.carry_select_adder(a, b, carry_in=carry_in)
class sgf2nint(_bitint, sgf2n):
bin_type = sgf2n
@classmethod
def compose(cls, bits):
bits = list(bits)
if len(bits) > cls.n_bits:
raise CompilerError('Too many bits')
res = cls()
res.bits = bits + [0] * (cls.n_bits - len(bits))
gmovs(res, sum(b << i for i,b in enumerate(bits)))
return res
@staticmethod
def get_bit_matrix(self_bits, other):
products = [x * other for x in self_bits]
return [util.bit_decompose(x, len(self_bits)) for x in products]
def load_int(self, other):
if -2**(self.n_bits-1) <= other < 2**(self.n_bits-1):
self.bin_type.load_int(self, other + 2**self.n_bits \
if other < 0 else other)
else:
raise CompilerError('Invalid signed %d-bit integer: %d' % \
(self.n_bits, other))
@vectorize
def load_other(self, other):
if isinstance(other, sgf2nint):
gmovs(self, self.compose(other.bit_decompose(self.n_bits)))
elif isinstance(other, sgf2n):
gmovs(self, other)
else:
gaddm(self, sgf2n(0), cgf2n(other))
def force_bit_decompose(self, n_bits=None):
return sgf2n(self).bit_decompose(n_bits)
class sgf2nuint(sgf2nint):
def load_int(self, other):
if 0 <= other < 2**self.n_bits:
sgf2n.load_int(self, other)
else:
raise CompilerError('Invalid unsigned %d-bit integer: %d' % \
(self.n_bits, other))
def fill_bit(self):
return 0
@staticmethod
def prep_comparison(a, b):
pass
class sgf2nuint32(sgf2nuint):
n_bits = 32
class sgf2nint32(sgf2nint):
n_bits = 32
def get_sgf2nint(n):
class sgf2nint_spec(sgf2nint):
n_bits = n
#sgf2nint_spec.__name__ = 'sgf2unint' + str(n)
return sgf2nint_spec
def get_sgf2nuint(n):
class sgf2nuint_spec(sgf2nint):
n_bits = n
#sgf2nuint_spec.__name__ = 'sgf2nuint' + str(n)
return sgf2nuint_spec
class sgf2nfloat(sgf2n):
@classmethod
def set_precision(cls, vlen, plen):
cls.vlen = vlen
cls.plen = plen
class v_type(sgf2nuint):
n_bits = 2 * vlen + 1
class p_type(sgf2nint):
n_bits = plen
class pdiff_type(sgf2nuint):
n_bits = plen
cls.v_type = v_type
cls.p_type = p_type
cls.pdiff_type = pdiff_type
def __init__(self, val, p=None, z=None, s=None):
super(sgf2nfloat, self).__init__()
if p is None and type(val) == sgf2n:
bits = val.bit_decompose(self.vlen + self.plen + 1)
self.v = self.v_type.compose(bits[:self.vlen])
self.p = self.p_type.compose(bits[self.vlen:-1])
self.s = bits[-1]
self.z = util.tree_reduce(operator.mul, (1 - b for b in self.v.bits))
else:
if p is None:
v, p, z, s = sfloat.convert_float(val, self.vlen, self.plen)
# correct sfloat
p += self.vlen - 1
v_bits = util.bit_decompose(v, self.vlen)
p_bits = util.bit_decompose(p, self.plen)
self.v = self.v_type.compose(v_bits)
self.p = self.p_type.compose(p_bits)
self.z = z
self.s = s
else:
self.v, self.p, self.z, self.s = val, p, z, s
v_bits = val.bit_decompose()[:self.vlen]
p_bits = p.bit_decompose()[:self.plen]
gmovs(self, util.bit_compose(v_bits + p_bits + [self.s]))
def add(self, other):
a = self.p < other.p
b = self.p == other.p
c = self.v < other.v
other_dominates = (b.if_else(c, a))
pmax, pmin = a.cond_swap(self.p, other.p, self.p_type)
vmax, vmin = other_dominates.cond_swap(self.v, other.v, self.v_type)
s3 = self.s ^ other.s
pdiff = self.pdiff_type(pmax - pmin)
d = self.vlen < pdiff
pow_delta = util.pow2(d.if_else(0, pdiff).bit_decompose(util.log2(self.vlen)))
v3 = vmax
v4 = self.v_type(sgf2n(vmax) * pow_delta) + self.v_type(s3.if_else(-vmin, vmin))
v = self.v_type(sgf2n(d.if_else(v3, v4) << self.vlen) / pow_delta)
v >>= self.vlen - 1
h = floatingpoint.PreOR(v.bits[self.vlen+1::-1])
tmp = sum(util.if_else(b, 0, 1 << i) for i,b in enumerate(h))
pow_p0 = 1 + self.v_type(tmp)
v = (v * pow_p0) >> 2
p = pmax - sum(self.p_type.compose([1 - b]) for b in h) + 1
v = self.z.if_else(other.v, other.z.if_else(self.v, v))
z = v == 0
p = z.if_else(0, self.z.if_else(other.p, other.z.if_else(self.p, p)))
s = other_dominates.if_else(other.s, self.s)
s = self.z.if_else(other.s, other.z.if_else(self.s, s))
return sgf2nfloat(v, p, z, s)
def mul(self, other):
v = (self.v * other.v) >> (self.vlen - 1)
b = v.bits[self.vlen]
v = b.if_else(v >> 1, v)
p = self.p + other.p + self.p_type.compose([b])
s = self.s + other.s
z = util.or_op(self.z, other.z)
return sgf2nfloat(v, p, z, s)
sgf2nfloat.set_precision(24, 8)
def parse_type(other, k=None, f=None):
# converts type to cfix/sfix depending on the case
if isinstance(other, cfix.scalars):
return cfix(other, k=k, f=f)
elif isinstance(other, cint):
tmp = cfix(k=k, f=f)
tmp.load_int(other)
return tmp
elif isinstance(other, sint):
tmp = sfix(k=k, f=f)
tmp.load_int(other)
return tmp
elif isinstance(other, sfloat):
tmp = sfix(other, k=k, f=f)
return tmp
else:
return other
[docs]class cfix(_number, _structure):
"""
Clear fixed-point number represented as clear integer. It supports
basic arithmetic (``+, -, *, /``), returning either
:py:class:`cfix` if the other operand is public
(cfix/regint/cint/int) or :py:class:`sfix` if the other operand is
an sfix. It also support comparisons (``==, !=, <, <=, >, >=``),
returning either :py:class:`regint` or :py:class:`sbitint`.
Similarly to :py:class:`Compiler.types.cint`, this type is
restricted to arithmetic circuits due to the fact that only
arithmetic protocols offer communication-less public-private
integer operations.
:param v: cfix/float/int
"""
__slots__ = ['value', 'f', 'k']
reg_type = 'c'
scalars = (int, float, regint, cint)
[docs] @classmethod
def set_precision(cls, f, k = None):
""" Set the precision of the integer representation.
The initial defaults are chosen to
allow the best optimization of probabilistic truncation in
computation modulo 2^64 (2*k < 64). Generally, 2*k must be at
most the integer length for rings and at most m-s-1 for
computation modulo an m-bit prime and statistical security s
(default 40).
:param f: bit length of decimal part (initial default 16)
:param k: whole bit length of fixed point, defaults to twice :py:obj:`f` if not given (initial default 31)
"""
cls.f = f
if k is None:
cls.k = 2 * f
else:
cls.k = k
[docs] @vectorized_classmethod
def load_mem(cls, address, mem_type=None):
""" Load from memory by public address. """
return cls._new(cint.load_mem(address))
[docs] @vectorized_classmethod
def read_from_socket(cls, client_id, n=1):
"""
Receive clear fixed-point value(s) from client. The client needs
to convert the values to the right integer representation.
:param client_id: Client id (regint)
:param n: number of values (default 1)
:param: vector size (int)
:returns: cfix (if n=1) or list of cfix
"""
cint_inputs = cint.read_from_socket(client_id, n)
if n == 1:
return cfix._new(cint_inputs)
else:
return list(map(cfix._new, cint_inputs))
[docs] @classmethod
def write_to_socket(self, client_id, values, message_type=ClientMessageType.NoType):
""" Send a list of clear fixed-point values to a client
(represented as clear integers).
:param client_id: Client id (regint)
:param values: list of cint
"""
for value in values:
assert(value.size == values[0].size)
def cfix_to_cint(fix_val):
return cint(fix_val.v)
cint_values = list(map(cfix_to_cint, values))
writesocketc(client_id, message_type, values[0].size, *cint_values)
@staticmethod
def malloc(size, creator_tape=None):
return program.malloc(size, cint, creator_tape=creator_tape)
@classmethod
def free(cls, addr):
return cint.free(addr)
@staticmethod
def n_elements():
return 1
@classmethod
def from_int(cls, other):
res = cls()
res.load_int(other)
return res
@classmethod
def _new(cls, other, k=None, f=None):
assert not isinstance(other, (list, tuple))
res = cls(k=k, f=f)
res.v = cint.conv(other)
return res
@staticmethod
def int_rep(v, f, k=None):
if isinstance(v, regint):
v = cint(v)
res = v * (2 ** f)
try:
res = int(round(res))
if k and res >= 2 ** (k - 1) or res < -2 ** (k - 1):
limit = 2 ** (k - f - 1)
raise CompilerError(
'Value out of fixed-point range [-%d, %d). '
'Use `sfix.set_precision(f, k)` with k being at least f+%d'
% (limit, limit, res.bit_length() - f + 1))
except TypeError:
pass
return res
@vectorize_init
@read_mem_value
def __init__(self, v=None, k=None, f=None, size=None):
f = self.f if f is None else f
k = self.k if k is None else k
self.f = f
self.k = k
if isinstance(v, cfix.scalars):
v = self.int_rep(v, f=f, k=k)
self.v = cint(v, size=size)
elif isinstance(v, cfix):
self.v = v.v
elif v is None:
self.v = cint(0)
else:
raise CompilerError('cannot initialize cfix with %s' % v)
def __iter__(self):
for x in self.v:
yield self._new(x, self.k, self.f)
def __len__(self):
return len(self.v)
def __getitem__(self, index):
if isinstance(index, slice):
return [self._new(x, k=self.k, f=self.f) for x in self.v[index]]
return self._new(self.v[index], k=self.k, f=self.f)
def get_vector(self):
return self
@vectorize
def load_int(self, v):
self.v = cint(v) * (2 ** self.f)
@classmethod
def conv(cls, other):
if isinstance(other, cls):
return other
else:
return cls(other)
[docs] def store_in_mem(self, address):
""" Store in memory by public address. """
self.v.store_in_mem(address)
@property
def size(self):
return self.v.size
def sizeof(self):
return self.size * 4
@read_mem_value
def parse_type(self, other):
res = parse_type(other, f=self.f, k=self.k)
# check attributes if available
try:
assert res.k == self.k
assert res.f == self.f
except AttributeError:
pass
return res
@vectorize
def add(self, other):
""" Clear fixed-point addition.
:param other: cfix/cint/regint/int """
other = self.parse_type(other)
if isinstance(other, cfix):
return cfix._new(self.v + other.v, k=self.k, f=self.f)
else:
return NotImplemented
def mul(self, other):
""" Clear fixed-point multiplication.
:param other: cfix/cint/regint/int/sint """
if isinstance(other, sint):
return sfix._new(self.v * other, k=self.k, f=self.f)
if isinstance(other, (int, regint, cint)):
return cfix._new(self.v * cint(other), k=self.k, f=self.f)
other = self.parse_type(other)
if isinstance(other, cfix):
assert self.f == other.f
sgn = cint(1 - 2 * ((self < 0) ^ (other < 0)))
absolute = self.v * other.v * sgn
val = sgn * (absolute >> self.f)
return cfix._new(val, k=self.k, f=self.f)
elif isinstance(other, sfix):
return NotImplemented
else:
raise CompilerError('Invalid type %s for cfix.__mul__' % type(other))
def positive_mul(self, other):
assert isinstance(other, float)
assert other >= 0
v = self.v * int(round(other * 2 ** self.f))
return self._new(v >> self.f, k=self.k, f=self.f)
@vectorize
def __sub__(self, other):
""" Clear fixed-point subtraction.
:param other: cfix/cint/regint/int """
other = self.parse_type(other)
if isinstance(other, cfix):
return cfix._new(self.v - other.v, k=self.k, f=self.f)
elif isinstance(other, sfix):
return sfix._new(self.v - other.v, k=self.k, f=self.f)
else:
return NotImplemented
@vectorize
def __neg__(self):
""" Clear fixed-point negation. """
# cfix type always has .v
return cfix._new(-self.v, f=self.f, k=self.k)
def __rsub__(self, other):
return -self + other
__rsub__.__doc__ = __sub__.__doc__
@vectorize
def __eq__(self, other):
""" Clear fixed-point comparison.
:param other: cfix/cint/regint/int
:return: 0/1
:rtype: regint """
other = self.parse_type(other)
if isinstance(other, cfix):
return self.v == other.v
elif isinstance(other, sfix):
return other.v.equal(self.v, self.k)
else:
raise NotImplementedError
@vectorize
def __lt__(self, other):
""" Clear fixed-point comparison. """
other = self.parse_type(other)
if isinstance(other, cfix):
assert self.k == other.k
return self.v.less_than(other.v, self.k)
elif isinstance(other, sfix):
if(self.k != other.k or self.f != other.f):
raise TypeError('Incompatible fixed point types in comparison')
return other.v.greater_than(self.v, self.k)
else:
raise NotImplementedError
@vectorize
def __le__(self, other):
""" Clear fixed-point comparison. """
other = self.parse_type(other)
if isinstance(other, cfix):
return 1 - (self > other)
elif isinstance(other, sfix):
return other.v.greater_equal(self.v, self.k)
else:
raise NotImplementedError
@vectorize
def __gt__(self, other):
""" Clear fixed-point comparison. """
other = self.parse_type(other)
if isinstance(other, cfix):
return other.__lt__(self)
elif isinstance(other, sfix):
return other.v.less_than(self.v, self.k)
else:
raise NotImplementedError
@vectorize
def __ge__(self, other):
""" Clear fixed-point comparison. """
other = self.parse_type(other)
if isinstance(other, cfix):
return 1 - (self < other)
elif isinstance(other, sfix):
return other.v.less_equal(self.v, self.k)
else:
raise NotImplementedError
@vectorize
def __ne__(self, other):
""" Clear fixed-point comparison. """
other = self.parse_type(other)
if isinstance(other, cfix):
return self.v != other.v
elif isinstance(other, sfix):
return other.v.not_equal(self.v, self.k)
else:
raise NotImplementedError
for op in __le__, __lt__, __ge__, __gt__, __ne__:
op.__doc__ = __eq__.__doc__
del op
@vectorize
def __truediv__(self, other):
""" Clear fixed-point division.
:param other: cfix/cint/regint/int """
other = self.parse_type(other)
if isinstance(other, cfix):
return cfix._new(library.cint_cint_division(
self.v, other.v, self.k, self.f), k=self.k, f=self.f)
elif isinstance(other, sfix):
assert self.k == other.k
assert self.f == other.f
return sfix._new(library.FPDiv(self.v, other.v, self.k, self.f,
nearest=sfix.round_nearest),
k=self.k, f=self.f)
else:
raise TypeError('Incompatible fixed point types in division')
@vectorize
def __rtruediv__(self, other):
""" Fixed-point division.
:param other: sfix/sint/cfix/cint/regint/int """
other = self.parse_type(other)
return other / self
def reveal(self):
return self
[docs] @vectorize
def print_plain(self):
""" Clear fixed-point output. """
print_float_plain(cint.conv(self.v), cint(-self.f), \
cint(0), cint(0), cint(0))
def output_if(self, cond):
cond_print_plain(cint.conv(cond), self.v, cint(-self.f, size=self.size))
[docs] def binary_output(self, player=None):
""" Write double-precision floating-point number to
``Player-Data/Binary-Output-P<playerno>-<threadno>``.
:param player: only output on given player (default all)
"""
if player == None:
player = -1
if not util.is_constant(player):
raise CompilerError('Player number must be known at compile time')
set_global_vector_size(self.size)
floatoutput(player, self.v, cint(-self.f), cint(0), cint(0))
reset_global_vector_size()
def link(self, other):
self.v.link(other.v)
def update(self, other):
self.v.update(other.v)
class _single(_number, _secret_structure):
""" Representation as single integer preserving the order """
""" E.g. fixed-point numbers """
__slots__ = ['v']
round_nearest = False
""" Whether to round deterministically to nearest instead of
probabilistically, e.g. after fixed-point multiplication. """
@vectorized_classmethod
def receive_from_client(cls, n, client_id, message_type=ClientMessageType.NoType):
"""
Securely obtain shares of values input by a client via
:py:func:`sint.receive_from_client`. Assumes client
has already converted values to integer representation.
:param n: number of inputs (int)
:param client_id: regint
:param size: vector size (default 1)
:returns: list of length ``n``
"""
sint_inputs = cls.int_type.receive_from_client(n, client_id,
message_type)
return list(map(cls._new, sint_inputs))
@classmethod
def reveal_to_clients(cls, clients, values):
""" Reveal securely to clients via :py:func:`sint.reveal_to_clients`.
:param clients: client ids (list or array)
:param values: list of values of this class
"""
cls.int_type.reveal_to_clients(clients, [x.v for x in values])
@vectorized_classmethod
def write_shares_to_socket(cls, client_id, values,
message_type=ClientMessageType.NoType):
""" Send shares of integer representations of a list of values
to a specified client socket.
:param client_id: regint
:param values: list of values of this type
"""
cls.int_type.write_shares_to_socket(
client_id, [x.v for x in values], message_type)
@vectorized_classmethod
def read_from_socket(cls, client_id, n=1):
return util.untuplify([cls._new(x) for x in util.tuplify(
cls.int_type.read_from_socket(client_id, n))])
@classmethod
def write_to_socket(cls, client_id, values):
cls.int_type.write_to_socket(client_id, [x.v for x in values])
def write_fully_to_socket(self, client_id):
self.v.write_fully_to_socket(client_id)
@vectorized_classmethod
def load_mem(cls, address, mem_type=None):
""" Load from memory by public address. """
return cls._new(cls.int_type.load_mem(address))
@classmethod
@read_mem_value
def conv(cls, other):
if isinstance(other, cls):
return other
elif isinstance(other, (list, tuple)):
return type(other)(cls.conv(x) for x in other)
else:
return cls(other)
@classmethod
def coerce(cls, other):
return cls.conv(other)
@classmethod
def malloc(cls, size, creator_tape=None):
return cls.int_type.malloc(size, creator_tape=creator_tape)
@classmethod
def free(cls, addr):
return cls.int_type.free(addr)
@classmethod
def n_elements(cls):
return cls.int_type.n_elements()
@classmethod
def mem_size(cls):
return cls.int_type.mem_size()
@classmethod
def dot_product(cls, x, y, res_params=None):
""" Secret dot product.
:param x: iterable of appropriate secret type
:param y: iterable of appropriate secret type and same length """
return cls.unreduced_dot_product(x, y, res_params).reduce_after_mul()
@classmethod
def unreduced_dot_product(cls, x, y, res_params=None):
dp = cls.int_type.dot_product([xx.pre_mul() for xx in x],
[yy.pre_mul() for yy in y])
return x[0].unreduced(dp, y[0], res_params, len(x))
@classmethod
def row_matrix_mul(cls, row, matrix, res_params=None):
int_matrix = [y.get_vector().pre_mul() for y in matrix]
col = cls.int_type.row_matrix_mul([x.pre_mul() for x in row],
int_matrix)
res = row[0].unreduced(col, matrix[0][0], res_params,
len(row)).reduce_after_mul()
return res
@classmethod
def matrix_mul(cls, A, B, n, res_params=None):
AA = A.pre_mul()
BB = B.pre_mul()
CC = cls.int_type.matrix_mul(AA, BB, n)
res = A.unreduced(CC, B, res_params, n).reduce_after_mul()
return res
@classmethod
def read_from_file(cls, *args, **kwargs):
""" Read shares from ``Persistence/Transactions-P<playerno>.data``.
Precision must be the same as when storing. See :ref:`this
section <persistence>` for details on the data format.
:param start: starting position in number of shares from beginning
(int/regint/cint)
:param n_items: number of items (int)
:param crash_if_missing: crash if file not found (default)
:returns: destination for final position, -1 for eof reached,
or -2 for file not found (regint)
:returns: list of shares
"""
stop, shares = cls.int_type.read_from_file(*args, **kwargs)
return stop, [cls._new(x) for x in shares]
@classmethod
def write_to_file(cls, shares, position=None):
""" Write shares of integer representation to
``Persistence/Transactions-P<playerno>.data``. See :ref:`this
section <persistence>` for details on the data format.
:param shares: (list or iterable of sfix)
:param position: start position (int/regint/cint),
defaults to end of file
"""
cls.int_type.write_to_file([x.v for x in shares], position)
def store_in_mem(self, address):
""" Store in memory by public address. """
self.v.store_in_mem(address)
@property
def size(self):
return self.v.size
def sizeof(self):
return self.size
def __len__(self):
""" Vector length. """
return len(self.v)
@vectorize
def __sub__(self, other):
""" Subtraction.
:param other: appropriate public or secret (incl. sint/cint/regint/int) """
other = self.coerce(other)
return self + (-other)
def __rsub__(self, other):
return -self + other
__rsub__.__doc__ = __sub__.__doc__
@vectorize
def __eq__(self, other):
""" Comparison.
:param other: appropriate public or secret (incl. sint/cint/regint/int)
:return: 0/1
:rtype: same as internal representation"""
other = self.coerce(other)
if isinstance(other, (cfix, _single)):
return self.v.equal(other.v, self.k)
else:
raise NotImplementedError
@vectorize
def __le__(self, other):
other = self.coerce(other)
if isinstance(other, (cfix, _single)):
return self.v.less_equal(other.v, self.k)
else:
raise NotImplementedError
@vectorize
def __lt__(self, other):
other = self.coerce(other)
if isinstance(other, (cfix, _single)):
return self.v.less_than(other.v, self.k)
else:
raise NotImplementedError
@vectorize
def __ge__(self, other):
other = self.coerce(other)
if isinstance(other, (cfix, _single)):
return self.v.greater_equal(other.v, self.k)
else:
raise NotImplementedError
@vectorize
def __gt__(self, other):
other = self.coerce(other)
if isinstance(other, (cfix, _single)):
return self.v.greater_than(other.v, self.k)
else:
raise NotImplementedError
@vectorize
def __ne__(self, other):
other = self.coerce(other)
if isinstance(other, (cfix, _single)):
return self.v.not_equal(other.v, self.k)
else:
raise NotImplementedError
for op in __le__, __lt__, __ge__, __gt__, __ne__:
op.__doc__ = __eq__.__doc__
del op
def link(self, other):
self.v.link(other.v)
def get_vector(self):
return self
class _fix(_single):
""" Secret fixed point type. """
__slots__ = ['v', 'f', 'k']
is_clear = False
def set_precision(cls, f, k = None):
cls.f = f
# default bitlength = 2*precision
if k is None:
cls.k = 2 * f
else:
cls.k = k
set_precision.__doc__ = cfix.set_precision.__doc__
set_precision = classmethod(set_precision)
@classmethod
def set_precision_from_args(cls, program, adapt_ring=False):
f = None
k = None
for arg in program.args:
m = re.match('f([0-9]+)$', arg)
if m:
f = int(m.group(1))
m = re.match('k([0-9]+)$', arg)
if m:
k = int(m.group(1))
if f is not None:
print ('Setting fixed-point precision to %d/%s' % (f, k))
cls.set_precision(f, k)
cfix.set_precision(f, k)
elif k is not None:
raise CompilerError('need to set fractional precision')
if 'nearest' in program.args:
print('Nearest rounding instead of probabilistic '
'for fixed-point computation')
cls.round_nearest = True
if adapt_ring and program.options.ring \
and 'fix_ring' not in program.args \
and 2 * cls.k > int(program.options.ring):
need = 2 ** int(math.ceil(math.log(2 * cls.k, 2)))
if need != int(program.options.ring):
print('Changing computation modulus to 2^%d' % need)
program.set_ring_size(need)
@classmethod
def coerce(cls, other, equal_precision=None):
if isinstance(other, (_fix, cls.clear_type)):
return other
else:
return cls.conv(other)
@classmethod
def from_sint(cls, other, k=None, f=None):
""" Convert secret integer.
:param other: sint """
res = cls(k=k, f=f)
res.load_int(cls.int_type.conv(other))
return res
@classmethod
def conv(cls, other):
if isinstance(other, _fix) and (cls.k, cls.f) == (other.k, other.f):
return other
else:
return super(_fix, cls).conv(other)
@classmethod
def _new(cls, other, k=None, f=None):
res = cls(k=k, f=f, initialize=False)
res.v = cls.int_type.conv(other)
return res
@vectorize_init
def __init__(self, _v=None, k=None, f=None, size=None, initialize=True):
if k is None:
k = self.k
else:
self.k = k
if f is None:
f = self.f
else:
self.f = f
assert k is not None
assert f is not None
def adjust(v):
f_diff = v.f - f
v = v.v
if f_diff < 0:
v <<= -f_diff
elif f_diff > 0:
v >>= f_diff
return v
if _v is None:
if initialize:
self.v = self.int_type(0)
else:
return
elif isinstance(_v, self.int_type):
self.load_int(_v)
elif isinstance(_v, cfix.scalars):
self.v = self.int_type(cfix.int_rep(_v, f=f, k=k), size=size)
elif isinstance(_v, self.float_type):
p = (f + _v.p)
b = (p.greater_equal(0, _v.vlen))
a = b*(_v.v << (p)) + (1-b)*(_v.v >> (-p))
self.v = (1-2*_v.s)*a
elif isinstance(_v, type(self)):
self.v = _v.v
elif isinstance(_v, cfix):
self.v = self.int_type(adjust(_v))
elif isinstance(_v, (MemValue, MemFix)):
#this is a memvalue object
self.v = type(self)(_v.read()).v
elif isinstance(_v, (list, tuple)):
self.v = self.int_type(list(self.conv(x).v for x in _v))
elif isinstance(_v, personal):
self.v = self.int_type(personal(_v.player, adjust(_v._v)))
else:
raise CompilerError('cannot convert %s to sfix' % _v)
if not isinstance(self.v, self.int_type):
raise CompilerError('sfix conversion failure: %s/%s' % (_v, self.v))
def load_int(self, v):
self.v = self.int_type(v) << self.f
def __getitem__(self, index):
return self._new(self.v[index])
@vectorize
def add(self, other):
""" Secret fixed-point addition.
:param other: sfix/cfix/sint/cint/regint/int """
other = self.coerce(other)
if isinstance(other, (_fix, cfix)):
return self._new(self.v + other.v, k=self.k, f=self.f)
elif isinstance(other, cfix.scalars):
tmp = cfix(other, k=self.k, f=self.f)
return self + tmp
else:
return NotImplemented
def mul(self, other):
""" Secret fixed-point multiplication.
:param other: sfix/cfix/sint/cint/regint/int """
if isinstance(other, (sint, cint, regint, int)):
return self._new(self.v * other, k=self.k, f=self.f)
elif isinstance(other, float):
if int(other) == other:
return self.mul(int(other))
v = int(round(other * 2 ** self.f))
if v == 0:
return 0
f = self.f
while v % 2 == 0:
f -= 1
v //= 2
k = len(bin(abs(v))) - 1
other = self.multipliable(v, k, f, self.size)
try:
other = self.coerce(other, equal_precision=False)
except:
return NotImplemented
if isinstance(other, (_fix, self.clear_type)):
k = max(self.k, other.k)
max_f = max(self.f, other.f)
min_f = min(self.f, other.f)
val = self.v.TruncMul(other.v, k + min_f, min_f,
nearest=self.round_nearest)
if 'vec' not in self.__dict__:
return self._new(val, k=k, f=max_f)
else:
return self.vec._new(val, k=k, f=max_f)
elif isinstance(other, cfix.scalars):
scalar_fix = cfix(other)
return self * scalar_fix
else:
return NotImplemented
@vectorize
def __neg__(self):
""" Secret fixed-point negation. """
return self._new(-self.v, k=self.k, f=self.f)
@vectorize
def __truediv__(self, other):
""" Secret fixed-point division.
:param other: sfix/cfix/sint/cint/regint/int """
if util.is_constant_float(other):
assert other != 0
log = math.ceil(math.log(abs(other), 2))
if 2 ** log == other and log < self.f:
return self * 2 ** -log
other_length = self.f + log
if other_length >= self.k - 1:
factor = 2 ** (self.k - other_length - 2)
self *= factor
other *= factor
if util.is_zero(self):
return 0
other = self.coerce(other)
assert self.k == other.k
assert self.f == other.f
if isinstance(other, (_fix, cfix)):
v = library.FPDiv(self.v, other.v, self.k, self.f,
nearest=self.round_nearest)
else:
raise TypeError('Incompatible fixed point types in division')
return self._new(v, k=self.k, f=self.f)
@vectorize
def __rtruediv__(self, other):
""" Secret fixed-point division.
:param other: sfix/cfix/sint/cint/regint/int """
return self.coerce(other) / self
@vectorize
def compute_reciprocal(self):
""" Secret fixed-point reciprocal. """
return type(self)(library.FPDiv(cint(2) ** self.f, self.v, self.k,
self.f, nearest=True))
def reveal(self):
""" Reveal secret fixed-point number.
:return: relevant clear type """
val = self.v.reveal()
class revealed_fix(self.clear_type):
f = self.f
k = self.k
return revealed_fix._new(val)
def bit_decompose(self, n_bits=None):
""" Bit decomposition. """
return self.v.bit_decompose(n_bits or self.k)
def update(self, other):
"""
Update register. Useful in loops like
:py:func:`~Compiler.library.for_range`.
:param other: any convertible type
"""
other = self.conv(other)
assert self.f == other.f
self.v.update(other.v)
[docs]class sfix(_fix):
""" Secret fixed-point number represented as secret integer, by
multiplying with ``2^f`` and then rounding. See :py:class:`sint`
for security considerations of the underlying integer operations.
The secret integer is stored as the :py:obj:`v` member.
It supports basic arithmetic (``+, -, *, /``), returning
:py:class:`sfix`, and comparisons (``==, !=, <, <=, >, >=``),
returning :py:class:`sbitint`. The other operand can be any of
sfix/sint/cfix/regint/cint/int/float. It also supports ``abs()``
and ``**``.
Note that the default precision (16 bits after the dot, 31 bits in
total) only allows numbers up to :math:`2^{31-16-1} \\approx
16000` with the smallest non-zero number being :math:`2^{-16}`.
You can change this using :py:func:`set_precision`.
:params _v: int/float/regint/cint/sint/sfloat
"""
int_type = sint
bit_type = sintbit
clear_type = cfix
get_type = staticmethod(lambda n: sint)
default_type = sint
@classmethod
def get_prec_type(cls, f, k=None):
class sfix_prec(cls):
pass
sfix_prec.set_precision(f, k)
return sfix_prec
@vectorized_classmethod
def get_raw_input_from(cls, player):
return cls._new(cls.int_type.get_raw_input_from(player))
[docs] @vectorized_classmethod
def get_random(cls, lower, upper, symmetric=True, public_randomness=False):
""" Uniform secret random number around centre of bounds.
Actual range can be smaller but never larger.
:param lower: float
:param upper: float
:param symmetric: symmetric distribution at higher cost
:param public_randomness: use public randomness (avoids preprocessing)
:param size: vector size (int, default 1)
"""
if public_randomness:
get_random_int = regint.get_random
get_random_bit = lambda: regint.get_random(1)
else:
get_random_int = cls.int_type.get_random_int
get_random_bit = cls.int_type.get_random_bit
f = cls.f
k = cls.k
log_range = int(math.log(upper - lower, 2))
n_bits = log_range + cls.f
gen_range = (2 ** (n_bits) - 1) / 2 ** cls.f
diff = upper - lower
factor = diff / gen_range
real = lambda x: cfix.int_rep(x, f, k) * 2 ** -f
real_range = real(real(factor) * gen_range)
average = lower + 0.5 * (upper - lower)
lower = average - 0.5 * real_range
upper = average + 0.5 * real_range
r = cls._new(get_random_int(n_bits)) * factor + lower
if symmetric:
lowest = math.floor(lower * 2 ** cls.f) / 2 ** cls.f
highest = math.ceil(upper * 2 ** cls.f) / 2 ** cls.f
if program.verbose:
print('randomness range [%f,%f], '
'fringes half the probability' % \
(lowest, highest))
return get_random_bit().if_else(r, -r + 2 * average)
else:
if program.verbose:
print('randomness range [%f,%f], %d bits' % \
(real(lower), real(lower) + real_range, n_bits))
return r
@classmethod
def direct_matrix_mul(cls, A, B, n, m, l, reduce=True, indices=None):
# pre-multiplication must be identity
tmp = cls.int_type.direct_matrix_mul(A, B, n, m, l, indices=indices)
res = unreduced_sfix._new(tmp)
if reduce:
res = res.reduce_after_mul()
return res
[docs] @classmethod
def dot_product(cls, x, y, res_params=None):
""" Secret dot product.
:param x: iterable of appropriate secret type
:param y: iterable of appropriate secret type and same length """
x, y = list(x), list(y)
if res_params is None:
if isinstance(x[0], cls.int_type):
x, y = y, x
if isinstance(y[0], cls.int_type):
return cls._new(cls.int_type.dot_product((xx.v for xx in x), y),
k=x[0].k, f=x[0].f)
return super().dot_product(x, y, res_params)
def expand_to_vector(self, size):
return self._new(self.v.expand_to_vector(size), k=self.k, f=self.f)
@read_mem_value
def coerce(self, other, equal_precision=True):
res = parse_type(other, k=self.k, f=self.f)
if equal_precision:
# check parameters if available
try:
assert res.k == self.k
assert res.f == self.f
except AttributeError:
pass
return res
def hard_conv_me(self, cls):
assert cls == sint
return self.v
def mul_no_reduce(self, other, res_params=None):
if not isinstance(other, type(self)):
return self * other
assert self.f == other.f
assert self.k == other.k
return self.unreduced(self.v * other.v)
def pre_mul(self):
return self.v
def unreduced(self, v, other=None, res_params=None, n_summands=1):
assert res_params is None or \
(res_params.k == self.k and res_params.f == self.f)
if other is None:
return unreduced_sfix(v, self.k + self.f, self.f)
else:
return unreduced_sfix(v, self.k + other.f, other.f)
@staticmethod
def multipliable(v, k, f, size):
return cfix._new(cint.conv(v, size=size), k, f)
[docs] def dot(self, other):
""" Dot product with any vector or iterable. """
if isinstance(other, sint):
return self._new(sint.dot_product(self.v, other), k=self.k, f=self.f)
elif isinstance(other, sfix):
assert self.k == other.k
assert self.f == other.f
return self._new(sint.dot_product(self.v, other.v).round(
self.k + other.f, self.f, nearest=self.round_nearest,
signed=True), k=self.k, f=self.f)
elif isinstance(other, (_int, cfix)):
return (self * other).sum()
else:
other = list(other)
assert len(self) == len(other)
return sum(a * b for a, b in zip(self, other))
[docs] def reveal_to(self, player):
""" Reveal secret value to :py:obj:`player`.
:param player: public integer (int/regint/cint)
:returns: :py:class:`personal`
"""
return personal(player, cfix._new(self.v.reveal_to(player)._v,
self.k, self.f))
def secure_shuffle(self, *args, **kwargs):
return self._new(self.v.secure_shuffle(*args, **kwargs),
k=self.k, f=self.f)
def secure_permute(self, *args, **kwargs):
return self._new(self.v.secure_permute(*args, **kwargs),
k=self.k, f=self.f)
def prefix_sum(self):
return self._new(self.v.prefix_sum(), k=self.k, f=self.f)
def sum(self):
return self._new(self.v.sum())
def get_reverse_vector(self):
return self._new(self.v.get_reverse_vector(), k=self.k, f=self.f)
def get_vector(self, *args, **kwargs):
return self._new(self.v.get_vector(*args, **kwargs), k=self.k, f=self.f)
@classmethod
def concat(cls, parts):
parts = list(parts)
int_parts = []
f = parts[0].f
k = parts[0].k
for part in parts:
assert part.f == f
assert part.k == k
int_parts.append(part.v)
return cls._new(cls.int_type.concat(int_parts), k=k, f=f)
@classmethod
def zip(cls, *parts):
int_parts = []
f = parts[0].f
k = parts[0].k
for part in parts:
assert part.f == f
assert part.k == k
int_parts.append(part.v)
return cls._new(cls.int_type.zip(*int_parts), k=k, f=f)
def __repr__(self):
return '<sfix{f=%d,k=%d} at %s>' % (self.f, self.k, self.v)
class unreduced_sfix(_single):
int_type = sint
@classmethod
def _new(cls, v):
return cls(v, sfix.k + sfix.f, sfix.f)
def __init__(self, v, k, m):
self.v = v
self.k = k
self.m = m
assert self.k is not None
assert self.m is not None
def __add__(self, other):
if is_zero(other):
return self
assert self.k == other.k
assert self.m == other.m
return unreduced_sfix(self.v + other.v, self.k, self.m)
__radd__ = __add__
@vectorize
def reduce_after_mul(self):
v = sfix.int_type.round(self.v, self.k, self.m,
nearest=sfix.round_nearest, signed=True)
return sfix._new(v, k=self.k - self.m, f=self.m)
def update(self, other):
assert self.k == other.k
assert self.m == other.m
self.v.update(other.v)
sfix.unreduced_type = unreduced_sfix
sfix.set_precision(16, 31)
cfix.set_precision(16, 31)
class squant(_single):
""" Quantization as in ArXiv:1712.05877v1 """
__slots__ = ['params']
int_type = sint
clamp = True
@classmethod
def set_params(cls, S, Z=0, k=8):
cls.params = squant_params(S, Z, k)
@classmethod
def from_sint(cls, other):
raise CompilerError('sint to squant conversion not implemented')
@classmethod
def conv(cls, other):
if isinstance(other, squant):
return other
else:
return cls(other)
@classmethod
def _new(cls, value, params=None):
res = cls(params=params)
res.v = value
return res
@read_mem_value
def __init__(self, value=None, params=None):
if params is not None:
self.params = params
if value is None:
# need to set v manually
pass
elif isinstance(value, cfix.scalars):
set_global_vector_size(1)
q = util.round_to_int(value / self.S + self.Z)
if util.is_constant(q) and (q < 0 or q >= 2**self.k):
raise CompilerError('%f not quantizable' % value)
self.v = self.int_type(q)
reset_global_vector_size()
elif isinstance(value, squant) and value.params == self.params:
self.v = value.v
else:
raise CompilerError('cannot convert %s to squant' % value)
def __getitem__(self, index):
return type(self)._new(self.v[index], self.params)
def get_params(self):
return self.params
@property
def S(self):
return self.params.S
@property
def Z(self):
return self.params.Z
@property
def k(self):
return self.params.k
def coerce(self, other):
other = self.conv(other)
return self._new(util.expand(other.v, self.size), other.params)
@vectorize
def add(self, other):
other = self.coerce(other)
assert self.get_params() == other.get_params()
return self._new(self.v + other.v - util.expand(self.Z, self.v.size))
def mul(self, other, res_params=None):
return self.mul_no_reduce(other, res_params).reduce_after_mul()
def mul_no_reduce(self, other, res_params=None):
if isinstance(other, (sint, cint, regint)):
return self._new(other * (self.v - self.Z) + self.Z,
params=self.get_params())
other = self.coerce(other)
tmp = (self.v - self.Z) * (other.v - other.Z)
return _unreduced_squant(tmp, (self.get_params(), other.get_params()),
res_params=res_params)
def pre_mul(self):
return self.v - util.expand(self.Z, self.v.size)
def unreduced(self, v, other, res_params=None, n_summands=1):
return _unreduced_squant(v, (self.get_params(), other.get_params()),
res_params, n_summands)
@vectorize
def for_mux(self, other):
other = self.coerce(other)
assert self.params == other.params
f = lambda x: self._new(x, self.params)
return f, self.v, other.v
@vectorize
def __neg__(self):
return self._new(-self.v + 2 * util.expand(self.Z, self.v.size))
class _unreduced_squant(Tape._no_truth):
def __init__(self, v, params, res_params=None, n_summands=1):
self.v = v
self.params = params
self.n_summands = n_summands
self.res_params = res_params or params[0]
def __add__(self, other):
if is_zero(other):
return self
assert self.params == other.params
assert self.res_params == other.res_params
return _unreduced_squant(self.v + other.v, self.params, self.res_params,
self.n_summands + other.n_summands)
__radd__ = __add__
def reduce_after_mul(self):
return squant_params.conv(self.res_params).reduce(self)
class squant_params(object):
max_n_summands = 2048
@staticmethod
def conv(other):
if isinstance(other, squant_params):
return other
else:
return squant_params(*other)
def __init__(self, S, Z=0, k=8):
try:
self.S = float(S)
except:
self.S = S
self.Z = MemValue.if_necessary(Z)
self.k = k
self._store = {}
if program.options.ring:
# cheaper probabilistic truncation
self.max_length = int(program.options.ring) - 1
else:
# safe choice for secret shift
self.max_length = 71
def __iter__(self):
yield self.S
yield self.Z
yield self.k
def is_constant(self):
return util.is_constant_float(self.S) and util.is_constant(self.Z)
def get(self, input_params, n_summands):
p = input_params
M = p[0].S * p[1].S / self.S
logM = util.log2(M)
n_shift = self.max_length - p[0].k - p[1].k - util.log2(n_summands)
if util.is_constant_float(M):
n_shift -= logM
int_mult = int(round(M * 2 ** (n_shift)))
else:
int_mult = MemValue(M.v << (n_shift + M.p))
shifted_Z = MemValue.if_necessary(self.Z << n_shift)
return n_shift, int_mult, shifted_Z
def precompute(self, *input_params):
self._store[input_params] = self.get(input_params, self.max_n_summands)
def get_stored(self, unreduced):
assert unreduced.n_summands <= self.max_n_summands
return self._store[unreduced.params]
def reduce(self, unreduced):
ps = (self,) + unreduced.params
if reduce(operator.and_, (p.is_constant() for p in ps)):
n_shift, int_mult, shifted_Z = self.get(unreduced.params,
unreduced.n_summands)
else:
n_shift, int_mult, shifted_Z = self.get_stored(unreduced)
size = unreduced.v.size
n_shift = util.expand(n_shift, size)
shifted_Z = util.expand(shifted_Z, size)
int_mult = util.expand(int_mult, size)
tmp = unreduced.v * int_mult + shifted_Z
shifted = tmp.round(self.max_length, n_shift,
nearest=squant.round_nearest,
signed=True)
if squant.clamp:
length = max(self.k, self.max_length - n_shift) + 1
top = (1 << self.k) - 1
over = shifted.greater_than(top, length)
under = shifted.less_than(0, length)
shifted = over.if_else(top, shifted)
shifted = under.if_else(0, shifted)
return squant._new(shifted, params=self)
[docs]class sfloat(_number, _secret_structure):
r"""
Secret floating-point number.
Represents :math:`(1 - 2s) \cdot (1 - z)\cdot v \cdot 2^p`.
v: significand
p: exponent
z: zero flag
s: sign bit
This uses integer operations internally, see :py:class:`sint` for security
considerations.
See `Aliasgari et al. <https://eprint.iacr.org/2012/405.pdf>`_ for
details.
The type supports basic arithmetic (``+, -, *, /``), returning
:py:class:`sfloat`, and comparisons (``==, !=, <, <=, >, >=``),
returning :py:class:`sint`. The other operand can be any of
sint/cfix/regint/cint/int/float.
This data type only works with arithmetic computation.
:param v: initialization (sfloat/sfix/float/int/sint/cint/regint)
"""
__slots__ = ['v', 'p', 'z', 's', 'size']
# single precision
vlen = 24
plen = 8
round_nearest = False
@staticmethod
def n_elements():
return 4
@classmethod
def malloc(cls, size, creator_tape=None):
return program.malloc(size * cls.n_elements(), sint,
creator_tape=creator_tape)
@classmethod
def is_address_tuple(cls, address):
if isinstance(address, (list, tuple)):
assert(len(address) == cls.n_elements())
return True
return False
[docs] @vectorized_classmethod
def load_mem(cls, address, mem_type=None):
""" Load from memory by public address. """
size = get_global_vector_size()
if cls.is_address_tuple(address):
return sfloat(*(sint.load_mem(a, size=size) for a in address))
res = []
for i in range(4):
res.append(sint.load_mem(address + i * size, size=size))
return sfloat(*res)
@classmethod
def set_error(cls, error):
# incompatible with loops
#cls.error += error - cls.error * error
cls.error = error
pass
@classmethod
def conv(cls, other):
if isinstance(other, cls):
return other
else:
return cls(other)
@classmethod
def coerce(cls, other):
return cls.conv(other)
@staticmethod
def convert_float(v, vlen, plen):
if v < 0:
s = 1
else:
s = 0
if v == 0:
v = 0
p = 0
z = 1
else:
p = int(math.floor(math.log(abs(v), 2))) - vlen + 1
vv = v
v = int(round(abs(v) * 2 ** (-p)))
if v == 2 ** vlen:
p += 1
v //= 2
z = 0
if p < -2 ** (plen - 1):
print('Warning: %e truncated to zero' % vv)
v, p, z = 0, 0, 1
if p >= 2 ** (plen - 1):
raise CompilerError('Cannot convert %s to float ' \
'with %d exponent bits' % (vv, plen))
return v, p, z, s
@vectorize_init
@read_mem_value
def __init__(self, v, p=None, z=None, s=None, size=None):
if program.options.binary:
raise CompilerError(
'floating-point operations not supported with binary circuits')
self.size = get_global_vector_size()
if p is None:
if isinstance(v, sfloat):
p = v.p
z = v.z
s = v.s
v = v.v
elif isinstance(v, sfix):
f = v.f
v, p, z, s = floatingpoint.Int2FL(v.v, v.k,
self.vlen)
p = p - f
elif util.is_constant_float(v):
v, p, z, s = self.convert_float(v, self.vlen, self.plen)
else:
v, p, z, s = floatingpoint.Int2FL(sint.conv(v),
program.bit_length,
self.vlen)
if isinstance(v, int):
if not ((v >= 2**(self.vlen-1) and v < 2**(self.vlen)) or v == 0):
raise CompilerError('Floating point number malformed: significand')
if isinstance(p, int):
if not (p >= -2**(self.plen - 1) and p < 2**(self.plen - 1)):
raise CompilerError('Floating point number malformed: exponent %d not unsigned %d-bit integer' % (p, self.plen))
if isinstance(z, int):
if not (z == 0 or z == 1):
raise CompilerError('Floating point number malformed: zero bit')
if isinstance(s, int):
if not (s == 0 or s == 1):
raise CompilerError('Floating point number malformed: sign')
# copying necessary for update to work properly
self.v = sint(v)
self.p = sint(p)
self.z = sint(z)
self.s = sint(s)
def __getitem__(self, index):
return sfloat(*(x[index] for x in self))
def __iter__(self):
yield self.v
yield self.p
yield self.z
yield self.s
[docs] def store_in_mem(self, address):
""" Store in memory by public address. """
if self.is_address_tuple(address):
for a, x in zip(address, self):
x.store_in_mem(a)
return
for i,x in enumerate((self.v, self.p, self.z, self.s)):
x.store_in_mem(address + i * self.size)
def sizeof(self):
return self.size * self.n_elements()
@vectorize
def add(self, other):
""" Secret floating-point addition.
:param other: sfloat/float/sfix/sint/cint/regint/int """
other = self.conv(other)
if isinstance(other, sfloat):
a,c,d,e = [sint() for i in range(4)]
t = sint()
t2 = sint()
v1 = self.v
v2 = other.v
p1 = self.p
p2 = other.p
s1 = self.s
s2 = other.s
z1 = self.z
z2 = other.z
a = p1.less_than(p2, self.plen)
b = floatingpoint.EQZ(p1 - p2, self.plen)
c = v1.less_than(v2, self.vlen)
ap1 = a*p1
ap2 = a*p2
aneg = 1 - a
bneg = 1 - b
cneg = 1 - c
av1 = a*v1
av2 = a*v2
cv1 = c*v1
cv2 = c*v2
pmax = ap2 + p1 - ap1
pmin = p2 - ap2 + ap1
vmax = bneg*(av2 + v1 - av1) + b*(cv2 + v1 - cv1)
vmin = bneg*(av1 + v2 - av2) + b*(cv1 + v2 - cv2)
s3 = s1 + s2 - 2 * s1 * s2
comparison.LTZ(d, self.vlen + pmin - pmax + sfloat.round_nearest,
self.plen)
pow_delta = floatingpoint.Pow2((1 - d) * (pmax - pmin),
self.vlen + 1 + sfloat.round_nearest)
# deviate from paper for more precision
#v3 = 2 * (vmax - s3) + 1
v3 = vmax
v4 = vmax * pow_delta + (1 - 2 * s3) * vmin
to_trunc = (d * v3 + (1 - d) * v4)
if program.options.ring:
to_trunc <<= 1 + sfloat.round_nearest
v = floatingpoint.TruncInRing(to_trunc,
2 * (self.vlen + 1 +
sfloat.round_nearest),
pow_delta)
else:
to_trunc *= two_power(self.vlen + sfloat.round_nearest)
v = to_trunc * floatingpoint.Inv(pow_delta)
comparison.Trunc(t, v, 2 * self.vlen + 1 + sfloat.round_nearest,
self.vlen - 1, signed=False)
v = t
u = floatingpoint.BitDec(v, self.vlen + 2 + sfloat.round_nearest,
self.vlen + 2 + sfloat.round_nearest,
list(range(1 + sfloat.round_nearest,
self.vlen + 2 + sfloat.round_nearest)))
# using u[0] doesn't seem necessary
h = floatingpoint.PreOR(u[:sfloat.round_nearest:-1])
p0 = self.vlen + 1 - sum(h)
pow_p0 = 1 + sum([two_power(i) * (1 - h[i]) for i in range(len(h))])
if self.round_nearest:
t2, overflow = \
floatingpoint.TruncRoundNearestAdjustOverflow(pow_p0 * v,
self.vlen + 3,
self.vlen)
p0 = p0 - overflow
else:
comparison.Trunc(t2, pow_p0 * v, self.vlen + 2, 2, signed=False)
v = t2
# deviate for more precision
#p = pmax - p0 + 1 - d
p = pmax - p0 + 1
zz = self.z*other.z
zprod = 1 - self.z - other.z + zz
v = zprod*t2 + self.z*v2 + other.z*v1
z = floatingpoint.EQZ(v, self.vlen)
p = (zprod*p + self.z*p2 + other.z*p1)*(1 - z)
s = (1 - b)*(a*other.s + aneg*self.s) + b*(c*other.s + cneg*self.s)
s = zprod*s + (other.z - zz)*self.s + (self.z - zz)*other.s
return sfloat(v, p, z, s)
else:
return NotImplemented
@vectorize_max
def mul(self, other):
""" Secret floating-point multiplication.
:param other: sfloat/float/sfix/sint/cint/regint/int """
other = self.conv(other)
if isinstance(other, sfloat):
v1 = sint()
v2 = sint()
b = sint()
c2expl = cint()
comparison.ld2i(c2expl, self.vlen)
if sfloat.round_nearest:
v1 = comparison.TruncRoundNearest(self.v*other.v, 2*self.vlen,
self.vlen-1)
else:
comparison.Trunc(v1, self.v*other.v, 2*self.vlen, self.vlen-1,
signed=False)
t = v1 - c2expl
comparison.LTZ(b, t, self.vlen+1)
comparison.Trunc(v2, b*v1 + v1, self.vlen+1, 1, signed=False)
z1, z2, s1, s2, p1, p2 = (x.expand_to_vector() for x in \
(self.z, other.z, self.s, other.s,
self.p, other.p))
z = z1 + z2 - self.z*other.z # = OR(z1, z2)
s = s1 + s2 - self.s*other.s*2 # = XOR(s1,s2)
p = (p1 + p2 - b + self.vlen)*(1 - z)
return sfloat(v2, p, z, s)
else:
return NotImplemented
def __sub__(self, other):
""" Secret floating-point subtraction.
:param other: sfloat/float/sfix/sint/cint/regint/int """
return self + -other
def __rsub__(self, other):
return -self + other
__rsub__.__doc__ = __sub__.__doc__
@vectorize
def __truediv__(self, other):
""" Secret floating-point division.
:param other: sfloat/float/sfix/sint/cint/regint/int """
other = self.conv(other)
v = floatingpoint.SDiv(self.v, other.v + other.z * (2**self.vlen - 1),
self.vlen, round_nearest=self.round_nearest)
b = v.less_than(two_power(self.vlen-1), self.vlen + 1)
overflow = v.greater_equal(two_power(self.vlen), self.vlen + 1)
underflow = v.less_than(two_power(self.vlen-2), self.vlen + 1)
v = (v + b * v) * (1 - overflow) * (1 - underflow) + \
overflow * (2**self.vlen - 1) + \
underflow * (2**(self.vlen-1)) * (1 - self.z)
p = (1 - self.z) * (self.p - other.p - self.vlen - b + 1)
z = self.z
s = self.s + other.s - 2 * self.s * other.s
sfloat.set_error(other.z)
return sfloat(v, p, z, s)
def __rtruediv__(self, other):
return self.conv(other) / self
__rtruediv__.__doc__ = __truediv__.__doc__
@vectorize
def __neg__(self):
""" Secret floating-point negation. """
return sfloat(self.v, self.p, self.z, (1 - self.s) * (1 - self.z))
@vectorize
def __lt__(self, other):
""" Secret floating-point comparison.
:param other: sfloat/float/sfix/sint/cint/regint/int
:return: 0/1 (sint) """
other = self.conv(other)
if isinstance(other, sfloat):
z1 = self.z
z2 = other.z
s1 = self.s
s2 = other.s
a = self.p.less_than(other.p, self.plen)
c = floatingpoint.EQZ(self.p - other.p, self.plen)
d = ((1 - 2*self.s)*self.v).less_than((1 - 2*other.s)*other.v, self.vlen + 1)
cd = c*d
ca = c*a
b1 = cd + a - ca
b2 = cd + 1 + ca - c - a
s12 = self.s*other.s
z12 = self.z*other.z
b = (z1 - z12)*(1 - s2) + (z2 - z12)*s1 + (1 + z12 - z1 - z2)*(s1 - s12 + (1 + s12 - s1 - s2)*b1 + s12*b2)
return b
else:
return NotImplemented
def __ge__(self, other):
""" Secret floating-point comparison. """
return 1 - (self < other)
@vectorize
def __gt__(self, other):
""" Secret floating-point comparison. """
return self.conv(other) < self
@vectorize
def __le__(self, other):
""" Secret floating-point comparison. """
return self.conv(other) >= self
@vectorize
def __eq__(self, other):
""" Secret floating-point comparison. """
other = self.conv(other)
# the sign can be both ways for zeroes
both_zero = self.z * other.z
return floatingpoint.EQZ(self.v - other.v, self.vlen) * \
floatingpoint.EQZ(self.p - other.p, self.plen) * \
(1 - self.s - other.s + 2 * self.s * other.s) * \
(1 - both_zero) + both_zero
def __ne__(self, other):
""" Secret floating-point comparison. """
return 1 - (self == other)
for op in __gt__, __le__, __ge__, __eq__, __ne__:
op.__doc__ = __lt__.__doc__
del op
def log2(self):
up = self.v.greater_than(1 << (self.vlen - 1), self.vlen)
return self.p + self.vlen - 1 + up
[docs] def round_to_int(self):
""" Secret floating-point rounding to integer.
:return: sint """
direction = self.p.greater_equal(-self.vlen, self.plen)
right = self.v.right_shift(-self.p - 1, self.vlen + 1)
up = right.mod2m(1, self.vlen + 1)
right = right.right_shift(1, self.vlen + 1) + up
abs_value = direction * right
return self.s.if_else(-abs_value, abs_value)
def value(self):
# Gets actual floating point value, if emulation is enabled.
return (1 - 2*self.s.value)*(1 - self.z.value)*self.v.value/float(2**self.p.value)
[docs] def reveal(self):
""" Reveal secret floating-point number.
:return: cfloat """
return cfloat(self.v.reveal(), self.p.reveal(), self.z.reveal(), self.s.reveal())
[docs] def update(self, other):
"""
Update register. Useful in loops like
:py:func:`~Compiler.library.for_range`.
:param other: any convertible type
"""
self.v.update(other.v)
self.p.update(other.p)
self.z.update(other.z)
self.s.update(other.s)
def for_mux(self, other):
other = self.coerce(other)
f = lambda x: type(self)(*x)
return f, sint(list(self)), sint(list(other))
[docs]class cfloat(Tape._no_truth):
""" Helper class for printing revealed sfloats. """
__slots__ = ['v', 'p', 'z', 's', 'nan']
@vectorize_init
def __init__(self, v, p=None, z=None, s=None, nan=0):
""" Parameters as with :py:class:`sfloat` but public. """
if s is None:
parts = [cint.conv(x) for x in (v.v, v.p, v.z, v.s, v.nan)]
else:
parts = [cint.conv(x) for x in (v, p, z, s, nan)]
self.v, self.p, self.z, self.s, self.nan = parts
@property
def size(self):
return self.v.size
[docs] @vectorize
def print_float_plain(self):
""" Output. """
print_float_plain(self.v, self.p, self.z, self.s, self.nan)
[docs] def binary_output(self, player=None):
""" Write double-precision floating-point number to
``Player-Data/Binary-Output-P<playerno>-<threadno>``.
:param player: only output on given player (default all)
"""
if player == None:
player = -1
floatoutput(player, self.v, self.p, self.z, self.s)
sfix.float_type = sfloat
_types = {
'c': cint,
's': sint,
'sg': sgf2n,
'cg': cgf2n,
'ci': regint,
}
def _get_type(t):
if t in _types:
return _types[t]
else:
return t
class _vectorizable:
@classmethod
def check(cls, index, length, sizes):
if isinstance(index, _clear):
index = regint.conv(index)
if length is not None:
from .GC.types import cbits
if isinstance(index, int):
index += length * (index < 0)
if index >= length or index < 0:
raise IndexError('index %s, length %s' % \
(str(index), str(length)))
elif cls.check_indices and not isinstance(index, cbits):
library.runtime_error_if(
(index >= length).bit_or(index < 0),
'overflow: %s/%s', index, sizes)
return index
def reveal_to_clients(self, clients):
""" Reveal contents to list of clients.
:param clients: list or array of client identifiers
"""
self.value_type.reveal_to_clients(clients, [self.get_vector()])
[docs]class Array(_vectorizable):
"""
Array accessible by public index. That is, ``a[i]`` works for an
array ``a`` and ``i`` being a :py:class:`regint`,
:py:class:`cint`, or a Python integer.
:param length: compile-time integer (int) or :py:obj:`None`
for unknown length (need to specify :py:obj:`address`)
:param value_type: basic type
:param address: if given (regint/int), the array will not be allocated
You can convert between arrays and register vectors by using slice
indexing. This allows for element-wise operations as long as
supported by the basic type. The following adds 10 secret integers
from the first two parties::
a = sint.Array(10)
a.input_from(0)
b = sint.Array(10)
b.input_from(1)
a[:] += b[:]
Arrays aren't initialized on creation, you need to call
:py:func:`assign_all` to initialize them to a constant value.
"""
check_indices = True
[docs] @classmethod
def create_from(cls, l):
""" Convert Python iterator or vector to array or copy another array.
Basic type will be taken from first element, further elements
must to be convertible to that.
:param l: Python iterable, register vector, or array
:returns: :py:class:`Array` of appropriate type containing the contents
of :py:obj:`l`
"""
if isinstance(l, cls):
res = l.same_shape()
res[:] = l[:]
return res
if isinstance(l, _number):
tmp = l
t = type(l)
else:
tmp = list(l)
t = type(tmp[0])
res = cls(len(tmp), t)
res.assign(tmp)
return res
def __init__(self, length, value_type, address=None, debug=None, alloc=True):
value_type = _get_type(value_type)
self.address = address
self.length = length
self.value_type = value_type
self.address = address
self.address_cache = {}
self.debug = debug
self.creator_tape = program.curr_tape
self.sink = None
if alloc:
self.alloc()
def alloc(self):
if self._address is None:
try:
self.address = self.value_type.malloc(self.length,
self.creator_tape)
except AttributeError:
raise CompilerError('cannot create Array of %s' % \
self.value_type)
def delete(self):
self.value_type.free(self.address)
self.address = None
@property
def address(self):
if self._address is None:
raise CompilerError('trying access unallocated memory')
return self._address
@address.setter
def address(self, address):
self._address = address
@read_mem_value
def get_address(self, index, size=None):
if isinstance(index, (_secret, _single)):
raise CompilerError('need cleartext index')
key = str(index), size or 1
index = self.check(index, self.length, self.length)
if (program.curr_block, key) not in self.address_cache:
n = self.value_type.n_elements()
length = self.length
if n == 1:
# length can be None for single-element arrays
length = 0
base = self.address + index * self.value_type.mem_size()
if size is not None and isinstance(base, _register) \
and not issubclass(self.value_type, _vec):
base = regint._expand_address(base, size)
self.address_cache[program.curr_block, key] = \
util.untuplify([base + i * length \
for i in range(n)])
if self.debug:
library.print_ln_if(index >= self.length, 'OF:' + self.debug)
library.print_ln_if(self.address_cache[program.curr_block, key] >= program.allocated_mem[self.value_type.reg_type], 'AOF:' + self.debug)
return self.address_cache[program.curr_block, key]
def get_slice(self, index):
if index.stop is None and self.length is None:
raise CompilerError('Cannot slice array of unknown length')
if index.step == 0:
raise CompilerError('slice step cannot be zero')
return index.start or 0, \
index.stop if self.length is None else \
min(index.stop or self.length, self.length), index.step or 1
def __getitem__(self, index):
""" Reading from array.
:param index: public (regint/cint/int/slice)
:return: vector if slice is given, basic type otherwise"""
if isinstance(index, slice):
start, stop, step = self.get_slice(index)
if step == 1:
return self.get_vector(start, stop - start)
else:
res_length = (stop - start - 1) // step + 1
addresses = regint.inc(res_length, start, step)
return self.get_vector(addresses, res_length)
return self._load(self.get_address(index))
def __setitem__(self, index, value):
""" Writing to array.
:param index: public (regint/cint/int)
:param value: convertible for relevant basic type """
if isinstance(index, slice):
start, stop, step = self.get_slice(index)
if step == 1:
return self.assign(value, start)
else:
res_length = (stop - start - 1) // step + 1
addresses = regint.inc(res_length, start, step)
return self.assign(value, addresses)
self._store(value, self.get_address(index))
def to_array(self):
return self
def get_sub(self, start, stop=None):
if stop is None:
stop = start
start = 0
return Array(stop - start, self.value_type,
address=self.address + start)
[docs] def maybe_get(self, condition, index):
""" Return entry if condition is true.
:param condition: 0/1 (regint/cint/int)
:param index: regint/cint/int
"""
return self[condition * index].zero_if_not(condition)
[docs] def maybe_set(self, condition, index, value):
""" Change entry if condition is true.
:param condition: 0/1 (regint/cint/int)
:param index: regint/cint/int
:param value: updated value
"""
if self.sink is None:
self.sink = self.value_type.Array(
1, address=self.value_type.malloc(1, creator_tape=program.tapes[0]))
addresses = (condition.if_else(x, y) for x, y in
zip(util.tuplify(self.get_address(condition * index)),
util.tuplify(self.sink.get_address(0))))
self._store(value, util.untuplify(tuple(addresses)))
# the following two are useful for compile-time lengths
# and thus differ from the usual Python syntax
def get_range(self, start, size):
return [self[start + i] for i in range(size)]
def set_range(self, start, values):
for i, value in enumerate(values):
self[start + i] = value
def _load(self, address):
return self.value_type.load_mem(address)
def _store(self, value, address):
tmp = self.value_type.conv(value)
if not isinstance(tmp, _vec) and tmp.size != self.value_type.mem_size():
raise CompilerError('size mismatch in array assignment')
tmp.store_in_mem(address)
def __len__(self):
if self.length is None:
raise CompilerError('this functionality is not available '
'for variable-length arrays')
return self.length
def total_size(self):
return self.length * self.value_type.n_elements()
@property
def shape(self):
return [self.length]
def __iter__(self):
for i in range(self.length):
yield self[i]
[docs] def same_shape(self, **kwargs):
""" Array of same length and type. """
return Array(self.length, self.value_type, **kwargs)
[docs] def assign(self, other, base=0):
""" Assignment.
:param other: vector/Array/Matrix/MultiArray/iterable of
compatible type and smaller size
:param base: index to start assignment at
"""
try:
other = other.get_vector()
except:
pass
try:
other = self.value_type.conv(other)
other.store_in_mem(self.get_address(base, other.size))
if len(self) != None and util.is_constant(base):
assert len(self) >= other.size + base
except (AttributeError, CompilerError):
if isinstance(other, Array):
@library.for_range_opt(len(other))
def _(i):
self[base + i] = other[i]
else:
for i,j in enumerate(other):
self[base + i] = j
return self
assign_vector = assign
assign_part_vector = assign
[docs] def assign_all(self, value, n_threads=None, conv=True):
""" Assign the same value to all entries.
:param value: convertible to basic type """
from Compiler.GC.types import bits
use_vector = util.is_constant(value) and \
not issubclass(self.value_type, (bits, squant))
if not use_vector:
if conv:
value = self.value_type.conv(value)
if value.size != 1:
raise CompilerError('cannot assign vector to all elements')
mem_value = MemValue(value)
if not util.is_constant(self.length) or program.options.garbled or \
not program.curr_tape.singular:
n_threads = None
if n_threads is not None:
self.address = MemValue.if_necessary(self.address)
@library.multithread(n_threads, self.length, max_size=program.budget)
def _(base, size):
if use_vector:
self.assign_vector(self.value_type(value, size=size), base)
else:
v = mem_value.read()
if isinstance(v, (sint, sfix)):
self.assign_vector(v.expand_to_vector(size), base=base)
else:
@library.for_range_opt(size)
def _(i):
self[base + i] = mem_value
return self
[docs] def get_vector(self, base=0, size=None):
""" Return vector with content.
:param base: starting point (regint/cint/int)
:param size: length (compile-time int) """
size = size or self.length - base
return self.value_type.load_mem(self.get_address(base, size), size=size)
get_part_vector = get_vector
[docs] def get_reverse_vector(self):
""" Return vector with content in reverse order. """
size = self.length
address = regint.inc(size, size - 1, -1)
return self.value_type.load_mem(self.address + address, size=size)
[docs] def get_part(self, base, size):
""" Part array.
:param base: start index (regint/cint/int)
:param size: integer
:returns: Array of same type
"""
return Array(size, self.value_type, self.get_address(base))
[docs] def get(self, indices):
""" Vector from arbitrary indices.
:param indices: regint vector or array
"""
return self.value_type.load_mem(
regint.inc(len(indices), self.address, 0) + indices,
size=len(indices))
def get_slice_addresses(self, slice):
assert self.value_type.n_elements() == 1
assert len(slice) <= self.total_size()
base = regint.inc(len(slice), slice.address, 1, 1)
inc = regint.inc(len(slice), self.address, 1, 1, 1)
addresses = regint.conv(slice.value_type.load_mem(base)) + inc
return addresses
def get_slice_vector(self, slice):
addresses = self.get_slice_addresses(slice)
return self.value_type.load_mem(addresses)
def assign_slice_vector(self, slice, vector):
addresses = self.get_slice_addresses(slice)
vector.store_in_mem(addresses)
[docs] def permute(self, permutation, reverse=False, n_threads=None):
""" Public permutation.
:param permutation: cleartext :py:class`Array` containing number
in :math:`[0,n-1]` where :math:`n` is the length of this array
:param reverse: whether to apply the inverse of the permutation
"""
if reverse:
self.assign_slice_vector(permutation, self.get_vector())
else:
self.assign_vector(self.get_slice_vector(permutation))
[docs] def expand_to_vector(self, index, size):
""" Create vector from single entry.
:param index: regint/cint/int
:param size: int
"""
assert self.value_type.n_elements() == 1
address = regint(size=size)
incint(address, regint(self.get_address(index), size=1), 0)
return self.value_type.load_mem(address, size=size)
def get_mem_value(self, index):
return MemValue(self[index], self.get_address(index))
[docs] def concat(self, other):
""" Concatenate two arrays. """
assert self.value_type == other.value_type
res = Array(len(self) + len(other), self.value_type)
res.assign_vector(self[:])
res.assign_vector(other[:], len(self))
return res
[docs] def read_from_file(self, start, *args, **kwargs):
""" Read content from ``Persistence/Transactions-P<playerno>.data``.
Precision must be the same as when storing if applicable. See
:ref:`this section <persistence>` for details on the data format.
:param start: starting position in number of shares from beginning
(int/regint/cint)
:param crash_if_missing: crash if file not found (default)
:returns: destination for final position, -1 for eof reached,
or -2 for file not found (regint)
"""
start = regint(start)
res = MemValue(0)
@library.multithread(None, len(self), max_size=program.budget)
def _(base, size):
stop, shares = self.value_type.read_from_file(
start, *args, size=size, **kwargs)
self.assign(shares[0], base=base)
start.iadd(size)
res.write(stop)
return res
[docs] def write_to_file(self, position=None):
""" Write shares of integer representation to
``Persistence/Transactions-P<playerno>.data``. See :ref:`this
section <persistence>` for details on the data format.
:param position: start position (int/regint/cint),
defaults to end of file
"""
if position is not None:
position = regint(position)
@library.multithread(None, len(self), max_size=program.budget)
def _(base, size):
self.value_type.write_to_file(self.get_vector(base=base, size=size),
position)
if position is not None:
position.iadd(size)
[docs] def read_from_socket(self, socket, debug=False):
""" Read content from socket. """
if debug:
library.print_str('reading %s...' % self)
# hard-coded budget for interopability
@library.multithread(None, len(self), max_size=10 ** 6)
def _(base, size):
self.assign_vector(
self.value_type.read_from_socket(socket, size=size), base=base)
if debug:
library.print_ln('done')
[docs] def write_to_socket(self, socket, debug=False):
""" Write content to socket. """
if debug:
library.print_ln('writing %s' % self)
# hard-coded budget for interopability
@library.multithread(None, len(self), max_size=10 ** 6)
def _(base, size):
self.value_type.write_to_socket(
socket, [self.get_vector(base=base, size=size)])
def __add__(self, other):
""" Vector addition.
:param other: vector or container of same length and type that supports operations with type of this array """
if is_zero(other):
return self
return self.get_vector() + other
def __sub__(self, other):
""" Vector subtraction.
:param other: vector or container of same length and type that supports operations with type of this array """
return self.get_vector() - other
def __rsub__(self, other):
return other - self.get_vector()
def __mul__(self, value):
""" Vector multiplication.
:param other: vector or container of same length and type that supports operations with type of this array """
return self.get_vector() * value
def __truediv__(self, value):
""" Vector division.
:param other: vector or container of same length and type that supports operations with type of this array """
return self.get_vector() / value
def __pow__(self, value):
""" Vector power-of computation.
:param other: compile-time integer (int) """
return self.get_vector() ** value
def __eq__(self, other):
return self.get_vector() == other
def __ne__(self, other):
return self.get_vector() != other
def __lt__(self, other):
return self.get_vector() < other
def __le__(self, other):
return self.get_vector() <= other
def __gt__(self, other):
return self.get_vector() > other
def __ge__(self, other):
return self.get_vector() >= other
__radd__ = __add__
__rmul__ = __mul__
def __iadd__(self, other):
self[:] += other.get_vector()
return self
def __isub__(self, other):
self[:] -= other.get_vector()
return self
def __imul__(self, other):
self[:] *= other.get_vector()
return self
def __itruediv__(self, other):
self[:] /= other.get_vector()
return self
def __neg__(self):
return -self.get_vector()
[docs] def dot(self, other):
""" Dot product with another array. """
M = Matrix(1, len(self), self.value_type, address=self.address)
return M.dot(other)
[docs] def shuffle(self):
""" Insecure shuffle in place. """
self.assign_vector(self.get(regint.inc(len(self)).shuffle()))
[docs] def secure_shuffle(self):
""" Secure shuffle in place according to the security model.
See :py:func:`MultiArray.secure_shuffle` for references. """
self.assign_vector(self.get_vector().secure_shuffle())
[docs] def secure_permute(self, *args, **kwargs):
""" Secure permute in place according to the security model.
See :py:func:`MultiArray.secure_shuffle` for references.
:param permutation: output of :py:func:`sint.get_secure_shuffle()`
:param reverse: whether to apply inverse (default: False)
"""
self.assign_vector(self.get_vector().secure_permute(*args, **kwargs))
[docs] def randomize(self, *args):
""" Randomize array according to data type.
If it is :py:class:`sfix`, the following will sample an
individual uniformly random entry of the array
:py:obj:`M` roughly in the range :math:`[a,b]`::
M.randomize(a, b)
"""
self.assign_vector(self.value_type.get_random(*args, size=len(self)))
[docs] def reveal(self):
""" Reveal the whole array.