Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

crypto/rsa: port PrivateKey.Validate to bigmod, add validations #70236

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/crypto/ecdsa/ecdsa.go
Original file line number Diff line number Diff line change
Expand Up @@ -327,9 +327,9 @@ func signNISTEC[Point nistPoint[Point]](c *nistCurve[Point], priv *PrivateKey, c
if err != nil {
return nil, err
}
s.Mul(r, c.N)
s.MulMod(r, c.N)
s.Add(e, c.N)
s.Mul(kInv, c.N)
s.MulMod(kInv, c.N)

// Again, the chance of this happening is cryptographically negligible.
if s.IsZero() == 1 {
Expand Down Expand Up @@ -528,12 +528,12 @@ func verifyNISTEC[Point nistPoint[Point]](c *nistCurve[Point], pub *PublicKey, h
inverse(c, w, s)

// p₁ = [e * s⁻¹]G
p1, err := c.newPoint().ScalarBaseMult(e.Mul(w, c.N).Bytes(c.N))
p1, err := c.newPoint().ScalarBaseMult(e.MulMod(w, c.N).Bytes(c.N))
if err != nil {
return false
}
// p₂ = [r * s⁻¹]Q
p2, err := Q.ScalarMult(Q, w.Mul(r, c.N).Bytes(c.N))
p2, err := Q.ScalarMult(Q, w.MulMod(r, c.N).Bytes(c.N))
if err != nil {
return false
}
Expand Down
94 changes: 72 additions & 22 deletions src/crypto/internal/bigmod/nat.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,8 @@ func (x *Nat) reset(n int) *Nat {
return x
}

// set assigns x = y, optionally resizing x to the appropriate size.
func (x *Nat) set(y *Nat) *Nat {
// Set assigns x = y, optionally resizing x to the appropriate size.
func (x *Nat) Set(y *Nat) *Nat {
x.reset(len(y.limbs))
copy(x.limbs, y.limbs)
return x
Expand Down Expand Up @@ -226,6 +226,29 @@ func (x *Nat) IsZero() choice {
//
// Both operands must have the same announced length.
func (x *Nat) cmpGeq(y *Nat) choice {
c := x.subCarry(y)
// If there was a carry, then subtracting y underflowed, so
// x is not greater than or equal to y.
return not(choice(c))
}

// Cmp compares x and y and returns the result of the compare:
// 1 if x > y
// 0 if x == y
// -1 if x < y
func (x *Nat) Cmp(y *Nat) int {
if x.Equal(y) == yes {
return 0
}
c := x.subCarry(y)
res := 1
if c > 0 {
res = -1
}
return res
}

func (x *Nat) subCarry(y *Nat) uint {
// Eliminate bounds checks in the loop.
size := len(x.limbs)
xLimbs := x.limbs[:size]
Expand All @@ -235,9 +258,7 @@ func (x *Nat) cmpGeq(y *Nat) choice {
for i := 0; i < size; i++ {
_, c = bits.Sub(xLimbs[i], yLimbs[i], c)
}
// If there was a carry, then subtracting y underflowed, so
// x is not greater than or equal to y.
return not(choice(c))
return c
}

// assign sets x <- y if on == 1, and does nothing otherwise.
Expand Down Expand Up @@ -274,7 +295,7 @@ func (x *Nat) add(y *Nat) (c uint) {
// sub computes x -= y. It returns the borrow of the subtraction.
//
// Both operands must have the same announced length.
func (x *Nat) sub(y *Nat) (c uint) {
func (x *Nat) Sub(y *Nat) (c uint) {
// Eliminate bounds checks in the loop.
size := len(x.limbs)
xLimbs := x.limbs[:size]
Expand All @@ -301,6 +322,7 @@ type Modulus struct {
leading int // number of leading zeros in the modulus
m0inv uint // -nat.limbs[0]⁻¹ mod _W
rr *Nat // R*R for montgomeryRepresentation
even bool
}

// rr returns R*R with R = 2^(_W * n) and n = len(m.nat.limbs).
Expand Down Expand Up @@ -374,16 +396,18 @@ func minusInverseModW(x uint) uint {
// The Int must be odd. The number of significant bits (and nothing else) is
// leaked through timing side-channels.
func NewModulusFromBig(n *big.Int) (*Modulus, error) {
if b := n.Bits(); len(b) == 0 {
b := n.Bits()
if len(b) == 0 {
return nil, errors.New("modulus must be >= 0")
} else if b[0]&1 != 1 {
return nil, errors.New("modulus must be odd")
}
m := &Modulus{}
m.even = b[0]&1 != 1
m.nat = NewNat().setBig(n)
m.leading = _W - bitLen(m.nat.limbs[len(m.nat.limbs)-1])
m.m0inv = minusInverseModW(m.nat.limbs[0])
m.rr = rr(m)
if !m.even {
m.m0inv = minusInverseModW(m.nat.limbs[0])
m.rr = rr(m)
}
return m, nil
}

Expand Down Expand Up @@ -508,8 +532,8 @@ func (out *Nat) resetFor(m *Modulus) *Nat {
//
// x and m operands must have the same announced length.
func (x *Nat) maybeSubtractModulus(always choice, m *Modulus) {
t := NewNat().set(x)
underflow := t.sub(m.nat)
t := NewNat().Set(x)
underflow := t.Sub(m.nat)
// We keep the result if x - m didn't underflow (meaning x >= m)
// or if always was set.
keep := not(choice(underflow)) | choice(always)
Expand All @@ -520,10 +544,10 @@ func (x *Nat) maybeSubtractModulus(always choice, m *Modulus) {
//
// The length of both operands must be the same as the modulus. Both operands
// must already be reduced modulo m.
func (x *Nat) Sub(y *Nat, m *Modulus) *Nat {
underflow := x.sub(y)
func (x *Nat) SubMod(y *Nat, m *Modulus) *Nat {
underflow := x.Sub(y)
// If the subtraction underflowed, add m.
t := NewNat().set(x)
t := NewNat().Set(x)
t.add(m.nat)
x.assign(choice(underflow), t)
return x
Expand Down Expand Up @@ -571,6 +595,9 @@ func (x *Nat) montgomeryReduction(m *Modulus) *Nat {
// All inputs should be the same length and already reduced modulo m.
// x will be resized to the size of m and overwritten.
func (x *Nat) montgomeryMul(a *Nat, b *Nat, m *Modulus) *Nat {
if m.even {
panic("crypto/rsa: montgomery multiplication on even modulus")
}
n := len(m.nat.limbs)
mLimbs := m.nat.limbs[:n]
aLimbs := a.limbs[:n]
Expand Down Expand Up @@ -707,17 +734,40 @@ func addMulVVW(z, x []uint, y uint) (carry uint) {
return carry
}

// Mul calculates x = x * y mod m.
// MulMod calculates x = x * y mod m.
//
// The length of both operands must be the same as the modulus. Both operands
// must already be reduced modulo m.
func (x *Nat) Mul(y *Nat, m *Modulus) *Nat {
func (x *Nat) MulMod(y *Nat, m *Modulus) *Nat {
// A Montgomery multiplication by a value out of the Montgomery domain
// takes the result out of Montgomery representation.
xR := NewNat().set(x).montgomeryRepresentation(m) // xR = x * R mod m
xR := NewNat().Set(x).montgomeryRepresentation(m) // xR = x * R mod m
return x.montgomeryMul(xR, y, m) // x = xR * y / R mod m
}

// Mul calculates z = x * y.
//
// All inputs should be the same length and already reduced modulo m.
// z will be resized to the size of m and overwritten.
func (z *Nat) Mul(x *Nat, y *Nat, m *Modulus) *Nat {
n := len(m.nat.limbs)
zLimbs := z.resetFor(m).limbs
xLimbs := x.limbs
yLimbs := y.limbs
switch n {
default:
for i := 0; i < n; i++ {
addMulVVW(zLimbs[i:], xLimbs, yLimbs[i])
}
case 2048 / _W:
const n = 2048 / _W // compiler hint
for i := 0; i < n; i++ {
addMulVVW2048(&zLimbs[i:][0], &xLimbs[0], yLimbs[i])
}
}
return z
}

// Exp calculates out = x^e mod m.
//
// The exponent e is represented in big-endian order. The output will be resized
Expand All @@ -734,7 +784,7 @@ func (out *Nat) Exp(x *Nat, e []byte, m *Modulus) *Nat {
NewNat(), NewNat(), NewNat(), NewNat(), NewNat(),
NewNat(), NewNat(), NewNat(), NewNat(), NewNat(),
}
table[0].set(x).montgomeryRepresentation(m)
table[0].Set(x).montgomeryRepresentation(m)
for i := 1; i < len(table); i++ {
table[i].montgomeryMul(table[i-1], table[0], m)
}
Expand Down Expand Up @@ -775,8 +825,8 @@ func (out *Nat) ExpShortVarTime(x *Nat, e uint, m *Modulus) *Nat {
// For short exponents, precomputing a table and using a window like in Exp
// doesn't pay off. Instead, we do a simple conditional square-and-multiply
// chain, skipping the initial run of zeroes.
xR := NewNat().set(x).montgomeryRepresentation(m)
out.set(xR)
xR := NewNat().Set(x).montgomeryRepresentation(m)
out.Set(xR)
for i := bits.UintSize - bitLen(e) + 1; i < bits.UintSize; i++ {
out.montgomeryMul(out, out, m)
if k := (e >> (bits.UintSize - i - 1)) & 1; k != 0 {
Expand Down
70 changes: 51 additions & 19 deletions src/crypto/internal/bigmod/nat_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@ func (*Nat) Generate(r *rand.Rand, size int) reflect.Value {

func testModAddCommutative(a *Nat, b *Nat) bool {
m := maxModulus(uint(len(a.limbs)))
aPlusB := new(Nat).set(a)
aPlusB := new(Nat).Set(a)
aPlusB.Add(b, m)
bPlusA := new(Nat).set(b)
bPlusA := new(Nat).Set(b)
bPlusA.Add(a, m)
return aPlusB.Equal(bPlusA) == 1
}
Expand All @@ -51,8 +51,8 @@ func TestModAddCommutative(t *testing.T) {

func testModSubThenAddIdentity(a *Nat, b *Nat) bool {
m := maxModulus(uint(len(a.limbs)))
original := new(Nat).set(a)
a.Sub(b, m)
original := new(Nat).Set(a)
a.SubMod(b, m)
a.Add(b, m)
return a.Equal(original) == 1
}
Expand All @@ -71,9 +71,9 @@ func TestMontgomeryRoundtrip(t *testing.T) {
aPlusOne := new(big.Int).SetBytes(natBytes(a))
aPlusOne.Add(aPlusOne, big.NewInt(1))
m, _ := NewModulusFromBig(aPlusOne)
monty := new(Nat).set(a)
monty := new(Nat).Set(a)
monty.montgomeryRepresentation(m)
aAgain := new(Nat).set(monty)
aAgain := new(Nat).Set(monty)
aAgain.montgomeryMul(monty, one, m)
if a.Equal(aAgain) != 1 {
t.Errorf("%v != %v", a, aAgain)
Expand Down Expand Up @@ -131,9 +131,12 @@ func TestModulusAndNatSizes(t *testing.T) {
// modulus strips leading zeroes and nat does not.
m := modulusFromBytes([]byte{
0x3f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff})
xb := []byte{0x3f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe}
0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
})
xb := []byte{
0x3f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe,
}
natFromBytes(xb).ExpandFor(m) // must not panic for shrinking
NewNat().SetBytes(xb, m)
}
Expand Down Expand Up @@ -260,12 +263,12 @@ func TestModSub(t *testing.T) {
m := modulusFromBytes([]byte{13})
x := &Nat{[]uint{6}}
y := &Nat{[]uint{7}}
x.Sub(y, m)
x.SubMod(y, m)
expected := &Nat{[]uint{12}}
if x.Equal(expected) != 1 {
t.Errorf("%+v != %+v", x, expected)
}
x.Sub(y, m)
x.SubMod(y, m)
expected = &Nat{[]uint{5}}
if x.Equal(expected) != 1 {
t.Errorf("%+v != %+v", x, expected)
Expand Down Expand Up @@ -323,7 +326,7 @@ func TestMulReductions(t *testing.T) {
A := NewNat().setBig(a).ExpandFor(N)
B := NewNat().setBig(b).ExpandFor(N)

if A.Mul(B, N).IsZero() != 1 {
if A.MulMod(B, N).IsZero() != 1 {
t.Error("a * b mod (a * b) != 0")
}

Expand All @@ -333,7 +336,7 @@ func TestMulReductions(t *testing.T) {
I := NewNat().setBig(i).ExpandFor(N)
one := NewNat().setBig(big.NewInt(1)).ExpandFor(N)

if A.Mul(I, N).Equal(one) != 1 {
if A.MulMod(I, N).Equal(one) != 1 {
t.Error("a * inv(a) mod b != 1")
}
}
Expand Down Expand Up @@ -401,7 +404,7 @@ func BenchmarkModSub(b *testing.B) {

b.ResetTimer()
for i := 0; i < b.N; i++ {
x.Sub(y, m)
x.SubMod(y, m)
}
}

Expand Down Expand Up @@ -434,7 +437,7 @@ func BenchmarkModMul(b *testing.B) {

b.ResetTimer()
for i := 0; i < b.N; i++ {
x.Mul(y, m)
x.MulMod(y, m)
}
}

Expand Down Expand Up @@ -472,9 +475,38 @@ func TestNewModFromBigZero(t *testing.T) {
t.Errorf("NewModulusFromBig(0) got %q, want %q", err, expected)
}

expected = "modulus must be odd"
_, err = NewModulusFromBig(big.NewInt(2))
if err == nil || err.Error() != expected {
t.Errorf("NewModulusFromBig(2) got %q, want %q", err, expected)
defer func(t *testing.T) {
if r := recover(); r != nil {
if s, ok := r.(string); !ok || !strings.Contains(s, "montgomery multiplication on even modulus") {
t.Errorf("Unexpected panic: %#v", r)
}
} else {
t.Error("Expected panic to be recovered, got nothing.")
}
}(t)

m, err := NewModulusFromBig(big.NewInt(10))
if err != nil {
t.Errorf("NewModulusFromBig(2) got %q, want %q", err, "")
}
x := NewNat().setBig(big.NewInt(1))
y := NewNat().setBig(big.NewInt(2))
x.MulMod(y, m)
}

func TestNatCmp(t *testing.T) {
testcases := [][3]int64{
{33, 22, 1},
{33, 33, 0},
{22, 33, -1},
}
for _, tc := range testcases {
a := new(big.Int).SetInt64(tc[0])
b := new(big.Int).SetInt64(tc[1])
na := natFromBytes(a.Bytes())
nb := natFromBytes(b.Bytes())
if res := na.Cmp(nb); res != int(tc[2]) {
t.Errorf("expected %d got %d for (%d).cmp(%d)", tc[2], res, tc[0], tc[1])
}
}
}
Loading