Quantization of Neural Networks for Fully Homomorphic Encryption
Machine Learning and the Need for Privacy
In machine learning, a neural network is a circuit of artificial neurons or nodes. Neural networks are the standard approach for building high-quality models over unstructured data (images, sound, and text). For tasks like image segmentation, raw audio generation, or machine translation, models based on neural networks currently provide by far the best performance. The rise of neural networks can be attributed to two main factors:
- Hardware infrastructure (e.g. GPUs and distributed training, and fast inference even for huge models)
- Model complexity (essentially, the number of parameters)
With ever more complexity and overhead, it’s increasingly difficult for machine learning specialists to build, maintain, and serve these models. As a result, a few corporations own most of the artificial intelligence (AI) space, and most ML models running in the cloud are only accessible via an API. Trust and privacy are becoming pressing concerns: end users have no visibility into what happens to the data they send, so their privacy is at stake.
In this article, we’re going to look at how Fully Homomorphic Encryption offers a unique solution to the problem of trust and privacy in data and AI. We’ll focus on quantization, which maps continuous values to a small, finite set of discrete values, typically for efficiency reasons. Quantization is also a way to make existing models more FHE-friendly, and we’ll explain what that means below.
Fully Homomorphic Encryption
Fully Homomorphic Encryption (FHE) allows you to compute on encrypted data. The computation is carried out with publicly available keys, without compromising security. The encrypted result is returned to the legitimate owner, who is the only one able to decrypt it with the private key. It is computationally infeasible for the untrusted server to decrypt anything, since it never has access to the private key.
Here at Zama, we use FHE to turn neural networks into equivalent computations that work end to end over encrypted data, as efficiently as possible and with as little loss of accuracy as possible. The technique that makes this possible is Programmable Bootstrapping (PBS), which lets us evaluate arbitrary univariate functions on encrypted values. This is critical in machine learning, since activation functions are nonlinear. In the rest of this post, we’ll treat the PBS as a simple table lookup over encrypted data.
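To make the "PBS as a table lookup" view concrete, here is a minimal clear-data simulation: nothing is encrypted, and the bit width, scale, and zero point are illustrative assumptions. Because a value quantized on `n_bits` can only take `2**n_bits` values, any activation can be precomputed once for every possible input and then evaluated by indexing.

```python
import numpy as np

# Clear-data simulation only: nothing is encrypted here.
# Assumed setup: inputs quantized on n_bits, with an illustrative scale and zero point.
n_bits = 4
scale, zero_point = 0.5, 8

# Every value an n_bits-bit input can take: 0, 1, ..., 2**n_bits - 1
q_inputs = np.arange(2**n_bits)

# Precompute the nonlinear activation (a sigmoid) once for every possible input
lookup_table = 1.0 / (1.0 + np.exp(-scale * (q_inputs - zero_point)))

def apply_activation(q_x: np.ndarray) -> np.ndarray:
    """Evaluate the activation by indexing the table (the role a PBS plays under encryption)."""
    return lookup_table[q_x]

q_x = np.array([0, 8, 15])
print(apply_activation(q_x))
```

Since the table has `2**n_bits` entries, the precision limit of the PBS directly bounds how many bits each value can use.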
Quantization and FHE-friendly Neural Networks
FHE supports a wide range of operations that can be applied effectively to machine learning models. In a previous post, we showed a practical example with a linear model; neural networks work in a similar way.
Zama initially used an approximate approach to convert neural networks into their FHE counterparts. The problem was that tuning the FHE parameters for each neural network architecture is far from trivial, so we shifted to an exact paradigm, in which the FHE computation gives exactly the same results as the computation in the clear.
The limitations of this new paradigm are twofold:
- Only multiplication by integers is possible.
- The precision of the PBS is limited (i.e. the table lookups we can apply have a maximum size, typically 2⁷ = 128 entries; the higher the precision needed, the slower the inference).
We can tackle these limitations with neural network quantization. Quantization is generally used to improve the efficiency or compression of neural networks, but it is also well suited to circumventing some of the limitations of FHE:
- Quantization replaces floating-point multiplications with integer multiplications, which FHE supports.
- Quantization can go down to small integers, which keeps values within the limited precision of the PBS.
Quantization in Neural Networks
Quantization of neural networks is a subdomain of neural network compression. Such techniques are mainly used to reduce inference time on embedded hardware or in browser-like environments. The goal is to replace costly floating-point operations with much more efficient, reasonably accurate approximations on small integers. Quantization is implemented in several of the best-known deep learning libraries. A common approach is affine quantization with p bits of precision. In a nutshell, when you want to go below 8 bits of precision, Quantization Aware Training is generally used so that the quantized network keeps an accuracy close to that of its floating-point counterpart.
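For reference, the unsigned affine quantization used in the example that follows can be summarized as below, in our own notation: $s$ is the scale, $Z$ the zero point, $x_{\min}$ and $x_{\max}$ the extremes of $X$, and $\hat{x}$ the dequantized value. The clipping step is standard practice and keeps $q$ within the $p$-bit range.

```latex
s = \frac{x_{\max} - x_{\min}}{2^{p} - 1}, \qquad
Z = \operatorname{round}\!\left(\frac{-x_{\min}}{s}\right)

q(x) = \operatorname{clip}\!\left(\operatorname{round}\!\left(\frac{x}{s}\right) + Z,\; 0,\; 2^{p} - 1\right), \qquad
\hat{x} = s\,\bigl(q(x) - Z\bigr)
```

When no clipping occurs, $\hat{x} = s \cdot \operatorname{round}(x / s)$, so the quantization error satisfies $|x - \hat{x}| \le s / 2$.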
Here’s a brief overview of how quantization works. Let X be a set of weights, inputs, or activation outputs that we want to quantize. Here is an example of unsigned quantization (i.e. the quantized outputs are integers greater than or equal to 0):
```python
import numpy as np

# Set seed for reproducibility
np.random.seed(42)

# Set the number of bits over which data will be quantized
n_bits = 2

# Generate random data to be quantized
X = np.random.rand(10)
# Output: [0.37454012 0.95071431 0.73199394 0.59865848 0.15601864
#          0.15599452 0.05808361 0.86617615 0.60111501 0.70807258]

max_X = np.max(X)
# Output: 0.9507
min_X = np.min(X)
# Output: 0.0581

# Largest integer representable with n_bits unsigned bits
max_q_value = 2 ** n_bits - 1
# Output: 3

# Named x_range to avoid shadowing Python's built-in range()
x_range = max_X - min_X
# Output: 0.8926

scale = x_range / max_q_value
# Output: 0.2975

# Zero point: the integer that represents the real value 0 (equivalently round(-min_X / scale))
Zp = np.round((-min_X * max_q_value) / x_range)
# Output: 0

# Quantize, clipping to the valid range [0, max_q_value]
q_X = np.clip(np.round(X / scale) + Zp, 0, max_q_value)
# Output: [1. 3. 2. 2. 1. 1. 0. 3. 2. 2.]

# We can now obtain the dequantized output
X_dequant = (q_X - Zp) * scale
# Output: [0.29754356 0.89263069 0.59508713 0.59508713 0.29754356
#          0.29754356 0.         0.89263069 0.59508713 0.59508713]
```
With the above equations, you can convert an entire neural network from floating-point values to p-bit integers, and adapt the matrix multiplications performed in fully connected layers or convolutional (CNN) layers by replacing real values with their quantized values.
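As an illustration of how a fully connected layer can keep its multiply-accumulate in integers, here is a minimal NumPy sketch. The `quantize` helper, the toy shapes, and the random data are our own assumptions, not Concrete's implementation. Writing $W \approx s_w (q_w - Z_w)$ and $x \approx s_x (q_x - Z_x)$, the product becomes an integer accumulation followed by one floating-point rescaling by $s_w s_x$.

```python
import numpy as np

def quantize(x: np.ndarray, n_bits: int):
    """Unsigned affine quantization (same scheme as above), with clipping."""
    max_q = 2**n_bits - 1
    scale = (x.max() - x.min()) / max_q
    zero_point = int(np.round(-x.min() / scale))
    q = np.clip(np.round(x / scale) + zero_point, 0, max_q).astype(np.int64)
    return q, scale, zero_point

rng = np.random.default_rng(0)
W = rng.normal(size=(4, 16))   # weights of a toy fully connected layer (hypothetical)
x = rng.normal(size=16)        # one input vector (hypothetical)
n_bits = 6

q_w, s_w, z_w = quantize(W, n_bits)
q_x, s_x, z_x = quantize(x, n_bits)

# Integer-only core: zero-point-centered operands, integer multiply-accumulate
acc = (q_w - z_w) @ (q_x - z_x)

# A single float rescaling maps the integer accumulator back to real values
y_quant = s_w * s_x * acc
y_float = W @ x
print(np.abs(y_quant - y_float).max())  # approximation error due to quantization
```

The design choice here is to push every floating-point factor into the final scale, so the expensive part (the accumulation) only ever multiplies integers.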
FHE-friendly Neural Network
From here on, we’ll use PyTorch and NumPy for our example. We also use Concrete Numpy, along with preliminary versions of the tools that compile torch models directly. Note that the torch compilation functions are at a very early stage of development, while the NumPy compilation functions are more robust; the next version of our framework will contain more stable torch compilation functions. Quantization is also performed during compilation, transparently for the user.
Let’s start by defining a standard PyTorch model and using concrete.torch to convert it into a NumPy-only model.
```python
# Import requirements
# Torch and NumPy
from torch import nn
import torch
import torchvision
import torchvision.transforms as transforms
import numpy as np

# Torch to NumPy conversion from Concrete
from concrete.torch import NumpyModule

# Set seeds for reproducibility
torch.manual_seed(0)
np.random.seed(0)

# Get the MNIST test set, normalized with the usual MNIST mean and std
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])
test_data = torchvision.datasets.MNIST('../data', train=False,
                                       transform=transform, download=True)
# A single batch holds the whole test set (10,000 images)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=10000, shuffle=False)

# Define a fully connected torch model
n_features = 28 * 28

class FC_model(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(in_features=n_features, out_features=128)
        self.sigmoid1 = nn.Sigmoid()
        self.fc2 = nn.Linear(in_features=128, out_features=64)
        self.sigmoid2 = nn.Sigmoid()
        self.fc3 = nn.Linear(in_features=64, out_features=10)

    def forward(self, x):
        out = self.fc1(x)
        out = self.sigmoid1(out)
        out = self.fc2(out)
        out = self.sigmoid2(out)
        out = self.fc3(out)
        return out

# You can train the model using this script https://gist.github.com/jfrery/2c362c01f71917b59700e906063518ef
PATH = "mnist_97p.pt"
torch_fc_model = FC_model()
torch_fc_model.load_state_dict(torch.load(PATH, map_location=torch.device('cpu')))
torch_fc_model.eval()

# Convert the torch model to NumPy (required by Concrete)
numpy_fc_model = NumpyModule(torch_fc_model)

# Load the whole test set and flatten it to (n_examples, n_features)
mnist_test_data, mnist_test_target = next(iter(test_loader))
mnist_test_data = mnist_test_data.detach().numpy()
mnist_test_target = mnist_test_target.detach().numpy()
mnist_test_data = mnist_test_data.reshape(mnist_test_data.shape[0], -1)

# Check that both models give the same accuracy
# Accuracy PyTorch
(torch_fc_model(torch.from_numpy(mnist_test_data)).detach().numpy().argmax(1) == mnist_test_target).mean()
# Output: 0.9733
# Accuracy NumPy
(numpy_fc_model(mnist_test_data).argmax(1) == mnist_test_target).mean()
# Output: 0.9733
```
We now have a fully connected network in NumPy that reaches 97.33% accuracy on the MNIST test set, the same as the original PyTorch model.
Next, we use the quantization tools provided in the framework to convert our model to integers. We select a precision of 6 bits for the weights, activations, and layer outputs.
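Converting the model to NumPy amounts to expressing the same forward pass with NumPy operations only. The sketch below shows what that means for `FC_model`; `numpy_fc_forward` and the `params` dictionary are hypothetical helpers, and we make no claim that `NumpyModule` is implemented this way. Random weights stand in for the trained ones.

```python
import numpy as np

def sigmoid(x: np.ndarray) -> np.ndarray:
    return 1.0 / (1.0 + np.exp(-x))

def numpy_fc_forward(x: np.ndarray, params: dict) -> np.ndarray:
    """Forward pass of FC_model with NumPy only.

    params maps layer names to (weight, bias) pairs laid out like torch's
    nn.Linear: weight has shape (out_features, in_features).
    """
    w1, b1 = params["fc1"]
    w2, b2 = params["fc2"]
    w3, b3 = params["fc3"]
    out = sigmoid(x @ w1.T + b1)
    out = sigmoid(out @ w2.T + b2)
    return out @ w3.T + b3

# Random parameters with the same shapes as FC_model (illustration only)
rng = np.random.default_rng(0)
shapes = {"fc1": (128, 28 * 28), "fc2": (64, 128), "fc3": (10, 64)}
params = {name: (rng.normal(scale=0.05, size=s), np.zeros(s[0])) for name, s in shapes.items()}

logits = numpy_fc_forward(rng.normal(size=(5, 28 * 28)), params)
print(logits.shape)  # (5, 10)
```

With the trained weights, `params` could be filled from `torch_fc_model.state_dict()`, whose keys for this model are `"fc1.weight"`, `"fc1.bias"`, and so on (converted with `.numpy()`).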
```python
# Quantization
# Specify the number of bits for quantization and choose between signed and unsigned quantization
n_bits = 6
is_signed = False

# Post-training affine quantization from Concrete
from concrete.quantization import PostTrainingAffineQuantization
# Quantized array from Concrete
from concrete.quantization import QuantizedArray

# Quantize our model with post-training affine quantization
pt_quant = PostTrainingAffineQuantization(n_bits=n_bits, numpy_model=numpy_fc_model, is_signed=is_signed)
# Calibrate layers and activations (here on the test set itself)
quant_module = pt_quant.quantize_module(mnist_test_data)

# Quantize the input
q_mnist_test_data = QuantizedArray(n_bits=n_bits, values=mnist_test_data, is_signed=is_signed)

# Mask of non-background pixels: MNIST has a lot of black pixels, which all
# share the same minimum normalized value (about -0.42421296)
arg_diff_values = mnist_test_data != mnist_test_data.min()

# Compare dequantized input values with real input values
# Real input
mnist_test_data[arg_diff_values][:16]
# Output: array([0.64495873, 1.9305104 , 1.5995764 , 1.4977505 , 0.33948106,
#                0.03400347, 2.401455  , 2.8087585 , 2.8087585 , 2.8087585 ,
#                2.8087585 , 2.6432915 , 2.0959773 , 2.0959773 , 2.0959773 ,
#                2.0959773 ], dtype=float32)
# Dequantized input
q_mnist_test_data.dequant()[arg_diff_values][:16]
# Output: array([0.66974755, 1.90620457, 1.59709032, 1.49405223, 0.3606333 ,
#                0.05151904, 2.42139499, 2.83354733, 2.83354733, 2.83354733,
#                2.83354733, 2.62747116, 2.11228074, 2.11228074, 2.11228074,
#                2.11228074])

# Check the quantized input values
q_mnist_test_data.qvalues[arg_diff_values][:16]
# Output: array([21, 45, 39, 37, 15,  9, 55, 63, 63, 63, 63, 59, 49, 49, 49, 49])
# Check the quantized weights of the first layer
quant_module.quant_layers_dict['fc1'].q_weights.qvalues[0][:16]
# Output: array([31, 30, 33, 31, 32, 33, 32, 32, 31, 32, 33, 30, 33, 34, 32, 33])

# Make sure all quantized inputs are integers in [0, 2**6 - 1], i.e. 64 possible values
np.unique(q_mnist_test_data.qvalues[arg_diff_values])
# Output: array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
#                17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33,
#                34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50,
#                51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63])

# Accuracy of the quantized NumPy model (6 bits)
(quant_module(q_mnist_test_data).dequant().argmax(1) == mnist_test_target).mean()
# Output: 0.9726

# Drop in accuracy (in percentage points) caused by quantizing the floating-point model with 6 bits
np.round(100*np.abs((numpy_fc_model(mnist_test_data).argmax(1) == mnist_test_target).mean() - (quant_module(q_mnist_test_data).dequant().argmax(1) == mnist_test_target).mean()), 2)
# Output: 0.07
```
Post-training quantization with 6 bits of precision keeps the model at essentially full performance: accuracy drops by only 0.07 percentage points (from 97.33% to 97.26%), which is negligible. As expected, all layers, activations, and inputs have been converted to integers, which removes one of the constraints of FHE: being restricted to integer multiplications is no longer a problem. The core computation (the matrix multiplication) is done entirely in integers, while the rest (the floating-point scales and the zero points) can be folded into a table lookup together with the activation function (essentially, a PBS).
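The folding of scales, zero points, and activation into one table lookup can be sketched in the clear as follows. The accumulator range, `acc_scale`, and the output grid are assumptions chosen for illustration, not values from the MNIST model; the range is picked to give a 2⁷-entry table.

```python
import numpy as np

# Clear-data sketch with assumed values (not taken from the MNIST model above)
acc_scale = 0.05            # product of weight and input scales, s_w * s_x
n_bits_out = 6
max_q_out = 2**n_bits_out - 1
acc_min, acc_max = -64, 63  # assumed accumulator range after calibration (2**7 values)

def float_path(acc: np.ndarray) -> np.ndarray:
    """Rescale the integer accumulator, apply a sigmoid, and requantize to [0, max_q_out]."""
    y = 1.0 / (1.0 + np.exp(-acc_scale * acc))  # sigmoid output in (0, 1)
    # Output grid: scale 1 / max_q_out and zero point 0
    return np.round(y * max_q_out).astype(np.int64)

# Fold all the float operations into one table indexed by the accumulator
table = float_path(np.arange(acc_min, acc_max + 1))

def fused_lookup(acc: np.ndarray) -> np.ndarray:
    """Same result with a single table lookup (shifted so indices start at 0)."""
    return table[acc - acc_min]

acc = np.array([-64, -3, 0, 25, 63])
print(fused_lookup(acc))
```

Because everything after the integer accumulation depends on a single integer, one lookup replaces the whole floating-point path, as long as the accumulator range fits in the table.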
Finally, in classical quantization, the inputs and the first and last layers are often left in floating point to preserve accuracy. For FHE, we need integers everywhere, including the first and last layers, so we apply affine quantization to the inputs and to the first and last layers, just as for the rest of the network. Note that in some tasks, such as NLP, the input is tokenized and is therefore already a good representation for FHE (a one-hot vector combined with a learnable embedding layer).
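The NLP case works because multiplying a one-hot vector by an embedding matrix simply selects a row. This small sketch checks that equivalence; the vocabulary size, embedding dimension, token ids, and random embedding are toy assumptions.

```python
import numpy as np

vocab_size, embed_dim = 8, 4
rng = np.random.default_rng(0)
embedding = rng.normal(size=(vocab_size, embed_dim))  # learnable embedding table (random here)

token_ids = np.array([3, 0, 7])                          # tokenized input: already integers
one_hot = np.eye(vocab_size, dtype=np.int64)[token_ids]  # shape (3, vocab_size), entries 0 or 1

# A one-hot vector times the embedding matrix just selects a row,
# so the first layer never multiplies by an arbitrary real-valued input.
via_matmul = one_hot @ embedding
via_lookup = embedding[token_ids]
assert np.allclose(via_matmul, via_lookup)
```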
Another challenge is that major ML frameworks commonly use a 32-bit accumulator for the core operation of a neural network (multiply-accumulate). Given the low precision currently available for the PBS (7 bits in Concrete Numpy), the accumulator has to fit in far fewer bits, which is a much harder problem to solve for the moment. We expect upcoming breakthroughs in FHE to solve this soon!
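To get a feel for the gap, one can bound the size of the accumulator. This worst-case estimate assumes unsigned operands and ignores zero points and calibration; `accumulator_bits` is a hypothetical helper, not part of Concrete.

```python
def accumulator_bits(n_terms: int, n_bits_weights: int, n_bits_inputs: int) -> int:
    """Worst-case bits needed to hold sum_i w_i * x_i for unsigned operands."""
    max_w = 2**n_bits_weights - 1
    max_x = 2**n_bits_inputs - 1
    return (n_terms * max_w * max_x).bit_length()

# First layer of the MNIST model: 28*28 = 784 inputs, 6-bit weights and inputs
print(accumulator_bits(784, 6, 6))  # 22
```

Even with 6-bit weights and inputs, the first layer's accumulator can need up to 22 bits in this worst case, well above the 7 bits currently handled by a PBS.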
Stay tuned to our blog to see how this work develops!
Thank you to Jordan Frery for his contribution to this article.
We’re open source: follow Zama on GitHub at github.com/zama-ai


