from __future__ import annotations
import math
from abc import abstractmethod
from typing import Any, Generic, Type, TypeVar
from Compiler import library as lib
from Compiler import util
from Compiler.GC.types import cbit, sbit, sbitint, sbits
from Compiler.program import Program
from Compiler.types import (Array, MemValue, MultiArray, _clear, _secret, cint,
regint, sint, sintbit)
from Compiler.oram import demux_array, get_n_threads
# Module-wide switches. They are read as globals by the classes and
# functions below.

# Adds messages on completion of heavy computation steps
debug = False
# Finer grained trace of steps that the ORAM performs
# + runtime error checks
# Warning: reveals information and makes the computation insecure
trace = False
# Thread count handed out by get_n_threads(); when set to None,
# get_n_threads() picks a value based on the loop length instead.
n_threads = 16
# Passed as the second argument to lib.map_sum wherever it is used below.
n_parallel = 1024
# Avoids any memory allocation if set to False
# Setting to False prevents some optimizations but allows for controlling the
# ORAMs outside of the main tape
allow_memory_allocation = True
def get_n_threads(n_loops):
    """Return the number of threads to use for a loop of ``n_loops`` rounds.

    The module-level ``n_threads`` setting wins whenever it is not None.
    Otherwise loops longer than 2048 iterations get 8 threads and shorter
    loops get None.
    """
    if n_threads is not None:
        return n_threads
    return 8 if n_loops > 2048 else None
