Source code for Compiler.decision_tree

from Compiler.types import *
from Compiler.sorting import *
from Compiler.library import *
from Compiler import util, oram

from itertools import accumulate
import math

debug = False
debug_split = False
max_leaves = None

def get_type(x):
    if isinstance(x, (Array, SubMultiArray)):
        return x.value_type
    elif isinstance(x, (tuple, list)):
        x = x[0] + x[-1]
        if util.is_constant(x):
            return cint
        else:
            return type(x)
    else:
        return type(x)

def PrefixSum(x):
    return x.get_vector().prefix_sum()

def PrefixSumR(x):
    tmp = get_type(x).Array(len(x))
    tmp.assign_vector(x)
    break_point()
    tmp[:] = tmp.get_reverse_vector().prefix_sum()
    break_point()
    return tmp.get_reverse_vector()

def PrefixSum_inv(x):
    tmp = get_type(x).Array(len(x) + 1)
    tmp.assign_vector(x, base=1)
    tmp[0] = 0
    return tmp.get_vector(size=len(x), base=1) - tmp.get_vector(size=len(x))

def PrefixSumR_inv(x):
    tmp = get_type(x).Array(len(x) + 1)
    tmp.assign_vector(x)
    tmp[-1] = 0
    return tmp.get_vector(size=len(x)) - tmp.get_vector(base=1, size=len(x))

class SortPerm:
    def __init__(self, x):
        B = sint.Matrix(len(x), 2)
        B.set_column(0, 1 - x.get_vector())
        B.set_column(1, x.get_vector())
        self.perm = Array.create_from(dest_comp(B))
    def apply(self, x):
        res = Array.create_from(x)
        reveal_sort(self.perm, res, False)
        return res
    def unapply(self, x):
        res = Array.create_from(x)
        reveal_sort(self.perm, res, True)
        return res

def Sort(keys, *to_sort, n_bits=None, time=False):
    if time:
        start_timer(1)
    for k in keys:
        assert len(k) == len(keys[0])
    n_bits = n_bits or [None] * len(keys)
    bs = Matrix.create_from(
        sum([k.get_vector().bit_decompose(nb)
             for k, nb in reversed(list(zip(keys, n_bits)))], []))
    get_vec = lambda x: x[:] if isinstance(x, Array) else x
    res = Matrix.create_from(get_vec(x).v if isinstance(get_vec(x), sfix) else x
                             for x in to_sort)
    res = res.transpose()
    if time:
        start_timer(11)
    radix_sort_from_matrix(bs, res)
    if time:
        stop_timer(11)
        stop_timer(1)
    res = res.transpose()
    return [sfix._new(get_vec(x), k=get_vec(y).k, f=get_vec(y).f)
            if isinstance(get_vec(y), sfix)
            else x for (x, y) in zip(res, to_sort)]

def VectMax(key, *data, debug=False):
    def reducer(x, y):
        b = x[0] > y[0]
        if debug:
            print_ln('max b=%s', b.reveal())
        return [b.if_else(xx, yy) for xx, yy in zip(x, y)]
    if debug:
        key = list(key)
        data = [list(x) for x in data]
        print_ln('vect max key=%s data=%s', util.reveal(key), util.reveal(data))
    res = util.tree_reduce(reducer, zip(key, *data))[1:]
    if debug:
        print_ln('vect max res=%s', util.reveal(res))
    return res

def GroupSum(g, x):
    assert len(g) == len(x)
    p = PrefixSumR(x) * g
    pi = SortPerm(g.get_vector().bit_not())
    p1 = pi.apply(p)
    s1 = PrefixSumR_inv(p1)
    d1 = PrefixSum_inv(s1)
    d = pi.unapply(d1) * g
    return PrefixSum(d)

def GroupPrefixSum(g, x):
    assert len(g) == len(x)
    s = get_type(x).Array(len(x) + 1)
    s[0] = 0
    s.assign_vector(PrefixSum(x), base=1)
    q = get_type(s).Array(len(x))
    q.assign_vector(s.get_vector(size=len(x)) * g)
    return s.get_vector(size=len(x), base=1) - GroupSum(g, q)

