# Source code for Compiler.dijkstra

""" This module implements `Dijkstra's algorithm
<https://en.wikipedia.org/wiki/Dijkstra%27s_algorithm>`_ based on
oblivious RAM. """


from Compiler.oram import *

from Compiler.program import Program

# ORAM implementation used as the default ``oram_type`` by HeapQ and by the
# test helpers further down in this module.
ORAM = OptimalORAM

# Cap the bit length of the current program at 64 (never raising it, because
# of the min()).
# NOTE(review): ``program`` is not imported by name in this file; presumably
# it arrives through ``from Compiler.oram import *`` -- confirm.
try:
    prog = program.Program.prog
    prog.set_bit_length(min(64, prog.bit_length))
except AttributeError:
    # Best effort: if any attribute in the chain above is missing the cap is
    # skipped silently.
    pass

class HeapEntry(object):
    """ Triple ``(empty, prio, value)`` stored in one heap slot.

    The three components are kept as plain attributes named after
    :py:attr:`fields`. Arithmetic operators act component-wise and return
    a new entry of the same class carrying the same ``int_type``.

    :param int_type: callable applied to both priorities before they are
      compared in ``>``
    :param args: either the three components as separate arguments or a
      single iterable yielding them
    """
    fields = ['empty', 'prio', 'value']

    def __init__(self, int_type, *args):
        self.int_type = int_type
        if len(args) == 0:
            raise CompilerError()
        # A single argument is taken to be an iterable of the components.
        components = args[0] if len(args) == 1 else args
        for name, component in zip(self.fields, components):
            vars(self)[name] = component

    def data(self):
        """ Return the payload ``(prio, value)`` without the empty flag. """
        return self.prio, self.value

    def __repr__(self):
        slots = vars(self)
        parts = ['%s=%s' % (name, slots[name]) for name in self.fields]
        return '(' + ', '.join(parts) + ')'

    def __eq__(self, other):
        # Compares every attribute, including int_type.
        return vars(self) == vars(other)

    def __gt__(self, other):
        # The result is forced to zero as soon as one side is empty.
        both_present = (1 - self.empty) * (1 - other.empty)
        return both_present * \
            (self.int_type(self.prio) > self.int_type(other.prio))

    def __iter__(self):
        slots = vars(self)
        for name in self.fields:
            yield slots[name]

    def _pairwise(self, other, combine):
        """ Build a new entry from ``combine(a, b)`` over paired components. """
        merged = [combine(mine, theirs) for mine, theirs in zip(self, other)]
        return type(self)(self.int_type, merged)

    def __add__(self, other):
        return self._pairwise(other, lambda a, b: a + b)

    def __sub__(self, other):
        return self._pairwise(other, lambda a, b: a - b)

    def __xor__(self, other):
        return self._pairwise(other, lambda a, b: a ^ b)

    def __mul__(self, other):
        # The factor stays on the left-hand side of every product.
        scaled = [other * component for component in self]
        return type(self)(self.int_type, scaled)

    __rxor__ = __xor__
    __rmul__ = __mul__

    def hard_conv_me(self, value_type):
        """ Return a copy with ``value_type.hard_conv`` applied to each
        component. """
        converted = [value_type.hard_conv(component) for component in self]
        return type(self)(self.int_type, *converted)

    def dump(self):
        """ Reveal and print all three components. """
        revealed = [reveal(component) for component in self]
        print_ln('empty %s, prio %s, value %s', *revealed)

class HeapORAM(object):
    """ ORAM wrapper whose slots are exchanged as :py:class:`HeapEntry`.

    :param size: size handed to the underlying ORAM
    :param oram_type: constructor of the underlying ORAM
    :param init_rounds: forwarded to the underlying ORAM
    :param int_type: type stored in every entry that is returned; its
      ``basic_type`` becomes the ORAM's ``value_type``
    :param entry_size: forwarded to the underlying ORAM, by default
      ``(32, log2(size))``
    """
    def __init__(self, size, oram_type, init_rounds, int_type, entry_size=None):
        if entry_size is None:
            entry_size = (32, log2(size))
        self.int_type = int_type
        self.oram = oram_type(size,
                              entry_size=entry_size,
                              init_rounds=init_rounds,
                              value_type=int_type.basic_type)

    def __getitem__(self, index):
        stored = self.oram.read(index)
        return self.make_entry(*stored)

    def make_entry(self, value, empty):
        """ Combine an ORAM payload tuple and its empty flag into an entry. """
        components = (empty,) + value
        return HeapEntry(self.int_type, components)

    def __setitem__(self, index, value):
        # Unconditional write that also stores the entry's own empty flag.
        self.oram.access(index, value.data(), True, new_empty=value.empty)

    def access(self, index, value, write):
        """ Run the ORAM's ``access`` with the entry's payload and wrap
        whatever it returns as an entry. """
        result = self.oram.access(index, value.data(), write)
        return self.make_entry(*result)

    def delete(self, index, for_real):
        """ Forward to the ORAM's ``delete``. """
        self.oram.delete(index, for_real)

    def read_and_maybe_remove(self, index):
        """ Forward to the ORAM and return ``(entry, state)``. """
        raw, state = self.oram.read_and_maybe_remove(index)
        return self.make_entry(*raw), state

    def add(self, index, entry, state):
        """ Hand an entry back to the ORAM using a state obtained from
        :py:func:`read_and_maybe_remove`. """
        position = MemValue(index)
        payload = [MemValue(item) for item in entry.data()]
        self.oram.add(Entry(position, payload, entry.empty), state=state)

    def __len__(self):
        return len(self.oram)

