Skip to content

Commit

Permalink
device: uniformly check ECDH output for zeros
Browse files Browse the repository at this point in the history
For some reason, this was omitted for response messages.

Reported-by: z <[email protected]>
Fixes: 8c34c4c ("First set of code review patches")
Signed-off-by: Jason A. Donenfeld <[email protected]>
  • Loading branch information
zx2c4 committed Feb 16, 2023
1 parent 1e2c3e5 commit c7b76d3
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 38 deletions.
2 changes: 1 addition & 1 deletion device/device.go
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
expiredPeers := make([]*Peer, 0, len(device.peers.keyMap))
for _, peer := range device.peers.keyMap {
handshake := &peer.handshake
handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic)
handshake.precomputedStaticStatic, _ = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic)
expiredPeers = append(expiredPeers, peer)
}

Expand Down
10 changes: 8 additions & 2 deletions device/noise-helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"crypto/hmac"
"crypto/rand"
"crypto/subtle"
"errors"
"hash"

"golang.org/x/crypto/blake2s"
Expand Down Expand Up @@ -94,9 +95,14 @@ func (sk *NoisePrivateKey) publicKey() (pk NoisePublicKey) {
return
}

func (sk *NoisePrivateKey) sharedSecret(pk NoisePublicKey) (ss [NoisePublicKeySize]byte) {
var errInvalidPublicKey = errors.New("invalid public key")

func (sk *NoisePrivateKey) sharedSecret(pk NoisePublicKey) (ss [NoisePublicKeySize]byte, err error) {
apk := (*[NoisePublicKeySize]byte)(&pk)
ask := (*[NoisePrivateKeySize]byte)(sk)
curve25519.ScalarMult(&ss, ask, apk)
return ss
if isZero(ss[:]) {
return ss, errInvalidPublicKey
}
return ss, nil
}
63 changes: 32 additions & 31 deletions device/noise-protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -175,8 +175,6 @@ func init() {
}

func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) {
errZeroECDHResult := errors.New("ECDH returned all zeros")

device.staticIdentity.RLock()
defer device.staticIdentity.RUnlock()

Expand Down Expand Up @@ -204,9 +202,9 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
handshake.mixHash(msg.Ephemeral[:])

// encrypt static key
ss := handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
if isZero(ss[:]) {
return nil, errZeroECDHResult
ss, err := handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
if err != nil {
return nil, err
}
var key [chacha20poly1305.KeySize]byte
KDF2(
Expand All @@ -221,7 +219,7 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e

// encrypt timestamp
if isZero(handshake.precomputedStaticStatic[:]) {
return nil, errZeroECDHResult
return nil, errInvalidPublicKey
}
KDF2(
&handshake.chainKey,
Expand Down Expand Up @@ -264,11 +262,10 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
mixKey(&chainKey, &InitialChainKey, msg.Ephemeral[:])

// decrypt static key
var err error
var peerPK NoisePublicKey
var key [chacha20poly1305.KeySize]byte
ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
if isZero(ss[:]) {
ss, err := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
if err != nil {
return nil
}
KDF2(&chainKey, &key, chainKey[:], ss[:])
Expand Down Expand Up @@ -384,12 +381,16 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
handshake.mixHash(msg.Ephemeral[:])
handshake.mixKey(msg.Ephemeral[:])

func() {
ss := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral)
handshake.mixKey(ss[:])
ss = handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
handshake.mixKey(ss[:])
}()
ss, err := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral)
if err != nil {
return nil, err
}
handshake.mixKey(ss[:])
ss, err = handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
if err != nil {
return nil, err
}
handshake.mixKey(ss[:])

// add preshared key

Expand All @@ -406,11 +407,9 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error

handshake.mixHash(tau[:])

func() {
aead, _ := chacha20poly1305.New(key[:])
aead.Seal(msg.Empty[:0], ZeroNonce[:], nil, handshake.hash[:])
handshake.mixHash(msg.Empty[:])
}()
aead, _ := chacha20poly1305.New(key[:])
aead.Seal(msg.Empty[:0], ZeroNonce[:], nil, handshake.hash[:])
handshake.mixHash(msg.Empty[:])

handshake.state = handshakeResponseCreated

Expand Down Expand Up @@ -455,17 +454,19 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
mixHash(&hash, &handshake.hash, msg.Ephemeral[:])
mixKey(&chainKey, &handshake.chainKey, msg.Ephemeral[:])

func() {
ss := handshake.localEphemeral.sharedSecret(msg.Ephemeral)
mixKey(&chainKey, &chainKey, ss[:])
setZero(ss[:])
}()
ss, err := handshake.localEphemeral.sharedSecret(msg.Ephemeral)
if err != nil {
return false
}
mixKey(&chainKey, &chainKey, ss[:])
setZero(ss[:])

func() {
ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
mixKey(&chainKey, &chainKey, ss[:])
setZero(ss[:])
}()
ss, err = device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
if err != nil {
return false
}
mixKey(&chainKey, &chainKey, ss[:])
setZero(ss[:])

// add preshared key (psk)

Expand All @@ -483,7 +484,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
// authenticate transcript

aead, _ := chacha20poly1305.New(key[:])
_, err := aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:])
_, err = aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:])
if err != nil {
return false
}
Expand Down
6 changes: 3 additions & 3 deletions device/noise_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@ func TestCurveWrappers(t *testing.T) {
pk1 := sk1.publicKey()
pk2 := sk2.publicKey()

ss1 := sk1.sharedSecret(pk2)
ss2 := sk2.sharedSecret(pk1)
ss1, err1 := sk1.sharedSecret(pk2)
ss2, err2 := sk2.sharedSecret(pk1)

if ss1 != ss2 {
if ss1 != ss2 || err1 != nil || err2 != nil {
t.Fatal("Failed to compute shared secet")
}
}
Expand Down
2 changes: 1 addition & 1 deletion device/peer.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
// pre-compute DH
handshake := &peer.handshake
handshake.mutex.Lock()
handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(pk)
handshake.precomputedStaticStatic, _ = device.staticIdentity.privateKey.sharedSecret(pk)
handshake.remoteStatic = pk
handshake.mutex.Unlock()

Expand Down

0 comments on commit c7b76d3

Please sign in to comment.