def GroupMax(g, keys, *x):
    if debug:
        print_ln('group max input g=%s keys=%s x=%s', util.reveal(g),
                 util.reveal(keys), util.reveal(x))
    assert len(keys) == len(g)
    for xx in x:
        assert len(xx) == len(g)
    n = len(g)
    m = int(math.ceil(math.log(n, 2)))
    keys = Array.create_from(keys)
    x = [Array.create_from(xx) for xx in x]
    g_new = Array.create_from(g)
    g_old = g_new.same_shape()
    for d in range(m):
        w = 2 ** d
        g_old[:] = g_new[:]
        break_point()
        vsize = n - w
        g_new.assign_vector(g_old.get_vector(size=vsize).bit_or(
            g_old.get_vector(size=vsize, base=w)), base=w)
        b = keys.get_vector(size=vsize) > keys.get_vector(size=vsize, base=w)
        for xx in [keys] + x:
            a = b.if_else(xx.get_vector(size=vsize),
                          xx.get_vector(size=vsize, base=w))
            xx.assign_vector(g_old.get_vector(size=vsize, base=w).if_else(
                xx.get_vector(size=vsize, base=w), a), base=w)
        break_point()
        if debug:
            print_ln('group max w=%s b=%s a=%s keys=%s x=%s g=%s', w, b.reveal(),
                     util.reveal(a), util.reveal(keys),
                     util.reveal(x), g_new.reveal())
    t = sint.Array(len(g))
    t[-1] = 1
    t.assign_vector(g.get_vector(size=n - 1, base=1))
    if debug:
        print_ln('group max end g=%s t=%s keys=%s x=%s', util.reveal(g),
                 util.reveal(t), util.reveal(keys), util.reveal(x))
    return [GroupSum(g, t[:] * xx) for xx in [keys] + x]

def ModifiedGini(g, y, debug=False):
    assert len(g) == len(y)
    y = [y.get_vector().bit_not(), y]
    u = [GroupPrefixSum(g, yy) for yy in y]
    s = [GroupSum(g, yy) for yy in y]
    w = [ss - uu for ss, uu in zip(s, u)]
    us = sum(u)
    ws = sum(w)
    uqs = u[0] ** 2 + u[1] ** 2
    wqs = w[0] ** 2 + w[1] ** 2
    res = sfix(uqs) / us + sfix(wqs) / ws
    if debug:
        print_ln('g=%s y=%s s=%s',
                 util.reveal(g), util.reveal(y),
                 util.reveal(s))
        print_ln('u0=%s', util.reveal(u[0]))
        print_ln('u0=%s', util.reveal(u[1]))
        print_ln('us=%s', util.reveal(us))
        print_ln('w0=%s', util.reveal(w[0]))
        print_ln('w1=%s', util.reveal(w[1]))
        print_ln('ws=%s', util.reveal(ws))
        print_ln('uqs=%s', util.reveal(uqs))
        print_ln('wqs=%s', util.reveal(wqs))
    if debug:
        print_ln('gini %s %s', type(res), util.reveal(res))
    return res

MIN_VALUE = -10000

def FormatLayer(h, g, *a):
    return CropLayer(h, *FormatLayer_without_crop(g, *a))

def FormatLayer_without_crop(g, *a, debug=False):
    for x in a:
        assert len(x) == len(g)
    v = [g.if_else(aa, 0) for aa in a]
    if debug:
        print_ln('format in %s', util.reveal(a))
        print_ln('format mux %s', util.reveal(v))
    v = Sort([g.bit_not()], *v, n_bits=[1])
    if debug:
        print_ln('format sort %s', util.reveal(v))
    return v

def CropLayer(k, *v):
    if max_leaves:
        n = min(2 ** k, max_leaves)
    else:
        n = 2 ** k
    return [vv[:min(n, len(vv))] for vv in v]

def TrainLeafNodes(h, g, y, NID):
    assert len(g) == len(y)
    assert len(g) == len(NID)
    Label = GroupSum(g, y.bit_not()) < GroupSum(g, y)
    return FormatLayer(h, g, NID, Label)

def GroupSame(g, y):
    assert len(g) == len(y)
    s = GroupSum(g, [sint(1)] * len(g))
    s0 = GroupSum(g, y.bit_not())
    s1 = GroupSum(g, y)
    if debug_split:
        print_ln('group same g=%s', util.reveal(g))
        print_ln('group same y=%s', util.reveal(y))
    return (s == s0).bit_or(s == s1)

def GroupFirstOne(g, b):
    assert len(g) == len(b)
    s = GroupPrefixSum(g, b)
    return s * b == 1