class HeapQ(object):
    """ Priority queue built from two ORAMs.

    ``heap`` stores ``(prio, value)`` entries addressed by heap position
    (the root is read from position 1), and ``value_index`` stores the
    current heap position of every value, so that :py:func:`update` can
    overwrite the entry of a value that is already queued. Parent and
    child are exchanged whenever ``parent > child``, i.e. the entry
    returned by :py:func:`pop` is presumably the one with the smallest
    priority.
    """
    def __init__(self, max_size, oram_type=ORAM, init_rounds=-1, int_type=sint, entry_size=None):
        """ Allocate the heap ORAM, the value index and the size counter.

        :param max_size: size of the value index; the heap itself gets
          ``2**log2(max_size)`` slots
        :param oram_type: ORAM constructor used for both structures
        :param init_rounds: forwarded to both ORAM constructors
        :param int_type: secret integer type; its ``basic_type`` is used
          for the ORAM contents and for converting arguments
        :param entry_size: pair of sizes, by default
          ``(32, log2(max_size))``; the whole pair goes to the heap and
          the second component to the value index
        """
        if entry_size is None:
            entry_size = (32, log2(max_size))
        basic_type = int_type.basic_type
        self.max_size = max_size
        self.levels = log2(max_size)
        self.depth = self.levels - 1
        self.heap = HeapORAM(2**self.levels, oram_type, init_rounds, int_type, entry_size=entry_size)
        self.value_index = oram_type(max_size, entry_size=entry_size[1], \
                                         init_rounds=init_rounds, \
                                         value_type=basic_type)
        # Number of queued entries, kept as a run-time value.
        self.size = MemValue(int_type(0))
        self.int_type = int_type
        self.basic_type = basic_type
        # Compile-time diagnostic (plain Python print).
        print('heap: %d levels, depth %d, size %d, index size %d' % \
            (self.levels, self.depth, self.heap.oram.size, self.value_index.size))
    def update(self, value, prio, for_real=True):
        """ Insert ``value`` with priority ``prio`` or change the priority
        of a value already present.

        All three arguments are converted with ``basic_type.hard_conv``
        before being handed to :py:func:`_update`.

        :param for_real: when zero the heap write in :py:func:`_update` is
          conditioned off and the size is left unchanged
        """
        self._update(self.basic_type.hard_conv(value), \
                         self.basic_type.hard_conv(prio), \
                         self.basic_type.hard_conv(for_real))
    def pop(self, for_real=True):
        """ Return the value stored at the root; see :py:func:`_pop`.

        :param for_real: converted with ``basic_type.hard_conv`` and used
          to condition the modifications
        """
        return self._pop(self.basic_type.hard_conv(for_real))
    def bubble_up(self, start):
        """ Walk from a heap position towards the root, exchanging parent
        and child wherever ``parent > child``.

        :param start: heap position the walk is derived from
        """
        bits = bit_decompose(start, self.levels)
        bits.reverse()
        # NOTE(review): PreOR presumably returns prefix ORs, in which case
        # the differences below form a one-hot vector marking the leading
        # set bit of start -- confirm against floatingpoint.PreOR.
        bits = [0] + floatingpoint.PreOR(bits, self.levels)
        bits = [bits[i+1] - bits[i] for i in range(self.levels)]
        shift = self.int_type.bit_compose(bits)
        # The walk begins at start * shift rather than at start itself.
        childpos = MemValue(start * shift)
        # The number of iterations is fixed at levels - 1, independent of
        # start.
        @for_range(self.levels - 1)
        def f(i):
            parentpos = childpos.right_shift(1, self.levels + 1)
            parent, parent_state = self.heap.read_and_maybe_remove(parentpos)
            child, child_state = self.heap.read_and_maybe_remove(childpos)
            swap = parent > child
            new_parent, new_child = cond_swap(swap, parent, child)
            # Both entries go back using the states returned by the reads.
            self.heap.add(childpos, new_child, child_state)
            self.heap.add(parentpos, new_parent, parent_state)
            # The index is only written where an exchange took place.
            self.value_index.access(new_parent.value, parentpos, swap)
            self.value_index.access(new_child.value, childpos, swap)
            childpos.write(parentpos)
    @method_block
    def _pop(self, for_real=True):
        """ Take the entry at position 1, move the entry at position
        ``size`` into its place and push that entry down again.

        :param for_real: condition for the modifications; the root
          replacement and the size decrement additionally require a
          non-zero size
        :return: the ``value`` component read from position 1
        """
        Program.prog.curr_tape.\
            start_new_basicblock(name='heapq-pop')
        pop_for_real = for_real * (self.size != 0)
        entry = self.heap[1]
        self.value_index.delete(entry.value, for_real)
        last = self.heap[self.basic_type(self.size)]
        self.heap.access(1, last, pop_for_real)
        # No index write when the only entry is being removed.
        self.value_index.access(last.value, 1, for_real * (self.size != 1))
        self.heap.delete(self.basic_type(self.size), for_real)
        self.size -= self.int_type(pop_for_real)
        parentpos = MemValue(self.basic_type(1))
        # Fixed number of iterations starting at position 1.
        @for_range(self.levels - 1)
        def f(i):
            childpos = 2 * parentpos
            left_child, l_state = self.heap.read_and_maybe_remove(childpos)
            right_child, r_state = self.heap.read_and_maybe_remove(childpos+1)
            # Continue with the right child when left > right.
            go_right = left_child > right_child
            otherchildpos = childpos + 1 - go_right
            childpos += go_right
            child, other_child = cond_swap(go_right, left_child, right_child)
            child_state, other_state = cond_swap(go_right, l_state, r_state)
            parent, parent_state = self.heap.read_and_maybe_remove(parentpos)
            swap = parent > child
            new_parent, new_child = cond_swap(swap, parent, child)
            # All three entries that were read are added back.
            self.heap.add(childpos, new_child, child_state)
            self.heap.add(otherchildpos, other_child, other_state)
            self.heap.add(parentpos, new_parent, parent_state)
            self.value_index.access(new_parent.value, parentpos, swap)
            self.value_index.access(new_child.value, childpos, swap)
            parentpos.write(childpos)
        self.check()
        return entry.value
    @method_block
    def _update(self, value, prio, for_real=True):
        """ Write ``(prio, value)`` at the position recorded for ``value``,
        or at a new last position when the value is not in the index, then
        call :py:func:`bubble_up` on that position.

        :param value: value to insert or update
        :param prio: priority to store
        :param for_real: condition for the heap write and the size increase
        """
        Program.prog.curr_tape.\
            start_new_basicblock(name='heapq-update')
        index, not_found = self.value_index.read(value)
        # The size grows before it is used as the position of a new entry.
        self.size += self.int_type(not_found * for_real)
        index = if_else(not_found, self.basic_type(self.size), index[0])
        self.value_index.access(value, self.basic_type(self.size), \
                                    not_found * for_real)
        self.heap.access(index, HeapEntry(self.int_type, 0, prio, value), for_real)
        self.bubble_up(index)
        self.check()
    def __len__(self):
        # Returns the MemValue holding the size, not a Python int.
        return self.size
    def check(self):
        """ Optional consistency checks, run at the end of
        :py:func:`_pop` and :py:func:`_update`.

        With ``debug`` set, the heap and the index are inspected with
        plain Python control flow and an exception is raised on the first
        violation. With ``debug_online`` set, a loop over all values
        reveals each index entry, compares it with the heap and calls
        ``crash()`` on a mismatch.

        NOTE(review): ``debug`` and ``debug_online`` are not defined in
        this file; presumably they come from the star import of
        ``Compiler.oram`` -- confirm.
        """
        if debug:
            for i in range(len(self.heap)):
                if ((2 * i + 1 < len(self.heap) and \
                         self.heap[i] > self.heap[2*i+1]) or \
                        (2 * i + 2 < len(self.heap) and \
                             self.heap[i] > self.heap[2*i+2])) and \
                             not self.heap[i].empty:
                    raise Exception('heap condition violated at %d' % i)
                if i >= self.size and not self.heap[i].empty:
                    raise Exception('wrong size at %d' % i)
                if i < self.size and self.heap[i].empty:
                    raise Exception('empty entry in heap at %d' % i)
                # if not self.heap[i].empty and \
                #         self.heap[i].value not in self.value_index:
                #     raise Exception('missing index at %d' % i)
            for value,(index,empty) in enumerate(self.value_index):
                if not empty and self.heap[index].value != value:
                    raise Exception('index violated at %d' % index)
        if debug_online:
            @for_range(self.max_size)
            def f(value):
                index, not_found = self.value_index.read(value)
                # Both parts of the index entry are revealed here.
                index, not_found = index[0].reveal(), not_found.reveal()
                @if_(not_found == 0)
                def f():
                    heap_value = self.heap[index].value.reveal()
                    @if_(heap_value != value)
                    def f():
                        print_ln('heap mismatch: %s:%s in index, %s in heap', \
                                     value, index, heap_value)
                        crash()            
    def dump(self, msg=''):
        """ Reveal and print the size, the heap and the value index.

        :param msg: line printed before the dump
        """
        print_ln(msg)
        print_ln('size: %s', self.size.reveal())
        print_str('heap:')
        # A LinearORAM is printed from its ``ram`` attribute, any other
        # ORAM through indexing.
        if isinstance(self.heap.oram, LinearORAM):
            for entry in self.heap.oram.ram:
                print_str(' %s:%s,%s', entry.empty().reveal(), \
                              entry.x[0].reveal(), entry.x[1].reveal())
        else:
            for i in range(self.max_size+1):
                print_str(' %s:%s', *(x.reveal() for x in self.heap.oram[i]))
        print_ln()
        print_str('value index:')
        if isinstance(self.value_index, LinearORAM):
            for entry in self.value_index.ram:
                print_str(' %s:%s', entry.empty().reveal(), entry.x[0].reveal())
        else:
            for i in range(self.max_size):
                print_str(' %s:%s', i, self.value_index[i].reveal())
        print_ln()
        print_ln()

