diff --git a/client.go b/client.go index 88bdc33fb..b935675b7 100644 --- a/client.go +++ b/client.go @@ -183,6 +183,14 @@ type Client struct { // The library is currently embedded in mautrix-meta (https://github.com/mautrix/meta), but may be separated later. MessengerConfig *MessengerConfig RefreshCAT func(context.Context) error + + // EncryptConcurrency controls how many goroutines are used when encrypting a + // message fan-out to multiple devices (e.g. large groups). If zero or + // negative, the library will pick a value automatically (runtime.NumCPU()). + // A value of 1 disables parallelization and preserves the previous + // sequential behavior. Small groups automatically fall back to sequential + // processing for lower overhead. + EncryptConcurrency int } type groupMetaCache struct { diff --git a/send.go b/send.go index e3399845b..62304348f 100644 --- a/send.go +++ b/send.go @@ -14,9 +14,11 @@ import ( "encoding/hex" "errors" "fmt" + "runtime" "sort" "strconv" "strings" + "sync" "time" "github.com/rs/zerolog" @@ -1164,74 +1166,251 @@ func (cli *Client) encryptMessageForDevices( ) ([]waBinary.Node, bool) { ownJID := cli.getOwnID() ownLID := cli.getOwnLID() - includeIdentity := false - participantNodes := make([]waBinary.Node, 0, len(allDevices)) - var retryDevices, retryEncryptionIdentities []types.JID - for _, jid := range allDevices { - plaintext := msgPlaintext - if (jid.User == ownJID.User || jid.User == ownLID.User) && dsmPlaintext != nil { - if jid == ownJID || jid == ownLID { + + // Scoped cache activation (batch load sessions & identity keys) – only if backend supports it + if cli.Store != nil && cli.Store.Cache != nil && len(allDevices) > 1 { // only bother if >1 device + addresses := make([]string, 0, len(allDevices)) + for _, jid := range allDevices { + if jid == ownJID || jid == ownLID { // skip own identity placeholders, own device sessions not needed continue } - plaintext = dsmPlaintext + addresses = append(addresses, 
jid.SignalAddress().String()) } - encryptionIdentity := jid - if jid.Server == types.DefaultUserServer { - lidForPN, err := cli.Store.LIDs.GetLIDForPN(ctx, jid) + if len(addresses) > 0 { + // Batch fetch + sessMap, err := cli.Store.Cache.GetSessions(ctx, addresses) if err != nil { - cli.Log.Warnf("Failed to get LID for %s: %v", jid, err) - } else if !lidForPN.IsEmpty() { - cli.migrateSessionStore(ctx, jid, lidForPN) - encryptionIdentity = lidForPN + cli.Log.Warnf("Scoped cache: failed to batch fetch sessions: %v", err) + sessMap = map[string][]byte{} } + idMap, err := cli.Store.Cache.GetIdentityKeys(ctx, addresses) + if err != nil { + cli.Log.Warnf("Scoped cache: failed to batch fetch identity keys: %v", err) + idMap = map[string][32]byte{} + } + cli.Store.SessionsCacheLock.Lock() + cli.Store.SessionsCache = sessMap + cli.Store.IdentityKeysCache = idMap + cli.Store.SessionsCacheLock.Unlock() + defer func() { + cli.Store.SessionsCacheLock.Lock() + sessionsCopy := cli.Store.SessionsCache + identityCopy := cli.Store.IdentityKeysCache + cli.Store.SessionsCache = nil + cli.Store.IdentityKeysCache = nil + cli.Store.SessionsCacheLock.Unlock() + // Flush new/updated sessions & identities back (best effort) + if len(sessionsCopy) > 0 { + if err := cli.Store.Cache.PutSessions(context.Background(), sessionsCopy); err != nil { + cli.Log.Warnf("Scoped cache: failed to flush sessions: %v", err) + } + } + if len(identityCopy) > 0 { + if err := cli.Store.Cache.PutIdentityKeys(context.Background(), identityCopy); err != nil { + cli.Log.Warnf("Scoped cache: failed to flush identity keys: %v", err) + } + } + }() } + } + includeIdentity := false + participantNodes := make([]waBinary.Node, 0, len(allDevices)) - encrypted, isPreKey, err := cli.encryptMessageForDeviceAndWrap( - ctx, plaintext, jid, encryptionIdentity, nil, encAttrs, - ) - if errors.Is(err, ErrNoSession) { - retryDevices = append(retryDevices, jid) - retryEncryptionIdentities = append(retryEncryptionIdentities, 
encryptionIdentity) - continue - } else if err != nil { - // TODO return these errors if it's a fatal one (like context cancellation or database) - cli.Log.Warnf("Failed to encrypt %s for %s: %v", id, jid, err) - continue - } + const parallelThreshold = 8 - participantNodes = append(participantNodes, *encrypted) - if isPreKey { - includeIdentity = true - } + const minConcurrency = 1 + const maxConcurrency = 64 + var concurrency int + if cli.EncryptConcurrency > 0 { + concurrency = cli.EncryptConcurrency + } else { + procs := runtime.GOMAXPROCS(0) + + concurrency = procs * 2 } - if len(retryDevices) > 0 { - bundles, err := cli.fetchPreKeys(ctx, retryDevices) - if err != nil { - cli.Log.Warnf("Failed to fetch prekeys for %v to retry encryption: %v", retryDevices, err) - } else { - for i, jid := range retryDevices { - resp := bundles[jid] - if resp.err != nil { - cli.Log.Warnf("Failed to fetch prekey for %s: %v", jid, resp.err) + + if concurrency < minConcurrency { + concurrency = minConcurrency + } + if concurrency > maxConcurrency { + concurrency = maxConcurrency + } + + if concurrency > len(allDevices) { + concurrency = len(allDevices) + } + + if len(allDevices) == 0 { + return participantNodes, includeIdentity + } + + if len(allDevices) < parallelThreshold || concurrency <= 1 { + // Fall back to original sequential implementation for small batches + var retryDevices, retryEncryptionIdentities []types.JID + for _, jid := range allDevices { + plaintext := msgPlaintext + if (jid.User == ownJID.User || jid.User == ownLID.User) && dsmPlaintext != nil { + if jid == ownJID || jid == ownLID { continue } - plaintext := msgPlaintext - if (jid.User == ownJID.User || jid.User == ownLID.User) && dsmPlaintext != nil { - plaintext = dsmPlaintext - } - encrypted, isPreKey, err := cli.encryptMessageForDeviceAndWrap( - ctx, plaintext, jid, retryEncryptionIdentities[i], resp.bundle, encAttrs, - ) + plaintext = dsmPlaintext + } + encryptionIdentity := jid + if jid.Server == 
types.DefaultUserServer { + lidForPN, err := cli.Store.LIDs.GetLIDForPN(ctx, jid) if err != nil { - // TODO return these errors if it's a fatal one (like context cancellation or database) - cli.Log.Warnf("Failed to encrypt %s for %s (retry): %v", id, jid, err) + cli.Log.Warnf("Failed to get LID for %s: %v", jid, err) + } else if !lidForPN.IsEmpty() { + cli.migrateSessionStore(ctx, jid, lidForPN) + encryptionIdentity = lidForPN + } + } + encrypted, isPreKey, err := cli.encryptMessageForDeviceAndWrap( + ctx, plaintext, jid, encryptionIdentity, nil, encAttrs, + ) + if errors.Is(err, ErrNoSession) { + retryDevices = append(retryDevices, jid) + retryEncryptionIdentities = append(retryEncryptionIdentities, encryptionIdentity) + continue + } else if err != nil { + cli.Log.Warnf("Failed to encrypt %s for %s: %v", id, jid, err) + continue + } + participantNodes = append(participantNodes, *encrypted) + if isPreKey { + includeIdentity = true + } + } + participantNodes, includeIdentity = cli.retryEncryptMissing(ctx, id, msgPlaintext, dsmPlaintext, encAttrs, ownJID, ownLID, participantNodes, includeIdentity, retryDevices, retryEncryptionIdentities) + return participantNodes, includeIdentity + } + + type workItem struct { + jid types.JID + plaintext []byte + encryptionIdentity types.JID + } + type resultItem struct { + jid types.JID + node *waBinary.Node + isPreKey bool + retry bool + err error + encIdent types.JID + } + jobs := make(chan workItem) + results := make(chan resultItem, len(allDevices)) + var wg sync.WaitGroup + + worker := func() { + defer wg.Done() + for wi := range jobs { + encrypted, isPreKey, err := cli.encryptMessageForDeviceAndWrap(ctx, wi.plaintext, wi.jid, wi.encryptionIdentity, nil, encAttrs) + if errors.Is(err, ErrNoSession) { + results <- resultItem{jid: wi.jid, retry: true, err: err, encIdent: wi.encryptionIdentity} + continue + } else if err != nil { + results <- resultItem{jid: wi.jid, err: err} + continue + } + results <- resultItem{jid: wi.jid, 
node: encrypted, isPreKey: isPreKey} + } + } + + for i := 0; i < concurrency; i++ { + wg.Add(1) + go worker() + } + + go func() { + for _, jid := range allDevices { + plaintext := msgPlaintext + if (jid.User == ownJID.User || jid.User == ownLID.User) && dsmPlaintext != nil { + if jid == ownJID || jid == ownLID { continue } - participantNodes = append(participantNodes, *encrypted) - if isPreKey { - includeIdentity = true + plaintext = dsmPlaintext + } + encryptionIdentity := jid + if jid.Server == types.DefaultUserServer { + lidForPN, err := cli.Store.LIDs.GetLIDForPN(ctx, jid) + if err != nil { + cli.Log.Warnf("Failed to get LID for %s: %v", jid, err) + } else if !lidForPN.IsEmpty() { + cli.migrateSessionStore(ctx, jid, lidForPN) + encryptionIdentity = lidForPN } } + jobs <- workItem{jid: jid, plaintext: plaintext, encryptionIdentity: encryptionIdentity} + } + close(jobs) + }() + + var retryDevices []types.JID + var retryEncryptionIdentities []types.JID + for completed := 0; completed < len(allDevices); completed++ { + res := <-results + if res.err != nil { + if res.retry { + retryDevices = append(retryDevices, res.jid) + retryEncryptionIdentities = append(retryEncryptionIdentities, res.encIdent) + } else { + cli.Log.Warnf("Failed to encrypt %s for %s: %v", id, res.jid, res.err) + } + continue + } + if res.node != nil { + participantNodes = append(participantNodes, *res.node) + if res.isPreKey { + includeIdentity = true + } + } + } + + go func() { wg.Wait(); close(results) }() + + participantNodes, includeIdentity = cli.retryEncryptMissing(ctx, id, msgPlaintext, dsmPlaintext, encAttrs, ownJID, ownLID, participantNodes, includeIdentity, retryDevices, retryEncryptionIdentities) + return participantNodes, includeIdentity +} + +func (cli *Client) retryEncryptMissing( + ctx context.Context, + id string, + msgPlaintext, dsmPlaintext []byte, + encAttrs waBinary.Attrs, + ownJID, ownLID types.JID, + participantNodes []waBinary.Node, + includeIdentity bool, + retryDevices, 
retryEncryptionIdentities []types.JID, +) ([]waBinary.Node, bool) { + if len(retryDevices) == 0 { + return participantNodes, includeIdentity + } + bundles, err := cli.fetchPreKeys(ctx, retryDevices) + if err != nil { + cli.Log.Warnf("Failed to fetch prekeys for %v to retry encryption: %v", retryDevices, err) + return participantNodes, includeIdentity + } + for i, jid := range retryDevices { + resp := bundles[jid] + if resp.err != nil { + cli.Log.Warnf("Failed to fetch prekey for %s: %v", jid, resp.err) + continue + } + plaintext := msgPlaintext + if (jid.User == ownJID.User || jid.User == ownLID.User) && dsmPlaintext != nil { + plaintext = dsmPlaintext + } + encrypted, isPreKey, err := cli.encryptMessageForDeviceAndWrap( + ctx, plaintext, jid, retryEncryptionIdentities[i], resp.bundle, encAttrs, + ) + if err != nil { + cli.Log.Warnf("Failed to encrypt %s for %s (retry): %v", id, jid, err) + continue + } + participantNodes = append(participantNodes, *encrypted) + if isPreKey { + includeIdentity = true } } return participantNodes, includeIdentity diff --git a/store/signal.go b/store/signal.go index f002e604d..8f09a8e90 100644 --- a/store/signal.go +++ b/store/signal.go @@ -7,6 +7,7 @@ package store import ( + "bytes" "context" "fmt" @@ -36,6 +37,16 @@ func (device *Device) GetLocalRegistrationID() uint32 { func (device *Device) SaveIdentity(ctx context.Context, address *protocol.SignalAddress, identityKey *identity.Key) error { addrString := address.String() + // Scoped cache active? 
+ device.SessionsCacheLock.RLock() + cacheActive := device.IdentityKeysCache != nil + device.SessionsCacheLock.RUnlock() + if cacheActive { + device.SessionsCacheLock.Lock() + device.IdentityKeysCache[addrString] = identityKey.PublicKey().PublicKey() + device.SessionsCacheLock.Unlock() + return nil + } err := device.Identities.PutIdentity(ctx, addrString, identityKey.PublicKey().PublicKey()) if err != nil { return fmt.Errorf("failed to save identity of %s: %w", addrString, err) @@ -45,6 +56,16 @@ func (device *Device) SaveIdentity(ctx context.Context, address *protocol.Signal func (device *Device) IsTrustedIdentity(ctx context.Context, address *protocol.SignalAddress, identityKey *identity.Key) (bool, error) { addrString := address.String() + device.SessionsCacheLock.RLock() + cacheActive := device.IdentityKeysCache != nil + if cacheActive { + if existing, ok := device.IdentityKeysCache[addrString]; ok { + device.SessionsCacheLock.RUnlock() + newKey := identityKey.PublicKey().PublicKey() + return bytes.Equal(existing[:], newKey[:]), nil + } + } + device.SessionsCacheLock.RUnlock() isTrusted, err := device.Identities.IsTrustedIdentity(ctx, addrString, identityKey.PublicKey().PublicKey()) if err != nil { return false, fmt.Errorf("failed to check if %s's identity is trusted: %w", addrString, err) @@ -84,6 +105,21 @@ func (device *Device) ContainsPreKey(ctx context.Context, preKeyID uint32) (bool func (device *Device) LoadSession(ctx context.Context, address *protocol.SignalAddress) (*record.Session, error) { addrString := address.String() + device.SessionsCacheLock.RLock() + if device.SessionsCache != nil { + if raw, ok := device.SessionsCache[addrString]; ok { + device.SessionsCacheLock.RUnlock() + if len(raw) == 0 { // placeholder for new session + return record.NewSession(SignalProtobufSerializer.Session, SignalProtobufSerializer.State), nil + } + sess, err := record.NewSessionFromBytes(raw, SignalProtobufSerializer.Session, SignalProtobufSerializer.State) + if err 
!= nil { + return nil, fmt.Errorf("failed to deserialize cached session with %s: %w", addrString, err) + } + return sess, nil + } + } + device.SessionsCacheLock.RUnlock() rawSess, err := device.Sessions.GetSession(ctx, addrString) if err != nil { return nil, fmt.Errorf("failed to load session with %s: %w", addrString, err) @@ -104,6 +140,15 @@ func (device *Device) GetSubDeviceSessions(ctx context.Context, name string) ([] func (device *Device) StoreSession(ctx context.Context, address *protocol.SignalAddress, record *record.Session) error { addrString := address.String() + device.SessionsCacheLock.RLock() + if device.SessionsCache != nil { + device.SessionsCacheLock.RUnlock() + device.SessionsCacheLock.Lock() + device.SessionsCache[addrString] = record.Serialize() + device.SessionsCacheLock.Unlock() + return nil + } + device.SessionsCacheLock.RUnlock() err := device.Sessions.PutSession(ctx, addrString, record.Serialize()) if err != nil { return fmt.Errorf("failed to store session with %s: %w", addrString, err) @@ -113,6 +158,14 @@ func (device *Device) StoreSession(ctx context.Context, address *protocol.Signal func (device *Device) ContainsSession(ctx context.Context, remoteAddress *protocol.SignalAddress) (bool, error) { addrString := remoteAddress.String() + device.SessionsCacheLock.RLock() + if device.SessionsCache != nil { + if _, ok := device.SessionsCache[addrString]; ok { + device.SessionsCacheLock.RUnlock() + return true, nil + } + } + device.SessionsCacheLock.RUnlock() hasSession, err := device.Sessions.HasSession(ctx, addrString) if err != nil { return false, fmt.Errorf("failed to check if store has session for %s: %w", addrString, err) diff --git a/store/sqlstore/container.go b/store/sqlstore/container.go index 29b36c08a..904727405 100644 --- a/store/sqlstore/container.go +++ b/store/sqlstore/container.go @@ -277,6 +277,7 @@ func (c *Container) initializeDevice(device *store.Device) { device.PrivacyTokens = innerStore device.EventBuffer = innerStore 
device.LIDs = c.LIDMap + device.Cache = innerStore device.Container = c device.Initialized = true } diff --git a/store/sqlstore/store.go b/store/sqlstore/store.go index be3c8dd8a..eb533fc08 100644 --- a/store/sqlstore/store.go +++ b/store/sqlstore/store.go @@ -67,6 +67,8 @@ func NewSQLStore(c *Container, jid types.JID) *SQLStore { } var _ store.AllSessionSpecificStores = (*SQLStore)(nil) +// Implement CacheStore (optional) if batch methods are used by higher level code. +var _ store.CacheStore = (*SQLStore)(nil) const ( putIdentityQuery = ` @@ -107,6 +109,122 @@ func (s *SQLStore) IsTrustedIdentity(ctx context.Context, address string, key [3 return *(*[32]byte)(existingIdentity) == key, nil } +// ---- Batch cache queries (scoped cache support) ---- +const ( + getCacheSessionsPrefix = `SELECT their_id, session FROM whatsmeow_sessions WHERE our_jid=$1 AND their_id IN (` + getCacheIdentityKeysPrefix = `SELECT their_id, identity FROM whatsmeow_identity_keys WHERE our_jid=$1 AND their_id IN (` + // We'll build VALUES list for bulk upsert; reuse single-row upsert syntax expanded. +) + +// GetSessions fetches multiple sessions in a single query. +func (s *SQLStore) GetSessions(ctx context.Context, addresses []string) (map[string][]byte, error) { + if len(addresses) == 0 { + return map[string][]byte{}, nil + } + // Build dynamic IN clause: ($2,$3,...) + placeholders := make([]string, len(addresses)) + args := make([]any, 1+len(addresses)) + args[0] = s.JID + for i, addr := range addresses { + placeholders[i] = fmt.Sprintf("$%d", i+2) + args[i+1] = addr + } + query := getCacheSessionsPrefix + strings.Join(placeholders, ",") + ")" + rows, err := s.db.Query(ctx, query, args...) 
+ if err != nil { + return nil, fmt.Errorf("failed to batch query sessions: %w", err) + } + defer rows.Close() + out := make(map[string][]byte, len(addresses)) + for rows.Next() { + var id string + var sess []byte + if err := rows.Scan(&id, &sess); err != nil { + return nil, fmt.Errorf("failed to scan session row: %w", err) + } + out[id] = sess + } + return out, nil +} + +// GetIdentityKeys fetches multiple identity keys in a single query. +func (s *SQLStore) GetIdentityKeys(ctx context.Context, addresses []string) (map[string][32]byte, error) { + if len(addresses) == 0 { + return map[string][32]byte{}, nil + } + placeholders := make([]string, len(addresses)) + args := make([]any, 1+len(addresses)) + args[0] = s.JID + for i, addr := range addresses { + placeholders[i] = fmt.Sprintf("$%d", i+2) + args[i+1] = addr + } + query := getCacheIdentityKeysPrefix + strings.Join(placeholders, ",") + ")" + rows, err := s.db.Query(ctx, query, args...) + if err != nil { + return nil, fmt.Errorf("failed to batch query identity keys: %w", err) + } + defer rows.Close() + out := make(map[string][32]byte, len(addresses)) + for rows.Next() { + var id string + var key []byte + if err := rows.Scan(&id, &key); err != nil { + return nil, fmt.Errorf("failed to scan identity key row: %w", err) + } + if len(key) == 32 { + out[id] = *(*[32]byte)(key) + } + } + return out, nil +} + +// PutSessions bulk upserts sessions in batches (simple approach w/ individual exec if dialect lacks multi-values?). +func (s *SQLStore) PutSessions(ctx context.Context, sessions map[string][]byte) error { + if len(sessions) == 0 { + return nil + } + // Build bulk INSERT ... 
ON CONFLICT + const rowWidth = 3 + values := make([]any, 0, len(sessions)*rowWidth) + parts := make([]string, 0, len(sessions)) + i := 0 + for addr, sess := range sessions { + // Skip empty new-session placeholders (len==0) to avoid writing meaningless zero-length row – but still store non-empty + values = append(values, s.JID, addr, sess) + parts = append(parts, fmt.Sprintf("($%d,$%d,$%d)", i*rowWidth+1, i*rowWidth+2, i*rowWidth+3)) + i++ + } + query := "INSERT INTO whatsmeow_sessions (our_jid, their_id, session) VALUES " + strings.Join(parts, ",") + " ON CONFLICT (our_jid, their_id) DO UPDATE SET session=excluded.session" + _, err := s.db.Exec(ctx, query, values...) + if err != nil { + return fmt.Errorf("failed to bulk upsert sessions: %w", err) + } + return nil +} + +// PutIdentityKeys bulk upserts identity keys. +func (s *SQLStore) PutIdentityKeys(ctx context.Context, identityKeys map[string][32]byte) error { + if len(identityKeys) == 0 { + return nil + } + const rowWidth = 3 + values := make([]any, 0, len(identityKeys)*rowWidth) + parts := make([]string, 0, len(identityKeys)) + i := 0 + for addr, key := range identityKeys { + values = append(values, s.JID, addr, key[:]) + parts = append(parts, fmt.Sprintf("($%d,$%d,$%d)", i*rowWidth+1, i*rowWidth+2, i*rowWidth+3)) + i++ + } + query := "INSERT INTO whatsmeow_identity_keys (our_jid, their_id, identity) VALUES " + strings.Join(parts, ",") + " ON CONFLICT (our_jid, their_id) DO UPDATE SET identity=excluded.identity" + _, err := s.db.Exec(ctx, query, values...) 
// CacheStore provides batch operations for sessions & identity keys used by
// the scoped per-SendMessage cache. Implementations are optional – if the
// Device's Cache field is nil, the scoped cache optimization is skipped
// transparently and the per-address stores are used directly.
//
// GetSessions/GetIdentityKeys load many rows in a single query before a
// large fan-out; PutSessions/PutIdentityKeys flush the scoped cache back in
// a single statement each when the send completes. Addresses with no stored
// value are simply absent from the returned maps.
type CacheStore interface {
	GetSessions(ctx context.Context, addresses []string) (map[string][]byte, error)
	GetIdentityKeys(ctx context.Context, addresses []string) (map[string][32]byte, error)
	PutSessions(ctx context.Context, sessions map[string][]byte) error
	PutIdentityKeys(ctx context.Context, identityKeys map[string][32]byte) error
}