Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,14 @@ type Client struct {
// The library is currently embedded in mautrix-meta (https://github.com/mautrix/meta), but may be separated later.
MessengerConfig *MessengerConfig
RefreshCAT func(context.Context) error

// EncryptConcurrency controls how many goroutines are used when encrypting a
// message fan-out to multiple devices (e.g. large groups). If zero or
// negative, the library picks a value automatically
// (2 * runtime.GOMAXPROCS(0)). The effective value is clamped to the range
// [1, 64] and never exceeds the number of target devices.
// A value of 1 disables parallelization and preserves the previous
// sequential behavior. Small groups (fewer than 8 devices) automatically fall
// back to sequential processing for lower overhead.
EncryptConcurrency int
}

type groupMetaCache struct {
Expand Down
283 changes: 231 additions & 52 deletions send.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@ import (
"encoding/hex"
"errors"
"fmt"
"runtime"
"sort"
"strconv"
"strings"
"sync"
"time"

"github.com/rs/zerolog"
Expand Down Expand Up @@ -1164,74 +1166,251 @@ func (cli *Client) encryptMessageForDevices(
) ([]waBinary.Node, bool) {
ownJID := cli.getOwnID()
ownLID := cli.getOwnLID()
includeIdentity := false
participantNodes := make([]waBinary.Node, 0, len(allDevices))
var retryDevices, retryEncryptionIdentities []types.JID
for _, jid := range allDevices {
plaintext := msgPlaintext
if (jid.User == ownJID.User || jid.User == ownLID.User) && dsmPlaintext != nil {
if jid == ownJID || jid == ownLID {

// Scoped cache activation (batch load sessions & identity keys) – only if backend supports it
if cli.Store != nil && cli.Store.Cache != nil && len(allDevices) > 1 { // only bother if >1 device
addresses := make([]string, 0, len(allDevices))
for _, jid := range allDevices {
if jid == ownJID || jid == ownLID { // skip own identity placeholders, own device sessions not needed
continue
}
plaintext = dsmPlaintext
addresses = append(addresses, jid.SignalAddress().String())
}
encryptionIdentity := jid
if jid.Server == types.DefaultUserServer {
lidForPN, err := cli.Store.LIDs.GetLIDForPN(ctx, jid)
if len(addresses) > 0 {
// Batch fetch
sessMap, err := cli.Store.Cache.GetSessions(ctx, addresses)
if err != nil {
cli.Log.Warnf("Failed to get LID for %s: %v", jid, err)
} else if !lidForPN.IsEmpty() {
cli.migrateSessionStore(ctx, jid, lidForPN)
encryptionIdentity = lidForPN
cli.Log.Warnf("Scoped cache: failed to batch fetch sessions: %v", err)
sessMap = map[string][]byte{}
}
idMap, err := cli.Store.Cache.GetIdentityKeys(ctx, addresses)
if err != nil {
cli.Log.Warnf("Scoped cache: failed to batch fetch identity keys: %v", err)
idMap = map[string][32]byte{}
}
cli.Store.SessionsCacheLock.Lock()
cli.Store.SessionsCache = sessMap
cli.Store.IdentityKeysCache = idMap
cli.Store.SessionsCacheLock.Unlock()
defer func() {
cli.Store.SessionsCacheLock.Lock()
sessionsCopy := cli.Store.SessionsCache
identityCopy := cli.Store.IdentityKeysCache
cli.Store.SessionsCache = nil
cli.Store.IdentityKeysCache = nil
cli.Store.SessionsCacheLock.Unlock()
// Flush new/updated sessions & identities back (best effort)
if len(sessionsCopy) > 0 {
if err := cli.Store.Cache.PutSessions(context.Background(), sessionsCopy); err != nil {
cli.Log.Warnf("Scoped cache: failed to flush sessions: %v", err)
}
}
if len(identityCopy) > 0 {
if err := cli.Store.Cache.PutIdentityKeys(context.Background(), identityCopy); err != nil {
cli.Log.Warnf("Scoped cache: failed to flush identity keys: %v", err)
}
}
}()
}
}
includeIdentity := false
participantNodes := make([]waBinary.Node, 0, len(allDevices))

encrypted, isPreKey, err := cli.encryptMessageForDeviceAndWrap(
ctx, plaintext, jid, encryptionIdentity, nil, encAttrs,
)
if errors.Is(err, ErrNoSession) {
retryDevices = append(retryDevices, jid)
retryEncryptionIdentities = append(retryEncryptionIdentities, encryptionIdentity)
continue
} else if err != nil {
// TODO return these errors if it's a fatal one (like context cancellation or database)
cli.Log.Warnf("Failed to encrypt %s for %s: %v", id, jid, err)
continue
}
const parallelThreshold = 8
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🤔 How was this number chosen? If I had 9 devices, you'd start as many goroutines as EncryptConcurrency allows (let's say I have 4 CPUs). Have you considered whether that's actually faster than just letting a single goroutine process that one extra device?


participantNodes = append(participantNodes, *encrypted)
if isPreKey {
includeIdentity = true
}
const minConcurrency = 1
const maxConcurrency = 64
var concurrency int
if cli.EncryptConcurrency > 0 {
concurrency = cli.EncryptConcurrency
} else {
procs := runtime.GOMAXPROCS(0)

concurrency = procs * 2
}
if len(retryDevices) > 0 {
bundles, err := cli.fetchPreKeys(ctx, retryDevices)
if err != nil {
cli.Log.Warnf("Failed to fetch prekeys for %v to retry encryption: %v", retryDevices, err)
} else {
for i, jid := range retryDevices {
resp := bundles[jid]
if resp.err != nil {
cli.Log.Warnf("Failed to fetch prekey for %s: %v", jid, resp.err)

if concurrency < minConcurrency {
concurrency = minConcurrency
}
if concurrency > maxConcurrency {
concurrency = maxConcurrency
}

if concurrency > len(allDevices) {
concurrency = len(allDevices)
}

if len(allDevices) == 0 {
return participantNodes, includeIdentity
}

if len(allDevices) < parallelThreshold || concurrency <= 1 {
// Fall back to original sequential implementation for small batches
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a way to minimize the amount of duplicated code here?
Maybe abstract this into a function that can either be run in goroutines or be run sequentially when parallelism is not possible?

var retryDevices, retryEncryptionIdentities []types.JID
for _, jid := range allDevices {
plaintext := msgPlaintext
if (jid.User == ownJID.User || jid.User == ownLID.User) && dsmPlaintext != nil {
if jid == ownJID || jid == ownLID {
continue
}
plaintext := msgPlaintext
if (jid.User == ownJID.User || jid.User == ownLID.User) && dsmPlaintext != nil {
plaintext = dsmPlaintext
}
encrypted, isPreKey, err := cli.encryptMessageForDeviceAndWrap(
ctx, plaintext, jid, retryEncryptionIdentities[i], resp.bundle, encAttrs,
)
plaintext = dsmPlaintext
}
encryptionIdentity := jid
if jid.Server == types.DefaultUserServer {
lidForPN, err := cli.Store.LIDs.GetLIDForPN(ctx, jid)
if err != nil {
// TODO return these errors if it's a fatal one (like context cancellation or database)
cli.Log.Warnf("Failed to encrypt %s for %s (retry): %v", id, jid, err)
cli.Log.Warnf("Failed to get LID for %s: %v", jid, err)
} else if !lidForPN.IsEmpty() {
cli.migrateSessionStore(ctx, jid, lidForPN)
encryptionIdentity = lidForPN
}
}
encrypted, isPreKey, err := cli.encryptMessageForDeviceAndWrap(
ctx, plaintext, jid, encryptionIdentity, nil, encAttrs,
)
if errors.Is(err, ErrNoSession) {
retryDevices = append(retryDevices, jid)
retryEncryptionIdentities = append(retryEncryptionIdentities, encryptionIdentity)
continue
} else if err != nil {
cli.Log.Warnf("Failed to encrypt %s for %s: %v", id, jid, err)
continue
}
participantNodes = append(participantNodes, *encrypted)
if isPreKey {
includeIdentity = true
}
}
participantNodes, includeIdentity = cli.retryEncryptMissing(ctx, id, msgPlaintext, dsmPlaintext, encAttrs, ownJID, ownLID, participantNodes, includeIdentity, retryDevices, retryEncryptionIdentities)
return participantNodes, includeIdentity
}

type workItem struct {
jid types.JID
plaintext []byte
encryptionIdentity types.JID
}
type resultItem struct {
jid types.JID
node *waBinary.Node
isPreKey bool
retry bool
err error
encIdent types.JID
}
jobs := make(chan workItem)
results := make(chan resultItem, len(allDevices))
var wg sync.WaitGroup

worker := func() {
defer wg.Done()
for wi := range jobs {
encrypted, isPreKey, err := cli.encryptMessageForDeviceAndWrap(ctx, wi.plaintext, wi.jid, wi.encryptionIdentity, nil, encAttrs)
if errors.Is(err, ErrNoSession) {
results <- resultItem{jid: wi.jid, retry: true, err: err, encIdent: wi.encryptionIdentity}
continue
} else if err != nil {
results <- resultItem{jid: wi.jid, err: err}
continue
}
results <- resultItem{jid: wi.jid, node: encrypted, isPreKey: isPreKey}
}
}

for i := 0; i < concurrency; i++ {
wg.Add(1)
go worker()
}

go func() {
for _, jid := range allDevices {
plaintext := msgPlaintext
if (jid.User == ownJID.User || jid.User == ownLID.User) && dsmPlaintext != nil {
if jid == ownJID || jid == ownLID {
continue
}
participantNodes = append(participantNodes, *encrypted)
if isPreKey {
includeIdentity = true
plaintext = dsmPlaintext
}
encryptionIdentity := jid
if jid.Server == types.DefaultUserServer {
lidForPN, err := cli.Store.LIDs.GetLIDForPN(ctx, jid)
if err != nil {
cli.Log.Warnf("Failed to get LID for %s: %v", jid, err)
} else if !lidForPN.IsEmpty() {
cli.migrateSessionStore(ctx, jid, lidForPN)
encryptionIdentity = lidForPN
}
}
jobs <- workItem{jid: jid, plaintext: plaintext, encryptionIdentity: encryptionIdentity}
}
close(jobs)
}()

var retryDevices []types.JID
var retryEncryptionIdentities []types.JID
for completed := 0; completed < len(allDevices); completed++ {
res := <-results
if res.err != nil {
if res.retry {
retryDevices = append(retryDevices, res.jid)
retryEncryptionIdentities = append(retryEncryptionIdentities, res.encIdent)
} else {
cli.Log.Warnf("Failed to encrypt %s for %s: %v", id, res.jid, res.err)
}
continue
}
if res.node != nil {
participantNodes = append(participantNodes, *res.node)
if res.isPreKey {
includeIdentity = true
}
}
}

go func() { wg.Wait(); close(results) }()

participantNodes, includeIdentity = cli.retryEncryptMissing(ctx, id, msgPlaintext, dsmPlaintext, encAttrs, ownJID, ownLID, participantNodes, includeIdentity, retryDevices, retryEncryptionIdentities)
return participantNodes, includeIdentity
}

func (cli *Client) retryEncryptMissing(
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it reasonable to abstract this logic outside of the function? 🤔

ctx context.Context,
id string,
msgPlaintext, dsmPlaintext []byte,
encAttrs waBinary.Attrs,
ownJID, ownLID types.JID,
participantNodes []waBinary.Node,
includeIdentity bool,
retryDevices, retryEncryptionIdentities []types.JID,
) ([]waBinary.Node, bool) {
if len(retryDevices) == 0 {
return participantNodes, includeIdentity
}
bundles, err := cli.fetchPreKeys(ctx, retryDevices)
if err != nil {
cli.Log.Warnf("Failed to fetch prekeys for %v to retry encryption: %v", retryDevices, err)
return participantNodes, includeIdentity
}
for i, jid := range retryDevices {
resp := bundles[jid]
if resp.err != nil {
cli.Log.Warnf("Failed to fetch prekey for %s: %v", jid, resp.err)
continue
}
plaintext := msgPlaintext
if (jid.User == ownJID.User || jid.User == ownLID.User) && dsmPlaintext != nil {
plaintext = dsmPlaintext
}
encrypted, isPreKey, err := cli.encryptMessageForDeviceAndWrap(
ctx, plaintext, jid, retryEncryptionIdentities[i], resp.bundle, encAttrs,
)
if err != nil {
cli.Log.Warnf("Failed to encrypt %s for %s (retry): %v", id, jid, err)
continue
}
participantNodes = append(participantNodes, *encrypted)
if isPreKey {
includeIdentity = true
}
}
return participantNodes, includeIdentity
Expand Down
Loading