def dijkstra(source, edges, e_index, oram_type=ORAM, n_loops=None,
             int_type=None, debug=False):
    """ Securely compute Dijkstra's algorithm on a secret graph. See
    :download:`../Programs/Source/dijkstra_example.mpc` for an
    explanation of the required inputs.

    :param source: source node (secret or clear-text integer)
    :param edges: ORAM representation of edges; every entry is read as
      ``(target, weight, last_edge_flag)``
    :param e_index: ORAM representation of vertices; the entry of a vertex
      is used as the position of its first edge in ``edges``
    :param oram_type: ORAM type to use internally (default:
      :py:func:`~Compiler.oram.OptimalORAM`)
    :param n_loops: when to stop (default: number of edges)
    :param int_type: secret integer type, forwarded to the distance ORAM
      as ``value_type``; the type used afterwards is read back from that
      ORAM (documented upstream as defaulting to sint, which cannot be
      verified from this module)
    :param debug: reveal and print the loop state in every iteration
    :return: the distance ORAM; entries are written as
      ``(distance, predecessor)``
    """
    vert_loops = n_loops * e_index.size // edges.size \
                 if n_loops else -1
    dist = oram_type(e_index.size, entry_size=(32, log2(e_index.size)),
                     init_rounds=vert_loops, value_type=int_type)
    # From here on the types are the ones the ORAM actually uses.
    int_type = dist.value_type
    basic_type = int_type.basic_type
    #visited = ORAM(e_index.size)
    #previous = oram_type(e_index.size)
    Q = HeapQ(e_index.size, oram_type, init_rounds=vert_loops,
              int_type=int_type)
    if n_loops is not None:
        # put initialization in different timer
        stop_timer()
        start_timer(1)
    dist[source] = (0, 0)
    Q.update(source, 0)
    if n_loops is not None:
        stop_timer(1)
        start_timer()
    # last_edge starts at 1 so that the first iteration pops the source.
    last_edge = MemValue(basic_type(1))
    i_edge = MemValue(int_type(0))
    u = MemValue(basic_type(0))
    running = MemValue(basic_type(1))

    # One edge is processed per iteration.
    @for_range(n_loops or edges.size)
    def f(i):
        print_ln('loop %s', i)
        time()
        # Stop for good once a new vertex is needed but the queue is empty.
        running.write(last_edge.bit_not().bit_or(Q.size > 0).bit_and(running))
        # Move on to the next vertex only after the last edge of the
        # current one.
        u.write(if_else(last_edge, Q.pop(last_edge), u))
        #visited.access(u, True, last_edge)
        i_edge.write(int_type(if_else(last_edge, e_index[u], i_edge)))
        v, weight, le = edges[i_edge]
        last_edge.write(le)
        i_edge.iadd(1)
        alt = int_type(dist[u][0]) + int_type(weight)
        #is_shorter = (alt < dist[v]) * (1 - visited[v])
        dv, not_visited = dist.read(v)
        # relying on default dv negative here
        is_shorter = (alt < int_type(dv[0])) + not_visited
        is_shorter *= running
        dist.access(v, (basic_type(alt), u), is_shorter)
        #previous.access(v, u, is_shorter)
        Q.update(v, basic_type(alt), is_shorter)
        if debug:
            print_ln('u: %s, v: %s, alt: %s, dv: %s, first visit: %s, '
                     'shorter: %s, running: %s, queue size: %s, last edge: %s',
                     u.reveal(), v.reveal(), alt.reveal(), dv[0].reveal(),
                     not_visited.reveal(), is_shorter.reveal(),
                     running.reveal(), Q.size.reveal(), last_edge.reveal())
    return dist

