Vanilla Encoding example

여기서는 $z \in \mathbb{C}^N$를 다항식 $m(X) \in \mathbb{C}[X]/(X^N + 1)$로 encoding하는 간단한 경우를 다룰 것입니다.

이를 위해 $\sigma: \mathbb{C}[X]/(X^N + 1) \rightarrow \mathbb{C}^N$이라는 canonical embedding을 사용할 것인데, 이는 벡터를 decoding하고 encoding합니다.

다항식 $m(X)$을 vector $z$로 decoding하려면 이 다항식을 특정 값에서 평가해야 하는데, 이 값은 cyclotomic polynomial $\Phi_M(X) = X^N + 1$의 root일 것입니다. 이 $N$개의 root는 $\xi, \xi^3, …, \xi^{2N-1}$입니다.

따라서 다항식 $m(X)$을 decoding하려면 $\sigma(m) = (m(\xi), m(\xi^3), …, m(\xi^{2N-1})) \in \mathbb{C}^N$으로 정의합니다. $\sigma$는 isomorphism을 정의하는데, 이는 bijective homomorphism으로, 따라서 어떤 벡터든지 해당하는 다항식으로 고유하게 encoding되고 그 반대도 성립합니다.

여기서 $\sigma$는 decoding이므로, 역함수 $\sigma^{-1}$을 계산하는 것은 vector $z \in \mathbb{C}^N$를 해당 다항식으로 encoding하는 것이 됩니다. 따라서 문제는 vector $z \in \mathbb{C}^N$가 주어졌을 때 다항식 $m(X) = \sum_{i=0}^{N-1} \alpha_i X^i \in \mathbb{C}[X]/(X^N + 1)$을 찾는 것인데, 이 때 $\sigma(m) = (m(\xi), m(\xi^3), …, m(\xi^{2N-1})) = (z_1, …, z_N)$가 됩니다.

이를 더 깊게 파헤치면 다음과 같은 시스템이 나옵니다:

$\sum_{j=0}^{N-1} \alpha_j (\xi^{2i-1})^j = z_i, \quad i=1, …, N.$

이는 다음과 같은 선형 방정식으로 볼 수 있습니다:

$A \alpha = z$

여기서 $A$는 $(\xi^{2i-1})_{i=1}^{N}$의 Vandermonde 행렬이고, $\alpha$는 다항식 계수의 벡터이며, $z$는 encoding하려는 벡터입니다.

따라서 $\alpha = A^{-1}z$이고, $\sigma^{-1}(z) = \sum_{i=0}^{N-1} \alpha_i X^i \in \mathbb{C}[X]/(X^N + 1)$임을 알 수 있습니다.

여기서 $M=8$, $N=\frac{M}{2}=4$, $\Phi_M(X)=X^4+1$, 그리고 $\omega=e^{\frac{2i\pi}{8}}=e^{\frac{i\pi}{4}}$로 두겠습니다. 목표는 다음 벡터들을 encoding하고 decoding한 후, 그 다항식을 더하고 곱한 다음 이를 다시 decoding하는 것입니다: $[1, 2, 3, 4]$ 및 $[-1, -2, -3, -4]$.

cyclotomic_polynomial

다항식을 decoding하기 위해서는 단순히 power of an $M$-th root of unity에서 평가하면 됩니다. 여기서는 $\xi_M = \omega = e^{i\pi/4}$를 선택합니다.

$\xi$와 $M$이 주어지면 $\sigma$ 및 그 역함수 $\sigma^{-1}$을 정의할 수 있습니다. 각각 decoding과 encoding을 나타냅니다.

import numpy as np
from numpy.polynomial import Polynomial

np.set_printoptions(precision=3)    # 소수점 3자리만 출력하도록 setting

Vanilla Encoder Class 선언

