Source code for Compiler.program

"""
This module contains the building blocks of the compiler such as code
blocks and registers. Most relevant is the central :py:class:`Program`
object that holds various properties of the computation.
"""

import inspect
import itertools
import math
import os
import re
import sys
import hashlib
import random
from collections import defaultdict, deque
from functools import reduce

import Compiler.instructions
import Compiler.instructions_base
import Compiler.instructions_base as inst_base
from Compiler.config import REG_MAX, USER_MEM, COST
from Compiler.exceptions import CompilerError
from Compiler.instructions_base import RegType

from . import allocator as al
from . import util

data_types = dict(
    triple=0,
    square=1,
    bit=2,
    inverse=3,
    dabit=4,
    mixed=5,
    random=6,
    open=7,
)

field_types = dict(
    modp=0,
    gf2n=1,
    bit=2,
)


class defaults:
    debug = False
    verbose = False
    outfile = None
    ring = 0
    field = 0
    binary = 0
    garbled = False
    prime = None
    galois = 40
    budget = 1000
    mixed = False
    edabit = False
    invperm = False
    split = None
    cisc = True
    comparison = None
    merge_opens = True
    preserve_mem_order = False
    max_parallel_open = 0
    dead_code_elimination = False
    noreallocate = False
    asmoutfile = None
    stop = False
    insecure = False
    keep_cisc = False