def convert_graph(G):
    """ Convert a `NetworkX directed graph
    <https://networkx.org/documentation/stable/reference/classes/digraph.html>`_
    to the cleartext representation of what :py:func:`dijkstra` expects.

    Vertices must be the integers ``0 .. len(G) - 1`` because they are
    used as list positions. ``G`` is not modified.

    :param G: graph where ``G[v]`` maps every successor of ``v`` to its
      attribute dictionary; a missing ``'weight'`` counts as 1
    :return: ``(edges, e_index)``. ``edges`` lists
      ``[target, weight, last]`` triples grouped by source vertex in
      ascending order, with ``last`` set to 1 on the final triple of
      every vertex. A vertex without outgoing edges contributes the
      single triple ``[v, 0, 1]``. ``e_index[v]`` is the position of the
      first triple of ``v``.
    """
    # The list is grown on demand: the number of triples is the number of
    # edges plus the number of vertices without outgoing edges, which a
    # buffer sized from the edge count alone cannot always hold.
    edges = []
    e_index = [None] * len(G)
    for v in sorted(G):
        e_index[v] = len(edges)
        neighbours = G[v]
        for u in sorted(neighbours):
            edges.append([u, neighbours[u].get('weight', 1), 0])
        if not neighbours:
            # Placeholder so that every vertex owns at least one triple.
            edges.append([v, 0, 0])
        edges[-1][-1] = 1
    return edges, e_index

