Source code for Compiler.circuit

"""
This module contains functionality for using circuits in the so-called
`Bristol Fashion`_ format. You can download a few examples including
the ones used below into ``Programs/Circuits`` as follows::

    make Programs/Circuits

.. _`Bristol Fashion`: https://homes.esat.kuleuven.be/~nsmart/MPC

"""
import math

from Compiler.GC.types import *
from Compiler.library import function_block, get_tape
from Compiler import util
import itertools
import struct

class Circuit:
    """
    Use a Bristol Fashion circuit in a high-level program. The
    following example adds signed 64-bit inputs from two different
    parties and prints the result::

        from circuit import Circuit
        sb64 = sbits.get_type(64)
        adder = Circuit('adder64')
        a, b = [sbitvec(sb64.get_input_from(i)) for i in (0, 1)]
        print_ln('%s', adder(a, b).elements()[0].reveal())

    Circuits can also be executed in parallel as the following example
    shows::

        from circuit import Circuit
        sb128 = sbits.get_type(128)
        key = sb128(0x2b7e151628aed2a6abf7158809cf4f3c)
        plaintext = sb128(0x6bc1bee22e409f96e93d7e117393172a)
        n = 1000
        aes128 = Circuit('aes_128')
        ciphertexts = aes128(sbitvec([key] * n), sbitvec([plaintext] * n))
        ciphertexts.elements()[n - 1].reveal().print_reg()

    This executes AES-128 1000 times in parallel and then outputs the
    last result, which should be
    ``0x3ad77bb40d7a3660a89ecaf32466ef97``, one of the test vectors
    for AES-128.

    :param name: circuit name; the description is read from
        ``Programs/Circuits/<name>.txt``
    """

    def __init__(self, name):
        self.filename = 'Programs/Circuits/%s.txt' % name
        # Fail early if the circuit file cannot be opened, but do not
        # keep the handle around: compile() re-opens the file itself.
        with open(self.filename):
            pass
        # Compiled function blocks, keyed as built in run().
        self.functions = {}

    def __call__(self, *inputs):
        """ Shorthand for :py:func:`run`. """
        return self.run(*inputs)

    def run(self, *inputs):
        """
        Evaluate the circuit on the given inputs.

        :param inputs: one sequence of wire values per circuit input
        :returns: one ``sbitvec`` per circuit output, passed through
            ``util.untuplify``
        """
        # One function block per pair of (``n`` of the first input
        # wire, current tape).
        n = inputs[0][0].n, get_tape()
        if n not in self.functions:
            self.functions[n] = function_block(
                lambda *args: self.compile(*args))
        flat_res = self.functions[n](*itertools.chain(*inputs))
        # Cut the flat list of output wires into the individual outputs
        # declared in the circuit file (recorded by compile()).
        res = []
        i = 0
        for l in self.n_output_wires:
            res.append(sbitvec.from_vec([flat_res[i + j] for j in range(l)]))
            i += l
        return util.untuplify(res)

    def compile(self, *all_inputs):
        """
        Read the circuit file and evaluate it gate by gate.

        :param all_inputs: values of all input wires, flattened in the
            order of the inputs declared in the circuit file
        :returns: list with the values of the last
            ``sum(self.n_output_wires)`` wires
        :raises Exception: if the file contains a gate other than
            ``XOR``, ``AND``, or ``INV``
        """
        # The context manager makes sure the file is closed again, also
        # if an assertion below fails.
        with open(self.filename) as f:
            lines = iter(f)
            next_line = lambda: next(lines).split()
            # Header line 1: number of gates and number of wires.
            n_gates, n_wires = (int(x) for x in next_line())
            self.n_wires = n_wires
            # Header line 2: number of inputs followed by their widths.
            input_line = [int(x) for x in next_line()]
            n_inputs = input_line[0]
            n_input_wires = input_line[1:]
            assert(n_inputs == len(n_input_wires))
            inputs = []
            s = 0
            for n in n_input_wires:
                inputs.append(all_inputs[s:s + n])
                s += n
            # Header line 3: number of outputs followed by their widths.
            output_line = [int(x) for x in next_line()]
            n_outputs = output_line[0]
            self.n_output_wires = output_line[1:]
            assert(n_outputs == len(self.n_output_wires))
            # Skip the line separating the header from the gates.
            next(lines)
            wires = [None] * n_wires
            self.wires = wires
            # The input wires come first, in order.
            i_wire = 0
            for input_regs, input_wires in zip(inputs, n_input_wires):
                assert(len(input_regs) == input_wires)
                for reg in input_regs:
                    wires[i_wire] = reg
                    i_wire += 1
            for i in range(n_gates):
                line = next_line()
                t = line[-1]
                if t in ('XOR', 'AND'):
                    # Format: 2 1 <in0> <in1> <out> <type>
                    assert line[0] == '2'
                    assert line[1] == '1'
                    assert len(line) == 6
                    ins = [wires[int(line[2 + i])] for i in range(2)]
                    if t == 'XOR':
                        wires[int(line[4])] = ins[0] ^ ins[1]
                    else:
                        wires[int(line[4])] = ins[0] & ins[1]
                elif t == 'INV':
                    # Format: 1 1 <in> <out> INV
                    assert line[0] == '1'
                    assert line[1] == '1'
                    assert len(line) == 5
                    wires[int(line[3])] = ~wires[int(line[2])]
                else:
                    # Skipping the gate would leave its output wire
                    # unset and silently corrupt the result.
                    raise Exception('unsupported gate type %s in %s' %
                                    (t, self.filename))
        # The output wires are the last ones.
        return self.wires[-sum(self.n_output_wires):]
