Skip to content

Commit f7bcc26

Browse files
committed
(DO NOT MERGE) crypto: allow run goolm side-by-side with libolm
Signed-off-by: Sumner Evans <[email protected]>
1 parent 77adb6d commit f7bcc26

File tree

7 files changed

+132
-24
lines changed

7 files changed

+132
-24
lines changed

Diff for: crypto/account.go

+80-12
Original file line numberDiff line numberDiff line change
@@ -7,64 +7,107 @@
77
package crypto
88

99
import (
10+
"bytes"
1011
"encoding/json"
12+
"fmt"
1113

1214
"github.com/tidwall/sjson"
1315

1416
"maunium.net/go/mautrix"
1517
"maunium.net/go/mautrix/crypto/canonicaljson"
18+
"maunium.net/go/mautrix/crypto/goolm/account"
1619
"maunium.net/go/mautrix/crypto/olm"
1720
"maunium.net/go/mautrix/crypto/signatures"
1821
"maunium.net/go/mautrix/id"
1922
)
2023

2124
type OlmAccount struct {
22-
Internal olm.Account
25+
InternalLibolm olm.Account
26+
InternalGoolm olm.Account
2327
signingKey id.SigningKey
2428
identityKey id.IdentityKey
2529
Shared bool
2630
KeyBackupVersion id.KeyBackupVersion
2731
}
2832

2933
func NewOlmAccount() *OlmAccount {
30-
account, err := olm.NewAccount()
34+
libolmAccount, err := olm.NewAccount()
35+
if err != nil {
36+
panic(err)
37+
}
38+
pickled, err := libolmAccount.Pickle([]byte("key"))
39+
if err != nil {
40+
panic(err)
41+
}
42+
goolmAccount, err := account.AccountFromPickled(pickled, []byte("key"))
3143
if err != nil {
3244
panic(err)
3345
}
3446
return &OlmAccount{
35-
Internal: account,
47+
InternalLibolm: libolmAccount,
48+
InternalGoolm: goolmAccount,
3649
}
3750
}
3851

3952
func (account *OlmAccount) Keys() (id.SigningKey, id.IdentityKey) {
4053
if len(account.signingKey) == 0 || len(account.identityKey) == 0 {
4154
var err error
42-
account.signingKey, account.identityKey, err = account.Internal.IdentityKeys()
55+
account.signingKey, account.identityKey, err = account.InternalLibolm.IdentityKeys()
56+
if err != nil {
57+
panic(err)
58+
}
59+
goolmSigningKey, goolmIdentityKey, err := account.InternalGoolm.IdentityKeys()
4360
if err != nil {
4461
panic(err)
4562
}
63+
if account.signingKey != goolmSigningKey {
64+
panic("account signing keys not equal")
65+
}
66+
if account.identityKey != goolmIdentityKey {
67+
panic("account identity keys not equal")
68+
}
4669
}
4770
return account.signingKey, account.identityKey
4871
}
4972

5073
func (account *OlmAccount) SigningKey() id.SigningKey {
5174
if len(account.signingKey) == 0 {
5275
var err error
53-
account.signingKey, account.identityKey, err = account.Internal.IdentityKeys()
76+
account.signingKey, account.identityKey, err = account.InternalLibolm.IdentityKeys()
77+
if err != nil {
78+
panic(err)
79+
}
80+
goolmSigningKey, goolmIdentityKey, err := account.InternalGoolm.IdentityKeys()
5481
if err != nil {
5582
panic(err)
5683
}
84+
if account.signingKey != goolmSigningKey {
85+
panic("account signing keys not equal")
86+
}
87+
if account.identityKey != goolmIdentityKey {
88+
panic("account identity keys not equal")
89+
}
5790
}
5891
return account.signingKey
5992
}
6093