# Secret value type handled by the ORAM: either sint or sbitint.
T = TypeVar("T", sint, sbitint)
# Matching secret bit type: either sintbit or sbit.
B = TypeVar("B", sintbit, sbit)
class SqrtOram(Generic[T, B]):
    """Oblivious RAM using the "Square-Root" algorithm.

    :param MultiArray data: The data with which to initialize the ORAM. One
        may provide a MultiArray such that every "block" can hold multiple
        elements (an Array). A flat sint vector is also accepted, in which
        case it is split into blocks of ``entry_length`` values.
    :param int entry_length: Number of values per block; also used as the
        width of the stash.
    :param sint value_type: The secret type to use, defaults to sint.
        Must be sint or sbitint.
    :param int k: Leave at 0, this parameter is used to recursively pass down
        the depth of this ORAM.
    :param int period: Leave at None, this parameter is used to recursively
        pass down the top-level period.
    :param bool initialize: When False, the initial shuffle is skipped and
        the caller must call ``refresh`` before use; otherwise the ORAM is
        not secure.
    :param bool empty_data: When True (and ``initialize`` is True), only the
        logical indices are shuffled and the data itself is left untouched.
    """
    # TODO: Preferably this is an Array of vectors, but this is currently not
    # supported.
    # One should regard these structures as Arrays where an entry may hold
    # more than one value (which is a nice property to have when using the
    # ORAM in practice).
    shuffle: MultiArray
    stash: MultiArray
    # A block has an index and data
    # `shuffle` and `stash` store the data,
    # `shufflei` and `stashi` store the index
    shufflei: Array
    stashi: Array

    shuffle_used: Array

    position_map: PositionMap

    # The size of the ORAM, i.e. how many elements it stores
    n: int
    # The period, i.e. how many calls can be made to the ORAM before it needs
    # to be refreshed
    T: int
    # Keep track of how far we are in the period, and coincidentally how large
    # the stash is (each access results in a fake or real block being put on
    # the stash)
    t: cint

    def __init__(self, data: T | MultiArray, entry_length: int = 1,
                 value_type: Type[T] = sint, k: int = 0,
                 period: int | None = None, initialize: bool = True,
                 empty_data=False) -> None:
        """Allocate all ORAM state and (optionally) perform the first shuffle.

        :raises Exception: if ``value_type`` is neither sint nor sbitint, if
            ``data`` is neither a MultiArray nor a sint, or if a flat sint
            ``data`` has a length that is not a multiple of ``entry_length``.
        """
        global debug, allow_memory_allocation

        # Only sint and sbitint are supported. Validate this before
        # value_type is used to allocate any memory below.
        if value_type != sint and value_type != sbitint:
            raise Exception("The value_type must be either sint or sbitint")

        # Correctly initialize the shuffle (memory) depending on the type of
        # data
        if isinstance(data, MultiArray):
            self.shuffle = data
            self.n = len(data)
        elif isinstance(data, sint):
            if len(data) % entry_length != 0:
                raise Exception('Data incorrectly padded.')
            # The length is an exact multiple here, so plain integer division
            # yields the number of blocks.
            self.n = len(data) // entry_length
            self.shuffle = MultiArray(
                (self.n, entry_length), value_type=value_type)
            self.shuffle.assign_part_vector(data.get_vector())
        else:
            raise Exception("Incorrect format.")

        # Set derived constants
        self.value_type = value_type
        self.bit_type: Type[B] = value_type.bit_type
        self.index_size = util.log2(self.n) + 1  # +1 because signed
        self.index_type = value_type.get_type(self.index_size)
        self.entry_length = entry_length
        self.size = self.n

        if debug:
            lib.print_ln(
                'Initializing SqrtORAM of size %s at depth %s', self.n, k)

        self.shuffle_used = cint.Array(self.n)
        # Logical indices of the blocks; starts out as 0..n-1 and is permuted
        # together with the data.
        self.shufflei = Array.create_from(
            self.index_type(regint.inc(self.n)))

        # Calculate the period if not given
        # upon recursion, the period should stay the same ("in sync"),
        # therefore it can be passed as a constructor parameter
        self.T = int(math.ceil(
            math.sqrt(self.n * util.log2(self.n) - self.n + 1))) if not period else period
        if debug and not period:
            lib.print_ln('Period set to %s', self.T)

        # Here we allocate the memory for the permutation
        # Note that self.shuffle_the_shuffle mutates this field
        # Why don't we pass it as an argument then? Well, this way we don't
        # have to allocate memory while shuffling, which keeps open the
        # possibility for multithreading
        self.permutation = Array.create_from(
            self.index_type(regint.inc(self.n)))

        # We allow the caller to postpone the initialization of the shuffle
        # This is the most expensive operation, and can be done in a thread
        # (only if you know what you're doing)
        # Note that if you do not initialize, the ORAM is insecure
        if initialize:
            # If the ORAM is not initialized with existing data, we can apply
            # a small optimization by forgoing shuffling the shuffle, as all
            # entries of the shuffle are equal and empty.
            if empty_data:
                random_shuffle = sint.get_secure_shuffle(self.n)
                self.shufflei.secure_permute(random_shuffle)
                self.permutation.assign(self.shufflei[:].inverse_permutation())
                if trace:
                    lib.print_ln('Calculated inverse permutation')
            else:
                self.shuffle_the_shuffle()
        else:
            print('You are opting out of default initialization for SqrtORAM. Be sure to call refresh before using the SqrtORAM, otherwise the ORAM is not secure.')

        # Initialize position map (recursive oram)
        self.position_map = PositionMap.create(self.permutation, k + 1,
                                               self.T, initialize=initialize)

        # Initialize stash
        self.stash = MultiArray((self.T, entry_length), value_type=value_type)
        self.stashi = Array(self.T, value_type=value_type)
        self.t = MemValue(cint(0))

        # Initialize temp variables needed during the computation
        self.found_ = self.bit_type.Array(size=self.T)
        self.j = MemValue(cint(0, size=1))

        # To prevent the compiler from recompiling the same code over and over
        # again, we should use @method_block
        # However, @method_block requires allocation (of return address),
        # which is not allowed when not in the main thread
        # Therefore, we only conditionally wrap the methods in a @method_block
        # if we are guaranteed to be running in the main thread
        # NOTE(review): this rebinds the class attributes on every
        # construction (including the recursive position maps), so the
        # methods get wrapped again each time - confirm this is intended.
        SqrtOram.shuffle_the_shuffle = lib.method_block(SqrtOram.shuffle_the_shuffle) if allow_memory_allocation else SqrtOram.shuffle_the_shuffle
        SqrtOram.refresh = lib.method_block(SqrtOram.refresh) if allow_memory_allocation else SqrtOram.refresh
        SqrtOram.reinitialize = lib.method_block(SqrtOram.reinitialize) if allow_memory_allocation else SqrtOram.reinitialize

    @lib.method_block
    def access(self, index: T, write: B, *value: T):
        """Obliviously read or write the block at secret ``index``.

        :param index: secret logical index of the block.
        :param write: secret bit; when set, ``value`` is stored in the block.
        :param value: the new block content (``entry_length`` values).
        :return: a MemValue holding the block that was accessed. The stash
            contribution is taken before the write is applied, the shuffle
            contribution after it.
        """
        global trace,n_parallel
        if trace:
            @lib.if_e(write.reveal() == 1)
            def _():
                lib.print_ln('Writing to secret index %s', index.reveal())

            @lib.else_
            def __():
                lib.print_ln('Reading from secret index %s', index.reveal())

        value = self.value_type(value, size=self.entry_length).get_vector(
            0, size=self.entry_length)
        index = MemValue(index)

        # Refresh if we have performed T (period) accesses
        @lib.if_(self.t == self.T)
        def _():
            self.refresh()

        found: B = MemValue(self.bit_type(False))
        result: T = MemValue(self.value_type(0, size=self.entry_length))

        # First we scan the stash for the item
        self.found_.assign_all(0)

        # This will result in a bit array with at most one True,
        # indicating where in the stash 'index' is found
        @lib.multithread(get_n_threads(self.T), self.T)
        def _(base, size):
            self.found_.assign_vector(
                (self.stashi.get_vector(base, size) == index.expand_to_vector(size)) &
                self.bit_type(regint.inc(size, base=base) <
                              self.t.expand_to_vector(size)),
                base=base)

        # To determine whether the item is found in the stash, we simply
        # check whether the demuxed array contains a True
        # TODO: What if the index=0?
        found.write(sum(self.found_))

        # Store the stash item into the result if found
        # If the item is not in the stash, the result will simply remain 0
        @lib.map_sum(get_n_threads(self.T), n_parallel, self.T,
                     self.entry_length, [self.value_type] * self.entry_length)
        def stash_item(i):
            entry = self.stash[i][:]
            access_here = self.found_[i]
            # This is a bit unfortunate
            # We should loop from 0 to self.t, but t is dynamic thus this is
            # impossible.
            # Therefore we loop till self.T (the max value of self.t)
            # is_in_time = i < self.t

            # If we are writing, we need to add the value
            self.stash[i] += write * access_here * (value - entry)
            return (entry * access_here)[:]
        result += self.value_type(stash_item(), size=self.entry_length)

        if trace:
            @lib.if_e(found.reveal() == 1)
            def _():
                lib.print_ln('Found item in stash')

            @lib.else_
            def __():
                lib.print_ln('Did not find item in stash')

        # Possible fake lookup of the item in the shuffle,
        # depending on whether we already found the item in the stash
        physical_address = self.position_map.get_position(index, found)
        # We set shuffle_used to True, to track that this shuffle item needs
        # to be refreshed with its equivalent on the stash once the period is
        # up.
        self.shuffle_used[physical_address] = cbit(True)

        # If the item was not found in the stash
        # ...we update the item in the shuffle
        self.shuffle[physical_address] += write * \
            found.bit_not() * (value - self.shuffle[physical_address][:])
        # ...and the item retrieved from the shuffle is our result
        result += self.shuffle[physical_address] * found.bit_not()
        # We append the newly retrieved item to the stash
        self.stash[self.t].assign(self.shuffle[physical_address][:])
        self.stashi[self.t] = self.shufflei[physical_address]

        if trace:
            @lib.if_((write * found.bit_not()).reveal())
            def _():
                lib.print_ln('Wrote (%s: %s) to shuffle[%s]', self.stashi[self.t].reveal(
                ), self.shuffle[physical_address].reveal(), physical_address)

        # Increase the "time" (i.e. access count in current period)
        self.t.iadd(1)

        return result

    def write(self, index: T, *value: T):
        """Store ``value`` in the block at secret ``index``.

        Thin wrapper that converts ``value`` to ``value_type`` and delegates
        to ``_write``. Returns None; this method is also bound to
        ``__setitem__``.
        """
        global trace, n_parallel
        if trace:
            lib.print_ln('Writing to secret index %s', index.reveal())

        # `value` is collected through *value and is therefore always a
        # tuple, so a single conversion covers every call.
        value = self.value_type(value, size=self.entry_length)

        self._write(index, *value)

    @lib.method_block
    def _write(self, index: T, *value: T):
        """Obliviously overwrite the block at secret ``index`` with ``value``.

        Same procedure as ``access`` with the write bit fixed to true.

        :return: a MemValue holding the accessed block (stash contribution
            before the update, shuffle contribution after it).
        """
        value = MemValue(self.value_type(value))
        index = MemValue(index)

        # Refresh if we have performed T (period) accesses
        @lib.if_(self.t == self.T)
        def _():
            self.refresh()

        found: B = MemValue(self.bit_type(False))
        result: T = MemValue(self.value_type(0, size=self.entry_length))

        # First we scan the stash for the item
        self.found_.assign_all(0)

        # This will result in a bit array with at most one True,
        # indicating where in the stash 'index' is found
        @lib.multithread(get_n_threads(self.T), self.T)
        def _(base, size):
            self.found_.assign_vector(
                (self.stashi.get_vector(base, size) == index.expand_to_vector(size)) &
                self.bit_type(regint.inc(size, base=base) <
                              self.t.expand_to_vector(size)),
                base=base)

        # To determine whether the item is found in the stash, we simply
        # check whether the demuxed array contains a True
        # TODO: What if the index=0?
        found.write(sum(self.found_))

        @lib.map_sum(get_n_threads(self.T), n_parallel, self.T,
                     self.entry_length, [self.value_type] * self.entry_length)
        def stash_item(i):
            entry = self.stash[i][:]
            access_here = self.found_[i]
            # This is a bit unfortunate
            # We should loop from 0 to self.t, but t is dynamic thus this is
            # impossible.
            # Therefore we loop till self.T (the max value of self.t)
            # is_in_time = i < self.t

            # We update the stash value
            self.stash[i] += access_here * (value - entry)
            return (entry * access_here)[:]
        result += self.value_type(stash_item(), size=self.entry_length)

        if trace:
            @lib.if_e(found.reveal() == 1)
            def _():
                lib.print_ln('Found item in stash')

            @lib.else_
            def __():
                lib.print_ln('Did not find item in stash')

        # Possible fake lookup of the item in the shuffle,
        # depending on whether we already found the item in the stash
        physical_address = self.position_map.get_position(index, found)
        # We set shuffle_used to True, to track that this shuffle item needs
        # to be refreshed with its equivalent on the stash once the period is
        # up.
        self.shuffle_used[physical_address] = cbit(True)

        # If the item was not found in the stash
        # ...we update the item in the shuffle
        self.shuffle[physical_address] += found.bit_not() * \
            (value - self.shuffle[physical_address][:])
        # ...and the item retrieved from the shuffle is our result
        result += self.shuffle[physical_address] * found.bit_not()
        # We append the newly retrieved item to the stash
        self.stash[self.t].assign(self.shuffle[physical_address][:])
        self.stashi[self.t] = self.shufflei[physical_address]

        if trace:
            @lib.if_(found.bit_not().reveal())
            def _():
                lib.print_ln('Wrote (%s: %s) to shuffle[%s]', self.stashi[self.t].reveal(
                ), self.shuffle[physical_address].reveal(), physical_address)

            lib.print_ln('Appended shuffle[%s]=(%s: %s) to stash at position t=%s', physical_address,
                         self.shufflei[physical_address].reveal(), self.shuffle[physical_address].reveal(), self.t)

        # Increase the "time" (i.e. access count in current period)
        self.t.iadd(1)

        return result

    @lib.method_block
    def read(self, index: T, *value: T):
        """Obliviously read the block at secret ``index``.

        Same procedure as ``access`` without any update of stash or shuffle
        contents. Also bound to ``__getitem__``.

        :param value: converted to ``value_type`` but otherwise unused here.
        :return: a MemValue holding the block stored at ``index``.
        """
        global debug, trace, n_parallel
        if trace:
            lib.print_ln('Reading from secret index %s', index.reveal())
        value = self.value_type(value)
        index = MemValue(index)

        # Refresh if we have performed T (period) accesses
        @lib.if_(self.t == self.T)
        def _():
            if debug:
                lib.print_ln('Refreshing SqrtORAM')
                lib.print_ln('t=%s according to me', self.t)

            self.refresh()

        found: B = MemValue(self.bit_type(False))
        result: T = MemValue(self.value_type(0, size=self.entry_length))

        # First we scan the stash for the item
        self.found_.assign_all(0)

        # This will result in a bit array with at most one True,
        # indicating where in the stash 'index' is found
        @lib.multithread(get_n_threads(self.T), self.T)
        def _(base, size):
            self.found_.assign_vector(
                (self.stashi.get_vector(base, size) == index.expand_to_vector(size)) &
                self.bit_type(regint.inc(size, base=base) <
                              self.t.expand_to_vector(size)),
                base=base)

        # To determine whether the item is found in the stash, we simply
        # check whether the demuxed array contains a True
        # TODO: What if the index=0?
        found.write(sum(self.found_))
        lib.check_point()

        # Store the stash item into the result if found
        # If the item is not in the stash, the result will simply remain 0
        @lib.map_sum(get_n_threads(self.T), n_parallel, self.T,
                     self.entry_length, [self.value_type] * self.entry_length)
        def stash_item(i):
            entry = self.stash[i][:]
            access_here = self.found_[i]
            # This is a bit unfortunate
            # We should loop from 0 to self.t, but t is dynamic thus this is
            # impossible.
            # Therefore we loop till self.T (the max value of self.t)
            # is_in_time = i < self.t
            return (entry * access_here)[:]
        result += self.value_type(stash_item(), size=self.entry_length)

        if trace:
            # @lib.for_range(self.t)
            # def _(i):
            #     lib.print_ln("stash[%s]=(%s: %s)", i, self.stashi[i].reveal() ,self.stash[i].reveal())

            @lib.if_e(found.reveal() == 1)
            def _():
                lib.print_ln('Found item in stash (found=%s)', found.reveal())

            @lib.else_
            def __():
                lib.print_ln('Did not find item in stash (found=%s)', found.reveal())

        # Possible fake lookup of the item in the shuffle,
        # depending on whether we already found the item in the stash
        physical_address = self.position_map.get_position(index, found)
        # We set shuffle_used to True, to track that this shuffle item needs
        # to be refreshed with its equivalent on the stash once the period is
        # up.
        self.shuffle_used[physical_address] = cbit(True)

        # If the item was not found in the stash
        # the item retrieved from the shuffle is our result
        result += self.shuffle[physical_address] * found.bit_not()
        # We append the newly retrieved item to the stash
        self.stash[self.t].assign(self.shuffle[physical_address][:])
        self.stashi[self.t] = self.shufflei[physical_address]

        if trace:
            lib.print_ln('Appended shuffle[%s]=(%s: %s) to stash at position t=%s', physical_address,
                         self.shufflei[physical_address].reveal(), self.shuffle[physical_address].reveal(), self.t)

        # Increase the "time" (i.e. access count in current period)
        self.t.iadd(1)

        return result

    __getitem__ = read
    __setitem__ = write

    def shuffle_the_shuffle(self) -> None:
        """Permute the memory using a newly generated permutation and store
        the permutation that would generate this particular shuffling in
        ``self.permutation``.

        This permutation is needed to know how to map logical addresses to
        physical addresses, and is used as such by the position map."""
        global trace
        # Random permutation on n elements
        random_shuffle = sint.get_secure_shuffle(self.n)
        if trace:
            lib.print_ln('Generated shuffle')
        # Apply the random permutation
        self.shuffle.secure_permute(random_shuffle)
        if trace:
            lib.print_ln('Shuffled shuffle')
        self.shufflei.secure_permute(random_shuffle)
        if trace:
            lib.print_ln('Shuffled shuffle indexes')
        lib.check_point()
        # Calculate the permutation that would have produced the newly
        # produced shuffle order. This can be calculated by regarding the
        # logical indexes (shufflei) as a permutation and calculating its
        # inverse, i.e. find P such that P([1,2,3,...]) = shufflei.
        # this is not necessarily equal to the inverse of the above generated
        # random_shuffle, as the shuffle may already be out of order (e.g.
        # when refreshing).
        self.permutation.assign(self.shufflei[:].inverse_permutation())
        # If shufflei does not contain exactly the indices
        # [i for i in range(self.n)],
        # the underlying waksman network of 'inverse_permutation' will hang.
        if trace:
            lib.print_ln('Calculated inverse permutation')

    def refresh(self):
        """Refresh the ORAM by reinserting the stash back into the shuffle, and
        reshuffling the shuffle.

        This must happen on the T'th (period) accesses to the ORAM."""
        self.j.write(0)

        # Shuffle and empty the stash, and store elements back into shuffle
        @lib.for_range_opt(self.n)
        def _(i):
            @lib.if_(self.shuffle_used[i])
            def _():
                self.shuffle[i] = self.stash[self.j]
                self.shufflei[i] = self.stashi[self.j]
                self.j += 1

        # Reset the clock
        self.t.write(0)
        # Reset shuffle_used
        self._reset_shuffle_used()

        # Reinitialize position map
        self.shuffle_the_shuffle()
        # Note that we skip here the step of "packing" the permutation.
        # Since the underlying memory of the position map is already aligned
        # in this packed structure, we can simply overwrite the memory while
        # maintaining the structure.
        self.position_map.reinitialize(self.permutation)

    def reinitialize(self, data: T):
        """Overwrite the stored data with ``data``, reset all bookkeeping and
        reshuffle, reusing the memory that is already allocated."""
        # Note that this method is only used during refresh, and as such is
        # only called with a permutation as data.

        # The logical addresses of some previous permutation are irrelevant
        # and must be reset
        self.shufflei.assign_vector(self.index_type(regint.inc(self.n)))
        # Reset the clock
        self.t.write(0)
        # Reset shuffle_used
        self._reset_shuffle_used()

        # Note that the self.shuffle is actually a MultiArray
        # This structure is preserved while overwriting the values using
        # assign_vector
        self.shuffle.assign_vector(self.value_type(
            data[:], size=self.n * self.entry_length))
        # Note that this updates self.permutation (see constructor for
        # explanation)
        self.shuffle_the_shuffle()
        self.position_map.reinitialize(self.permutation)

    def _reset_shuffle_used(self):
        """Clear every entry of ``shuffle_used``.

        Uses a single bulk assignment when memory allocation is allowed and
        an element-wise loop otherwise."""
        global allow_memory_allocation
        if allow_memory_allocation:
            self.shuffle_used.assign_all(0)
        else:
            @lib.for_range_opt(self.n)
            def _(i):
                self.shuffle_used[i] = cint(0)
