An Introduction to a Secure Multi-Party Computation Language

0. Preface

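This article introduces a language for writing secure multi-party computation (MPC) programs. The language can express complex computation logic. We introduce it through inputs, operators, conditionals, loops, arrays and functions, followed by some worked examples.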

1. Inputs

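Inputs are the basis of any computation. Each party provides its own input, the parties agree on a formula to compute, and they obtain the result without seeing one another's inputs. The MPC language distinguishes two kinds of input: clear (public) inputs and secret (private) inputs.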

1.1 Clear Input

A clear input is written in plaintext in the program, so every party can see its value. For example:

a = cint(2)
b = cint(5)

c = a + b
print_ln('Result is %s', c)

Clear inputs have the type cint (clear integer).

1.2 Secret Input

Secret inputs can be provided in two ways. We first give an example of each and then explain the difference.

Example 1

a = sint(2)
b = sint(5)

c = a + b
print_ln('Result is %s', c.reveal())

Example 2

a = sint.get_raw_input_from(0)
b = sint.get_raw_input_from(1)

c = a + b
print_ln('Result is %s', c.reveal())

Note that the data type changes to sint for secret inputs. The result must also be explicitly opened with reveal() before it can be printed.

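The second example also assigns a and b with get_raw_input_from(). The two examples differ in where the values come from. The first uses secret constants, and because their values are written into the program, this kind of input is of little practical use. The second reads private data from the users: get_raw_input_from(0) and get_raw_input_from(1) collect the secret inputs of the first party (party 0) and the second party (party 1), respectively.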
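The following is a minimal plain-Python sketch of additive secret sharing, one common way MPC protocols represent secret values. It makes the difference between a secret input and a revealed result concrete. It is an illustration under simplifying assumptions, not the MPC language's actual protocol. The modulus PRIME and the helpers share, add_shares and reveal are hypothetical, and real protocols such as SPDZ additionally authenticate shares with MACs.

```python
import secrets

# Hypothetical modulus for illustration: 2**61 - 1 is a Mersenne prime.
PRIME = 2**61 - 1


def share(value, n_parties):
    """Split value into n_parties random shares that sum to value mod PRIME."""
    shares = [secrets.randbelow(PRIME) for _ in range(n_parties - 1)]
    shares.append((value - sum(shares)) % PRIME)
    return shares


def add_shares(x_shares, y_shares):
    """Each party adds its own two shares locally; no communication is needed."""
    return [(x + y) % PRIME for x, y in zip(x_shares, y_shares)]


def reveal(shares):
    """Opening: combine all shares to recover the plaintext value."""
    return sum(shares) % PRIME


# Party 0 secret-shares its input 2, party 1 secret-shares its input 5.
a_shares = share(2, n_parties=2)
b_shares = share(5, n_parties=2)
c_shares = add_shares(a_shares, b_shares)

# Any single share is a uniformly random number and says nothing about a or b.
print(reveal(c_shares))  # 7
```

In this sketch only the final reveal call combines the shares. It plays roughly the role that reveal() plays in the MPC language.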

Besides integer inputs, the MPC language also supports computation on GF(2^n) data (the sgf2n type), with a default length of 40 bits.
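In GF(2^n), an element is an n-bit vector. Addition is bitwise XOR, and multiplication is carry-less polynomial multiplication reduced modulo an irreducible polynomial of degree n. The plain-Python sketch below uses n = 8 and the AES polynomial x^8 + x^4 + x^3 + x + 1, because the text does not say which degree-40 polynomial the language uses. It only illustrates the arithmetic and is not the language's sgf2n implementation.

```python
AES_POLY = 0x11B  # x^8 + x^4 + x^3 + x + 1, irreducible over GF(2)


def gf_add(a, b):
    """Addition in GF(2^n) is XOR (no carries)."""
    return a ^ b


def gf_mul(a, b, poly=AES_POLY, n=8):
    """Carry-less multiply, reducing modulo poly whenever degree n is reached."""
    result = 0
    while b:
        if b & 1:
            result ^= a
        b >>= 1
        a <<= 1
        if a & (1 << n):
            a ^= poly
    return result


print(hex(gf_add(0x57, 0x83)))  # 0xd4
print(hex(gf_mul(0x57, 0x83)))  # 0xc1
```

Both printed values match the worked examples in the AES specification (FIPS-197).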

2. Mathematical Operators

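The MPC language is built around computation and naturally supports the basic arithmetic and comparison operators.

Specifically: +, -, *, <, <=, >, >=, ==, !=

Comparisons return 0 or 1. When the operands are secret, the result is a secret bit, which must be revealed before it can be printed.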
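Because comparisons yield 0 or 1, their results can be used directly in arithmetic, for example summed to count matches. The naive_search example in Section 7 uses this pattern with ==. Below is a plain-Python sketch in which ordinary integers stand in for secret values. less_than and equals are hypothetical helper names.

```python
def less_than(x, y):
    """Comparison as a 0/1 integer, like the bit a comparison returns."""
    return int(x < y)


def equals(x, y):
    return int(x == y)


values = [3, 8, 1, 9]
# Summing comparison bits counts matches without any branching.
print(sum(less_than(v, 5) for v in values))  # 2
print(sum(equals(v, 8) for v in values))     # 1
```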

3. Conditionals

The MPC language supports conditionals in the following forms.

3.1 if

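Evaluates the condition after if. If it holds, the indented body is executed; otherwise it is skipped. Programs are written in Python, so an ordinary if is decided when the program is compiled. Its condition must therefore be known at compile time, such as the loop index idx below, and cannot be a secret value.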

if idx % 5 == 0:
  bd_val[idx] = raw_bit_dec[idx // 5]

3.2 if_then()...else_then()...end_if()

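Evaluates the condition when the program runs. If it holds, the code after if_then() is executed; otherwise the code after else_then() is executed. end_if() closes the block. The condition must be a clear value such as a cint. In the example below it is cint(0), which is false, so a[0] is set to 789.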

a = Array(1, cint)  # assumed declaration; the original snippet does not define a
if_then(cint(0))
a[0] = 123
else_then()
a[0] = 789
end_if()
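In the MPC language, a Python if is decided at compile time and if_then() takes a clear condition, so neither can branch on a secret value. When the condition is secret, a common approach is to compute both candidate values and combine them arithmetically with the condition bit. The comment next to if_else in compute_intersection (Section 7, Example 4) shows this approach. Below is a plain-Python sketch in which ordinary integers stand in for secret values. oblivious_select and oblivious_max are hypothetical names.

```python
def oblivious_select(bit, if_true, if_false):
    """Return if_true when bit == 1 and if_false when bit == 0.

    Both candidate values always enter the computation, so the
    sequence of operations does not depend on the condition.
    """
    return bit * if_true + (1 - bit) * if_false


def oblivious_max(x, y):
    # The comparison yields a 0/1 bit, which selects the larger value.
    return oblivious_select(int(x < y), y, x)


print(oblivious_select(0, 123, 789))  # 789
print(oblivious_max(30, 70))          # 70
```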

4. Loops

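The MPC language supports loops in several forms, shown below. Examples 1 to 3 use ordinary Python loops, which are unrolled when the program is compiled, so every iteration produces its own copy of the instructions. Example 4 uses the @for_range decorator, which produces a loop that runs at run time and avoids this unrolling. The difference matters for large iteration counts.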

Example 1

out_bytes[1] = sum(in_bytes[idx] for idx in range(1, 8, 2))

Example 2

for j in range(4):
  temp[j] = round_key[(i-1) * 4 + j]

例3

for i in range(1, numRounds):
  roundKey = createRoundKey(expandedKey, i)
  aesRound(roundKey)(state)

Example 4

@for_range(n)
def _(i):
  a[i] = i
  b[i] = i + 60
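The difference between the two loop styles can be pictured with a toy model of code generation. This plain-Python sketch is purely illustrative. The instruction strings and both helper functions are hypothetical and do not reflect the compiler's real output.

```python
def unrolled_loop(n):
    """Toy model of a Python for loop: one instruction is emitted per iteration."""
    return [f"store a[{i}] <- {i}" for i in range(n)]


def runtime_loop(n):
    """Toy model of @for_range: the body is emitted once, plus loop control."""
    return [
        "i <- 0",
        f"loop: if i >= {n} jump end",
        "store a[i] <- i",
        "i <- i + 1",
        "jump loop",
        "end:",
    ]


for n in (10, 10_000):
    print(n, len(unrolled_loop(n)), len(runtime_loop(n)))
```

In this model, program size grows linearly with n for the unrolled loop but stays constant for the run-time loop.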

5. Arrays

The MPC language supports arrays. Creating an array requires specifying its length and element type, for example:

a = Array(n, sint)

Array elements can be assigned in the following ways.

Assigning a single element

a[i] = 60

Assigning in a loop

a = Array(n, sint)
b = Array(n, sint)
@for_range(n)
def _(i):
    a[i] = i
    b[i] = i + 60

6. Functions

The MPC language lets you organize logic into functions. You can use functions from the bundled packages via import, or define your own.

Example 1

from path_oram import OptimalORAM

array = OptimalORAM(10000)
array[1] = 1
print_ln('%s', array[1].reveal())

Example 2

def test(actual, expected):
  if isinstance(actual, (sint, sgf2n)):
    actual = actual.reveal()
  print_ln('expected %s, got %s', expected, actual)

Example 3

def millionnaires():
  print_ln("Waiting for Alice's input")
  alice = sint.get_input_from(0)
  print_ln("Waiting for Bob's input")
  bob = sint.get_input_from(1)

  # b is a secret bit: 1 if Alice's input is smaller than Bob's, 0 otherwise
  b = alice < bob
  print_ln('The richest is: %s', b.reveal())

millionnaires()

7. Worked Examples

The following are more complex programs written in the MPC language, which we hope will give you ideas for writing your own.

Example 1 Dijkstra_tutorial

import dijkstra
from path_oram import OptimalORAM

n = 1000

dist = dijkstra.test_dijkstra_on_cycle(n, OptimalORAM)

for i in range(n):
    print_ln('%s: %s', i, dist[i][0].reveal())
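For reference, below is the textbook plaintext Dijkstra algorithm with a binary heap. It sketches the computation the secure program performs and is not the dijkstra module's code. The secure version keeps data secret and is passed OptimalORAM, an ORAM, which serves to hide memory access patterns. The undirected cycle with unit weights built by the hypothetical cycle_graph is an assumption, because the text does not say how test_dijkstra_on_cycle constructs its graph.

```python
import heapq


def dijkstra(adj, source):
    """Shortest distances from source; adj[u] is a list of (v, weight) pairs."""
    dist = {node: float("inf") for node in adj}
    dist[source] = 0
    heap = [(0, source)]
    while heap:
        d, u = heapq.heappop(heap)
        if d > dist[u]:
            continue  # stale heap entry
        for v, w in adj[u]:
            nd = d + w
            if nd < dist[v]:
                dist[v] = nd
                heapq.heappush(heap, (nd, v))
    return dist


def cycle_graph(n):
    """Hypothetical test graph: undirected cycle 0-1-...-(n-1)-0, unit weights."""
    adj = {i: [] for i in range(n)}
    for i in range(n):
        j = (i + 1) % n
        adj[i].append((j, 1))
        adj[j].append((i, 1))
    return adj


dist = dijkstra(cycle_graph(1000), 0)
for i in (0, 1, 500, 999):
    print(i, dist[i])
```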

Example 2 Gale-Shapley_tutorial

from Compiler import gs
from Compiler.path_oram import OptimalORAM

mm = gs.Matchmaker(50, oram_type=OptimalORAM)
mm.init_hard()
mm.match()
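Matchmaker computes a stable matching in the Gale-Shapley setting. For comparison, below is a plaintext sketch of the classic proposer-optimal deferred-acceptance algorithm. It does not describe how gs.Matchmaker works internally, and the preference lists are made-up sample data. In the secure version, the preferences would be secret, and the oram_type argument selects the ORAM used for secret-indexed memory.

```python
def gale_shapley(proposer_prefs, acceptor_prefs):
    """Proposer-optimal stable matching (classic deferred acceptance).

    proposer_prefs[m] lists acceptors from most to least preferred;
    acceptor_prefs[w] lists proposers from most to least preferred.
    Returns a dict mapping each proposer to its matched acceptor.
    """
    n = len(proposer_prefs)
    # rank[w][m]: position of proposer m in acceptor w's list (lower is better)
    rank = [{m: r for r, m in enumerate(prefs)} for prefs in acceptor_prefs]
    next_choice = [0] * n  # index of the next acceptor each proposer will try
    partner = [None] * n   # partner[w]: proposer currently held by acceptor w
    free = list(range(n))  # proposers without a partner
    while free:
        m = free.pop()
        w = proposer_prefs[m][next_choice[m]]
        next_choice[m] += 1
        current = partner[w]
        if current is None:
            partner[w] = m
        elif rank[w][m] < rank[w][current]:
            partner[w] = m
            free.append(current)  # the displaced proposer becomes free again
        else:
            free.append(m)        # w rejects m; m tries its next choice later
    return {partner[w]: w for w in range(n)}


proposers = [[0, 1, 2], [1, 0, 2], [0, 1, 2]]
acceptors = [[1, 0, 2], [0, 1, 2], [0, 1, 2]]
print(gale_shapley(proposers, acceptors))
```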

Example 3 oram_tutorial

from path_oram import OptimalORAM

array = OptimalORAM(10000)
array[1] = 1
print_ln('%s', array[1].reveal())
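As an ORAM, OptimalORAM is meant to hide which element a secret index refers to. The simplest way to achieve that is a linear scan that touches every element on each access and uses equality bits to pick or update the target. The plain-Python sketch below shows this approach. ORAM schemes such as Path ORAM exist to make each access cheaper than this O(n) scan. oblivious_read and oblivious_write are hypothetical helpers, and ordinary integers stand in for secret values.

```python
def oblivious_read(array, index):
    """Read array[index] by touching every element; only one equality bit is 1."""
    return sum(int(i == index) * value for i, value in enumerate(array))


def oblivious_write(array, index, new_value):
    """Return a copy of array with array[index] replaced, touching every element."""
    return [
        int(i == index) * new_value + (1 - int(i == index)) * value
        for i, value in enumerate(array)
    ]


array = [0] * 10000
array = oblivious_write(array, 1, 1)
print(oblivious_read(array, 1))  # 1
```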

Example 4 tpmpc_tutorial

from util import if_else

program.bit_length = 32

def millionnaires():
    """ Secure comparison, receiving input from each party via stdin """
    print_ln("Waiting for Alice's input")
    alice = sint.get_input_from(0)
    print_ln("Waiting for Bob's input")
    bob = sint.get_input_from(1)

    b = alice < bob
    print_ln('The richest is: %s', b.reveal())

def naive_search(n):
    """ Search secret list for private input from Bob """
    # hardcoded "secret" list from Alice - in a real application this should be a private input
    a = [sint(i) for i in range(n)]
    print_ln("Waiting for search input from Bob")
    b = sint.get_input_from(1)

    eq_bits = [x == b for x in a]
    b_in_a = sum(eq_bits)
    print_ln("Is b in Alice's list? %s", b_in_a.reveal())

def scalable_search(n):
    """ Search using SPDZ loop to avoid loop unrolling """
    array = Array(n, sint)

    @for_range(n)
    def _(i):
        array[i] = sint(i)

    print_ln("Waiting for search input from Bob")
    b = sint.get_input_from(1)

    # need to use MemValue and Array inside @for_range loops,
    # instead of basic sint/cint registers
    result = MemValue(sint(0))

    @for_range(n)
    def _(i):
        result.write(result + (array[i] == b))

    print_ln("Is b in Alice's list? %s", result.reveal())

def compute_intersection(a, b):
    """ Naive quadratic private set intersection.

    Returns: secret Array with intersection (padded to len(a)), and
    secret Array of bits indicating whether Alice's input matches or not """
    n = len(a)
    if n != len(b):
        raise CompilerError('Inconsistent lengths to compute_intersection')
    intersection = Array(n, sint)
    is_match_at = Array(n, sint)

    @for_range(n)
    def _(i):
        @for_range(n)
        def _(j):
            match = a[i] == b[j]
            is_match_at[i] += match
            intersection[i] = if_else(match, a[i], intersection[i]) # match * a[i] + (1 - match) * intersection[i]
    return intersection, is_match_at

def set_intersection_example(n):
    """Naive private set intersection on two Arrays, followed by computing the size and average of the intersection"""
    a = Array(n, sint)
    b = Array(n, sint)
    print_ln('Running PSI example')
    @for_range(n)
    def _(i):
        a[i] = i
        b[i] = i + 60
    intersection, is_match_at = compute_intersection(a,b)

    print_ln('Printing set intersection (0: not in intersection)')
    size = MemValue(sint(0))
    total = MemValue(sint(0))
    @for_range(n)
    def _(i):
        size.write(size + is_match_at[i])
        total.write(total + intersection[i])
        print_str('%s ', intersection[i].reveal())
    print_ln('\nIntersection size: %s', size.reveal())

    total_fixed = sfix()
    total_fixed.load_int(total.read())
    print_ln('Average in intersection: %s', (total_fixed / size.read()).reveal())



millionnaires()
naive_search(100)
scalable_search(10000)
set_intersection_example(100)
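The plaintext re-implementation of the same quadratic intersection below shows what set_intersection_example(100) should print. With a = 0..99 and b = 60..159, the intersection is 60..99, so its size is 40 and its average is 79.5. plain_intersection is a hypothetical name. Ordinary integers replace the secret Arrays.

```python
def plain_intersection(a, b):
    """Plaintext version of the quadratic compute_intersection."""
    n = len(a)
    intersection = [0] * n
    is_match_at = [0] * n
    for i in range(n):
        for j in range(n):
            match = int(a[i] == b[j])
            is_match_at[i] += match
            # Same selection as if_else(match, a[i], intersection[i])
            intersection[i] = match * a[i] + (1 - match) * intersection[i]
    return intersection, is_match_at


n = 100
a = list(range(n))
b = [i + 60 for i in range(n)]
intersection, is_match_at = plain_intersection(a, b)
size = sum(is_match_at)
print(size, sum(intersection) / size)  # 40 79.5
```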

Example 5 vickrey

import util
from Compiler import types

import math
import re
r = re.search(r'(\D*)(\d*)', program.name)

if r.group(2):
    n_inputs = int(r.group(2))
else:
    n_inputs = 100

n_parties = 2
n_threads = int(math.ceil(2 ** (int(math.log(n_inputs, 2) - 7))))
n_loops = 1
n_bits = 64
#value_type = types.get_sgf2nuint(n_bits)
value_type = sint

program.set_bit_length(n_bits)
program.set_security(40)

print_ln('n_inputs = %s, n_parties = %s, n_threads = %s, n_loops = %s, '
         'value_type = %s',
         n_inputs, n_parties, n_threads, n_loops, value_type.__name__)

@for_range(n_loops)
def f(_):
    Bid = types.getNamedTupleType('party', 'price')
    bids = Bid.get_array(n_inputs, value_type)

    for i in range(n_inputs):
        # i * 10 because inputs are all zero by default
        bids[i] = Bid(i, value_type.get_raw_input_from(i % n_parties) + i * 10)
        #bids = [Bid(i, value_type(i * 10)) for i in range(n_parties)]

    def bid_sort(a, b):
        comp = a.price < b.price
        res = util.cond_swap(comp, a, b)
        for i in res:
            i.price = value_type.hard_conv(i.price)
        return res

    def first_and_second(left, right):
        top = left[0].price < right[0].price
        cross = [left[i].price < right[1-i].price for i in range(2)]
        first = top.if_else(right[0], left[0])
        tmp = [cross[i].if_else(right[1-i], left[i]) for i in (0,1)]
        second = top.if_else(*tmp)
        for i in (first, second):
            i.price = value_type.hard_conv(i.price)
        return first, second

    results = Bid.get_array(2 * n_threads, value_type)

    def thread():
        i = get_arg()
        n_per_thread = n_inputs // n_threads
        if n_per_thread % 2 != 0:
            raise Exception('Number of inputs must be divisible by 2')
        start = i * n_per_thread
        tuples = [bid_sort(bids[start+2*j], bids[start+2*j+1]) \
                  for j in range(n_per_thread // 2)]
        first, second = util.tree_reduce(first_and_second, tuples)
        results[2*i], results[2*i+1] = first, second

    tape = program.new_tape(thread)
    threads = [program.run_tape(tape, i) for i in range(n_threads)]
    for i in threads:
        program.join_tape(i)

    tuples = [(results[2*i], results[2*i+1]) for i in range(n_threads)]
    first, second = util.tree_reduce(first_and_second, tuples)

    print_ln('Winner: %s, price: %s', first.party.reveal(), second.price.reveal())
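In a Vickrey (second-price) auction, the highest bidder wins and pays the second-highest price, which is what the final print_ln outputs. The program finds both with a tournament: bids are sorted in pairs, then pairs are merged by first_and_second. The plain-Python sketch below mirrors that logic on (party, price) tuples so it can be checked on small inputs. bid_sort, tree_reduce and vickrey here are hypothetical re-implementations, not the util functions, and threading is omitted.

```python
def bid_sort(a, b):
    """Order two (party, price) bids with the higher price first."""
    return (b, a) if a[1] < b[1] else (a, b)


def first_and_second(left, right):
    """Merge two (highest, second-highest) pairs into the overall top two."""
    top = left[0][1] < right[0][1]
    cross = [left[i][1] < right[1 - i][1] for i in range(2)]
    first = right[0] if top else left[0]
    # Candidates for second place: max(left[0], right[1]) and max(left[1], right[0])
    tmp = [right[1 - i] if cross[i] else left[i] for i in range(2)]
    second = tmp[0] if top else tmp[1]
    return first, second


def tree_reduce(fn, items):
    """Combine items pairwise, level by level, until one result remains."""
    items = list(items)
    while len(items) > 1:
        merged = [fn(items[k], items[k + 1]) for k in range(0, len(items) - 1, 2)]
        if len(items) % 2:
            merged.append(items[-1])
        items = merged
    return items[0]


def vickrey(prices):
    """Return (winning party, price paid) for an even number of bids."""
    if len(prices) % 2 != 0:
        raise ValueError("Number of inputs must be divisible by 2")
    bids = list(enumerate(prices))
    pairs = [bid_sort(bids[2 * j], bids[2 * j + 1]) for j in range(len(bids) // 2)]
    first, second = tree_reduce(first_and_second, pairs)
    return first[0], second[1]


print(vickrey([30, 70, 50, 10]))  # (1, 50)
```

With all raw inputs zero, the program's bids are i * 10, so in this sketch vickrey([i * 10 for i in range(100)]) returns (99, 980).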

Example 6 fixed_point_tutorial

program.bit_length = 80
print("program.bit_length:", program.bit_length)
program.security = 40

n = 10
m = 5

# array of fixed points
A = Array(n, sfix)

for i in range(n):
    A[i] = sfix(i)

print_ln('array of fixed points')
for i in range(n):
    print_ln('%s', A[i].reveal())

# matrix of fixed points
M = Matrix(n, m, sfix)

for i in range(n):
    for j in range(m):
        M[i][j] = sfix(i*j)

print_ln('matrix of fixed points')
for i in range(n):
    for j in range(m):
        print_str('%s ', M[i][j].reveal())
    print_ln(' ')


# assign scalar to sfix
A[5] = sfix(1.12345)
print_ln('%s', A[5].reveal())

AC = Array(n, cfix)

for i in range(n):
    AC[i] = cfix(1.5 * i)

for i in range(n):
    print_ln('%s', AC[i])

# assign sint to sfix
s = sint(10)
sa = sfix(); sa.load_int(s)
print_ln('successfully assigned sint to sfix %s', sa.reveal())

# division between fixed points
sb = sfix(2.5)
print_ln('division between %s %s = %s', sa.reveal(), sb.reveal(), (sa/sb).reveal())
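sfix and cfix are fixed-point types. A common fixed-point encoding stores a real number x as the integer round(x * 2^F), where F is the number of fractional bits. The plain-Python sketch below uses this encoding. F = 16 is an assumption for illustration, not necessarily the language's default. The sketch shows why multiplication and division need rescaling, and why in such an encoding a value like 1.12345 is stored only approximately.

```python
F = 16  # hypothetical number of fractional bits, chosen for illustration


def encode(x):
    """Represent a real number as the integer round(x * 2**F)."""
    return round(x * 2**F)


def decode(v):
    return v / 2**F


def fx_mul(a, b):
    # The raw product carries 2F fractional bits; shift right by F to rescale.
    return (a * b) >> F


def fx_div(a, b):
    # Pre-scale the dividend by 2**F so the quotient keeps F fractional bits.
    return (a << F) // b


sa, sb = encode(10), encode(2.5)
print(decode(fx_div(sa, sb)))                    # 4.0
print(decode(fx_mul(encode(1.5), encode(2.0))))  # 3.0
print(decode(encode(1.12345)))  # close to 1.12345, within 2**-(F+1)
```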