[docs]class Program(object): """A program consists of a list of tapes representing the whole computation. When compiling an :file:`.mpc` file, the single instance is available as :py:obj:`program`. When compiling directly from Python code, an instance has to be created before running any instructions. """ def __init__(self, args, options=defaults, name=None): from .non_linear import KnownPrime, Prime self.options = options self.verbose = options.verbose self.args = args self.name = name self.init_names(args) self._security = 40 self.used_security = 0 self.prime = None self.tapes = [] if sum(x != 0 for x in (options.ring, options.field, options.binary)) > 1: raise CompilerError("can only use one out of -B, -R, -F") if options.prime and (options.ring or options.binary): raise CompilerError("can only use one out of -B, -R, -p") if options.ring: self.set_ring_size(int(options.ring)) else: self.bit_length = int(options.binary) or int(options.field) if options.prime: self.prime = int(options.prime) max_bit_length = int(options.prime).bit_length() - 2 if self.bit_length > max_bit_length: raise CompilerError( "integer bit length can be maximal %s" % max_bit_length ) self.bit_length = self.bit_length or max_bit_length self.non_linear = KnownPrime(self.prime) else: self.non_linear = Prime(self.security) if not self.bit_length: self.bit_length = 64 print("Default bit length for compilation:", self.bit_length) if not (options.binary or options.garbled): print("Default security parameter for compilation:", self._security) self.galois_length = int(options.galois) if self.verbose: print("Galois length:", self.galois_length) self.tape_counter = 0 self._curr_tape = None self.DEBUG = options.debug self.allocated_mem = RegType.create_dict(lambda: USER_MEM) self.free_mem_blocks = defaultdict(al.BlockAllocator) self.later_mem_blocks = defaultdict(list) self.allocated_mem_blocks = {} self.saved = 0 self.req_num = None self.tape_stack = [] self.n_threads = 1 self.public_input_file = None self.types = {} if self.options.budget: self.budget = int(self.options.budget) else: if self.options.optimize_hard: self.budget = 100000 else: self.budget = defaults.budget self.to_merge = [ Compiler.instructions.asm_open_class, Compiler.instructions.gasm_open_class, Compiler.instructions.muls_class, Compiler.instructions.gmuls_class, Compiler.instructions.mulrs_class, Compiler.instructions.gmulrs, Compiler.instructions.dotprods_class, Compiler.instructions.gdotprods_class, Compiler.instructions.asm_input_class, Compiler.instructions.gasm_input_class, Compiler.instructions.inputfix_class, Compiler.instructions.inputfloat_class, Compiler.instructions.inputmixed_class, Compiler.instructions.trunc_pr_class, Compiler.instructions_base.Mergeable, ] import Compiler.GC.instructions as gc self.to_merge += [ gc.ldmsdi, gc.stmsdi, gc.ldmsd, gc.stmsd, gc.stmsdci, gc.andrs, gc.ands, gc.inputb, gc.inputbvec, gc.reveal, ] self.use_trunc_pr = False """ Setting whether to use special probabilistic truncation. """ self.use_dabit = options.mixed """ Setting whether to use daBits for non-linear functionality. """ self._edabit = options.edabit """ Whether to use the low-level INVPERM instruction (only implemented with the assumption of a semi-honest two-party environment)""" self._invperm = options.invperm self._split = False if options.split: self.use_split(int(options.split)) self._square = False self._always_raw = False self._linear_rounds = False self.warn_about_mem = [True] self.relevant_opts = set() self.n_running_threads = None self.input_files = {} self.base_addresses = util.dict_by_id() self._protect_memory = False self.mem_protect_stack = [] self._always_active = True self.active = True self.prevent_breaks = False self.cisc_to_function = True if not self.options.cisc: self.options.cisc = not self.options.optimize_hard Program.prog = self from . import comparison, instructions, instructions_base, types instructions.program = self instructions_base.program = self types.program = self comparison.program = self comparison.set_variant(options) def get_args(self): return self.args def max_par_tapes(self): """Upper bound on number of tapes that will be run in parallel. (Excludes empty tapes)""" return self.n_threads def init_names(self, args): self.programs_dir = "Programs" if self.verbose: print("Compiling program in", self.programs_dir) for dirname in (self.programs_dir, "Player-Data"): if not os.path.exists(dirname): os.mkdir(dirname) # create extra directories if needed for dirname in ["Public-Input", "Bytecode", "Schedules"]: if not os.path.exists(self.programs_dir + "/" + dirname): os.mkdir(self.programs_dir + "/" + dirname) if self.name is None: self.name = args[0].split("/")[-1] exts = ".mpc", ".py" for ext in exts: if self.name.endswith(ext): self.name = self.name[:-len(ext)] infiles = [args[0]] for x in (self.programs_dir, sys.path[0] + "/Programs"): for ext in exts: filename = args[0] if not filename.endswith(ext): filename += ext filename = x + "/Source/" + filename if os.path.abspath(filename) not in \ [os.path.abspath(f) for f in infiles]: infiles += [filename] existing = [f for f in infiles if os.path.exists(f)] if len(existing) == 1: self.infile = existing[0] elif len(existing) > 1: raise CompilerError("ambiguous input files: " + ", ".join(existing)) else: raise CompilerError( "found none of the potential input files: " + ", ".join("'%s'" % x for x in infiles)) """ self.name is input file name (minus extension) + any optional arguments. Used to generate output filenames """ if self.options.outfile: self.name = self.options.outfile + "-" + self.name else: self.name = self.name if len(args) > 1: self.name += "-" + "-".join(re.sub("/", "_", arg) for arg in args[1:]) def set_ring_size(self, ring_size): from .non_linear import Ring for tape in self.tapes: prev = tape.req_bit_length["p"] if prev and prev != ring_size: raise CompilerError("cannot have different ring sizes") self.bit_length = ring_size - 1 self.non_linear = Ring(ring_size) self.options.ring = str(ring_size)
[docs] def new_tape(self, function, args=[], name=None, single_thread=False): """ Create a new tape from a function. See :py:func:`~Compiler.library.multithread` and :py:func:`~Compiler.library.for_range_opt_multithread` for easier-to-use higher-level functionality. The following runs two threads defined by two different functions:: def f(): ... def g(): ... tapes = [program.new_tape(x) for x in (f, g)] thread_numbers = program.run_tapes(tapes) program.join_tapes(threads_numbers) :param function: Python function defining the thread :param args: arguments to the function :param name: name used for files :param single_thread: Boolean indicating whether tape will never be run in parallel to itself :returns: tape handle """ if name is None: name = function.__name__ name = "%s-%s" % (self.name, name) # make sure there is a current tape self.curr_tape tape_index = len(self.tapes) self.tape_stack.append(self.curr_tape) self.curr_tape = Tape(name, self) self.curr_tape.singular = single_thread self.tapes.append(self.curr_tape) function(*args) self.finalize_tape(self.curr_tape) if self.tape_stack: self.curr_tape = self.tape_stack.pop() return tape_index
def run_tape(self, tape_index, arg): return self.run_tapes([[tape_index, arg]])[0]
[docs] def run_tapes(self, args): """Run tapes in parallel. See :py:func:`new_tape` for an example. :param args: list of tape handles or tuples of tape handle and extra argument (for :py:func:`~Compiler.library.get_arg`) :returns: list of thread numbers """ if not self.curr_tape.singular: raise CompilerError( "Compiler does not support " "recursive spawning of threads" ) args = [list(util.tuplify(arg)) for arg in args] singular_tapes = set() for arg in args: if self.tapes[arg[0]].singular: if arg[0] in singular_tapes: raise CompilerError("cannot run singular tape in parallel") singular_tapes.add(arg[0]) assert len(arg) assert len(arg) <= 2 if len(arg) == 1: arg += [0] thread_numbers = [] while len(thread_numbers) < len(args): free_threads = self.curr_tape.free_threads if free_threads: thread_numbers.append(min(free_threads)) free_threads.remove(thread_numbers[-1]) else: thread_numbers.append(self.n_threads) self.n_threads += 1 self.curr_tape.start_new_basicblock(name="pre-run_tape") Compiler.instructions.run_tape( *sum(([x] + list(y) for x, y in zip(thread_numbers, args)), []) ) self.curr_tape.start_new_basicblock(name="post-run_tape") for arg in args: self.curr_block.req_node.children.append( self.tapes[arg[0]].req_tree) return thread_numbers
def join_tape(self, thread_number): self.join_tapes([thread_number])
[docs] def join_tapes(self, thread_numbers): """Wait for completion of tapes. See :py:func:`new_tape` for an example. :param thread_numbers: list of thread numbers """ self.curr_tape.start_new_basicblock(name="pre-join_tape") for thread_number in thread_numbers: Compiler.instructions.join_tape(thread_number) self.curr_tape.free_threads.add(thread_number) self.curr_tape.start_new_basicblock(name="post-join_tape")
def update_req(self, tape): if self.req_num is None: self.req_num = tape.req_num else: self.req_num += tape.req_num def write_bytes(self): """Write all non-empty threads and schedule to files.""" nonempty_tapes = [t for t in self.tapes] sch_filename = self.programs_dir + "/Schedules/%s.sch" % self.name sch_file = open(sch_filename, "w") print("Writing to", sch_filename) sch_file.write(str(self.max_par_tapes()) + "\n") sch_file.write(str(len(nonempty_tapes)) + "\n") sch_file.write(" ".join("%s:%d" % (tape.name, len(tape)) for tape in nonempty_tapes) + "\n") sch_file.write("1 0\n") sch_file.write("0\n") sch_file.write(" ".join(sys.argv) + "\n") req = max(x.req_bit_length["p"] for x in self.tapes) if self.options.ring: sch_file.write("R:%s" % self.options.ring) elif self.options.prime: sch_file.write("p:%s" % self.options.prime) else: sch_file.write("lgp:%s" % req) sch_file.write("\n") sch_file.write("opts: %s\n" % " ".join(self.relevant_opts)) sch_file.write("sec:%d\n" % self.used_security) sch_file.close() h = hashlib.sha256() for tape in self.tapes: tape.write_bytes() h.update(tape.hash) print('Hash:', h.hexdigest()) def finalize_tape(self, tape): if not tape.purged: tape.optimize(self.options) tape.write_bytes() if self.options.asmoutfile: tape.write_str(self.options.asmoutfile + "-" + tape.name) tape.purge() @property def curr_tape(self): """The tape that is currently running.""" if self._curr_tape is None: assert not self.tapes self._curr_tape = Tape(self.name, self) self.tapes.append(self._curr_tape) return self._curr_tape @curr_tape.setter def curr_tape(self, value): self._curr_tape = value @property def curr_block(self): """The basic block that is currently being created.""" return self.curr_tape.active_basicblock def malloc(self, size, mem_type, reg_type=None, creator_tape=None, use_freed=True): """Allocate memory from the top""" if not isinstance(size, int): raise CompilerError("size must be known at compile time") if size == 0: return if isinstance(mem_type, type): try: size *= math.ceil(mem_type.n / mem_type.unit) except AttributeError: pass self.types[mem_type.reg_type] = mem_type mem_type = mem_type.reg_type elif reg_type is not None: self.types[mem_type] = reg_type single_size = None if not (creator_tape or self.curr_tape).singular: if self.n_running_threads: single_size = size size *= self.n_running_threads else: raise CompilerError("cannot allocate memory " "outside main thread") blocks = self.free_mem_blocks[mem_type] addr = blocks.pop(size) if use_freed else None if addr is not None: self.saved += size else: addr = self.allocated_mem[mem_type] self.allocated_mem[mem_type] += size if len(str(addr)) != len(str(addr + size)) and self.verbose: print("Memory of type '%s' now of size %d" % (mem_type, addr + size)) if addr + size >= 2**64: raise CompilerError("allocation exceeded for type '%s'" % mem_type) self.allocated_mem_blocks[addr, mem_type] = size, self.curr_block.alloc_pool if single_size: from .library import get_thread_number, runtime_error_if bak = self.curr_tape.active_basicblock self.curr_tape.active_basicblock = self.curr_tape.basicblocks[0] tn = get_thread_number() runtime_error_if(tn > self.n_running_threads, "malloc") res = addr + single_size * (tn - 1) self.curr_tape.active_basicblock = bak self.base_addresses[res] = addr return res else: return addr def free(self, addr, mem_type): """Free memory""" now = True if not util.is_constant(addr): addr = self.base_addresses[addr] now = self.curr_tape == self.tapes[0] size, pool = self.allocated_mem_blocks[addr, mem_type] if self.curr_block.alloc_pool is not pool: raise CompilerError("Cannot free memory across function blocks") self.allocated_mem_blocks.pop((addr, mem_type)) if now: self.free_mem_blocks[mem_type].push(addr, size) else: self.later_mem_blocks[mem_type].append((addr, size)) def free_later(self): for mem_type in self.later_mem_blocks: for block in self.later_mem_blocks[mem_type]: self.free_mem_blocks[mem_type].push(*block) self.later_mem_blocks.clear() def finalize(self): # optimize the tapes for tape in self.tapes: tape.optimize(self.options) if self.tapes: self.update_req(self.curr_tape) # finalize the memory self.finalize_memory() # communicate protocol compability Compiler.instructions.active(self._always_active) self.write_bytes() if self.options.asmoutfile: for tape in self.tapes: tape.write_str(self.options.asmoutfile + "-" + tape.name) # Making sure that the public_input_file has been properly closed if self.public_input_file is not None: self.public_input_file.close() def finalize_memory(self): self.curr_tape.start_new_basicblock(None, "memory-usage", req_node=self.curr_tape.req_tree) # reset register counter to 0 if not self.options.noreallocate: self.curr_tape.init_registers() for mem_type, size in sorted(self.allocated_mem.items()): if size and (not self.options.garbled or \ mem_type not in ('s', 'sg', 'c', 'cg')): # print "Memory of type '%s' of size %d" % (mem_type, size) if mem_type in self.types: self.types[mem_type].load_mem(size - 1, mem_type) else: from Compiler.types import _get_type _get_type(mem_type).load_mem(size - 1, mem_type) if self.verbose: if self.saved: print("Saved %s memory units through reallocation" % self.saved)
[docs] def public_input(self, x): """Append a value to the public input file.""" if self.public_input_file is None: self.public_input_file = open( self.programs_dir + "/Public-Input/%s" % self.name, "w" ) self.public_input_file.write("%s\n" % str(x))
def get_binary_input_file(self, player): key = player, 'bin' if key not in self.input_files: filename = 'Player-Data/Input-Binary-P%d-0' % player print('Writing binary data to', filename) self.input_files[key] = open(filename, 'wb') return self.input_files[key]
[docs] def set_bit_length(self, bit_length): """Change the integer bit length for non-linear functions.""" self.bit_length = bit_length print("Changed bit length for comparisons etc. to", bit_length)
def set_security(self, security): changed = self._security != security self._security = security self.non_linear.set_security(security) if changed: print("Changed statistical security for comparison etc. to", security) @property def security(self): """The statistical security parameter for non-linear functions.""" self.used_security = max(self.used_security, self._security) return self._security @security.setter def security(self, security): self.set_security(security) def optimize_for_gc(self): import Compiler.GC.instructions as gc self.to_merge += [gc.xors] def get_tape_counter(self): res = self.tape_counter self.tape_counter += 1 return res @property def use_trunc_pr(self): if not self._use_trunc_pr: self.relevant_opts.add("trunc_pr") return self._use_trunc_pr @use_trunc_pr.setter def use_trunc_pr(self, change): self._use_trunc_pr = change
[docs] def use_edabit(self, change=None): """Setting whether to use edaBits for non-linear functionality (default: false). :param change: change setting if not :py:obj:`None` :returns: setting if :py:obj:`change` is :py:obj:`None` """ if change is None: if not self._edabit: self.relevant_opts.add("edabit") return self._edabit else: self._edabit = change
[docs] def use_invperm(self, change=None): """ Set whether to use the low-level INVPERM instruction to inverse a permutation (see sint.inverse_permutation). The INVPERM instruction assumes a semi-honest two-party environment. If false, a general protocol implemented in the high-level language is used. :param change: change setting if not :py:obj:`None` :returns: setting if :py:obj:`change` is :py:obj:`None` """ if change is None: if not self._invperm: self.relevant_opts.add("invperm") return self._invperm else: self._invperm = change
def use_edabit_for(self, *args): return True
[docs] def use_split(self, change=None): """Setting whether to use local arithmetic-binary share conversion for non-linear functionality (default: false). :param change: change setting if not :py:obj:`None` :returns: setting if :py:obj:`change` is :py:obj:`None` """ if change is None: if not self._split: self.relevant_opts.add("split") return self._split else: if change and not self.options.ring: raise CompilerError("splitting only supported for rings") assert change > 1 or change is False self._split = change
[docs] def use_square(self, change=None): """Setting whether to use preprocessed square tuples (default: false). :param change: change setting if not :py:obj:`None` :returns: setting if :py:obj:`change` is :py:obj:`None` """ if change is None: return self._square else: self._square = change
def always_raw(self, change=None): if change is None: return self._always_raw else: self._always_raw = change def linear_rounds(self, change=None): if change is None: return self._linear_rounds else: self._linear_rounds = change
[docs] def options_from_args(self): """Set a number of options from the command-line arguments.""" if "trunc_pr" in self.args: self.use_trunc_pr = True if "signed_trunc_pr" in self.args: self.use_trunc_pr = -1 if "split" in self.args or "split3" in self.args: self.use_split(3) for arg in self.args: m = re.match("split([0-9]+)", arg) if m: self.use_split(int(m.group(1))) if "raw" in self.args: self.always_raw(True) if "edabit" in self.args: self.use_edabit(True) if "invperm" in self.args: self.use_invperm(True) if "linear_rounds" in self.args: self.linear_rounds(True)
def disable_memory_warnings(self): self.warn_about_mem.append(False) self.curr_block.warn_about_mem = False
[docs] def protect_memory(self, status): """ Enable or disable memory protection. """ self._protect_memory = status
def open_memory_scope(self, key=None): self.mem_protect_stack.append(self._protect_memory) self.protect_memory(key or object()) def close_memory_scope(self): self.protect_memory(self.mem_protect_stack.pop()) def use_cisc(self): return self.options.cisc and (not self.prime or self.rabbit_gap()) \ and not self.options.max_parallel_open def rabbit_gap(self): assert self.prime p = self.prime logp = int(round(math.log(p, 2))) return abs(p - 2 ** logp) / p < 2 ** -self.security @property def active(self): """ Whether to use actively secure protocols. """ return self._active @active.setter def active(self, change): self._always_active &= change self._active = change def semi_honest(self): self._always_active = False @staticmethod def read_schedule(schedule): m = re.search(r"([^/]*)\.mpc", schedule) if m: schedule = m.group(1) if not os.path.exists(schedule): schedule = "Programs/Schedules/%s.sch" % schedule try: return open(schedule).readlines() except FileNotFoundError: print( "%s not found, have you compiled the program?" % schedule, file=sys.stderr, ) sys.exit(1) @classmethod def read_tapes(cls, schedule): lines = cls.read_schedule(schedule) for tapename in lines[2].split(" "): yield tapename.strip().split(":")[0] @classmethod def read_n_threads(cls, schedule): return int(cls.read_schedule(schedule)[0]) @classmethod def read_domain_size(cls, schedule): from Compiler.instructions import reqbl_class tapename = cls.read_schedule(schedule)[2].strip().split(":")[0] for inst in Tape.read_instructions(tapename): if inst.code == reqbl_class.code: bl = inst.args[0] return (abs(bl.i) + 63) // 64 * 8
class Tape: """A tape contains a list of basic blocks, onto which instructions are added.""" def __init__(self, name, program): """Set prime p and the initial instructions and registers.""" self.program = program name += "-%d" % program.get_tape_counter() self.init_names(name) self.init_registers() self.req_tree = self.ReqNode(name) self.basicblocks = [] self.purged = False self.block_counter = 0 self.active_basicblock = None self.old_allocated_mem = program.allocated_mem.copy() self.start_new_basicblock(req_node=self.req_tree) self._is_empty = False self.merge_opens = True self.if_states = [] self.req_bit_length = defaultdict(lambda: 0) self.function_basicblocks = {} self.functions = [] self.singular = True self.free_threads = set() self.loop_breaks = [] self.warned_about_mem = False class BasicBlock(object): def __init__(self, parent, name, scope, exit_condition=None, req_node=None): self.parent = parent self.instructions = [] self.name = name self.open_queue = [] self.exit_condition = exit_condition self.exit_block = None self.previous_block = None self.scope = scope self.children = [] if scope is not None: scope.children.append(self) self.alloc_pool = scope.alloc_pool else: self.alloc_pool = al.AllocPool() self.purged = False self.n_rounds = 0 self.n_to_merge = 0 self.rounds = Tape.ReqNum() self.warn_about_mem = parent.program.warn_about_mem[-1] self.req_node = req_node self.used_from_scope = set() def __len__(self): return len(self.instructions) def new_reg(self, reg_type, size=None): return self.parent.new_reg(reg_type, size=size) def set_return(self, previous_block, sub_block): self.previous_block = previous_block self.sub_block = sub_block def adjust_return(self): offset = self.sub_block.get_offset(self) self.previous_block.return_address_store.args[1] = offset def set_exit(self, condition, exit_true=None): """Sets the block which we start from next, depending on the condition. (Default is to go to next block in the list) """ self.exit_condition = condition self.exit_block = exit_true for reg in condition.get_used(): reg.can_eliminate = False def add_jump(self): """Add the jump for this block's exit condition to list of instructions (must be done after merging)""" self.instructions.append(self.exit_condition) def get_offset(self, next_block): return next_block.offset - (self.offset + len(self.instructions)) def adjust_jump(self): """Set the correct relative jump offset""" offset = self.get_offset(self.exit_block) self.exit_condition.set_relative_jump(offset) def purge(self, retain_usage=True): def relevant(inst): req_node = Tape.ReqNode("") req_node.num = Tape.ReqNum() inst.add_usage(req_node) return req_node.num != {} if retain_usage: self.usage_instructions = list(filter(relevant, self.instructions)) else: self.usage_instructions = [] if len(self.usage_instructions) > 1000: print("Retaining %d instructions" % len(self.usage_instructions)) del self.instructions self.purged = True def add_usage(self, req_node): if self.purged: instructions = self.usage_instructions else: instructions = self.instructions for inst in instructions: inst.add_usage(req_node) req_node.num["all", "round"] += self.n_rounds req_node.num["all", "inv"] += self.n_to_merge req_node.num += self.rounds def expand_cisc(self): if self.parent.program.options.keep_cisc is not None: skip = ["LTZ", "Trunc", "EQZ"] skip += self.parent.program.options.keep_cisc.split(",") else: skip = [] tape = self.parent tape.start_new_basicblock(scope=self.scope, req_node=self.req_node, name="cisc") start_block = tape.basicblocks[-1] start_block.alloc_pool = self.alloc_pool for inst in self.instructions: inst.expand_merged(skip) self.instructions = tape.active_basicblock.instructions if start_block == tape.basicblocks[-1]: res = self else: res = start_block tape.basicblocks[-1] = self return res def __str__(self): return self.name def is_empty(self): """Returns True if the list of basic blocks is empty. Note: False is returned even when tape only contains basic blocks with no instructions. However, these are removed when optimize is called.""" if not self.purged: self._is_empty = len(self.basicblocks) == 0 return self._is_empty def start_new_basicblock(self, scope=False, name="", req_node=None): assert not self.program.prevent_breaks if self.program.verbose and self.active_basicblock and \ self.program.allocated_mem != self.old_allocated_mem: print("New allocated memory in %s " % self.active_basicblock.name, end="") for t, n in self.program.allocated_mem.items(): if n != self.old_allocated_mem[t]: print("%s:%d " % (t, n - self.old_allocated_mem[t]), end="") print() self.old_allocated_mem = self.program.allocated_mem.copy() # use False because None means no scope if scope is False: scope = self.active_basicblock suffix = "%s-%d" % (name, self.block_counter) self.block_counter += 1 if req_node is None: req_node = self.active_basicblock.req_node sub = self.BasicBlock(self, self.name + "-" + suffix, scope, req_node=req_node) self.basicblocks.append(sub) self.active_basicblock = sub # print 'Compiling basic block', sub.name def init_registers(self): self.reg_counter = RegType.create_dict(lambda: 0) def init_names(self, name): self.name = name self.outfile = self.program.programs_dir + "/Bytecode/" + self.name + ".bc" def __len__(self): if self.purged: return self.size else: return sum(len(block) for block in self.basicblocks) def purge(self): self.size = len(self) for block in self.basicblocks: block.purge() self._is_empty = len(self.basicblocks) == 0 del self.basicblocks del self.active_basicblock self.purged = True def unpurged(function): def wrapper(self, *args, **kwargs): if self.purged: return return function(self, *args, **kwargs) return wrapper @unpurged def optimize(self, options): if len(self.basicblocks) == 0: print("Tape %s is empty" % self.name) return if self.if_states: print("Tracebacks for open blocks:") for state in self.if_states: try: print(util.format_trace(state.caller)) except AttributeError: pass print() raise CompilerError("Unclosed if/else blocks, see tracebacks above") if self.program.verbose: print( "Processing tape", self.name, "with %d blocks" % len(self.basicblocks) ) for block in self.basicblocks: al.determine_scope(block, options) # merge open instructions # need to do this if there are several blocks if (options.merge_opens and self.merge_opens) or options.dead_code_elimination: for i, block in enumerate(self.basicblocks): if len(block.instructions) > 0 and self.program.verbose: print( "Processing basic block %s, %d/%d, %d instructions" % ( block.name, i, len(self.basicblocks), len(block.instructions), ) ) # the next call is necessary for allocation later even without merging merger = al.Merger(block, options, tuple(self.program.to_merge)) if options.dead_code_elimination: if len(block.instructions) > 1000000: print("Eliminate dead code...") merger.eliminate_dead_code() if options.merge_opens and self.merge_opens: if len(block.instructions) == 0: block.used_from_scope = util.set_by_id() continue if len(block.instructions) > 1000000: print("Merging instructions...") numrounds = merger.longest_paths_merge() block.n_rounds = numrounds block.n_to_merge = len(merger.open_nodes) if options.verbose: block.rounds = merger.req_num if merger.counter and self.program.verbose: print( "Block requires", ", ".join( "%d %s" % (y, x.__name__) for x, y in list(merger.counter.items()) ), ) if merger.counter and self.program.verbose: print( "Block requires %s rounds" % ", ".join( "%d %s" % (y, x.__name__) for x, y in list(merger.rounds.items()) ) ) # free memory merger = None if options.dead_code_elimination: block.instructions = [ x for x in block.instructions if x is not None ] if not (options.merge_opens and self.merge_opens): print("Not merging instructions in tape %s" % self.name) if options.cisc: self.expand_cisc() # add jumps offset = 0 for block in self.basicblocks: if block.exit_condition is not None: block.add_jump() block.offset = offset offset += len(block.instructions) for block in self.basicblocks: if block.exit_block is not None: block.adjust_jump() if block.previous_block is not None: block.adjust_return() # now remove any empty blocks (must be done after setting jumps) self.basicblocks = [x for x in self.basicblocks if len(x.instructions) != 0] # allocate registers reg_counts = self.count_regs() if options.noreallocate: if self.program.verbose: print("Tape register usage:", dict(reg_counts)) else: if self.program.verbose: print("Tape register usage before re-allocation:", dict(reg_counts)) print( "modp: %d clear, %d secret" % (reg_counts[RegType.ClearModp], reg_counts[RegType.SecretModp]) ) print( "GF2N: %d clear, %d secret" % (reg_counts[RegType.ClearGF2N], reg_counts[RegType.SecretGF2N]) ) print("Re-allocating...") allocator = al.StraightlineAllocator(REG_MAX, self.program) # make addresses available in functions for addr in self.program.base_addresses: if addr.program == self and self.basicblocks: allocator.alloc_reg(addr, self.basicblocks[-1].alloc_pool) seen = set() def alloc(block): allocator.update_usage(block.alloc_pool) for reg in sorted( block.used_from_scope, key=lambda x: (x.reg_type, x.i) ): allocator.alloc_reg(reg, block.alloc_pool) seen.add(block) def alloc_loop(block): left = deque([block]) while left: block = left.popleft() alloc(block) for child in block.children: if child not in seen: left.append(child) allocator.old_pool = None for i, block in enumerate(reversed(self.basicblocks)): if len(block.instructions) > 1000000: print( "Allocating %s, %d/%d" % (block.name, i, len(self.basicblocks)) ) if block.exit_condition is not None: jump = block.exit_condition.get_relative_jump() if ( isinstance(jump, int) and jump < 0 and block.exit_block.scope is not None ): alloc_loop(block.exit_block.scope) usage = allocator.max_usage.copy() allocator.process(block.instructions, block.alloc_pool) if self.program.verbose and usage != allocator.max_usage: print("Allocated registers in %s " % block.name, end="") for t, n in allocator.max_usage.items(): if n > usage[t]: print("%s:%d " % (t, n - usage[t]), end="") print() allocator.finalize(options) if self.program.verbose: print("Tape register usage:", dict(allocator.max_usage)) scopes = set(block.alloc_pool for block in self.basicblocks) n_fragments = sum(scope.n_fragments() for scope in scopes) print("%d register fragments in %d scopes" % (n_fragments, len(scopes))) # offline data requirements if self.program.verbose: print("Compile offline data requirements...") for block in self.basicblocks: block.req_node.add_block(block) self.req_num = self.req_tree.aggregate() if self.program.verbose: print("Tape requires", self.req_num) for req, num in sorted(self.req_num.items()): if num == float("inf") or num >= 2**64: num = -1 if req[1] in data_types: self.basicblocks[-1].instructions.append( Compiler.instructions.use( field_types[req[0]], data_types[req[1]], num, add_to_prog=False ) ) elif req[1] == "input": self.basicblocks[-1].instructions.append( Compiler.instructions.use_inp( field_types[req[0]], req[2], num, add_to_prog=False ) ) elif req[0] == "modp": self.basicblocks[-1].instructions.append( Compiler.instructions.use_prep(req[1], num, add_to_prog=False) ) elif req[0] == "gf2n": self.basicblocks[-1].instructions.append( Compiler.instructions.guse_prep(req[1], num, add_to_prog=False) ) elif req[0] == "edabit": self.basicblocks[-1].instructions.append( Compiler.instructions.use_edabit( False, req[1], num, add_to_prog=False ) ) elif req[0] == "sedabit": self.basicblocks[-1].instructions.append( Compiler.instructions.use_edabit( True, req[1], num, add_to_prog=False ) ) elif req[0] == "matmul": self.basicblocks[-1].instructions.append( Compiler.instructions.use_matmul(*req[1], num, add_to_prog=False) ) if not self.is_empty(): # bit length requirement for x in ("p", "2"): if self.req_bit_length[x]: bl = self.req_bit_length[x] if self.program.options.ring: bl = -int(self.program.options.ring) self.basicblocks[-1].instructions.append( Compiler.instructions.reqbl(bl, add_to_prog=False) ) if self.program.verbose: print("Tape requires prime bit length", self.req_bit_length["p"]) print("Tape requires galois bit length", self.req_bit_length["2"]) @unpurged def expand_cisc(self): mapping = {None: None} blocks = self.basicblocks[:] self.basicblocks = [] for block in blocks: expanded = block.expand_cisc() mapping[block] = expanded for block in self.basicblocks: if block not in mapping: mapping[block] = block for block in self.basicblocks: block.exit_block = mapping[block.exit_block] if block.exit_block is not None: assert block.exit_block in self.basicblocks if block.previous_block and mapping[block] != block: mapping[block].previous_block = block.previous_block mapping[block].sub_block = block.sub_block block.previous_block = None del block.sub_block @unpurged def _get_instructions(self): return itertools.chain.from_iterable(b.instructions for b in self.basicblocks) @unpurged def get_encoding(self): """Get the encoding of the program, in human-readable format.""" return [i.get_encoding() for i in self._get_instructions() if i is not None] @unpurged def get_bytes(self): """Get the byte encoding of the program as an actual string of bytes.""" return b"".join( i.get_bytes() for i in self._get_instructions() if i is not None ) @unpurged def write_encoding(self, filename): """Write the readable encoding to a file.""" print("Writing to", filename) f = open(filename, "w") for line in self.get_encoding(): f.write(str(line) + "\n") f.close() @unpurged def write_str(self, filename): """Write the sequence of instructions to a file.""" print("Writing to", filename) f = open(filename, "w") n = 0 for block in self.basicblocks: if block.instructions: f.write("# %s\n" % block.name) for line in block.instructions: f.write("%s # %d\n" % (line, n)) n += 1 f.close() @unpurged def write_bytes(self, filename=None): """Write the program's byte encoding to a file.""" if filename is None: filename = self.outfile if not filename.endswith(".bc"): filename += ".bc" if "Bytecode" not in filename: filename = self.program.programs_dir + "/Bytecode/" + filename print("Writing to", filename) f = open(filename, "wb") h = hashlib.sha256() for i in self._get_instructions(): if i is not None: b = i.get_bytes() f.write(b) h.update(b) f.close() self.hash = h.digest() def new_reg(self, reg_type, size=None): return self.Register(reg_type, self, size=size) def count_regs(self, reg_type=None): if reg_type is None: return self.reg_counter else: return self.reg_counter[reg_type] def __str__(self): return self.name class ReqNum(defaultdict): def __init__(self, init={}): super(Tape.ReqNum, self).__init__(lambda: 0, init) def __add__(self, other): res = Tape.ReqNum() for i, count in list(self.items()): res[i] += count for i, count in list(other.items()): res[i] += count return res def __mul__(self, other): res = Tape.ReqNum() for i in self: res[i] = other * self[i] return res __rmul__ = __mul__ def set_all(self, value): if Program.prog.options.verbose and \ value == float("inf") and self["all", "inv"] > 0: print("Going to unknown from %s" % self) res = Tape.ReqNum() for i in self: res[i] = value return res def max(self, other): res = Tape.ReqNum() for i in self: res[i] = max(self[i], other[i]) for i in other: res[i] = max(self[i], other[i]) return res def cost(self): return sum( num * COST[req[0]][req[1]] for req, num in list(self.items()) if req[1] != "input" and req[0] != "edabit" ) def pretty(self): def t(x): return "integer" if x == "modp" else x def f(num): try: return "%12.0f" % num except: return str(num) res = [] for req, num in self.items(): domain = t(req[0]) if num < 0: num = float('inf') n = f(num) if req[1] == "input": res += ["%s %s inputs from player %d" % (n, domain, req[2])] elif domain.endswith("edabit"): if domain == "sedabit": eda = "strict edabits" else: eda = "loose edabits" res += ["%s %s of length %d" % (n, eda, req[1])] elif domain == "matmul": res += [ "%s matrix multiplications (%dx%d * %dx%d)" % (n, req[1][0], req[1][1], req[1][1], req[1][2]) ] elif req[0] != "all": res += ["%s %s %ss" % (n, domain, req[1])] if self["all", "round"]: res += ["%s virtual machine rounds" % f(self["all", "round"])] return res def __str__(self): return ", ".join(self.pretty()) def __repr__(self): return repr(dict(self)) class ReqNode(object): __slots__ = ["num", "_children", "name", "blocks", "aggregated"] def __init__(self, name): self._children = [] self.name = name self.blocks = [] self.aggregated = None @property def children(self): self.aggregated = None return self._children def aggregate(self, *args): if self.aggregated is not None: return self.aggregated self.num = Tape.ReqNum() for block in self.blocks: block.add_usage(self) res = reduce( lambda x, y: x + y.aggregate(self.name), self.children, self.num ) self.aggregated = res return res def increment(self, data_type, num=1): self.num[data_type] += num self.aggregated = None def add_block(self, block): self.blocks.append(block) self.aggregated = None class ReqChild(object): __slots__ = ["aggregator", "nodes", "parent"] def __init__(self, aggregator, parent): self.aggregator = aggregator self.nodes = [] self.parent = parent def aggregate(self, name): res = self.aggregator([node.aggregate() for node in self.nodes]) try: n_reps = self.aggregator([1]) n_rounds = res["all", "round"] n_invs = res["all", "inv"] if (n_invs / n_rounds) * 1000 < n_reps: print( self.nodes[0].blocks[0].name, "blowing up rounds: ", "(%d / %d) ** 3 < %d" % (n_rounds, n_reps, n_invs), ) except Exception: pass return res def add_node(self, tape, name): new_node = Tape.ReqNode(name) self.nodes.append(new_node) return new_node def open_scope(self, aggregator, scope=False, name=""): req_node = self.active_basicblock.req_node child = self.ReqChild(aggregator, req_node) req_node.children.append(child) node = child.add_node(self, "%s-%d" % (name, len(self.basicblocks))) self.start_new_basicblock(name=name, req_node=node) return child def close_scope(self, outer_scope, parent_req_node, name): self.start_new_basicblock(outer_scope, name, req_node=parent_req_node) def require_bit_length(self, bit_length, t="p"): if t == "p": if self.program.prime: if bit_length >= self.program.prime.bit_length() - 1: raise CompilerError( "required bit length %d too much for %d" % (bit_length, self.program.prime) ) self.req_bit_length[t] = max(bit_length + 1, self.req_bit_length[t]) else: self.req_bit_length[t] = max(bit_length, self.req_bit_length) @staticmethod def read_instructions(tapename): tape = open("Programs/Bytecode/%s.bc" % tapename, "rb") while tape.peek(): yield inst_base.ParsedInstruction(tape) class _no_truth(object): __slots__ = [] def __bool__(self): raise CompilerError( "Cannot derive truth value from register. " "This is a catch-all error appearing if you try to use a " "run-time value where the compiler expects a compile-time " "value, most likely a Python integer. " "In some cases, you can fix this by using 'compile.py -l'." ) class Register(_no_truth): """ Class for creating new registers. The register's index is automatically assigned based on the block's reg_counter dictionary. """ __slots__ = [ "reg_type", "program", "absolute_i", "relative_i", "size", "vector", "vectorbase", "caller", "can_eliminate", "duplicates", "block", ] maximum_size = 2 ** (64 - inst_base.Instruction.code_length) - 1 def __init__(self, reg_type, program, size=None, i=None): """Creates a new register. reg_type must be one of those defined in RegType.""" if Compiler.instructions_base.get_global_instruction_type() == "gf2n": if reg_type == RegType.ClearModp: reg_type = RegType.ClearGF2N elif reg_type == RegType.SecretModp: reg_type = RegType.SecretGF2N self.reg_type = reg_type self.program = program self.block = program.active_basicblock if size is None: size = Compiler.instructions_base.get_global_vector_size() if size is not None and size > self.maximum_size: raise CompilerError("vector too large: %d" % size) self.size = size self.vectorbase = self self.relative_i = 0 if i is not None: self.i = i elif size is not None: self.i = program.reg_counter[reg_type] program.reg_counter[reg_type] += size else: self.i = float("inf") self.vector = [] self.can_eliminate = True self.duplicates = util.set_by_id([self]) if Program.prog.DEBUG: self.caller = [frame[1:] for frame in inspect.stack()[1:]] else: self.caller = None @property def i(self): return self.vectorbase.absolute_i + self.relative_i @i.setter def i(self, value): self.vectorbase.absolute_i = value - self.relative_i def set_size(self, size): if self.size == size: return else: raise CompilerError( "Mismatch of instruction and register size:" " %s != %s" % (self.size, size) ) def set_vectorbase(self, vectorbase): if self.vectorbase is not self: raise CompilerError("Cannot assign one register" "to several vectors") self.relative_i = self.i - vectorbase.i self.vectorbase = vectorbase def _new_by_number(self, i, size=1): return Tape.Register(self.reg_type, self.program, size=size, i=i) def get_vector(self, base=0, size=None): if size is None: size = self.size if base == 0 and size == self.size: return self if size == 1: return self[base] res = self._new_by_number(self.i + base, size=size) res.set_vectorbase(self) self.create_vector_elements() res.vector = self.vector[base : base + size] return res def create_vector_elements(self): if self.vector: return elif self.size == 1: self.vector = [self] return self.vector = [] for i in range(self.size): reg = self._new_by_number(self.i + i) reg.set_vectorbase(self) self.vector.append(reg) def get_all(self): return self.vector or [self] def __getitem__(self, index): if self.size == 1 and index == 0: return self if not self.vector: self.create_vector_elements() return self.vector[index] def __len__(self): return self.size def copy(self): return Tape.Register(self.reg_type, Program.prog.curr_tape) def link(self, other): if Program.prog.options.noreallocate: raise CompilerError("reallocation necessary for linking, " "remove option -u") assert self.reg_type == other.reg_type self.duplicates |= other.duplicates for dup in self.duplicates: dup.duplicates = self.duplicates def update(self, other): """ Update register. Useful in loops like :py:func:`~Compiler.library.for_range`. :param other: any convertible type """ diff_block = isinstance(other, Tape.Register) and self.block != other.block other = type(self)(other) if not diff_block: self.program.start_new_basicblock(name="update") if self.program != other.program: raise CompilerError( 'cannot update register with one from another thread') self.link(other) @property def is_gf2n(self): return ( self.reg_type == RegType.ClearGF2N or self.reg_type == RegType.SecretGF2N ) @property def is_clear(self): return ( self.reg_type == RegType.ClearModp or self.reg_type == RegType.ClearGF2N or self.reg_type == RegType.ClearInt ) def __str__(self): return self.reg_type + str(self.i) + \ ("(%d)" % self.size if self.size is not None and self.size > 1 else "") __repr__ = __str__