class CKKSEncoder:
    """Basic CKKS encoder to encode complex vectors into polynomials."""

    def __init__(self, M: int):
        """
        Initialization of the encoder for M a power of 2.
        xi, which is an M-th root of unity will, be used as a basis for our computations.
        """
        # ========================= EDIT HERE =========================
        self.xi = np.exp(2 * np.pi * 1j / M)
        # =============================================================
        self.M = M

    @staticmethod
    def vandermonde(xi: np.complex128, M: int) -> np.array:
        """Computes the Vandermonde matrix from a m-th root of unity."""
        N = M // 2
        matrix = []
        # We will generate each row of the matrix
        for i in range(N):
            # For each row we select a different root
            # ========================= EDIT HERE =========================
            root = xi ** (2 * i + 1)
            # =============================================================
            row = []

            # Then we store its powers
            for j in range(N):
                # ========================= EDIT HERE =========================
                row.append(root ** j)
                # =============================================================
            matrix.append(row)
        return matrix

    def sigma_inverse(self, b: np.array) -> 'Polynomial':
        """Encodes the vector b in a polynomial using an M-th root of unity."""

        # First we create the Vandermonde matrix
        A = CKKSEncoder.vandermonde(self.xi, self.M)

        # Then we solve the system
        """
        선형 방정식 Mx = y의 해는 np.linalg.solve(M, y)로 구할 수 있습니다.
        """
        # ========================= EDIT HERE =========================
        coeffs = np.linalg.solve(A, b)
        # =============================================================

        # Finally we output the polynomial
        p = Polynomial(coeffs)
        return p

    def sigma(self, p: Polynomial) -> np.array:
        """Decodes a polynomial by applying it to the M-th roots of unity."""

        outputs = []
        N = self.M //2

        # We simply apply the polynomial on the roots
        for i in range(N):
            # ========================= EDIT HERE =========================
            root = self.xi ** (2 * i + 1)
            # =============================================================
            output = p(root)
            outputs.append(output)
        return np.array(outputs)
# Set the parameters
M = 8

# Initialize our encoder
encoder = CKKSEncoder(M)

Vandermonde matrix란?

"""What is Vandermonde matrix?"""
matrix = encoder.vandermonde(encoder.xi, encoder.M)
print(np.array(matrix))
[[ 1.   +0.j     0.707+0.707j  0.   +1.j    -0.707+0.707j]
 [ 1.   +0.j    -0.707+0.707j  0.   -1.j     0.707+0.707j]
 [ 1.   +0.j    -0.707-0.707j  0.   +1.j     0.707-0.707j]
 [ 1.   +0.j     0.707-0.707j  0.   -1.j    -0.707-0.707j]]

Polynomial로 Encoding

b = np.array([1, 2, 3, 4])
p = encoder.sigma_inverse(b)
Polynomial(p.coef.round(3))

$ x ↦ (2.5+0j)+((-0+0.707j))x+((-0+0.5j))x^2+(0.707j)x^3 $

Decoding

b_reconstructed = encoder.sigma(p)
b_reconstructed
array([1.+1.110e-16j, 2.+1.110e-16j, 3.+5.551e-17j, 4.-2.220e-16j])$

Error

np.linalg.norm(b_reconstructed - b)

$2.7755575615628914e-16$

Addition

m1 = np.array([1, 2, 3, 4])
m2 = np.array([1, -2, 3, -4])

p1 = encoder.sigma_inverse(m1)
p2 = encoder.sigma_inverse(m2)
Polynomial(p1.coef.round(3))

$x ↦ (2.5+0j) + ((-0+0.707j))x + ((-0+0.5j))x^2 + (0.707j)x^3$

Polynomial(p2.coef.round(3))

$x ↦ (-0.5+0j)+((-0.707+0j))x+(-2.5j)x^2+((0.707+0j))x^3$

p_add = p1 + p2
Polynomial(p_add.coef.round(3))

$x ↦ (2+0j)+((-0.707+0.707j))x+((-0-2j))x^2+((0.707+0.707j))x^3$

Decoding 결과

encoder.sigma(p_add)
array([ 2.000e+00-9.197e-17j, -8.882e-16+2.220e-16j,
        6.000e+00+2.220e-16j, -4.441e-16+0.000e+00j])

Multiplication

poly_modulo = Polynomial([1, 0, 0, 0, 1])
poly_modulo

$x ↦ 1.0+0.0x+0.0x^2+0.0x^3+1.0x^4$

p_mult = p1 * p2
Polynomial(p_mult.coef.round(3))

$x↦(-1.25+0j)+((-1.768-0.354j))x+(-7j)x^2+((3.536-0.707j))x^3+((1.25+0j))x^4 +((1.768+0.354j))x^5+(0.5j)x^6$

이렇게 4차 함수를 넘어가기 때문에 위 다항식인 $x^4+1$(Cyclotomic polynomial)으로 나눠주어야 한다.

p_mult = p1 * p2 % poly_modulo
Polynomial(p_mult.coef.round(3))

$ x↦(-2.5+0j)+((-3.536-0.707j))x+((-0-7.5j))x^2+((3.536-0.707j))x^3$

Decoding 결과

encoder.sigma(p_mult)
array([  1.+7.216e-16j,  -4.+8.327e-16j,   9.-5.274e-15j, -16.-2.609e-15j])

Reference

https://blog.openmined.org/ckks-explained-part-1-simple-encoding-and-decoding/