# Shared Keccak-f circuit; stays None until sha3_256() creates it on first use.
Keccak_f = None
def sha3_256(x):
    """
    This function implements SHA3-256 for inputs consisting of whole
    bytes. Longer inputs are absorbed in several blocks of 1088 bits::

        from circuit import sha3_256
        a = sbitvec.from_vec([])
        b = sbitvec.from_hex('cc')
        c = sbitvec.from_hex('41fb')
        d = sbitvec.from_hex('1f877c')
        e = sbitvec.from_vec([sbit(0)] * 8)
        f = sbitvec.from_hex('41fb6834928423874832892983984728289238949827929283743858382828372f17188141fb6834928423874832892983984728289238949827929283743858382828372f17188141fb6834928423874832892983984728289238949827')
        g = sbitvec.from_hex('41fb6834928423874832892983984728289238949827929283743858382828372f17188141fb6834928423874832892983984728289238949827929283743858382828372f17188141fb6834928423874832892983984728289238949827929283743858382828372f17188141fb6834928423874832892983984728289238949827929283743858382828372f171881')
        h = sbitvec.from_vec([sbit(0)] * 3000)
        for x in a, b, c, d, e, f, g, h:
            sha3_256(x).reveal_print_hex()

    This should output the hashes of the inputs above, starting with
    the `test vectors
    <https://github.com/XKCP/XKCP/blob/master/tests/TestVectors/ShortMsgKAT_SHA3-256.txt>`_
    of SHA3-256 for 0, 8, 16, and 24 bits, followed by the hash of the
    0 byte and the hashes of the three longer inputs::

        Reg[0] = 0xa7ffc6f8bf1ed76651c14756a061d662f580ff4de43b49fa82d80a4b80f8434a #
        Reg[0] = 0x677035391cd3701293d385f037ba32796252bb7ce180b00b582dd9b20aaad7f0 #
        Reg[0] = 0x39f31b6e653dfcd9caed2602fd87f61b6254f581312fb6eeec4d7148fa2e72aa #
        Reg[0] = 0xbc22345e4bd3f792a341cf18ac0789f1c9c966712a501b19d1b6632ccd408ec5 #
        Reg[0] = 0x5d53469f20fef4f8eab52b88044ede69c77a6a68a60728609fc4a65ff531e7d0 #
        Reg[0] = 0xf5f673ec50d662039871fd53fae3ced069baf09030132d6d60d2ba7040b02b18 #
        Reg[0] = 0xa8a42e808f9dc0f43366d5de91511f42e9c3f8f37de0307f010bf629401edd2a #
        Reg[0] = 0xf722631013ecacd42b4c7259e9fe22b8c81a86e9fe0d4a626800e7f50c5a8978 #

    :param x: input bits (``sbitvec``); the number of bits has to be a
        multiple of eight
    :returns: the 256-bit hash as ``sbitvec``
    """

    global Keccak_f
    if Keccak_f is None:
        # only one instance
        Keccak_f = Circuit('Keccak_f')

    # whole bytes
    assert len(x.v) % 8 == 0
    # rate
    r = 1088
    # lane width
    w = 64
    # The fixed 8-bit suffix can overflow into an additional block, so
    # the length including the suffix is rounded up to a multiple of
    # the rate.
    length_with_suffix = len(x.v) + 8
    n_blocks = max(math.ceil(length_with_suffix / r), 1)
    upper_block_length = n_blocks * r

    # Number of parallel instances; a single one for the empty input.
    if x.v:
        n = x.v[0].n
    else:
        n = 1
    # Suffix byte 0x06 followed by zeroes up to the block boundary.
    d = sbitvec([sbits.get_type(8)(0x06)] * n)
    sbn = sbits.get_type(n)
    padding = [sbn(0)] * (upper_block_length - 8 - len(x.v))
    P_flat = x.v + d.v + padding
    assert len(P_flat) == upper_block_length
    P_flat[-1] = ~P_flat[-1] # set last bit to 1

    def position(col, row, i):
        # Index of bit i of lane (col, row) in the flat 1600-bit vector
        # exchanged with the circuit: the bit order within every byte
        # is reversed and the whole vector is mirrored.
        j = (5 * row + col) * w + i // 8 * 8 + 7 - i % 8
        return 1600 - 1 - j

    def flatten(state):
        res = [None] * 1600
        for row in range(5):
            for col in range(5):
                for i in range(w):
                    res[position(col, row, i)] = state[col][row][i]
        return res

    def unflatten(S_flat):
        res = [[[None] * w for j in range(5)] for i in range(5)]
        for row in range(5):
            for col in range(5):
                for i in range(w):
                    res[col][row][i] = S_flat[position(col, row, i)]
        return res

    # Initial state
    S = [[[sbn(0) for i in range(w)] for i in range(5)] for i in range(5)]

    def insert_block(local_S, local_P):
        # XOR one block of r bits into the first r // w lanes.
        assert len(local_P) == r
        P1 = [local_P[i * w:(i + 1) * w] for i in range(r // w)]
        for col in range(5):
            for row in range(5):
                if col + 5 * row < r // w:
                    for i in range(w):
                        local_S[col][row][i] ^= P1[col + 5 * row][i]

    # Absorb the padded input block by block.
    for block_id in range(n_blocks):
        block = P_flat[block_id * r:(block_id + 1) * r]
        insert_block(S, block)
        S = unflatten(Keccak_f(flatten(S)))

    # Squeeze until at least 256 output bits are available. The loop
    # variables are deliberately distinct from the parameter ``x``.
    Z = []
    while len(Z) < 256:
        for row in range(5):
            for col in range(5):
                if col + 5 * row < r // w:
                    Z += S[col][row]
        if len(Z) < 256:
            S = unflatten(Keccak_f(flatten(S)))
    return sbitvec.from_vec(Z[:256])
class ieee_float:
    """
    This gives access to IEEE754 floating-point operations using
    Bristol Fashion circuits. The following example computes the
    standard deviation of 10 integers input by each of party 0 and 1::

        from circuit import ieee_float

        values = []

        for i in range(2):
            for j in range(10):
                values.append(sbitint.get_type(64).get_input_from(i))

        fvalues = [ieee_float(x) for x in values]

        avg = sum(fvalues) / ieee_float(len(fvalues))
        var = sum(x * x for x in fvalues) / ieee_float(len(fvalues)) - avg * avg
        stddev = var.sqrt()

        print_ln('avg: %s', avg.reveal())
        print_ln('var: %s', var.reveal())
        print_ln('stddev: %s', stddev.reveal())

    """

    # Circuits already loaded, keyed by operation name.
    _circuits = {}
    is_clear = False

    @classmethod
    def circuit(cls, name):
        """ Return the circuit ``FP-<name>``, loading it on first use. """
        loaded = cls._circuits.get(name)
        if loaded is None:
            loaded = cls._circuits[name] = Circuit('FP-' + name)
        return loaded

    def __init__(self, value):
        if isinstance(value, (sbitint, sbitintvec)):
            # Secret integers are converted by the i2f circuit.
            bits = self.circuit('i2f')(sbitvec.conv(value))
        elif isinstance(value, sbitvec):
            # Bit vectors are used as they are.
            bits = value
        elif util.is_constant_float(value):
            # Constants: reinterpret the double as 64-bit integer.
            (pattern,) = struct.unpack('Q', struct.pack('d', value))
            bits = sbitvec(sbits.get_type(64)(pattern))
        else:
            raise Exception('cannot convert type %s' % type(value))
        self.value = bits

    def _binary(self, name, other):
        # Apply the two-input circuit called name to self and other.
        gate = self.circuit(name)
        return ieee_float(gate(self.value, other.value))

    @staticmethod
    def _single_or_vec(res):
        # Hand out the only element if there is no parallelism.
        if res.v[0].n != 1:
            return res
        return res.elements()[0]

    def __add__(self, other):
        return self._binary('add', other)

    def __radd__(self, other):
        # Allows sum(), which starts with the integer zero.
        if not util.is_zero(other):
            return NotImplemented
        return self

    def __neg__(self):
        # Invert the last bit of a copy of the bit list.
        flipped = self.value.v[:]
        flipped[-1] = ~flipped[-1]
        return ieee_float(sbitvec.from_vec(flipped))

    def __sub__(self, other):
        return self + -other

    def __mul__(self, other):
        return self._binary('mul', other)

    def __truediv__(self, other):
        return self._binary('div', other)

    def __eq__(self, other):
        # Only the first output wire of the eq circuit is used.
        outcome = self.circuit('eq')(self.value, other.value)
        return self._single_or_vec(sbitvec.from_vec(outcome.v[:1]))

    def sqrt(self):
        gate = self.circuit('sqrt')
        return ieee_float(gate(self.value))

    def to_int(self):
        converted = self.circuit('f2i')(self.value)
        return self._single_or_vec(sbitintvec.from_vec(converted))

    def reveal(self):
        """ Reveal as ``cbitfloat``. Only for single values. """
        assert self.value.v[0].n == 1
        # Bits 0-51, bits 52-62, and bit 63.
        parts = (self.value.v[:52], self.value.v[52:63], [self.value.v[63]])
        mantissa, exponent, sign = [
            sbitvec.from_vec(part).elements()[0].reveal() for part in parts]
        # Add the implicit leading one to the 52 mantissa bits.
        significand = 2 ** 52 + mantissa
        power = exponent - 2 ** 10 - 51
        # Mantissa and exponent both zero.
        zero = cbit((mantissa.to_regint() == 0) * (exponent.to_regint() == 0))
        # All eleven exponent bits set.
        exp_all_ones = (exponent.to_regint() == 2 ** 11 - 1)
        return cbitfloat(significand, power, zero, sign, exp_all_ones)