Secure Multiparty Computation in Python — Deep Dive

Building MPC protocols with MPyC

MPyC (Multiparty Computation in Python) provides a high-level framework where you write computations using secure types that look like regular Python variables. Install with pip install mpyc.

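from mpyc.runtime import mpc

async def secure_average():
    await mpc.start()

    # mpc.SecInt(32) creates a secure integer type for 32-bit values
    secint = mpc.SecInt(32)

    # Each party inputs their private value; with senders set to all parties,
    # mpc.input returns one secret-shared value per party
    my_salary = int(input("Enter your salary: "))
    salaries = mpc.input(secint(my_salary), senders=list(range(mpc.m)))

    # Sum the secret-shared values; no party sees an individual salary
    total = mpc.sum(salaries)
    count = len(salaries)  # the number of parties is public

    # Reveal only the sum. Because count is public, this discloses exactly as
    # much as the average, and dividing in the clear avoids a secure division.
    average = await mpc.output(total) / count
    print(f"Average salary: {average}")

    await mpc.shutdown()

mpc.run(secure_average())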

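Save the script as secure_avg.py and run it in three terminals, one per party (-M sets the number of parties, -I the index of the local party):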

python secure_avg.py -M3 -I0  # Party 0
python secure_avg.py -M3 -I1  # Party 1
python secure_avg.py -M3 -I2  # Party 2

Each party enters their salary locally. MPyC handles secret sharing, communication, and reconstruction. Only the average is revealed.

Under the hood: Shamir secret sharing in MPyC

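MPyC uses Shamir secret sharing over finite fields. For m parties with threshold t (default: t = (m-1)//2), each secret s becomes the constant term of a random polynomial f of degree t, and party i receives the share f(i). Any t+1 shares determine f and hence s = f(0), while t or fewer shares reveal nothing about s: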

# Conceptual implementation of Shamir sharing (Python 3.8+ for pow(x, -1, p))
import secrets

def share_secret(secret, num_parties, threshold, prime):
    """Split secret into num_parties Shamir shares; any threshold+1 reconstruct it."""
    # Random polynomial of degree t: f(x) = secret + a1*x + a2*x^2 + ... + at*x^t
    # Coefficients come from a cryptographically secure RNG (not the random module)
    coeffs = [secret % prime] + [secrets.randbelow(prime) for _ in range(threshold)]

    shares = []
    for i in range(1, num_parties + 1):
        # Party i receives the point (i, f(i))
        value = sum(c * pow(i, j, prime) for j, c in enumerate(coeffs)) % prime
        shares.append((i, value))
    return shares

def reconstruct_secret(shares, prime):
    """Reconstruct f(0) from threshold+1 shares using Lagrange interpolation."""
    secret = 0
    for i, (xi, yi) in enumerate(shares):
        numerator = denominator = 1
        for j, (xj, _) in enumerate(shares):
            if i != j:
                # Build the Lagrange basis coefficient for share i, evaluated at x = 0
                numerator = (numerator * (-xj)) % prime
                denominator = (denominator * (xi - xj)) % prime
        lagrange = (numerator * pow(denominator, -1, prime)) % prime
        secret = (secret + yi * lagrange) % prime
    return secret
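Written as formulas, with the share points x_i = i and values y_i = f(x_i) used in the code, sharing and reconstruction over any set S of t+1 parties are:

$$
f(x) = s + a_1 x + a_2 x^2 + \dots + a_t x^t \pmod{p}, \qquad y_i = f(x_i), \quad x_i = i
$$

$$
s = f(0) = \sum_{i \in S} y_i \lambda_i \pmod{p}, \qquad \lambda_i = \prod_{j \in S,\ j \neq i} \frac{-x_j}{x_i - x_j}
$$

The factors -x_j and x_i - x_j are exactly what the numerator and denominator loops in reconstruct_secret multiply together.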

Addition of shared values is free — each party adds their shares locally. Multiplication requires a sub-protocol (Beaver triples or resharing) that adds a communication round.
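Both claims can be checked with a toy sketch of Shamir sharing. It restates the sharing and interpolation logic so that it runs on its own. The prime 2**61 - 1, the 5 parties, and the threshold t = 2 are illustrative choices. Real protocols, including MPyC's, follow each multiplication with an interactive degree-reduction step; here that step is only described in a comment, not implemented.

```python
import secrets

PRIME = 2**61 - 1  # a Mersenne prime, used here as the field modulus

def share(secret, n, t):
    """Shamir-share `secret` among n parties with a random degree-t polynomial."""
    coeffs = [secret % PRIME] + [secrets.randbelow(PRIME) for _ in range(t)]
    return [(i, sum(c * pow(i, k, PRIME) for k, c in enumerate(coeffs)) % PRIME)
            for i in range(1, n + 1)]

def reconstruct(shares):
    """Lagrange interpolation at x = 0."""
    total = 0
    for xi, yi in shares:
        num = den = 1
        for xj, _ in shares:
            if xj != xi:
                num = num * -xj % PRIME
                den = den * (xi - xj) % PRIME
        total = (total + yi * num * pow(den, -1, PRIME)) % PRIME
    return total

n, t = 5, 2
a_shares = share(20, n, t)
b_shares = share(22, n, t)

# Addition: each party adds its two shares locally; no messages are exchanged.
sum_shares = [(x, (ya + yb) % PRIME) for (x, ya), (_, yb) in zip(a_shares, b_shares)]
print(reconstruct(sum_shares[: t + 1]))  # → 42 (any t+1 shares suffice)

# Multiplication: local products lie on a polynomial of degree 2t, so 2t+1 shares
# are needed to open the result. Before the next multiplication the parties must
# interact to bring the degree back down to t (not implemented here).
prod_shares = [(x, ya * yb % PRIME) for (x, ya), (_, yb) in zip(a_shares, b_shares)]
print(reconstruct(prod_shares[: 2 * t + 1]))  # → 440
```

Reconstructing the product from only t + 1 = 3 shares would generally give a wrong value, which is why the interactive round after each multiplication cannot be skipped.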

Secure comparisons and conditionals

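Comparisons are expensive in MPC. Arithmetic secret sharing natively supports only addition and multiplication, so evaluating a < b typically requires first decomposing the values into secret-shared bits. MPyC overloads Python's comparison operators for secure types: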

from mpyc.runtime import mpc

async def find_maximum():
    await mpc.start()
    secint = mpc.SecInt(32)

    # Each party inputs a value
    values = mpc.input(secint(int(input("Your value: "))), senders=list(range(mpc.m)))

    # Secure comparison — finds max without revealing individual values
    maximum = values[0]
    for v in values[1:]:
        is_greater = v > maximum  # Returns a secure bit
        maximum = mpc.if_else(is_greater, v, maximum)

    result = await mpc.output(maximum)
    print(f"Maximum value: {result}")
    await mpc.shutdown()

mpc.run(find_maximum())
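To see why comparisons cost so much more than arithmetic, here is a plaintext sketch of a less-than circuit. It uses only the operations secret sharing supports natively (addition, subtraction, multiplication), applied to 0/1 bits. This is a naive illustration, not MPyC's actual comparison protocol. It also assumes the inputs have already been bit-decomposed, which in MPC is itself an expensive sub-protocol.

```python
def to_bits(value, width):
    """Little-endian bit decomposition (done in the clear here; in MPC it is a protocol)."""
    return [(value >> i) & 1 for i in range(width)]

def less_than(a_bits, b_bits):
    """Return (a < b as 0/1, number of multiplications) for equal-length
    little-endian bit lists, using only +, - and * on the bits."""
    lt = 0      # comparison result for the bits processed so far
    mults = 0
    for a, b in zip(a_bits, b_bits):     # from least to most significant bit
        diff = a + b - 2 * a * b         # XOR(a, b): 1 multiplication
        lt = lt + diff * (b - lt)        # if the bits differ, a < b exactly when b's bit is 1: 1 multiplication
        mults += 2
    return lt, mults

result, cost = less_than(to_bits(1000, 32), to_bits(70000, 32))
print(result, cost)  # → 1 64
```

Each bit costs two multiplications. Because every update of lt depends on the previous one, the chain needs about one round of communication per bit when run under MPC. Production protocols restructure the circuit to reduce the round count.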

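Each comparison involves multiple rounds of communication. Sorting compounds the cost. A secure sort must be data-oblivious, meaning that which elements get compared cannot depend on the secret values. It therefore typically uses a sorting network such as Batcher's odd-even merge sort, which performs O(n log² n) comparisons arranged in O(log² n) parallel layers. This makes secure sorting significantly slower than plaintext sorting.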
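A data-oblivious sort can be sketched in plain Python with Batcher's odd-even merge sort. This is one common sorting-network choice, not necessarily what any particular MPC library uses. The comparator schedule depends only on n, which is what makes it usable on secret-shared data. The sketch assumes n is a power of two and uses Python's min and max where an MPC version would use a secure comparison followed by mpc.if_else, as in the maximum example above.

```python
import random

def batcher_pairs(n):
    """Comparator schedule of Batcher's odd-even merge sort (n a power of two).
    The pairs depend only on n, never on the data."""
    assert n > 0 and n & (n - 1) == 0, "n must be a power of two"
    pairs = []
    p = 1
    while p < n:                      # size of the sorted blocks being merged
        k = p
        while k >= 1:                 # comparator distance within the merge
            for j in range(k % p, n - k, 2 * k):
                for i in range(min(k, n - j - k)):
                    # only compare elements that belong to the same merge block
                    if (i + j) // (2 * p) == (i + j + k) // (2 * p):
                        pairs.append((i + j, i + j + k))
            k //= 2
        p *= 2
    return pairs

def oblivious_sort(values):
    """Apply the network with compare-exchange steps (plaintext stand-in for
    a secure comparison plus two secure selections)."""
    x = list(values)
    for i, j in batcher_pairs(len(x)):
        x[i], x[j] = min(x[i], x[j]), max(x[i], x[j])
    return x

data = [random.randrange(100) for _ in range(8)]
print(oblivious_sort(data) == sorted(data))  # → True
print(len(batcher_pairs(8)))                 # → 19 comparators for n = 8
```

The same 19 compare-exchange operations run for every input of length 8, so an observer of the protocol's message pattern learns nothing about the order of the data.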

Private set intersection (PSI)

PSI determines which elements two parties share in common without revealing elements that don’t match. This has direct applications in contact discovery, ad attribution, and threat intelligence sharing.

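import hashlib
import math
import secrets

# Toy Diffie-Hellman-style PSI over the Mersenne prime 2**127 - 1 (NOT secure; production uses elliptic curves)
P = 2**127 - 1

def hash_to_group(element):
    """Hash an element to a nonzero residue mod P."""
    digest = hashlib.sha256(element.encode()).digest()
    return int.from_bytes(digest, 'big') % (P - 1) + 1

def new_key():
    """Secret exponent coprime to P - 1, so x -> x**key mod P is a bijection."""
    while True:
        key = secrets.randbelow(P - 3) + 2
        if math.gcd(key, P - 1) == 1:
            return key

def blind(elements, key):
    """Raise each hashed element to a party's secret key: H(x)**key mod P."""
    return [pow(hash_to_group(e), key, P) for e in elements]

def reblind(blinded, key):
    """Apply the other party's key; exponents commute, (H**a)**b == (H**b)**a."""
    return [pow(v, key, P) for v in blinded]

# Flow: client sends blind(A, a); server returns reblind(..., b) plus blind(B, b);
# the client reblinds the server's set with a and compares the two lists.
def compute_intersection_size(set_a_hashes, set_b_hashes):
    """Count matching doubly blinded values without revealing non-matching elements."""
    return len(set(set_a_hashes) & set(set_b_hashes))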

Production PSI implementations use elliptic curve Diffie-Hellman or oblivious pseudorandom functions. Libraries like openmined/PSI provide optimized implementations handling millions of elements.

Secure aggregation for federated learning

Multiple clients train local models and want to average their gradients without revealing individual updates. Secure aggregation achieves this:

import numpy as np
from typing import List

class SecureAggregator:
    """Simplified pairwise masking secure aggregation (Bonawitz et al.)."""

    def __init__(self, num_clients: int, vector_size: int):
        self.num_clients = num_clients
        self.vector_size = vector_size

    def generate_pairwise_masks(self, client_id: int, seed_pairs: dict) -> np.ndarray:
        """Generate masks that cancel out during aggregation."""
        mask = np.zeros(self.vector_size)
        for other_id, seed in seed_pairs.items():
            rng = np.random.default_rng(seed)
            pairwise = rng.normal(0, 1, self.vector_size)  # illustrative only; real protocols use uniform integers mod R
            if client_id < other_id:
                mask += pairwise
            else:
                mask -= pairwise  # Opposite sign ensures cancellation
        return mask

    def mask_gradient(self, gradient: np.ndarray, mask: np.ndarray) -> np.ndarray:
        """Client adds mask to their gradient before sending."""
        return gradient + mask

    def aggregate(self, masked_gradients: List[np.ndarray]) -> np.ndarray:
        """Server sums masked gradients; the masks cancel (up to float rounding),
        so dividing by the client count yields the average true gradient."""
        return np.sum(masked_gradients, axis=0) / len(masked_gradients)

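The key insight: each pair of clients shares a random seed. The client with the lower ID adds the mask derived from that seed, and the other client subtracts the same mask. When the server sums all masked gradients, the pairwise masks cancel, leaving only the true aggregate. In the integer arithmetic of the real protocol they cancel exactly; in this floating-point sketch they cancel up to rounding. This simplified version also requires every client to submit, because a missing client's masks never cancel.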
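The protocol by Bonawitz et al. works with vectors of integers modulo R. There, uniformly random masks hide each value completely and cancel exactly. The following is a minimal sketch of that variant. It assumes a fixed-point quantization with illustrative SCALE and MODULUS values and simulates the pairwise seeds with a local RNG in place of the key exchange real clients would use. It also ignores client dropouts.

```python
import numpy as np

MODULUS = 2**32   # masked values live in Z_R with R = 2**32 (illustrative choice)
SCALE = 2**16     # fixed-point scale for quantizing float gradients (illustrative)

def quantize(g):
    """Map float gradients to integers mod R; negative values wrap around."""
    return np.round(g * SCALE).astype(np.int64) % MODULUS

def dequantize(q):
    """Map integers mod R back to floats, reading the upper half as negative."""
    q = np.where(q >= MODULUS // 2, q - MODULUS, q)
    return q / SCALE

def pairwise_mask(client_id, seeds, size):
    """Sum of uniform masks in Z_R: add for higher-id peers, subtract for lower-id peers."""
    mask = np.zeros(size, dtype=np.int64)
    for other_id, seed in seeds.items():
        r = np.random.default_rng(seed).integers(0, MODULUS, size, dtype=np.int64)
        mask = (mask + r) % MODULUS if client_id < other_id else (mask - r) % MODULUS
    return mask

num_clients, size = 3, 4
rng = np.random.default_rng(0)
# One seed per unordered pair of clients (simulated; in practice agreed via key exchange)
pair_seeds = {(i, j): int(rng.integers(2**31))
              for i in range(num_clients) for j in range(i + 1, num_clients)}
gradients = [rng.normal(0, 0.1, size) for _ in range(num_clients)]

masked = []
for cid in range(num_clients):
    # Map each peer of this client to the seed they share
    seeds = {(j if i == cid else i): s for (i, j), s in pair_seeds.items() if cid in (i, j)}
    masked.append((quantize(gradients[cid]) + pairwise_mask(cid, seeds, size)) % MODULUS)

total = np.sum(masked, axis=0) % MODULUS   # pairwise masks cancel exactly mod R
average = dequantize(total) / num_clients
print(np.allclose(average, np.mean(gradients, axis=0), atol=1e-4))  # → True
```

The only error left in the average comes from quantization, not from the masks. The server never sees an unmasked vector, yet the masks' modular sum is exactly zero.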

PySyft for federated computation

PySyft integrates MPC with federated learning, allowing data scientists to train models on distributed data:

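# Conceptual PySyft workflow (API simplified for clarity; exact names vary across PySyft versions)
# patient_df (real data) and mock_df (synthetic stand-in) are assumed pandas DataFrames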
import syft as sy

# Data owners create private datasets
domain = sy.orchestra.launch(name="hospital-a", port=8080)
dataset = sy.Dataset(
    name="patient-records",
    asset_list=[
        sy.Asset(name="vitals", data=patient_df, mock=mock_df)
    ]
)
domain.upload_dataset(dataset)

# Data scientist submits computation request
@sy.syft_function(input_policy=sy.ExactMatch, output_policy=sy.SingleExecutionExactOutput)
def compute_avg_bp(vitals):
    return vitals["blood_pressure"].mean()

# Hospital approves and runs computation
# Data scientist gets only the aggregate result
result = compute_avg_bp(vitals=domain.datasets["patient-records"]["vitals"])

PySyft separates code submission from execution — the data owner reviews what will run on their data before approving.

Performance characteristics and optimization

MPC performance depends heavily on the computation structure:

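| Operation | Communication rounds | Relative cost (rough) |
|---|---|---|
| Addition | 0 | 1x |
| Scalar multiplication | 0 | 1x |
| Multiplication | 1 | 100x |
| Comparison | ~32 for 32-bit values (grows with bit length) | 3,000x |
| Division | ~32 | 3,000x |
| Sorting (n elements) | O(log² n) layers; O(n log² n) comparisons in total | Very expensive |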

Optimization strategies:

Minimize multiplicative depth. Restructure computation to reduce sequential multiplications. For example, computing a * b * c * d as (a * b) * (c * d) takes 2 rounds instead of 3.

Use preprocessing. Generate Beaver triples (correlated randomness) in an offline phase, before parties provide inputs. Each online multiplication then needs only local arithmetic plus the opening of two masked values.

Batch operations. Send multiple messages in a single network round rather than one at a time.

Choose the right tool. For two-party computation with complex branching logic, garbled circuits (via emp-toolkit or obliv-c) often outperform secret sharing. For arithmetic among many parties, secret sharing (as in MPyC) is usually the better fit.
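Preprocessing is easiest to see in Beaver's multiplication protocol. The sketch below uses additive secret sharing rather than Shamir sharing to stay short. It assumes a trusted dealer generates the triple in the offline phase; real systems replace the dealer with a cryptographic preprocessing protocol. The prime and the party count are illustrative.

```python
import secrets

P = 2**61 - 1  # prime field modulus (illustrative)

def additive_share(x, n):
    """Split x into n random additive shares that sum to x mod P."""
    shares = [secrets.randbelow(P) for _ in range(n - 1)]
    shares.append((x - sum(shares)) % P)
    return shares

def open_value(shares):
    """Reconstruct a value: every party broadcasts its share and all sum them."""
    return sum(shares) % P

# Offline phase (before any inputs exist): shares of a random triple c = a * b.
n = 3
a, b = secrets.randbelow(P), secrets.randbelow(P)
a_sh, b_sh, c_sh = additive_share(a, n), additive_share(b, n), additive_share(a * b % P, n)

# Online phase: multiply secret-shared inputs x = 6 and y = 7.
x_sh, y_sh = additive_share(6, n), additive_share(7, n)
d = open_value([(xi - ai) % P for xi, ai in zip(x_sh, a_sh)])  # d = x - a, safe to reveal
e = open_value([(yi - bi) % P for yi, bi in zip(y_sh, b_sh)])  # e = y - b, safe to reveal

# x*y = c + d*b + e*a + d*e: each party computes its share locally,
# and only one party adds the public constant d*e.
z_sh = [(ci + d * bi + e * ai) % P for ci, ai, bi in zip(c_sh, a_sh, b_sh)]
z_sh[0] = (z_sh[0] + d * e) % P

print(open_value(z_sh))  # → 42
```

Because a and b are uniformly random and never revealed, the opened values d and e leak nothing about x and y. Both openings can happen in the same round.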

Security considerations

Collusion threshold. With Shamir sharing and threshold t, up to t parties can collude without learning anything. If t+1 parties collude, all secrets are exposed. Choose the threshold based on your trust model.

Input validation. MPC guarantees privacy, not correctness of inputs. A malicious party can input arbitrary values. Range proofs or zero-knowledge proofs can verify inputs fall within expected bounds.

Output leakage. The revealed output itself may leak information. If two parties compute the maximum of their values, the party with the smaller value learns the other party's value exactly, because it is the output. The party with the larger value learns that the other's input is no greater than its own. Careful protocol design considers what the output reveals.

Network security. MPC protocols assume authenticated, encrypted channels between parties. Always use TLS for inter-party communication.
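The collusion threshold can be checked by brute force in a toy field. The sketch below assumes m = 3 parties, so MPyC's default threshold is t = (3-1)//2 = 1. It uses a tiny prime P = 101 and an arbitrary fixed polynomial f(x) = 42 + 15x, both illustrative, so that every candidate secret can be enumerated.

```python
P = 101              # tiny prime field so every possibility can be enumerated
SECRET, A1 = 42, 15  # f(x) = 42 + 15x: degree t = 1, as for m = 3 parties

def f(x):
    return (SECRET + A1 * x) % P

def consistent_secrets(shares):
    """All (secret, a1) pairs whose polynomial matches every observed share."""
    return [(s, a1) for s in range(P) for a1 in range(P)
            if all((s + a1 * x) % P == y for x, y in shares)]

one_share = consistent_secrets([(1, f(1))])
two_shares = consistent_secrets([(1, f(1)), (2, f(2))])

# t = 1 share: each of the 101 candidate secrets fits exactly one polynomial
print(len(one_share), len({s for s, _ in one_share}))  # → 101 101
# t + 1 = 2 shares: only the true secret survives
print(two_shares)  # → [(42, 15)]
```

A single share is consistent with every possible secret, each through exactly one polynomial, so one colluder learns nothing. Two colluding shares pin the secret down.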

The one thing to remember: Production MPC in Python combines secret sharing (for arithmetic) with careful protocol design (minimizing communication rounds) — with MPyC providing the most accessible framework for multi-party computation and PySyft bridging MPC with federated machine learning.
