Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,14 @@ type Client struct {
// The library is currently embedded in mautrix-meta (https://github.com/mautrix/meta), but may be separated later.
MessengerConfig *MessengerConfig
RefreshCAT func(context.Context) error

// EncryptConcurrency controls how many goroutines are used when encrypting a
// message fan-out to multiple devices (e.g. large groups). If zero or
// negative, the library will pick a value automatically (runtime.NumCPU()).
// A value of 1 disables parallelization and preserves the previous
// sequential behavior. Small groups automatically fall back to sequential
// processing for lower overhead.
EncryptConcurrency int
}

type groupMetaCache struct {
Expand Down
220 changes: 167 additions & 53 deletions send.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@ import (
"encoding/hex"
"errors"
"fmt"
"runtime"
"sort"
"strconv"
"strings"
"sync"
"time"

"github.com/rs/zerolog"
Expand Down Expand Up @@ -1166,74 +1168,186 @@ func (cli *Client) encryptMessageForDevices(
ownLID := cli.getOwnLID()
includeIdentity := false
participantNodes := make([]waBinary.Node, 0, len(allDevices))
var retryDevices, retryEncryptionIdentities []types.JID
for _, jid := range allDevices {
plaintext := msgPlaintext
if (jid.User == ownJID.User || jid.User == ownLID.User) && dsmPlaintext != nil {
if jid == ownJID || jid == ownLID {

// Heuristic: below this size, sequential loop is cheaper than goroutine scheduling.
const parallelThreshold = 8
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🤔 How was this number set? So, if I had 9 devices, you'd start a new goroutine for however many `EncryptConcurrency` is (let's say I have 4 CPUs). Have you considered whether that's faster than just letting the same single goroutine process that one extra device?

concurrency := cli.EncryptConcurrency
if concurrency <= 0 {
concurrency = runtime.NumCPU()
}
if concurrency < 1 {
concurrency = 1
}
Comment thread
jlucaso1 marked this conversation as resolved.
Outdated

if len(allDevices) < parallelThreshold || concurrency == 1 {
// Fall back to original sequential implementation for small batches
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a way to minimize the amount of duplicated code here?
Maybe abstract this into a function that can be run in a goroutine, or run sequentially if threading is not possible?

var retryDevices, retryEncryptionIdentities []types.JID
for _, jid := range allDevices {
plaintext := msgPlaintext
if (jid.User == ownJID.User || jid.User == ownLID.User) && dsmPlaintext != nil {
if jid == ownJID || jid == ownLID {
continue
}
plaintext = dsmPlaintext
}
encryptionIdentity := jid
if jid.Server == types.DefaultUserServer {
lidForPN, err := cli.Store.LIDs.GetLIDForPN(ctx, jid)
if err != nil {
cli.Log.Warnf("Failed to get LID for %s: %v", jid, err)
} else if !lidForPN.IsEmpty() {
cli.migrateSessionStore(ctx, jid, lidForPN)
encryptionIdentity = lidForPN
}
}
encrypted, isPreKey, err := cli.encryptMessageForDeviceAndWrap(
ctx, plaintext, jid, encryptionIdentity, nil, encAttrs,
)
if errors.Is(err, ErrNoSession) {
retryDevices = append(retryDevices, jid)
retryEncryptionIdentities = append(retryEncryptionIdentities, encryptionIdentity)
continue
} else if err != nil {
cli.Log.Warnf("Failed to encrypt %s for %s: %v", id, jid, err)
continue
}
plaintext = dsmPlaintext
participantNodes = append(participantNodes, *encrypted)
if isPreKey {
includeIdentity = true
}
}
encryptionIdentity := jid
if jid.Server == types.DefaultUserServer {
lidForPN, err := cli.Store.LIDs.GetLIDForPN(ctx, jid)
if err != nil {
cli.Log.Warnf("Failed to get LID for %s: %v", jid, err)
} else if !lidForPN.IsEmpty() {
cli.migrateSessionStore(ctx, jid, lidForPN)
encryptionIdentity = lidForPN
participantNodes, includeIdentity = cli.retryEncryptMissing(ctx, id, msgPlaintext, dsmPlaintext, encAttrs, ownJID, ownLID, participantNodes, includeIdentity, retryDevices, retryEncryptionIdentities)
return participantNodes, includeIdentity
}

type workItem struct {
jid types.JID
plaintext []byte
encryptionIdentity types.JID
}
type resultItem struct {
jid types.JID
node *waBinary.Node
isPreKey bool
retry bool
err error
encIdent types.JID
}
jobs := make(chan workItem)
results := make(chan resultItem, len(allDevices))
var wg sync.WaitGroup

worker := func() {
defer wg.Done()
for wi := range jobs {
encrypted, isPreKey, err := cli.encryptMessageForDeviceAndWrap(ctx, wi.plaintext, wi.jid, wi.encryptionIdentity, nil, encAttrs)
if errors.Is(err, ErrNoSession) {
results <- resultItem{jid: wi.jid, retry: true, err: err, encIdent: wi.encryptionIdentity}
continue
} else if err != nil {
results <- resultItem{jid: wi.jid, err: err}
continue
}
results <- resultItem{jid: wi.jid, node: encrypted, isPreKey: isPreKey}
}
}

for i := 0; i < concurrency; i++ {
wg.Add(1)
go worker()
}

go func() {
for _, jid := range allDevices {
plaintext := msgPlaintext
if (jid.User == ownJID.User || jid.User == ownLID.User) && dsmPlaintext != nil {
if jid == ownJID || jid == ownLID {
continue
}
plaintext = dsmPlaintext
}
encryptionIdentity := jid
if jid.Server == types.DefaultUserServer {
lidForPN, err := cli.Store.LIDs.GetLIDForPN(ctx, jid)
if err != nil {
cli.Log.Warnf("Failed to get LID for %s: %v", jid, err)
} else if !lidForPN.IsEmpty() {
cli.migrateSessionStore(ctx, jid, lidForPN)
encryptionIdentity = lidForPN
}
}
jobs <- workItem{jid: jid, plaintext: plaintext, encryptionIdentity: encryptionIdentity}
}
close(jobs)
}()

var retryDevices []types.JID
var retryEncryptionIdentities []types.JID
for completed := 0; completed < len(allDevices); completed++ {
res := <-results
if res.err != nil {
if res.retry {
retryDevices = append(retryDevices, res.jid)
retryEncryptionIdentities = append(retryEncryptionIdentities, res.encIdent)
} else {
cli.Log.Warnf("Failed to encrypt %s for %s: %v", id, res.jid, res.err)
}
continue
}
if res.node != nil {
participantNodes = append(participantNodes, *res.node)
if res.isPreKey {
includeIdentity = true
}
}
}

go func() { wg.Wait(); close(results) }()

participantNodes, includeIdentity = cli.retryEncryptMissing(ctx, id, msgPlaintext, dsmPlaintext, encAttrs, ownJID, ownLID, participantNodes, includeIdentity, retryDevices, retryEncryptionIdentities)
return participantNodes, includeIdentity
}

func (cli *Client) retryEncryptMissing(
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it reasonable to move this retry logic out into its own function, separate from the main encryption path? 🤔

ctx context.Context,
id string,
msgPlaintext, dsmPlaintext []byte,
encAttrs waBinary.Attrs,
ownJID, ownLID types.JID,
participantNodes []waBinary.Node,
includeIdentity bool,
retryDevices, retryEncryptionIdentities []types.JID,
) ([]waBinary.Node, bool) {
if len(retryDevices) == 0 {
return participantNodes, includeIdentity
}
bundles, err := cli.fetchPreKeys(ctx, retryDevices)
if err != nil {
cli.Log.Warnf("Failed to fetch prekeys for %v to retry encryption: %v", retryDevices, err)
return participantNodes, includeIdentity
}
for i, jid := range retryDevices {
resp := bundles[jid]
if resp.err != nil {
cli.Log.Warnf("Failed to fetch prekey for %s: %v", jid, resp.err)
continue
}
plaintext := msgPlaintext
if (jid.User == ownJID.User || jid.User == ownLID.User) && dsmPlaintext != nil {
plaintext = dsmPlaintext
}
encrypted, isPreKey, err := cli.encryptMessageForDeviceAndWrap(
ctx, plaintext, jid, encryptionIdentity, nil, encAttrs,
ctx, plaintext, jid, retryEncryptionIdentities[i], resp.bundle, encAttrs,
)
if errors.Is(err, ErrNoSession) {
retryDevices = append(retryDevices, jid)
retryEncryptionIdentities = append(retryEncryptionIdentities, encryptionIdentity)
continue
} else if err != nil {
// TODO return these errors if it's a fatal one (like context cancellation or database)
cli.Log.Warnf("Failed to encrypt %s for %s: %v", id, jid, err)
if err != nil {
cli.Log.Warnf("Failed to encrypt %s for %s (retry): %v", id, jid, err)
continue
}

participantNodes = append(participantNodes, *encrypted)
if isPreKey {
includeIdentity = true
}
}
if len(retryDevices) > 0 {
bundles, err := cli.fetchPreKeys(ctx, retryDevices)
if err != nil {
cli.Log.Warnf("Failed to fetch prekeys for %v to retry encryption: %v", retryDevices, err)
} else {
for i, jid := range retryDevices {
resp := bundles[jid]
if resp.err != nil {
cli.Log.Warnf("Failed to fetch prekey for %s: %v", jid, resp.err)
continue
}
plaintext := msgPlaintext
if (jid.User == ownJID.User || jid.User == ownLID.User) && dsmPlaintext != nil {
plaintext = dsmPlaintext
}
encrypted, isPreKey, err := cli.encryptMessageForDeviceAndWrap(
ctx, plaintext, jid, retryEncryptionIdentities[i], resp.bundle, encAttrs,
)
if err != nil {
// TODO return these errors if it's a fatal one (like context cancellation or database)
cli.Log.Warnf("Failed to encrypt %s for %s (retry): %v", id, jid, err)
continue
}
participantNodes = append(participantNodes, *encrypted)
if isPreKey {
includeIdentity = true
}
}
}
}
return participantNodes, includeIdentity
}

Expand Down