[docs]class TreeTrainer: """ Decision tree training by `Hamada et al.`_ :param x: sample data (by attribute, list or :py:obj:`~Compiler.types.Matrix`) :param y: binary labels (list or sint vector) :param h: height (int) :param binary: binary attributes instead of continuous :param attr_lengths: attribute description for mixed data (list of 0/1 for continuous/binary) :param n_threads: number of threads (default: single thread) .. _`Hamada et al.`: https://arxiv.org/abs/2112.12906 """ def ApplyTests(self, x, AID, Threshold): m = len(x) n = len(AID) assert len(AID) == len(Threshold) for xx in x: assert len(xx) == len(AID) e = sint.Matrix(m, n) AID = Array.create_from(AID) @for_range_multithread(self.n_threads, 1, m) def _(j): e[j][:] = AID[:] == j xx = sum(x[j] * e[j] for j in range(m)) if self.debug > 1: print_ln('apply e=%s xx=%s', util.reveal(e), util.reveal(xx)) print_ln('threshold %s', util.reveal(Threshold)) return 2 * xx < Threshold def AttributeWiseTestSelection(self, g, x, y, time=False, debug=False): assert len(g) == len(x) assert len(g) == len(y) if time: start_timer(2) s = ModifiedGini(g, y, debug=debug or self.debug > 2) if time: stop_timer(2) if debug or self.debug > 1: print_ln('gini %s', s.reveal()) xx = x t = get_type(x).Array(len(x)) t[-1] = MIN_VALUE t.assign_vector(xx.get_vector(size=len(x) - 1) + \ xx.get_vector(size=len(x) - 1, base=1)) gg = g p = sint.Array(len(x)) p[-1] = 1 p.assign_vector(gg.get_vector(base=1, size=len(x) - 1).bit_or( xx.get_vector(size=len(x) - 1) == \ xx.get_vector(size=len(x) - 1, base=1))) break_point() if debug: print_ln('attribute t=%s p=%s', util.reveal(t), util.reveal(p)) s = p[:].if_else(MIN_VALUE, s) t = p[:].if_else(MIN_VALUE, t[:]) if debug: print_ln('attribute s=%s t=%s', util.reveal(s), util.reveal(t)) if time: start_timer(3) s, t = GroupMax(gg, s, t) if time: stop_timer(3) if debug: print_ln('attribute s=%s t=%s', util.reveal(s), util.reveal(t)) return t, s def GlobalTestSelection(self, x, y, g): assert len(y) == len(g) for xx in x: assert(len(xx) == len(g)) m = len(x) n = len(y) u, t = [get_type(x).Matrix(m, n) for i in range(2)] v = get_type(y).Matrix(m, n) s = sfix.Matrix(m, n) @for_range_multithread(self.n_threads, 1, m) def _(j): single = not self.n_threads or self.n_threads == 1 time = self.time and single if debug: print_ln('run %s', j) @if_e(self.attr_lengths[j]) def _(): u[j][:], v[j][:] = Sort((PrefixSum(g), x[j]), x[j], y, n_bits=[util.log2(n), 1], time=time) @else_ def _(): u[j][:], v[j][:] = Sort((PrefixSum(g), x[j]), x[j], y, n_bits=[util.log2(n), None], time=time) if self.debug_threading: print_ln('global sort %s %s %s', j, util.reveal(u[j]), util.reveal(v[j])) t[j][:], s[j][:] = self.AttributeWiseTestSelection( g, u[j], v[j], time=time, debug=self.debug_selection) if self.debug_threading: print_ln('global attribute %s %s %s', j, util.reveal(t[j]), util.reveal(s[j])) n = len(g) a = sint.Array(n) if self.debug_threading: print_ln('global s=%s', util.reveal(s)) if self.debug_gini: print_ln('Gini indices ' + ' '.join(str(i) + ':%s' for i in range(m)), *(ss[0].reveal() for ss in s)) if self.time: start_timer(4) if self.debug > 1: print_ln('s=%s', s.reveal_nested()) print_ln('t=%s', t.reveal_nested()) a[:], tt = VectMax((s[j][:] for j in range(m)), range(m), (t[j][:] for j in range(m)), debug=self.debug > 1) tt = Array.create_from(tt) if self.time: stop_timer(4) if self.debug > 1: print_ln('a=%s', util.reveal(a)) print_ln('tt=%s', util.reveal(tt)) return a[:], tt[:] def TrainInternalNodes(self, k, x, y, g, NID): assert len(g) == len(y) for xx in x: assert len(xx) == len(g) AID, Threshold = self.GlobalTestSelection(x, y, g) s = GroupSame(g[:], y[:]) if self.debug > 1 or debug_split: print_ln('AID=%s', util.reveal(AID)) print_ln('Threshold=%s', util.reveal(Threshold)) print_ln('GroupSame=%s', util.reveal(s)) AID, Threshold = s.if_else(0, AID), s.if_else(MIN_VALUE, Threshold) if self.debug > 1 or debug_split: print_ln('AID=%s', util.reveal(AID)) print_ln('Threshold=%s', util.reveal(Threshold)) b = self.ApplyTests(x, AID, Threshold) layer = FormatLayer_without_crop(g[:], NID, AID, Threshold, debug=self.debug > 1) return *layer, b @method_block def train_layer(self, k): x = self.x y = self.y g = self.g NID = self.NID if self.debug > 1: print_ln('g=%s', g.reveal()) print_ln('y=%s', y.reveal()) print_ln('x=%s', x.reveal_nested()) self.nids[k], self.aids[k], self.thresholds[k], b = \ self.TrainInternalNodes(k, x, y, g, NID) if self.debug > 1: print_ln('layer %s:', k) for name, data in zip(('NID', 'AID', 'Thr'), (self.nids[k], self.aids[k], self.thresholds[k])): print_ln(' %s: %s', name, data.reveal()) NID[:] = 2 ** k * b + NID b_not = b.bit_not() if self.debug > 1: print_ln('b_not=%s', b_not.reveal()) g[:] = GroupFirstOne(g, b_not) + GroupFirstOne(g, b) y[:], g[:], NID[:], *xx = Sort([b], y, g, NID, *x, n_bits=[1]) for i, xxx in enumerate(xx): x[i] = xxx def __init__(self, x, y, h, binary=False, attr_lengths=None, n_threads=None): assert not (binary and attr_lengths) if binary: attr_lengths = [1] * len(x) else: attr_lengths = attr_lengths or ([0] * len(x)) for l in attr_lengths: assert l in (0, 1) self.attr_lengths = Array.create_from(regint(attr_lengths)) Array.check_indices = False Matrix.disable_index_checks() for xx in x: assert len(xx) == len(y) n = len(y) self.g = sint.Array(n) self.g.assign_all(0) self.g[0] = 1 self.NID = sint.Array(n) self.NID.assign_all(1) self.y = Array.create_from(y) self.x = Matrix.create_from(x) self.nids, self.aids = [sint.Matrix(h, n) for i in range(2)] self.thresholds = self.x.value_type.Matrix(h, n) self.n_threads = n_threads self.debug_selection = False self.debug_threading = False self.debug_gini = False self.debug = False self.time = False
[docs] def train(self): """ Train and return decision tree. """ h = len(self.nids) @for_range(h) def _(k): self.train_layer(k) return self.get_tree(h)
[docs] def train_with_testing(self, *test_set, output=False): """ Train decision tree and test against test data. :param y: binary labels (list or sint vector) :param x: sample data (by attribute, list or :py:obj:`~Compiler.types.Matrix`) :param output: output tree after every level :returns: tree """ for k in range(len(self.nids)): self.train_layer(k) tree = self.get_tree(k + 1) if output: output_decision_tree(tree) test_decision_tree('train', tree, self.y, self.x, n_threads=self.n_threads) if test_set: test_decision_tree('test', tree, *test_set, n_threads=self.n_threads) return tree
def get_tree(self, h): Layer = [None] * (h + 1) for k in range(h): Layer[k] = CropLayer(k, self.nids[k], self.aids[k], self.thresholds[k]) Layer[h] = TrainLeafNodes(h, self.g[:], self.y[:], self.NID) return Layer
def DecisionTreeTraining(x, y, h, binary=False): return TreeTrainer(x, y, h, binary=binary).train()
[docs]def output_decision_tree(layers): """ Print decision tree output by :py:class:`TreeTrainer`. """ print_ln('full model %s', util.reveal(layers)) for i, layer in enumerate(layers[:-1]): print_ln('level %s:', i) for j, x in enumerate(('NID', 'AID', 'Thr')): print_ln(' %s: %s', x, util.reveal(layer[j])) print_ln('leaves:') for j, x in enumerate(('NID', 'result')): print_ln(' %s: %s', x, util.reveal(layers[-1][j]))
def pick(bits, x): if len(bits) == 1: return bits[0] * x[0] else: try: return x[0].dot_product(bits, x) except: return sum(aa * bb for aa, bb in zip(bits, x))
[docs]def run_decision_tree(layers, data): """ Run decision tree against sample data. :param layers: tree output by :py:class:`TreeTrainer` :param data: sample data (:py:class:`~Compiler.types.Array`) :returns: binary label """ h = len(layers) - 1 index = 1 for k, layer in enumerate(layers[:-1]): assert len(layer) == 3 for x in layer: assert len(x) <= 2 ** k bits = layer[0].equal(index, k) threshold = pick(bits, layer[2]) key_index = pick(bits, layer[1]) if key_index.is_clear: key = data[key_index] else: key = pick( oram.demux(key_index.bit_decompose(util.log2(len(data)))), data) child = 2 * key < threshold index += child * 2 ** k bits = layers[h][0].equal(index, h) return pick(bits, layers[h][1])
def test_decision_tree(name, layers, y, x, n_threads=None, time=False): if time: start_timer(100) n = len(y) x = x.transpose().reveal() y = y.reveal() guess = regint.Array(n) truth = regint.Array(n) correct = regint.Array(2) parts = regint.Array(2) layers = [[Array.create_from(util.reveal(x)) for x in layer] for layer in layers] @for_range_multithread(n_threads, 1, n) def _(i): guess[i] = run_decision_tree([[part[:] for part in layer] for layer in layers], x[i]).reveal() truth[i] = y[i].reveal() @for_range(n) def _(i): parts[truth[i]] += 1 c = (guess[i].bit_xor(truth[i]).bit_not()) correct[truth[i]] += c print_ln('%s for height %s: %s/%s (%s/%s, %s/%s)', name, len(layers) - 1, sum(correct), n, correct[0], parts[0], correct[1], parts[1]) if time: stop_timer(100)
[docs]class TreeClassifier: """ Tree classification with convenient interface. Uses :py:class:`TreeTrainer` internally. :param max_depth: the depth of the decision tree :param n_threads: number of threads used in training """ def __init__(self, max_depth, n_threads=None): self.max_depth = max_depth self.n_threads = n_threads @staticmethod def get_attr_lengths(attr_types): if attr_types == None: return None else: return [1 if x == 'b' else 0 for x in attr_types]
[docs] def fit(self, X, y, attr_types=None): """ Train tree. :param X: sample data with row-wise samples (sint/sfix matrix) :param y: binary labels (sint list/array) """ self.tree = TreeTrainer( X.transpose(), y, self.max_depth, attr_lengths=self.get_attr_lengths(attr_types), n_threads=self.n_threads).train()
[docs] def fit_with_testing(self, X_train, y_train, X_test, y_test, attr_types=None, output_trees=False, debug=False): """ Train tree with accuracy output after every level. :param X_train: training data with row-wise samples (sint/sfix matrix) :param y_train: training binary labels (sint list/array) :param X_test: testing data with row-wise samples (sint/sfix matrix) :param y_test: testing binary labels (sint list/array) :param attr_types: attributes types (list of 'b'/'c' for binary/continuous; default is all continuous) :param output_trees: output tree after every level :param debug: output debugging information """ trainer = TreeTrainer(X_train.transpose(), y_train, self.max_depth, attr_lengths=self.get_attr_lengths(attr_types), n_threads=self.n_threads) trainer.debug = debug trainer.debug_gini = debug trainer.debug_threading = debug > 1 self.tree = trainer.train_with_testing(y_test, X_test.transpose(), output=output_trees)
[docs] def predict(self, X): """ Use tree for prediction. :param X: sample data with row-wise samples (sint/sfix matrix) :returns: sint array """ res = sint.Array(len(X)) @for_range(len(X)) def _(i): res[i] = run_decision_tree(self.tree, X[i]) return res
[docs] def output(self): """ Output decision tree. """ output_decision_tree(self.tree)
[docs]def preprocess_pandas(data): """ Preprocess pandas data frame to suit :py:class:`TreeClassifier` by expanding non-continuous attributes to several binary attributes as a unary encoding. :returns: a tuple of the processed data and a type list for the :py:obj:`attr_types` argument. """ import pandas import numpy res = [] types = [] for i, t in enumerate(data.dtypes): if pandas.api.types.is_int64_dtype(t): res.append(data.iloc[:,i].to_numpy()) types.append('c') elif pandas.api.types.is_object_dtype(t): values = list(filter(lambda x: isinstance(x, str), list(data.iloc[:,i].unique()))) print('converting the following to unary:', values) if len(values) == 2: res.append(data.iloc[:,i].to_numpy() == values[1]) types.append('b') else: for value in values: res.append(data.iloc[:,i].to_numpy() == value) types.append('b') else: raise CompilerError('unknown pandas type: ' + t) res = numpy.array(res) res = numpy.swapaxes(res, 0, 1) return res, types