def test_dijkstra(G, source, oram_type=ORAM, n_loops=None, int_type=sint):
    """ Securely compute Dijkstra's algorithm on a cleartext graph.

    The graph is converted with :py:func:`convert_graph`, copied entry by
    entry into two ORAMs and handed to :py:func:`dijkstra`.

    :param G: directed graph with NetworkX interface
    :param source: source node (secret or clear-text integer)
    :param oram_type: ORAM type for the graph and for the computation
    :param n_loops: when to stop (default: number of edges)
    :param int_type: secret integer type (default: sint)
    :return: the result of :py:func:`dijkstra`
    """
    edge_rows, index_rows = convert_graph(G)
    basic_type = int_type.basic_type
    edges = oram_type(len(edge_rows),
                      entry_size=(log2(len(G)), log2(len(G)), 1),
                      init_rounds=0,
                      value_type=basic_type)
    e_index = oram_type(len(index_rows),
                        entry_size=log2(len(edge_rows)),
                        value_type=basic_type)

    # Python-level loop: one ORAM write is emitted per edge.
    edge_loops = n_loops or edges.size
    for position in range(edge_loops):
        cint(position).print_reg('edge')
        time()
        edges[position] = edge_rows[position]

    if n_loops:
        # Scale the loop bound from edges to vertices.
        vert_loops = n_loops * e_index.size // edges.size
    else:
        vert_loops = e_index.size
    for vertex in range(vert_loops):
        cint(vertex).print_reg('vert')
        time()
        e_index[vertex] = index_rows[vertex]

    return dijkstra(source, edges, e_index, oram_type, n_loops, int_type)

