feat(compiler): complexity per node #788

Open
wants to merge 1 commit into
base: main

umut-sahin
Contributor

import numpy as np
from concrete import fhe

def f(x, y):
    return (x**2) + (y//2)

inputset = fhe.inputset(fhe.uint3, fhe.uint6)
configuration = fhe.Configuration(
    enable_unsafe_features=True,
    use_insecure_key_cache=True,
    insecure_key_cache_location=".keys",
)

compiler = fhe.Compiler(f, {"x": "encrypted", "y": "encrypted"})
circuit = compiler.compile(inputset, configuration, verbose=True)

prints

...

Bit-Width Assigned Computation Graph
--------------------------------------------------------------------------------
%0 = x                           # EncryptedScalar<uint3>        ∈ [0, 7]
%1 = y                           # EncryptedScalar<uint6>        ∈ [0, 63]
%2 = 2                           # ClearScalar<uint3>            ∈ [2, 2]
%3 = power(%0, %2)               # EncryptedScalar<uint7>        ∈ [0, 49]
%4 = 2                           # ClearScalar<uint3>            ∈ [2, 2]
%5 = floor_divide(%1, %4)        # EncryptedScalar<uint7>        ∈ [0, 31]
%6 = add(%3, %5)                 # EncryptedScalar<uint7>        ∈ [2, 77]
return %6
--------------------------------------------------------------------------------

...

Optimizer
--------------------------------------------------------------------------------
### Optimizer display
--- Circuit
  7 bits integers
  1 manp (maxi log2 norm2)
--- User config
  1.000000e+00 error per pbs call
  1.000000e-05 error per circuit call
-- Solution correctness
  For each pbs call:  1/210619, p_error (4.747897e-06)
  For the full circuit: 1/105997 global_p_error(9.434180e-06)
--- Complexity for the full circuit
  3.060000e+02 Millions Operations
-- Circuit Solution
CircuitSolution {
    circuit_keys: CircuitKeys {
        secret_keys: [
            SecretLweKey {
                identifier: 0,
                polynomial_size: 512,
                glwe_dimension: 4,
                description: "big-secret[#0 : partitions [0]]",
            },
            SecretLweKey {
                identifier: 1,
                polynomial_size: 4096,
                glwe_dimension: 1,
                description: "big-secret[#1 : partitions [1]]",
            },
            SecretLweKey {
                identifier: 2,
                polynomial_size: 739,
                glwe_dimension: 1,
                description: "small-secret[#2 : partitions [2]]",
            },
            SecretLweKey {
                identifier: 3,
                polynomial_size: 887,
                glwe_dimension: 1,
                description: "small-secret[#3 : partitions [3]]",
            },
        ],
        keyswitch_keys: [
            KeySwitchKey {
                identifier: 0,
                input_key: SecretLweKey {
                    identifier: 0,
                    polynomial_size: 512,
                    glwe_dimension: 4,
                    description: "big-secret[#0 : partitions [0]]",
                },
                output_key: SecretLweKey {
                    identifier: 2,
                    polynomial_size: 739,
                    glwe_dimension: 1,
                    description: "small-secret[#2 : partitions [2]]",
                },
                ks_decomposition_parameter: KsDecompositionParameters {
                    level: 3,
                    log2_base: 4,
                },
                unitary_cost: 9098525.0,
                description: "ks[#0 : partitions [0] -> [2]]",
            },
            KeySwitchKey {
                identifier: 1,
                input_key: SecretLweKey {
                    identifier: 1,
                    polynomial_size: 4096,
                    glwe_dimension: 1,
                    description: "big-secret[#1 : partitions [1]]",
                },
                output_key: SecretLweKey {
                    identifier: 3,
                    polynomial_size: 887,
                    glwe_dimension: 1,
                    description: "small-secret[#3 : partitions [3]]",
                },
                ks_decomposition_parameter: KsDecompositionParameters {
                    level: 4,
                    log2_base: 4,
                },
                unitary_cost: 29113481.0,
                description: "ks[#1 : partitions [1] -> [3]]",
            },
        ],
        bootstrap_keys: [
            BootstrapKey {
                identifier: 0,
                input_key: SecretLweKey {
                    identifier: 2,
                    polynomial_size: 739,
                    glwe_dimension: 1,
                    description: "small-secret[#2 : partitions [2]]",
                },
                output_key: SecretLweKey {
                    identifier: 0,
                    polynomial_size: 512,
                    glwe_dimension: 4,
                    description: "big-secret[#0 : partitions [0]]",
                },
                br_decomposition_parameter: BrDecompositionParameters {
                    level: 1,
                    log2_base: 23,
                },
                unitary_cost: 47296000.0,
                description: "pbs[#0 : partitions [2] -> [0]]",
            },
            BootstrapKey {
                identifier: 1,
                input_key: SecretLweKey {
                    identifier: 3,
                    polynomial_size: 887,
                    glwe_dimension: 1,
                    description: "small-secret[#3 : partitions [3]]",
                },
                output_key: SecretLweKey {
                    identifier: 1,
                    polynomial_size: 4096,
                    glwe_dimension: 1,
                    description: "big-secret[#1 : partitions [1]]",
                },
                br_decomposition_parameter: BrDecompositionParameters {
                    level: 1,
                    log2_base: 22,
                },
                unitary_cost: 203456512.0,
                description: "pbs[#1 : partitions [3] -> [1]]",
            },
        ],
        conversion_keyswitch_keys: [
            ConversionKeySwitchKey {
                identifier: 0,
                input_key: SecretLweKey {
                    identifier: 1,
                    polynomial_size: 4096,
                    glwe_dimension: 1,
                    description: "big-secret[#1 : partitions [1]]",
                },
                output_key: SecretLweKey {
                    identifier: 0,
                    polynomial_size: 512,
                    glwe_dimension: 4,
                    description: "big-secret[#0 : partitions [0]]",
                },
                ks_decomposition_parameter: KsDecompositionParameters {
                    level: 1,
                    log2_base: 25,
                },
                fast_keyswitch: false,
                unitary_cost: 0.0,
                description: "fks[#0 : partitions [1] -> [0]]",
            },
        ],
        circuit_bootstrap_keys: [],
        private_functional_packing_keys: [],
    },
    instructions_keys: [],
    crt_decomposition: [],
    complexity: 305751974.0,
    p_error: 4.747896820674186e-6,
    global_p_error: 9.43417962422087e-6,
    is_feasible: true,
    error_msg: "",
}###
--------------------------------------------------------------------------------

Statistics
--------------------------------------------------------------------------------
...
complexity_per_node: {
    %3: 56_394_525
    %5: 232_569_993
    %6: 2_048
}
...
complexity: 305_751_974
--------------------------------------------------------------------------------
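As a sanity check on the numbers above, the per-node values for `%3` and `%5` match the sum of one keyswitch and one bootstrap `unitary_cost` from the optimizer dump. The sketch below only redoes that arithmetic. The pairing of `ks[#i]` with `pbs[#i]` is an assumption read off the partition descriptions, and `table_lookup_cost` is a hypothetical helper, not part of Concrete.

```python
# Unitary costs copied from the CircuitSolution dump above.
keyswitch_cost = {0: 9_098_525, 1: 29_113_481}
bootstrap_cost = {0: 47_296_000, 1: 203_456_512}

# Per-node complexities copied from the Statistics section above.
complexity_per_node = {"%3": 56_394_525, "%5": 232_569_993, "%6": 2_048}


def table_lookup_cost(index: int) -> int:
    """Hypothetical helper: assume one table lookup = ks[#index] + pbs[#index]."""
    return keyswitch_cost[index] + bootstrap_cost[index]


# %3 (power) lines up with ks[#0] + pbs[#0],
# %5 (floor_divide) lines up with ks[#1] + pbs[#1].
print(table_lookup_cost(0) == complexity_per_node["%3"])  # True
print(table_lookup_cost(1) == complexity_per_node["%5"])  # True
```

Under that reading, each table lookup is charged for the keyswitch into its small key plus the bootstrap back to its big key.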

This PR is not complete at the moment!

The complexity of WoP-PBS is not calculated yet.

It turns out this is not as straightforward as expected.

Fusing can lead to unexpected output:

def f(x, y):
    return (x**2) <= (y//2)

inputset = fhe.inputset(fhe.uint3, fhe.uint6)

results in

Bit-Width Assigned Computation Graph
--------------------------------------------------------------------------------
%0 = x                           # EncryptedScalar<uint3>        ∈ [0, 7]
%1 = y                           # EncryptedScalar<uint6>        ∈ [1, 63]
%2 = 2                           # ClearScalar<uint3>            ∈ [2, 2]
%3 = power(%0, %2)               # EncryptedScalar<uint6>        ∈ [0, 49]
%4 = 2                           # ClearScalar<uint3>            ∈ [2, 2]
%5 = floor_divide(%1, %4)        # EncryptedScalar<uint5>        ∈ [0, 31]
%6 = less_equal(%3, %5)          # EncryptedScalar<uint1>        ∈ [0, 1]
return %6
--------------------------------------------------------------------------------

...

complexity_per_node: {
    %6: 509_855_888
}

This is because %3 and %5 are fused into the internal table lookups of %6, so their cost is reported under %6 alone.

There are some inconsistencies between the compiler and the optimizer.

The inconsistency is even present in the example code in this PR. It would be best to resolve it and add tests covering different combinations of operations.
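One way to see the mismatch in the first example is to compare the sum of the per-node complexities with the total the optimizer reports. This is plain arithmetic on the numbers printed in the Statistics output. It makes no attempt to say where the remainder comes from.

```python
# Values copied from the Statistics output of the first example.
complexity_per_node = {"%3": 56_394_525, "%5": 232_569_993, "%6": 2_048}
total_complexity = 305_751_974

# What the per-node breakdown accounts for, and what is left over.
attributed = sum(complexity_per_node.values())
unattributed = total_complexity - attributed

print(f"attributed:   {attributed:_}")    # attributed:   288_966_566
print(f"unattributed: {unattributed:_}")  # unattributed: 16_785_408
```

A test along these lines could assert that `unattributed` is zero once the compiler and the optimizer agree.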

By the way, this issue does not originate only from this PR. Subtraction, for example, is implemented as x + (-y) in the compiler, but the optimizer treats it as a single linear operation and reports a complexity that is half of what it should be.
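The factor of two can be illustrated with a toy cost model. Everything below is hypothetical: the `Op` class, the two helper functions, and the assumption that negation and addition each cost the same as one linear operation. The value 2_048 is borrowed from the cost reported for the add node `%6` purely for illustration.

```python
from dataclasses import dataclass
from typing import List

# Hypothetical cost of one linear operation (borrowed from node %6 above).
LINEAR_OP_COST = 2_048


@dataclass(frozen=True)
class Op:
    name: str
    cost: int


def optimizer_view_of_sub() -> List[Op]:
    # Assumed optimizer view: subtraction is a single linear operation.
    return [Op("sub", LINEAR_OP_COST)]


def compiler_lowering_of_sub() -> List[Op]:
    # Assumed compiler lowering: x - y becomes x + (-y), i.e. two operations.
    return [Op("neg", LINEAR_OP_COST), Op("add", LINEAR_OP_COST)]


def total_cost(ops: List[Op]) -> int:
    return sum(op.cost for op in ops)


ratio = total_cost(optimizer_view_of_sub()) / total_cost(compiler_lowering_of_sub())
print(ratio)  # 0.5
```

If negation were free, or if levelled operations carried no cost at all, the two views would agree, so the discrepancy depends on which cost model the two sides are meant to share.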

@cla-bot cla-bot bot added the cla-signed label Apr 16, 2024
@umut-sahin umut-sahin requested review from BourgerieQuentin and removed request for aquint-zama April 16, 2024 09:09
@github-actions github-actions bot left a comment

Benchmark

| Benchmark suite | Current: aa3f1ef | Previous: 34de883 | Ratio |
|---|---|---|---|
| v0 PBS table generation | 58507193 ns/iter (± 2035155) | 58441405 ns/iter (± 1486671) | 1.00 |
| v0 PBS simulate dag table generation | 36857855 ns/iter (± 368758) | 36690256 ns/iter (± 214036) | 1.00 |
| v0 WoP-PBS table generation | 67165152 ns/iter (± 969936) | 67074158 ns/iter (± 358713) | 1.00 |

This comment was automatically generated by workflow using github-action-benchmark.

@rudy-6-4
Contributor

rudy-6-4 commented May 3, 2024

The compiler does not provide the optimizer with a cost model for LevelledOp nodes; it sends a cost of 0 for every LevelledOp node. So resolving the inconsistency just means assigning a cost of 0 to Dot nodes.
Subtraction is converted by the compiler to a LevelledOp, so it should have a cost of zero. @umut-sahin Did you observe a Dot node here?
