This document explains how to implement machine learning models with Torch in Concrete ML, leveraging Fully Homomorphic Encryption (FHE).
There are two approaches to build FHE-compatible deep networks:

- **Quantization Aware Training (QAT)**: This method requires using custom layers to quantize weights and activations to low bit-widths. Concrete ML works with Brevitas, a library that provides QAT support for PyTorch. Use `compile_brevitas_qat_model` to compile models in this mode.
- **Post Training Quantization (PTQ)**: This method allows compiling a vanilla PyTorch model. However, accuracy may decrease significantly when quantizing weights and activations to fewer than 7 bits. Additionally, depending on the model size, quantizing with 6-8 bits can be incompatible with FHE constraints. You therefore need to determine the trade-off between model accuracy and FHE compatibility. Use `compile_torch_model` to compile models in this mode.

Both approaches require setting the `rounding_threshold_bits` parameter accordingly. You should experiment to find the best value, starting with an initial value of `6`; a minimal experimentation sketch is shown below. See here for more details.
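A minimal sketch of such an experiment, assuming a trained `torch_model`, a representative `torch_input`, a labeled test set `x_test`/`y_test` (all hypothetical names), and that the returned `QuantizedModule` exposes the `quantize_input`/`dequantize_output` helpers:

```python
from concrete.ml.torch.compile import compile_torch_model

# Hypothetical sweep: compile with decreasing rounding bit-widths and compare
# simulated accuracy (FHE simulation is described further below)
for rounding_bits in [8, 7, 6, 5, 4]:
    quantized_module = compile_torch_model(
        torch_model,
        torch_input,
        n_bits=6,
        rounding_threshold_bits={"n_bits": rounding_bits, "method": "approximate"},
    )
    q_x = quantized_module.quantize_input(x_test)
    q_y = quantized_module.forward(q_x, fhe="simulate")
    y_pred = quantized_module.dequantize_output(q_y).argmax(axis=1)
    print(f"rounding_threshold_bits={rounding_bits}: accuracy={(y_pred == y_test).mean():.3f}")
```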
{% hint style="info" %} See the common compilation errors page for explanations and solutions to some common errors raised by the compilation function. {% endhint %}
The following example uses a simple QAT PyTorch model that implements a fully connected neural network with two hidden layers. Due to its small size, making this model respect FHE constraints is relatively easy. To use QAT, Brevitas `QuantIdentity` nodes must be inserted in the PyTorch model, including one that quantizes the input of the `forward` function.
```python
import brevitas.nn as qnn
import torch.nn as nn
import torch

N_FEAT = 12
n_bits = 3

class QATSimpleNet(nn.Module):
    def __init__(self, n_hidden):
        super().__init__()

        self.quant_inp = qnn.QuantIdentity(bit_width=n_bits, return_quant_tensor=True)
        self.fc1 = qnn.QuantLinear(N_FEAT, n_hidden, True, weight_bit_width=n_bits, bias_quant=None)
        self.quant2 = qnn.QuantIdentity(bit_width=n_bits, return_quant_tensor=True)
        self.fc2 = qnn.QuantLinear(n_hidden, n_hidden, True, weight_bit_width=n_bits, bias_quant=None)
        self.quant3 = qnn.QuantIdentity(bit_width=n_bits, return_quant_tensor=True)
        self.fc3 = qnn.QuantLinear(n_hidden, 2, True, weight_bit_width=n_bits, bias_quant=None)

    def forward(self, x):
        x = self.quant_inp(x)
        x = self.quant2(torch.relu(self.fc1(x)))
        x = self.quant3(torch.relu(self.fc2(x)))
        x = self.fc3(x)
        return x
```
Once the model is trained, use `compile_brevitas_qat_model` from Concrete ML to perform conversion and compilation of the QAT network. Here, 3-bit quantization is used for both the weights and activations. This function automatically identifies the number of quantization bits used in the Brevitas model.
```python
from concrete.ml.torch.compile import compile_brevitas_qat_model
import numpy

torch_input = torch.randn(100, N_FEAT)
torch_model = QATSimpleNet(30)
quantized_module = compile_brevitas_qat_model(
    torch_model, # our model
    torch_input, # a representative input-set to be used for both quantization and compilation
    rounding_threshold_bits={"n_bits": 6, "method": "approximate"}
)
```
{% hint style="warning" %}
If QuantIdentity
layers are missing for any input or intermediate value, the compile function will raise an error. See the common compilation errors page for an explanation.
{% endhint %}
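For illustration, here is a hypothetical variant of the QAT network above that omits the input `QuantIdentity` node; based on the behavior described in this warning, compiling it would raise such an error:

```python
class MissingQuantInputNet(nn.Module):
    def __init__(self, n_hidden):
        super().__init__()
        # No qnn.QuantIdentity before fc1: the input of forward is never quantized,
        # so compile_brevitas_qat_model is expected to raise an error on this model
        self.fc1 = qnn.QuantLinear(N_FEAT, n_hidden, True, weight_bit_width=n_bits, bias_quant=None)

    def forward(self, x):
        return self.fc1(x)
```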
The following example demonstrates a simple PyTorch model that implements a fully connected neural network with two hidden layers. The model is compiled with `compile_torch_model` to use FHE.
```python
import torch.nn as nn
import torch

N_FEAT = 12
n_bits = 6

class PTQSimpleNet(nn.Module):
    def __init__(self, n_hidden):
        super().__init__()
        self.fc1 = nn.Linear(N_FEAT, n_hidden)
        self.fc2 = nn.Linear(n_hidden, n_hidden)
        self.fc3 = nn.Linear(n_hidden, 2)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x
```
```python
from concrete.ml.torch.compile import compile_torch_model
import numpy

torch_input = torch.randn(100, N_FEAT)
torch_model = PTQSimpleNet(5)
quantized_module = compile_torch_model(
    torch_model, # our model
    torch_input, # a representative input-set to be used for both quantization and compilation
    n_bits=6,
    rounding_threshold_bits={"n_bits": 6, "method": "approximate"}
)
```
The quantization parameters, along with the number of neurons in each layer, determine the accumulator bit-width of the network. Larger accumulator bit-widths result in higher accuracy but slower FHE inference time.
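As a rough rule of thumb (an illustrative estimate, not the exact formula Concrete ML uses), a dot product of `n` terms between `b_w`-bit weights and `b_a`-bit activations requires an accumulator of about `b_w + b_a + ceil(log2(n))` bits:

```python
import math

# Illustrative estimate for the QAT example above: 3-bit weights and
# activations, hidden layers of 30 neurons
b_w, b_a, n = 3, 3, 30
print(b_w + b_a + math.ceil(math.log2(n)))  # about an 11-bit accumulator
```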
- **QAT**: Configure quantization parameters such as `bit_width` and `weight_bit_width`, and set `n_bits=None` in the `compile_brevitas_qat_model` call.
- **PTQ**: Set the `n_bits` value in the `compile_torch_model` function. Manually determine the trade-off between accuracy, FHE compatibility, and latency. A sketch of both configurations follows.
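A minimal sketch of both configurations, reusing the models and `torch_input` defined above:

```python
# QAT: bit-widths are read from the Brevitas layers, so n_bits is left as None
qat_module = compile_brevitas_qat_model(
    QATSimpleNet(30),
    torch_input,
    n_bits=None,
    rounding_threshold_bits={"n_bits": 6, "method": "approximate"},
)

# PTQ: n_bits controls the quantization applied by Concrete ML
ptq_module = compile_torch_model(
    PTQSimpleNet(5),
    torch_input,
    n_bits=6,
    rounding_threshold_bits={"n_bits": 6, "method": "approximate"},
)
```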
The model can now perform encrypted inference.
```python
x_test = numpy.array([numpy.random.randn(N_FEAT)])
y_pred = quantized_module.forward(x_test, fhe="execute")
```
In this example, the input values `x_test` and the predicted values `y_pred` are floating points. The quantization (respectively de-quantization) step is done in the clear within the `forward` method, before (respectively after) any FHE computations.
You can perform the inference on clear data in order to evaluate the impact of quantization and of FHE computation on the accuracy of your model. See this section for more details.
There are two approaches:

- `quantized_module.forward(quantized_x, fhe="simulate")`: This method simulates FHE execution, taking into account Table Lookup errors. De-quantization must be done in a second step, as for actual FHE execution. Simulation takes into account the `p_error`/`global_p_error` parameters. A minimal sketch is shown after this list.
- `quantized_module.forward(quantized_x, fhe="disable")`: This method computes predictions in the clear on quantized data, and then de-quantizes the result. The return value of this function contains the de-quantized (float) output of running the model in the clear. Calling this function on clear data is useful when debugging, but it does not perform actual FHE simulation.
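A minimal simulation sketch, assuming the `QuantizedModule` exposes the `quantize_input` and `dequantize_output` helpers:

```python
# Quantize the clear inputs, simulate FHE execution, then de-quantize
q_x = quantized_module.quantize_input(x_test)
q_y_sim = quantized_module.forward(q_x, fhe="simulate")
y_sim = quantized_module.dequantize_output(q_y_sim)
```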
{% hint style="info" %}
FHE simulation allows to measure the impact of the Table Lookup error on the model accuracy. You can adjust the Table Lookup error using p_error
/global_p_error
, as described in the approximate computation section.
{% endhint %}
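For instance, assuming the compile functions accept a `p_error` keyword argument, as described in the approximate computation section, a sketch might look like:

```python
quantized_module = compile_torch_model(
    torch_model,
    torch_input,
    n_bits=6,
    rounding_threshold_bits={"n_bits": 6, "method": "approximate"},
    p_error=0.01,  # tolerate a higher Table Lookup error probability for faster inference
)
```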
Concrete ML supports a variety of PyTorch operators that can be used to build fully connected or convolutional neural networks, with normalization and activation layers. Moreover, many element-wise operators are supported.
- `torch.nn.Identity`
- `torch.clip`
- `torch.clamp`
- `torch.round`
- `torch.floor`
- `torch.min`
- `torch.max`
- `torch.abs`
- `torch.neg`
- `torch.sign`
- `torch.logical_or`, `torch.Tensor` operator `|`
- `torch.logical_not`
- `torch.gt`, `torch.greater`
- `torch.ge`, `torch.greater_equal`
- `torch.lt`, `torch.less`
- `torch.le`, `torch.less_equal`
- `torch.eq`
- `torch.where`
- `torch.exp`
- `torch.log`
- `torch.pow`
- `torch.sum`
- `torch.mul`, `torch.Tensor` operator `*`
- `torch.div`, `torch.Tensor` operator `/`
- `torch.nn.BatchNorm2d`
- `torch.nn.BatchNorm3d`
- `torch.erf`, `torch.special.erf`
- `torch.nn.functional.pad`
- `torch.reshape`
- `torch.Tensor.view`
- `torch.flatten`
- `torch.unsqueeze`
- `torch.squeeze`
- `torch.transpose`
- `torch.concat`, `torch.cat`
- `torch.nn.Unfold`
- `torch.Tensor.expand`
- `torch.Tensor.to` -- for casting to `dtype`
- `torch.nn.Linear`
- `torch.conv1d`, `torch.nn.Conv1d`
- `torch.conv2d`, `torch.nn.Conv2d`
- `torch.nn.AvgPool2d`
- `torch.nn.MaxPool2d`
Concrete ML also supports some of their QAT equivalents from Brevitas.
- `brevitas.nn.QuantLinear`
- `brevitas.nn.QuantConv1d`
- `brevitas.nn.QuantConv2d`
- `brevitas.nn.QuantIdentity`
- `torch.nn.CELU`
- `torch.nn.ELU`
- `torch.nn.GELU`
- `torch.nn.Hardsigmoid`
- `torch.nn.Hardswish`
- `torch.nn.Hardtanh`
- `torch.nn.LeakyReLU`
- `torch.nn.LogSigmoid`
- `torch.nn.Mish`
- `torch.nn.PReLU`
- `torch.nn.ReLU6`
- `torch.nn.ReLU`
- `torch.nn.SELU`
- `torch.nn.Sigmoid`
- `torch.nn.SiLU`
- `torch.nn.Softplus`
- `torch.nn.Softshrink`
- `torch.nn.Softsign`
- `torch.nn.Tanh`
- `torch.nn.Tanhshrink`
- `torch.nn.Threshold` -- partial support
{% hint style="info" %}
The equivalent versions from torch.functional
are also supported.
{% endhint %}
{% hint style="success" %} Zama 5-Question Developer Survey
We want to hear from you! Take 1 minute to share your thoughts and helping us enhance our documentation and libraries. 👉 Click here to participate. {% endhint %}