def test_dijkstra_on_cycle(n, oram_type=ORAM, n_loops=None, int_type=sint):
    """ Run :py:func:`dijkstra` from vertex 0 on a cycle of ``n`` vertices.

    Every vertex gets two edges of weight 1: edge ``2 * v`` leads to
    ``v - 1`` and edge ``2 * v + 1`` to ``v + 1`` (both modulo ``n``); the
    second one carries the last-edge flag.

    :param n: number of vertices
    :param oram_type: ORAM type for the graph and for the computation
    :param n_loops: number of edges to write and loop bound handed to
      :py:func:`dijkstra` (default: all edges)
    :param int_type: secret integer type
    """
    n_edges = 2 * n
    edges = oram_type(n_edges, entry_size=(log2(n),log2(n),1), init_rounds=0,
                      value_type=int_type.basic_type)
    e_index = oram_type(n, entry_size=log2(n_edges), init_rounds=0, \
                        value_type=int_type.basic_type)
    @for_range(n_loops or edges.size)
    def f(i):
        cint(i).print_reg('edge')
        time()
        # i >> 1 is the source vertex, i % 2 selects the direction.
        neighbour = ((i >> 1) + 2 * (i % 2) - 1 + n) % n
        edges[i] = (neighbour, 1, i % 2)
    vert_loops = n_loops * e_index.size // edges.size \
                 if n_loops else e_index.size
    @for_range(vert_loops)
    def f(i):
        cint(i).print_reg('vert')
        time()
        e_index[i] = 2 * i
    return dijkstra(0, edges, e_index, oram_type, n_loops, int_type)

def test_dijkstra_on_complete(n, oram_type=ORAM, n_loops=None, int_type=sint):
    """ Run :py:func:`dijkstra` from vertex 0 on a graph where vertex
    ``i`` owns the ``n`` edge slots ``i * n .. i * n + n - 1``.

    Slot ``j < n - 1`` holds ``(j, 1, 0)`` and the last slot holds
    ``(n - 1, 1, 1)``.

    :param n: number of vertices
    :param oram_type: ORAM type for the graph and for the computation
    :param n_loops: bound for the initialisation loops; its square is
      handed to :py:func:`dijkstra` (default: ``n`` and ``n**2``)
    :param int_type: secret integer type
    """
    n_edges = n**2
    edges = oram_type(n_edges, entry_size=(log2(n),log2(n),1), init_rounds=0,
                      value_type=int_type.basic_type)
    e_index = oram_type(n, entry_size=log2(n_edges), init_rounds=0, \
                        value_type=int_type.basic_type)
    @for_range(n_loops or n)
    def f(i):
        @for_range(n_loops - 1 if n_loops else n - 1)
        def f(j):
            cint(i).print_reg('v1')
            cint(j).print_reg('v2')
            time()
            edges[i*n+j] = (j, 1, 0)
        edges[i*n+n-1] = (n - 1, 1, 1)
    if n_loops is not None:
        # The vertex index is initialised under timer 2.
        stop_timer()
        start_timer(2)
    @for_range(n_loops or n)
    def f(i):
        cint(i).print_reg('vert')
        time()
        e_index[i] = n * i
    if n_loops is not None:
        stop_timer(2)
        start_timer()
    return dijkstra(0, edges, e_index, oram_type, \
                    n_loops**2 if n_loops else n**2, int_type)

class ExtInt(object):
    """ Integer extended by an infinity flag.

    ``x`` is the finite part and ``inf`` the flag. Comparisons return
    arithmetic expressions rather than Python booleans, so both parts may
    be secret values.
    """
    def __init__(self, x, inf=False):
        self.x = x
        self.inf = inf
    def __add__(self, other):
        # Component-wise for two ExtInt, otherwise only x is changed.
        if isinstance(other, ExtInt):
            return ExtInt(self.x + other.x, self.inf + other.inf)
        else:
            return ExtInt(self.x + other, self.inf)
    def __sub__(self, other):
        if isinstance(other, ExtInt):
            return ExtInt(self.x - other.x, self.inf - other.inf)
        else:
            return ExtInt(self.x - other, self.inf)
    def __rsub__(self, other):
        return ExtInt(other - self.x, -self.inf)
    def __mul__(self, other):
        """ Scale both components by a factor that is not an ExtInt.

        :raises Exception: if ``other`` is an ExtInt
        """
        if isinstance(other, ExtInt):
            raise Exception('product of two ExtInt values is not supported')
        return ExtInt(self.x * other, self.inf * other)
    __radd__ = __add__
    __rmul__ = __mul__
    def __lt__(self, other):
        # Against an ExtInt the result is 1 whenever other is infinite,
        # otherwise the finite comparison masked by this value's own flag.
        if isinstance(other, ExtInt):
            return ((1 - self.inf) * (1 - other.inf) * (self.x < other.x)) + \
                other.inf
        else:
            return (1 - self.inf) * (self.x < other)
    def __gt__(self, other):
        # Mirror image of __lt__: 1 whenever this value is infinite.
        if isinstance(other, ExtInt):
            return ((1 - self.inf) * (1 - other.inf) * (self.x > other.x)) + \
                self.inf
        else:
            return 1 - (1 -self.inf) * (1 - (self.x > other))
    def __repr__(self):
        if self.inf:
            return 'T'
        else:
            return str(self.x)