class PositionMap(Generic[T, B]):
    """Common base of the position maps that translate a secret logical
    address into a physical address of the shuffle."""

    # Entries of a permutation are packed in groups of PACK = 2**PACK_LOG
    PACK_LOG: int = 3
    PACK: int = 1 << PACK_LOG

    n: int  # n in the paper
    depth: cint  # k in the paper
    value_type: Type[T]

    def __init__(self, n: int, value_type: Type[T] = sint, k: int = -1) -> None:
        """Record the size, depth and the types derived from ``value_type``."""
        self.n = n
        self.depth = MemValue(cint(k))
        self.value_type = value_type
        self.bit_type = value_type.bit_type
        index_bits = util.log2(n) + 1  # +1 because signed
        self.index_type = self.value_type.get_type(index_bits)

    @abstractmethod
    def get_position(self, logical_address: _secret, fake: B) -> Any:
        """Retrieve the block at the given (secret) logical address."""
        # The base implementation only emits a trace line; subclasses call it
        # through super() and perform the actual lookup themselves.
        if not trace:
            return
        print_at_depth(self.depth, 'Scanning %s for logical address %s (fake=%s)',
                       self.__class__.__name__, logical_address.reveal(), sintbit(fake).reveal())

    def reinitialize(self, permutation: T):
        """Reinitialize this PositionMap.

        Since the reinitialization occurs at runtime (`on SqrtORAM.refresh()`),
        we cannot simply call __init__ on self. Instead, we must take care to
        reuse and overwrite the same memory.
        """
        pass

    @classmethod
    def create(cls, permutation: Array, k: int, period: int,
               value_type: Type[T] = sint, **kwargs) -> PositionMap:
        """Creates a new PositionMap. This is the method one should call when
        needing a new position map. Depending on the size of the given data, it
        will either instantiate a RecursivePositionMap or
        a LinearPositionMap."""
        n = len(permutation)
        # Once the packed permutation fits within one period a linear scan is
        # used; anything larger recurses. Extra keyword arguments are only
        # forwarded to the recursive variant.
        needs_recursion = n / PositionMap.PACK > period

        if needs_recursion:
            if debug:
                lib.print_ln(
                    'Initializing RecursivePositionMap at depth %s of size %s', k, n)
            return RecursivePositionMap(permutation, period, value_type, k=k,
                                        **kwargs)

        if debug:
            lib.print_ln(
                'Initializing LinearPositionMap at depth %s of size %s', k, n)
        return LinearPositionMap(permutation, value_type, k=k)