6194
func (account *OlmAccount) IdentityKey() id.IdentityKey {
6295
if len(account.identityKey) == 0 {
6396
var err error
64-
account.signingKey, account.identityKey, err = account.Internal.IdentityKeys()
97+
account.signingKey, account.identityKey, err = account.InternalLibolm.IdentityKeys()
98+
if err != nil {
99+
panic(err)
100+
}
101+
goolmSigningKey, goolmIdentityKey, err := account.InternalGoolm.IdentityKeys()
65102
if err != nil {
66103
panic(err)
67104
}
105+
if account.signingKey != goolmSigningKey {
106+
panic("account signing keys not equal")
107+
}
108+
if account.identityKey != goolmIdentityKey {
109+
panic("account identity keys not equal")
110+
}
68111
}
69112
return account.identityKey
70113
}
@@ -78,7 +121,15 @@ func (account *OlmAccount) SignJSON(obj any) (string, error) {
78121
}
79122
objJSON, _ = sjson.DeleteBytes(objJSON, "unsigned")
80123
objJSON, _ = sjson.DeleteBytes(objJSON, "signatures")
81-
signed, err := account.Internal.Sign(canonicaljson.CanonicalJSONAssumeValid(objJSON))
124+
signed, err := account.InternalLibolm.Sign(canonicaljson.CanonicalJSONAssumeValid(objJSON))
125+
goolmSigned, goolmErr := account.InternalGoolm.Sign(canonicaljson.CanonicalJSONAssumeValid(objJSON))
126+
if err != nil {
127+
if goolmErr == nil {
128+
panic("libolm errored, but goolm did not on account.SignJSON")
129+
}
130+
} else if !bytes.Equal(signed, goolmSigned) {
131+
panic("libolm and goolm signed are not equal in account.SignJSON")
132+
}
82133
return string(signed), err
83134
}
84135

@@ -102,19 +153,36 @@ func (account *OlmAccount) getInitialKeys(userID id.UserID, deviceID id.DeviceID
102153
return deviceKeys
103154
}
104155