class Vector(object):
    """ Works like a vector.

    Mix-in providing element-wise ``+`` and ``-``, and ``*`` as inner
    product (with another Vector) or scaling (with anything else). The
    concrete class supplies ``__len__``, indexing and a constructor taking
    a length.
    """
    def __add__(self, other):
        print('add', type(self))
        res = type(self)(len(self))
        @for_range(len(self))
        def f(i):
            res[i] = self[i] + other[i]
        return res
    def __sub__(self, other):
        print('sub', type(other))
        # The result has the type of the right-hand operand.
        res = type(other)(len(self))
        @for_range(len(self))
        def f(i):
            res[i] = self[i] - other[i]
        return res
    def __mul__(self, other):
        if isinstance(other, Vector):
            # Inner product accumulated in a one-element vector.
            res = type(self)(1)
            res[0] = ExtInt(0)
            @for_range(len(self))
            def f(i):
                res[0] += self[i] * other[i]
            return res[0]
        else:
            print('mul', type(self))
            res = type(self)(len(self))
            @for_range_parallel(1024, len(self))
            def f(i):
                res[i] = self[i] * other
            return res
    __rmul__ = __mul__

class VectorList(Vector, list):
    """ Python list with the :py:class:`Vector` operators. """
    pass

class VectorArray(Vector):
    """ Vector of :py:class:`ExtInt` backed by two arrays, one holding
    the finite parts and one holding the infinity flags.

    :param length: number of elements
    :param address: if given, the arrays are placed at ``address`` and
      ``address + length`` instead of being freshly allocated
    """
    def __init__(self, length, address=None):
        self.length = length
        if address is None:
            self.arrays = [Array(length, 's') for i in range(2)]
        else:
            self.arrays = [Array(length, 's', addr) \
                               for addr in (address,address+length)]
    def assign(self, values):
        """ Copy ``values`` element by element. """
        @for_range(len(self))
        def f(i):
            self[i] = values[i]
    def assign_all(self, value):
        """ Set every element to the ExtInt ``value``. """
        self.arrays[0].assign_all(value.x)
        self.arrays[1].assign_all(value.inf)
    def __getitem__(self, index):
        return ExtInt(*[v[index] for v in self.arrays])
    def __setitem__(self, index, value):
        self.arrays[0][index] = value.x
        self.arrays[1][index] = value.inf
    def __len__(self):
        return len(self.arrays[0])

class IntVectorArray(Vector, Array):
    """ ``Array`` of type ``'s'`` with the :py:class:`Vector` operators. """
    def __init__(self, length):
        Array.__init__(self, length, 's')

class Matrix(object):
    """ Matrix of :py:class:`ExtInt` stored in one array of
    ``2 * rows * columns`` cells; row ``i`` is exposed as a
    :py:class:`VectorArray` at offset ``2 * columns * i``.
    """
    def __init__(self, rows, columns):
        self.rows = rows
        self.columns = columns
        self.address = Array(2 * rows * columns, 's').address
    def __getitem__(self, index):
        # A view on the row's memory, not a copy.
        return VectorArray(self.columns, self.address + 2 * self.columns * index)
    def __setitem__(self, index, value):
        self[index].assign(value)
    def __len__(self):
        return self.rows
    def assign_all(self, value):
        """ Set every element to ``value`` and return the matrix. """
        @for_range(len(self))
        def f(i):
            self[i].assign_all(value)
        return self

def updatevector(vector, index, value):
    """ Overwrite ``vector[i]`` with ``value`` wherever ``index[i]`` is 1
    and keep it wherever ``index[i]`` is 0. """
    @for_range_parallel(1024, len(vector))
    def f(i):
        vector[i] += index[i] * (value - vector[i])

