Skip to content

Commit c7b76d3

Browse files
committed
device: uniformly check ECDH output for zeros
For some reason, this was omitted for response messages. Reported-by: z <[email protected]> Fixes: 8c34c4c ("First set of code review patches") Signed-off-by: Jason A. Donenfeld <[email protected]>
1 parent 1e2c3e5 commit c7b76d3

File tree

5 files changed

+45
-38
lines changed

5 files changed

+45
-38
lines changed

Diff for: device/device.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,7 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
265265
expiredPeers := make([]*Peer, 0, len(device.peers.keyMap))
266266
for _, peer := range device.peers.keyMap {
267267
handshake := &peer.handshake
268-
handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic)
268+
handshake.precomputedStaticStatic, _ = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic)
269269
expiredPeers = append(expiredPeers, peer)
270270
}
271271

Diff for: device/noise-helpers.go

+8-2
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"crypto/hmac"
1010
"crypto/rand"
1111
"crypto/subtle"
12+
"errors"
1213
"hash"
1314

1415
"golang.org/x/crypto/blake2s"
@@ -94,9 +95,14 @@ func (sk *NoisePrivateKey) publicKey() (pk NoisePublicKey) {
9495
return
9596
}
9697

97-
func (sk *NoisePrivateKey) sharedSecret(pk NoisePublicKey) (ss [NoisePublicKeySize]byte) {
98+
var errInvalidPublicKey = errors.New("invalid public key")
99+
100+
func (sk *NoisePrivateKey) sharedSecret(pk NoisePublicKey) (ss [NoisePublicKeySize]byte, err error) {
98101
apk := (*[NoisePublicKeySize]byte)(&pk)
99102
ask := (*[NoisePrivateKeySize]byte)(sk)
100103
curve25519.ScalarMult(&ss, ask, apk)
101-
return ss
104+
if isZero(ss[:]) {
105+
return ss, errInvalidPublicKey
106+
}
107+
return ss, nil
102108
}

Diff for: device/noise-protocol.go

+32-31
Original file line numberDiff line numberDiff line change
@@ -175,8 +175,6 @@ func init() {
175175
}
176176

177177
func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) {
178-
errZeroECDHResult := errors.New("ECDH returned all zeros")
179-
180178
device.staticIdentity.RLock()
181179
defer device.staticIdentity.RUnlock()
182180

@@ -204,9 +202,9 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
204202
handshake.mixHash(msg.Ephemeral[:])
205203

206204
// encrypt static key
207-
ss := handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
208-
if isZero(ss[:]) {
209-
return nil, errZeroECDHResult
205+
ss, err := handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
206+
if err != nil {
207+
return nil, err
210208
}
211209
var key [chacha20poly1305.KeySize]byte
212210
KDF2(
@@ -221,7 +219,7 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
221219

222220
// encrypt timestamp
223221
if isZero(handshake.precomputedStaticStatic[:]) {
224-
return nil, errZeroECDHResult
222+
return nil, errInvalidPublicKey
225223
}
226224
KDF2(
227225
&handshake.chainKey,
@@ -264,11 +262,10 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
264262
mixKey(&chainKey, &InitialChainKey, msg.Ephemeral[:])
265263

266264
// decrypt static key
267-
var err error
268265
var peerPK NoisePublicKey
269266
var key [chacha20poly1305.KeySize]byte
270-
ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
271-
if isZero(ss[:]) {
267+
ss, err := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
268+
if err != nil {
272269
return nil
273270
}
274271
KDF2(&chainKey, &key, chainKey[:], ss[:])
@@ -384,12 +381,16 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
384381
handshake.mixHash(msg.Ephemeral[:])
385382
handshake.mixKey(msg.Ephemeral[:])
386383

387-
func() {
388-
ss := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral)
389-
handshake.mixKey(ss[:])
390-
ss = handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
391-
handshake.mixKey(ss[:])
392-
}()
384+
ss, err := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral)
385+
if err != nil {
386+
return nil, err
387+
}
388+
handshake.mixKey(ss[:])
389+
ss, err = handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
390+
if err != nil {
391+
return nil, err
392+
}
393+
handshake.mixKey(ss[:])
393394

394395
// add preshared key
395396

@@ -406,11 +407,9 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
406407

407408
handshake.mixHash(tau[:])
408409

409-
func() {
410-
aead, _ := chacha20poly1305.New(key[:])
411-
aead.Seal(msg.Empty[:0], ZeroNonce[:], nil, handshake.hash[:])
412-
handshake.mixHash(msg.Empty[:])
413-
}()
410+
aead, _ := chacha20poly1305.New(key[:])
411+
aead.Seal(msg.Empty[:0], ZeroNonce[:], nil, handshake.hash[:])
412+
handshake.mixHash(msg.Empty[:])
414413

415414
handshake.state = handshakeResponseCreated
416415

@@ -455,17 +454,19 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
455454
mixHash(&hash, &handshake.hash, msg.Ephemeral[:])
456455
mixKey(&chainKey, &handshake.chainKey, msg.Ephemeral[:])
457456

458-
func() {
459-
ss := handshake.localEphemeral.sharedSecret(msg.Ephemeral)
460-
mixKey(&chainKey, &chainKey, ss[:])
461-
setZero(ss[:])
462-
}()
457+
ss, err := handshake.localEphemeral.sharedSecret(msg.Ephemeral)
458+
if err != nil {
459+
return false
460+
}
461+
mixKey(&chainKey, &chainKey, ss[:])
462+
setZero(ss[:])
463463

464-
func() {
465-
ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
466-
mixKey(&chainKey, &chainKey, ss[:])
467-
setZero(ss[:])
468-
}()
464+
ss, err = device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
465+
if err != nil {
466+
return false
467+
}
468+
mixKey(&chainKey, &chainKey, ss[:])
469+
setZero(ss[:])
469470

470471
// add preshared key (psk)
471472

@@ -483,7 +484,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
483484
// authenticate transcript
484485

485486
aead, _ := chacha20poly1305.New(key[:])
486-
_, err := aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:])
487+
_, err = aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:])
487488
if err != nil {
488489
return false
489490
}

Diff for: device/noise_test.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,10 @@ func TestCurveWrappers(t *testing.T) {
2424
pk1 := sk1.publicKey()
2525
pk2 := sk2.publicKey()
2626

27-
ss1 := sk1.sharedSecret(pk2)
28-
ss2 := sk2.sharedSecret(pk1)
27+
ss1, err1 := sk1.sharedSecret(pk2)
28+
ss2, err2 := sk2.sharedSecret(pk1)
2929

30-
if ss1 != ss2 {
30+
if ss1 != ss2 || err1 != nil || err2 != nil {
3131
t.Fatal("Failed to compute shared secet")
3232
}
3333
}

Diff for: device/peer.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
9292
// pre-compute DH
9393
handshake := &peer.handshake
9494
handshake.mutex.Lock()
95-
handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(pk)
95+
handshake.precomputedStaticStatic, _ = device.staticIdentity.privateKey.sharedSecret(pk)
9696
handshake.remoteStatic = pk
9797
handshake.mutex.Unlock()
9898

0 commit comments

Comments
 (0)