class RecursivePositionMap(PositionMap[T, B], SqrtOram[T, B]):
    """Position map that is itself a SqrtOram over the packed permutation.

    The permutation is stored in blocks of ``PositionMap.PACK`` entries, and
    those blocks are kept in a SqrtOram that has its own position map.
    """

    def __init__(self, permutation: Array, period: int,
                 value_type: Type[T] = sint, k: int = -1, **kwargs) -> None:
        """Pack ``permutation`` into blocks and initialize the underlying ORAM.

        ``kwargs`` are forwarded unchanged to ``SqrtOram.__init__``.
        """
        # NOTE(review): value_type is not forwarded here, so the base class
        # falls back to sint; SqrtOram.__init__ below assigns value_type,
        # bit_type and index_type again.
        PositionMap.__init__(self, len(permutation), k=k)
        pack = PositionMap.PACK

        # We pack the permutation into a smaller structure, index with a new
        # permutation
        packed_size = int(math.ceil(self.n / pack))
        packed_structure = MultiArray(
            (packed_size, pack), value_type=value_type)

        # Copy the permutation block by block into the packed structure
        @lib.for_range(packed_size)
        def _(i):
            packed_structure[i] = Array.create_from(
                permutation.get_vector(base=i * pack, size=pack))

        SqrtOram.__init__(self, packed_structure, value_type=value_type,
                          period=period, entry_length=pack, k=self.depth,
                          **kwargs)

        # Initialize random temp variables needed during the computation
        self.block_index_demux: Array = self.bit_type.Array(self.T)
        self.element_index_demux: Array = self.bit_type.Array(PositionMap.PACK)

    @lib.method_block
    def get_position(self, logical_address: T, fake: B) -> _clear:
        """Look up the physical address stored for ``logical_address``.

        The stash is scanned first; afterwards one block is always fetched
        from the shuffle through the nested position map and appended to the
        stash. The resulting address is revealed before it is returned.
        """
        super().get_position(logical_address, fake)

        pack = PositionMap.PACK
        pack_log = PositionMap.PACK_LOG

        # The item at logical_address
        # will be in block with index h (block.<h>)
        # at position l in block.data (block.data<l>)
        # h = logical_address >> pack_log, l = logical_address & (pack - 1)
        program = Program.prog
        h = MemValue(self.value_type.bit_compose(sbits.get_type(program.bit_length)(
            logical_address).right_shift(pack_log, program.bit_length)))
        l = self.value_type.bit_compose(
            sbits.get_type(program.bit_length)(logical_address) & (pack - 1))

        global trace
        if trace:
            print_at_depth(self.depth, '-> logical_address=%s: h=%s, l=%s', logical_address.reveal(), h.reveal(), l.reveal())
            # @lib.for_range(self.t)
            # def _(i):
            #     print_at_depth(self.depth, "stash[%s]=(%s: %s)", i, self.stashi[i].reveal() ,self.stash[i].reveal())

        # The resulting physical address
        p = MemValue(self.index_type(-1))
        found: B = MemValue(self.bit_type(False))

        # First we try and retrieve the item from the stash at position
        # stash[h][l]
        # Since h and l are secret, we do this by scanning the entire stash

        # First we scan the stash for the block we need
        self.block_index_demux.assign_all(0)

        @lib.for_range_opt_multithread(get_n_threads(self.T), self.T)
        def _(i):
            self.block_index_demux[i] = ( self.stashi[i] == h) & self.bit_type(i < self.t)
        # We can determine if the 'index' is in the stash by checking the
        # block_index_demux array
        found = sum(self.block_index_demux)
        # Once a block is found, we use the following condition to pick the
        # correct item from that block
        demux_array(l.bit_decompose(PositionMap.PACK_LOG), self.element_index_demux)

        # Finally we use the conditions to conditionally write p
        @lib.map_sum(get_n_threads(self.T * pack), n_parallel, self.T * pack, 1, [self.value_type])
        def p_(i):
            # We should loop from 0 through self.t, but runtime loop lengths
            # are not supported by map_sum
            # Therefore we include the check (i < self.t)
            return self.stash[i // pack][i % pack] * self.block_index_demux[i // pack] * self.element_index_demux[i % pack] * (i // pack< self.t)
        p.write(p_())

        if trace:
            @lib.if_e(found.reveal() == 0)
            def _(): print_at_depth(self.depth, 'Retrieve shuffle[%s]:', h.reveal())

            @lib.else_
            def __():
                print_at_depth(self.depth, 'Retrieve dummy element from shuffle:')

        # Then we try and retrieve the item from the shuffle (the actual
        # memory)
        # Depending on whether we found the item in the stash, we either
        # fetch block 'h' in which 'index' resides, or a random block from
        # the shuffle
        p_prime = self.position_map.get_position(h, found)
        self.shuffle_used[p_prime] = cbit(True)

        # The block retrieved from the shuffle
        block_p_prime: Array = self.shuffle[p_prime]

        if trace:
            @lib.if_e(found.reveal() == 0)
            def _():
                print_at_depth(self.depth, 'Retrieved position from shuffle[%s]=(%s: %s)',
                               p_prime.reveal(), self.shufflei[p_prime].reveal(), self.shuffle[p_prime].reveal())

            @lib.else_
            def __():
                print_at_depth(self.depth, 'Retrieved dummy position from shuffle[%s]=(%s: %s)',
                               p_prime.reveal(), self.shufflei[p_prime].reveal(), self.shuffle[p_prime].reveal())

        # We add the retrieved block from the shuffle to the stash
        self.stash[self.t].assign(block_p_prime[:])
        self.stashi[self.t] = self.shufflei[p_prime]
        # Increase t
        self.t += 1

        # True when this lookup is fake OR the block was not found in the
        # stash; only then p is taken from the shuffle block below.
        # NOTE(review): this comment previously read "if found or not fake",
        # which is not what the expression computes - confirm the intent.
        condition: B = self.bit_type(fake.bit_or(found.bit_not()))
        # Retrieve l'th item from block
        # l is secret, so we must use linear scan
        hit = Array.create_from((regint.inc(pack) == l.expand_to_vector(
            pack)) & condition.expand_to_vector(pack))

        @lib.for_range_opt(pack)
        def _(i):
            p.write((hit[i]).if_else(block_p_prime[i], p))

        return p.reveal()

    def reinitialize(self, permutation: T):
        """Reinitialize by delegating to ``SqrtOram.reinitialize``."""
        SqrtOram.reinitialize(self, permutation)
class LinearPositionMap(PositionMap):
    """Position map that keeps the permutation in a plain Array and answers
    lookups with a linear scan over all of its entries."""

    # physical[i] is the physical address stored for logical address i
    physical: Array
    # used[i] is set once entry i has been returned by get_position
    used: Array

    def __init__(self, data: Array, value_type: Type[T] = sint, k: int = -1) -> None:
        """Wrap ``data`` (the permutation) without copying it."""
        PositionMap.__init__(self, len(data), value_type, k=k)
        self.physical = data
        self.used = self.bit_type.Array(self.n)

        # Initialize random temp variables needed during the computation
        self.physical_demux: Array = self.bit_type.Array(self.n)

    @lib.method_block
    def get_position(self, logical_address: T, fake: B) -> _clear:
        """
        This method corresponds to GetPosBase in the paper.

        :param logical_address: secret logical address to look up.
        :param fake: secret bit; when set, the first entry that is not yet
            marked as used is selected instead of ``logical_address``.
        :return: the selected physical address, revealed.
        """
        super().get_position(logical_address, fake)
        global trace
        if trace:
            # The address is invalid when it is negative OR too large. The two
            # conditions exclude each other, so they must be OR-ed; their
            # product would always be zero and the check could never fire.
            @lib.if_(((logical_address < 0).bit_or(logical_address >= self.n)).reveal())
            def _():
                lib.runtime_error(
                    'logical_address must lie between 0 and self.n - 1')

        fake = MemValue(self.bit_type(fake))
        logical_address = MemValue(logical_address)

        p: MemValue = MemValue(self.index_type(-1))
        done: B = self.bit_type(False)

        # In order to get an address at secret logical_address,
        # we need to perform a linear scan.
        self.physical_demux.assign_all(0)

        @lib.for_range_opt_multithread(get_n_threads(self.n), self.n)
        def condition_i(i):
            self.physical_demux[i] = \
                (self.bit_type(fake).bit_not() & self.bit_type(logical_address == i)) \
                | (fake & self.used[i].bit_not())

        # In the event that fake=True, there are likely multiple entries in
        # physical_demux set to True (i.e. where self.used[i] = False)
        # We only need one, so we pick the first one we find
        @lib.for_range_opt(self.n)
        def _(i):
            self.physical_demux[i] &= done.bit_not()
            done.update(done | self.physical_demux[i])

        # Retrieve the value from the physical memory obliviously
        @lib.map_sum_opt(get_n_threads(self.n), self.n, [self.value_type])
        def calc_p(i):
            return self.physical[i] * self.physical_demux[i]
        p.write(calc_p())

        # Update self.used
        self.used.assign(self.used[:] | self.physical_demux[:])

        if trace:
            # Valid physical addresses are 0 .. len(self.physical) - 1
            @lib.if_((p.reveal() < 0).bit_or(p.reveal() >= len(self.physical)))
            def _():
                lib.runtime_error(
                    '%s Did not find requested logical_address in shuffle, something went wrong.', self.depth)

        return p.reveal()

    def reinitialize(self, data : T):
        """Overwrite the stored permutation with ``data`` and clear ``used``.

        Uses a single bulk assignment for ``used`` when memory allocation is
        allowed and an element-wise loop otherwise."""
        self.physical.assign(data)

        global allow_memory_allocation
        if allow_memory_allocation:
            self.used.assign_all(False)
        else:
            @lib.for_range_opt(self.n)
            def _(i):
                self.used[i] = self.bit_type(0)
def print_at_depth(depth: cint, message: str, *args):
    """Print ``message`` prefixed by ``depth`` and indented by two spaces per
    level of depth.

    ``args`` are handed to ``lib.print_ln`` as the format arguments of
    ``message``.
    """
    lib.print_str('%s', depth)

    @lib.for_range(depth)
    def _(i):
        # Two spaces of indentation for every level
        for _unused in range(2):
            lib.print_char(' ')

    lib.print_ln(message, *args)