105-
func (account *OlmAccount) getOneTimeKeys(userID id.UserID, deviceID id.DeviceID, currentOTKCount int) map[id.KeyID]mautrix.OneTimeKey {
106-
newCount := int(account.Internal.MaxNumberOfOneTimeKeys()/2) - currentOTKCount
156+
func (a *OlmAccount) getOneTimeKeys(userID id.UserID, deviceID id.DeviceID, currentOTKCount int) map[id.KeyID]mautrix.OneTimeKey {
157+
newCount := int(a.InternalLibolm.MaxNumberOfOneTimeKeys()/2) - currentOTKCount
107158
if newCount > 0 {
108-
account.Internal.GenOneTimeKeys(uint(newCount))
159+
a.InternalLibolm.GenOneTimeKeys(uint(newCount))
160+
161+
pickled, err := a.InternalLibolm.Pickle([]byte("key"))
162+
if err != nil {
163+
panic(err)
164+
}
165+
a.InternalGoolm, err = account.AccountFromPickled(pickled, []byte("key"))
166+
if err != nil {
167+
panic(err)
168+
}
109169
}
110170
oneTimeKeys := make(map[id.KeyID]mautrix.OneTimeKey)
111-
internalKeys, err := account.Internal.OneTimeKeys()
171+
internalKeys, err := a.InternalLibolm.OneTimeKeys()
172+
if err != nil {
173+
panic(err)
174+
}
175+
goolmInternalKeys, err := a.InternalGoolm.OneTimeKeys()
112176
if err != nil {
113177
panic(err)
114178
}
115179
for keyID, key := range internalKeys {
180+
if goolmInternalKeys[keyID] != key {
181+
panic(fmt.Sprintf("key %s not found in getOneTimeKeys", keyID))
182+
}
183+
116184
key := mautrix.OneTimeKey{Key: key}
117-
signature, _ := account.SignJSON(key)
185+
signature, _ := a.SignJSON(key)
118186
key.Signatures = signatures.NewSingleSignature(userID, id.KeyAlgorithmEd25519, deviceID.String(), signature)
119187
key.IsSigned = true
120188
oneTimeKeys[id.NewKeyID(id.KeyAlgorithmSignedCurve25519, keyID)] = key

Diff for: crypto/encryptolm.go

+9-1
Original file line numberDiff line numberDiff line change
@@ -116,9 +116,17 @@ func (mach *OlmMachine) createOutboundSessions(ctx context.Context, input map[id
116116
log.Error().Err(err).Msg("Failed to verify signature of one-time key")
117117
} else if !ok {
118118
log.Warn().Msg("One-time key has invalid signature from device")
119-
} else if sess, err := mach.account.Internal.NewOutboundSession(identity.IdentityKey, oneTimeKey.Key); err != nil {
119+
} else if sess, err := mach.account.InternalLibolm.NewOutboundSession(identity.IdentityKey, oneTimeKey.Key); err != nil {
120120
log.Error().Err(err).Msg("Failed to create outbound session with claimed one-time key")
121121
} else {
122+
goolmSess, err := mach.account.InternalGoolm.NewOutboundSession(identity.IdentityKey, oneTimeKey.Key)
123+
if err != nil {
124+
panic("goolm NewOutboundSession errored")
125+
}
126+
if sess.Describe() != goolmSess.Describe() {
127+
panic("goolm NewOutboundSession and libolm NewOutboundSession returned different values")
128+
}
129+
122130
wrapped := wrapSession(sess)
123131
err = mach.CryptoStore.AddSession(ctx, identity.IdentityKey, wrapped)
124132
if err != nil {

Diff for: crypto/goolm/account/account.go

+2
Original file line numberDiff line numberDiff line change
@@ -336,8 +336,10 @@ func (a *Account) UnpickleLibOlm(buf []byte) error {
336336
} else if pickledVersion != accountPickleVersionLibOLM && pickledVersion != 3 && pickledVersion != 2 {
337337
return fmt.Errorf("unpickle account: %w (found version %d)", olm.ErrBadVersion, pickledVersion)
338338
} else if err = a.IdKeys.Ed25519.UnpickleLibOlm(decoder); err != nil { // read the ed25519 key pair
339+
fmt.Printf("123 %+v\n", err)
339340
return err
340341
} else if err = a.IdKeys.Curve25519.UnpickleLibOlm(decoder); err != nil { // read curve25519 key pair
342+
fmt.Printf("456 %+v\n", err)
341343
return err
342344
}
343345

Diff for: crypto/machine.go

+3-2
Original file line numberDiff line numberDiff line change
@@ -279,7 +279,7 @@ func (mach *OlmMachine) HandleOTKCounts(ctx context.Context, otkCount *mautrix.O
279279
mach.receivedOTKsForSelf.Store(true)
280280
}
281281

282-
minCount := mach.account.Internal.MaxNumberOfOneTimeKeys() / 2
282+
minCount := mach.account.InternalLibolm.MaxNumberOfOneTimeKeys() / 2
283283
if otkCount.SignedCurve25519 < int(minCount) {
284284
traceID := time.Now().Format("15:04:05.000000")
285285
log := mach.Log.With().Str("trace_id", traceID).Logger()
@@ -749,7 +749,8 @@ func (mach *OlmMachine) ShareKeys(ctx context.Context, currentOTKCount int) erro
749749
return err
750750
}
751751
mach.lastOTKUpload = time.Now()
752-
mach.account.Internal.MarkKeysAsPublished()
752+
mach.account.InternalLibolm.MarkKeysAsPublished()
753+
mach.account.InternalGoolm.MarkKeysAsPublished()
753754
mach.account.Shared = true
754755
return mach.saveAccount(ctx)
755756
}

Diff for: crypto/machine_test.go

+3-2
Original file line numberDiff line numberDiff line change
@@ -78,10 +78,11 @@ func TestOlmMachineOlmMegolmSessions(t *testing.T) {
7878
otk = otkTmp
7979
break
8080
}
81-
machineIn.account.Internal.MarkKeysAsPublished()
81+
machineIn.account.InternalLibolm.MarkKeysAsPublished()
82+
machineIn.account.InternalGoolm.MarkKeysAsPublished()
8283

8384
// create outbound olm session for sending machine using OTK
84-
olmSession, err := machineOut.account.Internal.NewOutboundSession(machineIn.account.IdentityKey(), otk.Key)
85+
olmSession, err := machineOut.account.InternalLibolm.NewOutboundSession(machineIn.account.IdentityKey(), otk.Key)
8586
if err != nil {
8687
t.Errorf("Failed to create outbound olm session: %v", err)
8788
}

Diff for: crypto/sessions.go

+14-3
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,14 @@
77
package crypto
88

99
import (
10+
"bytes"
1011
"errors"
1112
"time"
1213

14+
"go.mau.fi/util/exerrors"
15+
1316
"maunium.net/go/mautrix/crypto/olm"
1417
"maunium.net/go/mautrix/event"
15-
1618
"maunium.net/go/mautrix/id"
1719
)
1820

@@ -68,11 +70,20 @@ func wrapSession(session olm.Session) *OlmSession {
6870
}
6971

7072
func (account *OlmAccount) NewInboundSessionFrom(senderKey id.Curve25519, ciphertext string) (*OlmSession, error) {
71-
session, err := account.Internal.NewInboundSessionFrom(&senderKey, ciphertext)
73+
session, err := account.InternalLibolm.NewInboundSessionFrom(&senderKey, ciphertext)
7274
if err != nil {
7375
return nil, err
7476
}
75-
_ = account.Internal.RemoveOneTimeKeys(session)
77+
goolmSession, err := account.InternalGoolm.NewInboundSessionFrom(&senderKey, ciphertext)
78+
if err != nil {
79+
return nil, err
80+
}
81+
if !bytes.Equal(exerrors.Must(goolmSession.Pickle([]byte("123"))), exerrors.Must(session.Pickle([]byte("123")))) {
82+
panic("goolm inbound session and libolm inbound session from ciphertext are different")
83+
}
84+
85+
_ = account.InternalLibolm.RemoveOneTimeKeys(session)
86+
_ = account.InternalGoolm.RemoveOneTimeKeys(goolmSession)
7687
return wrapSession(session), nil
7788
}
7889

Diff for: crypto/sql_store.go

+21-4
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
package crypto
88

99
import (
10+
"bytes"
1011
"context"
1112
"database/sql"
1213
"database/sql/driver"
@@ -21,6 +22,7 @@ import (
2122
"go.mau.fi/util/dbutil"
2223

2324
"maunium.net/go/mautrix"
25+
"maunium.net/go/mautrix/crypto/goolm/account"
2426
"maunium.net/go/mautrix/crypto/goolm/libolmpickle"
2527
"maunium.net/go/mautrix/crypto/olm"
2628
"maunium.net/go/mautrix/crypto/sql_store_upgrade"
@@ -127,35 +129,50 @@ func (store *SQLCryptoStore) FindDeviceID(ctx context.Context) (deviceID id.Devi
127129
// PutAccount stores an OlmAccount in the database.
128130
func (store *SQLCryptoStore) PutAccount(ctx context.Context, account *OlmAccount) error {
129131
store.Account = account
130-
bytes, err := account.Internal.Pickle(store.PickleKey)
132+
pickled, err := account.InternalLibolm.Pickle(store.PickleKey)
131133
if err != nil {
132134
return err
133135
}
136+
goolmBytes, err := account.InternalGoolm.Pickle(store.PickleKey)
137+
if err != nil {
138+
panic(fmt.Errorf("pickling goolm account errored %w", err))
139+
}
140+
if !bytes.Equal(pickled, goolmBytes) {
141+
panic("libolm and goolm pickled to different values")
142+
}
134143
_, err = store.DB.Exec(ctx, `
135144
INSERT INTO crypto_account (device_id, shared, sync_token, account, account_id, key_backup_version) VALUES ($1, $2, $3, $4, $5, $6)
136145
ON CONFLICT (account_id) DO UPDATE SET shared=excluded.shared, sync_token=excluded.sync_token,
137146
account=excluded.account, account_id=excluded.account_id,
138147
key_backup_version=excluded.key_backup_version
139-
`, store.DeviceID, account.Shared, store.SyncToken, bytes, store.AccountID, account.KeyBackupVersion)
148+
`, store.DeviceID, account.Shared, store.SyncToken, pickled, store.AccountID, account.KeyBackupVersion)
140149
return err
141150
}
142151

143152
// GetAccount retrieves an OlmAccount from the database.
144153
func (store *SQLCryptoStore) GetAccount(ctx context.Context) (*OlmAccount, error) {
145154
if store.Account == nil {
146155
row := store.DB.QueryRow(ctx, "SELECT shared, sync_token, account, key_backup_version FROM crypto_account WHERE account_id=$1", store.AccountID)
147-
acc := &OlmAccount{Internal: olm.NewBlankAccount()}
156+
acc := &OlmAccount{
157+
InternalLibolm: olm.NewBlankAccount(),
158+
InternalGoolm: &account.Account{},
159+
}
148160
var accountBytes []byte
149161
err := row.Scan(&acc.Shared, &store.SyncToken, &accountBytes, &acc.KeyBackupVersion)
150162
if err == sql.ErrNoRows {
151163
return nil, nil
152164
} else if err != nil {
153165
return nil, err
154166
}
155-
err = acc.Internal.Unpickle(accountBytes, store.PickleKey)
167+
err = acc.InternalLibolm.Unpickle(accountBytes, store.PickleKey)
156168
if err != nil {
157169
return nil, err
158170
}
171+
fmt.Printf("%s\n", accountBytes)
172+
err = acc.InternalGoolm.Unpickle(accountBytes, store.PickleKey)
173+
if err != nil {
174+
panic(fmt.Sprintf("failed to unpickle account using goolm: %+v", err))
175+
}
159176
store.Account = acc
160177
}
161178
return store.Account, nil

0 commit comments

Comments
 (0)