Implementation of the paper "BitNet: Scaling 1-bit Transformers for Large Language Models" in PyTorch.
BitLinear = tensor -> LayerNorm -> binarize -> absmax quantization
pip install bitnet
import torch
from bitnet import BitLinear
from bitnet.main import Transformer
# Example 1: BitLinear layer on a random (batch, dim) input
x = torch.randn(10, 512)
layer = BitLinear(512)
# the forward pass returns the output and its dequantized counterpart
y, dequant = layer(x)
print(y, dequant)
# Example 2: Transformer built from the same building blocks
x = torch.randn(1, 1, 10, 512)
layer = Transformer(512, 8, 8, 64)
y = layer(x)
print(y)
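
As a rough guide to what the BitLinear pipeline above (LayerNorm -> binarize -> absmax quantization) computes, here is a minimal, self-contained sketch. The class name NaiveBitLinear, the 8-bit activation range, and the exact scaling constants are assumptions made for illustration; the packaged BitLinear may differ (for example, it would need a straight-through estimator to train).

import torch
import torch.nn as nn
import torch.nn.functional as F

class NaiveBitLinear(nn.Module):
    """Illustrative sketch of the BitLinear pipeline, not the packaged layer."""

    def __init__(self, dim, out_dim=None, bits=8):
        super().__init__()
        out_dim = out_dim or dim
        self.norm = nn.LayerNorm(dim)
        self.weight = nn.Parameter(torch.randn(out_dim, dim) * dim ** -0.5)
        self.q_max = 2 ** (bits - 1) - 1  # absmax quantization range (assumed 8-bit)

    def forward(self, x):
        # 1) LayerNorm the input activations
        x = self.norm(x)
        # 2) Binarize the weights to {-1, +1} via the sign of the centered weights
        w = self.weight
        w_bin = torch.sign(w - w.mean())
        # 3) Absmax-quantize the activations into [-q_max, q_max]
        scale = self.q_max / x.abs().max().clamp(min=1e-5)
        x_q = (x * scale).round().clamp(-self.q_max, self.q_max)
        # 4) Project with the binary weights, then rescale (dequantize) the output
        y = F.linear(x_q, w_bin)
        dequant = y * w.abs().mean() / scale
        return y, dequant

# usage mirrors Example 1 above
x = torch.randn(10, 512)
y, dequant = NaiveBitLinear(512)(x)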
License: MIT
Todo:
- Fix the error in the Transformer forward pass