def binarymin(A):
    """ Find the minimum of ``A`` by pairwise comparison.

    :param A: vector of :py:class:`ExtInt`
    :return: ``(i, min)`` where ``min`` is the smallest element and ``i``
      holds, per position, the product of the comparison bits on the way
      to that position (a plain ``[1]`` for a single element)

    NOTE(review): for an odd length above 1 the last element takes part
    in no comparison and its slot in ``i`` is never written; presumably
    callers use power-of-two lengths -- confirm.
    """
    if len(A) == 1:
        return [1], A[0]
    else:
        half = len(A) // 2
        A_prime = VectorArray(half)
        B = IntVectorArray(half)
        i = IntVectorArray(len(A))
        @for_range_parallel(128, half)
        def f(j):
            # B[j] is 1 when the left element of pair j is the smaller one.
            B[j] = A[2*j] < A[2*j+1]
            A_prime[j] = if_else(B[j], A[2*j], A[2*j+1])
        i_prime, min = binarymin(A_prime)
        @for_range_parallel(1024, half)
        def f(j):
            i[2*j] = B[j] * i_prime[j]
            i[2*j+1] = (1 - B[j]) * i_prime[j]
        return i, min

def stupid_dijkstra(M, s, n_loops=None):
    """ Dijkstra's algorithm on an adjacency matrix without ORAM.

    :param M: :py:class:`Matrix` of edge weights
    :param s: selector vector applied with :py:func:`updatevector` to set
      the distance of the source to 0
    :param n_loops: bound for both nested loops (default: ``len(M)``); if
      given, initialisation and minimum search run under timers 1 and 2
    :return: ``(d, P)``, the distance vector and the matrix ``P`` updated
      alongside it
    """
    if n_loops is not None:
        stop_timer()
        start_timer(1)
    P = Matrix(len(M), len(M))
    P.assign_all(ExtInt(0))
    # All distances start out infinite.
    d = VectorArray(len(M))
    d.assign_all(ExtInt(0,True))
    # q is added to the distances before every minimum search; the entry
    # of a selected vertex is set to infinity below.
    q = VectorArray(len(M))
    q.assign_all(ExtInt(0))
    d_prime = VectorArray(len(M))
    updatevector(d, s, 0)
    if n_loops is not None:
        stop_timer(1)
        start_timer()
    @for_range(n_loops or len(M))
    def f(i):
        if n_loops is not None:
            stop_timer()
            start_timer(2)
        d_prime.assign(d + q)
        k, min = binarymin(d_prime)
        updatevector(q, k, ExtInt(0,True))
        if n_loops is not None:
            stop_timer(2)
            start_timer()
        @for_range(n_loops or len(M))
        def f(j):
            # a is the inner product of (d + M[j]) with the selector k.
            a = (d + M[j]) * k
            c = a < d[j]
            P[j] = P[j] + c * (k - P[j])
            d[j] += c * (a - d[j])
    return d, P

def convert_graph_to_matrix(G):
    """ Build the weight :py:class:`Matrix` of a graph.

    Entries without an edge are infinite; an edge without ``'weight'``
    attribute counts as 1.

    :param G: graph where ``G[u]`` maps every neighbour ``v`` of ``u`` to
      its attribute dictionary
    """
    M = Matrix(len(G), len(G))
    M.assign_all(ExtInt(0,True))
    for u in G:
        for v in G[u]:
            M[u][v] = ExtInt(G[u][v].get('weight', 1))
    return M

def test_stupid_dijkstra(G, source):
    """ Run :py:func:`stupid_dijkstra` on a graph; the source selector is
    ``demux`` applied to the ``log2(len(G))`` bits of ``source``. """
    return stupid_dijkstra(convert_graph_to_matrix(G), \
                           demux(bit_decompose(source, log2(len(G)))))

def test_stupid_dijkstra_on_cycle(n, n_loops=None):
    """ Run :py:func:`stupid_dijkstra` from vertex 0 on a cycle of ``n``
    vertices where every vertex has weight-1 entries for both cycle
    neighbours.

    :param n: number of vertices
    :param n_loops: forwarded to :py:func:`stupid_dijkstra`; if given, the
      graph is built under timer 1
    """
    if n_loops is not None:
        stop_timer()
        start_timer(1)
    M = Matrix(n, n)
    M.assign_all(ExtInt(0,True))
    s = IntVectorArray(n)
    s.assign_all(0)
    s[0] = 1
    @for_range(n)
    def f(i):
        M[i][(i+1)%n] = ExtInt(1)
        M[i][(i-1+n)%n] = ExtInt(1)
    if n_loops is not None:
        stop_timer(1)
        start_timer()
    return stupid_dijkstra(M, s, n_loops)