Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions send.go
Original file line number Diff line number Diff line change
Expand Up @@ -1014,10 +1014,21 @@ func (cli *Client) makeDeviceIdentityNode() waBinary.Node {
}
}

// Reducing the number of I/Os to the DB improves performance by len(allDevices)*time(I/O), which is heavily noticeable in large groups.
// One I/O to the DB takes 100ms on average, so sending to a group of 1000 people with at least 2 devices each takes around 3 minutes to encrypt.
// Now it should take less than 4 seconds on SQLite, and less than 1 second on MSSQL and Postgres.
func (cli *Client) encryptMessageForDevices(ctx context.Context, allDevices []types.JID, ownID types.JID, id string, msgPlaintext, dsmPlaintext []byte, encAttrs waBinary.Attrs) ([]waBinary.Node, bool) {
includeIdentity := false
participantNodes := make([]waBinary.Node, 0, len(allDevices))
var retryDevices []types.JID
//Cache all sessions and identity keys relative to this message to decrease query time
//This will reduce the number of queries to the db from 3*len(allDevices) -> 2 queries only, heavily improving performance
var addresses []string
for _, jid := range allDevices {
addresses = append(addresses, jid.SignalAddress().String())
}
cli.Store.SessionsCache = cli.Store.Cache.GetSessions(addresses)
cli.Store.IdentityKeysCache = cli.Store.Cache.GetIdentityKeys(addresses)
for _, jid := range allDevices {
plaintext := msgPlaintext
if jid.User == ownID.User && dsmPlaintext != nil {
Expand All @@ -1041,6 +1052,13 @@ func (cli *Client) encryptMessageForDevices(ctx context.Context, allDevices []ty
}
}
if len(retryDevices) > 0 {
//By making the commands to the db run in a separate goroutine,
//a bug will appear when none of the devices in "allDevices" have any sessions in the db,
//causing the program to use the old approach to get the session and not find any,
//which causes the program to crash.
//By filling the cache with at least one element, the prekeys collected in bundles will also be stored in the cache, avoiding this issue
cli.Store.SessionsCache[ownID.SignalAddress().String()] = []byte{}
cli.Store.IdentityKeysCache[ownID.SignalAddress().String()] = [32]byte{}
bundles, err := cli.fetchPreKeys(ctx, retryDevices)
if err != nil {
cli.Log.Warnf("Failed to fetch prekeys for %v to retry encryption: %v", retryDevices, err)
Expand All @@ -1066,7 +1084,21 @@ func (cli *Client) encryptMessageForDevices(ctx context.Context, allDevices []ty
}
}
}
//Remove the dummy key
delete(cli.Store.SessionsCache, ownID.SignalAddress().String())
delete(cli.Store.IdentityKeysCache, ownID.SignalAddress().String())
}
//Store All Sessions at once. This decreases the number of commands to the database from 2*len(allDevices) -> 2 commands only
//An alternate, faster method would be using "go StoreSessions" which is incompatible with file databases such as sqlite
if len(cli.Store.SessionsCache) > 0 {
cli.Store.Cache.StoreSessions(cli.Store.SessionsCache)
}
if len(cli.Store.IdentityKeysCache) > 0 {
cli.Store.Cache.StoreIdentityKeys(cli.Store.IdentityKeysCache)
}
//clear the cache once the encryption is done to release memory
clear(cli.Store.SessionsCache)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

interesting that we clear cache after every call - why does it need to be available global then?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since the I/O to the db happens inside LoadSession and IsTrustedIdentity in store/signal.go, the check for whether we already have the value had to be made there. Since both of these functions are whatsmeow's implementations of another library, this was the only suitable place I could find to put it.

clear(cli.Store.IdentityKeysCache)
return participantNodes, includeIdentity
}

Expand Down
36 changes: 36 additions & 0 deletions store/signal.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@ func (device *Device) GetLocalRegistrationId() uint32 {
}

func (device *Device) SaveIdentity(address *protocol.SignalAddress, identityKey *identity.Key) {
//If device has a Cache, store it there to insert them all at once later
if device.IdentityKeysCache != nil && len(device.IdentityKeysCache) > 0 {
device.IdentityKeysCache[address.String()] = identityKey.PublicKey().PublicKey()
return
}
for i := 0; ; i++ {
err := device.Identities.PutIdentity(address.String(), identityKey.PublicKey().PublicKey())
if err == nil || !device.handleDatabaseError(i, err, "save identity of %s", address.String()) {
Expand All @@ -43,6 +48,13 @@ func (device *Device) SaveIdentity(address *protocol.SignalAddress, identityKey
}

func (device *Device) IsTrustedIdentity(address *protocol.SignalAddress, identityKey *identity.Key) bool {
//Check if device has a Cache. If not use the default method
if device.IdentityKeysCache != nil && len(device.IdentityKeysCache) > 0 {
if cache, ok := device.IdentityKeysCache[address.String()]; ok {
return cache == identityKey.PublicKey().PublicKey()
}
return true
}
for i := 0; ; i++ {
isTrusted, err := device.Identities.IsTrustedIdentity(address.String(), identityKey.PublicKey().PublicKey())
if err == nil || !device.handleDatabaseError(i, err, "check if %s's identity is trusted", address.String()) {
Expand Down Expand Up @@ -87,6 +99,18 @@ func (device *Device) ContainsPreKey(preKeyID uint32) bool {
}

func (device *Device) LoadSession(address *protocol.SignalAddress) *record.Session {
//Check if device has a Cache. If not use the default method
if device.SessionsCache != nil && len(device.SessionsCache) > 0 {
if rawSess, ok := device.SessionsCache[address.String()]; ok {
sess, err := record.NewSessionFromBytes(rawSess, SignalProtobufSerializer.Session, SignalProtobufSerializer.State)
if err != nil {
device.Log.Errorf("Failed to deserialize session with %s: %v", address.String(), err)
return record.NewSession(SignalProtobufSerializer.Session, SignalProtobufSerializer.State)
}
return sess
}
return record.NewSession(SignalProtobufSerializer.Session, SignalProtobufSerializer.State)
}
var rawSess []byte
for i := 0; ; i++ {
var err error
Expand All @@ -111,6 +135,11 @@ func (device *Device) GetSubDeviceSessions(name string) []uint32 {
}

func (device *Device) StoreSession(address *protocol.SignalAddress, record *record.Session) {
//If device has a Cache, store it there to insert them all at once later
if device.SessionsCache != nil && len(device.SessionsCache) > 0 {
device.SessionsCache[address.String()] = record.Serialize()
return
}
for i := 0; ; i++ {
err := device.Sessions.PutSession(address.String(), record.Serialize())
if err == nil || !device.handleDatabaseError(i, err, "store session with %s", address.String()) {
Expand All @@ -120,6 +149,13 @@ func (device *Device) StoreSession(address *protocol.SignalAddress, record *reco
}

func (device *Device) ContainsSession(remoteAddress *protocol.SignalAddress) bool {
//Check if device has a Cache. If not use the default method
if device.SessionsCache != nil && len(device.SessionsCache) > 0 {
if _, ok := device.SessionsCache[remoteAddress.String()]; ok {
return true
}
return false
}
for i := 0; ; i++ {
hasSession, err := device.Sessions.HasSession(remoteAddress.String())
if err == nil || !device.handleDatabaseError(i, err, "store has session for %s", remoteAddress.String()) {
Expand Down
2 changes: 2 additions & 0 deletions store/sqlstore/container.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ func (c *Container) scanDevice(row scannable) (*store.Device, error) {
device.ChatSettings = innerStore
device.MsgSecrets = innerStore
device.PrivacyTokens = innerStore
device.Cache = innerStore
device.Container = c
device.Initialized = true

Expand Down Expand Up @@ -255,6 +256,7 @@ func (c *Container) PutDevice(device *store.Device) error {
device.ChatSettings = innerStore
device.MsgSecrets = innerStore
device.PrivacyTokens = innerStore
device.Cache = innerStore
device.Initialized = true
}
return err
Expand Down
109 changes: 109 additions & 0 deletions store/sqlstore/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -777,3 +777,112 @@ func (s *SQLStore) GetPrivacyToken(user types.JID) (*store.PrivacyToken, error)
return &token, nil
}
}

// Query templates for the bulk session / identity-key cache helpers.
const (
// getCacheSessionsQuery is a query prefix: it ends in "IN " and the caller
// must append a parenthesized placeholder list ($2, $3, ...) for their_id.
getCacheSessionsQuery = `SELECT their_id, session FROM whatsmeow_sessions WHERE our_jid=$1 AND their_id IN `
// getCacheIdentityKeysQuery is the identity-key counterpart of
// getCacheSessionsQuery and is completed by the caller in the same way.
// NOTE(review): confirm that identity_info matches the column name in the
// whatsmeow_identity_keys schema.
getCacheIdentityKeysQuery = `SELECT their_id, identity_info FROM whatsmeow_identity_keys WHERE our_jid=$1 AND their_id IN `
// storeCacheSessionsQuery is a fmt template: %s is replaced with a
// comma-separated list of ($n, $n+1, $n+2) placeholder tuples. Rows that
// already exist for (our_jid, their_id) have their session overwritten.
storeCacheSessionsQuery = `
INSERT INTO whatsmeow_sessions (our_jid, their_id, session) VALUES %s
ON CONFLICT (our_jid, their_id) DO UPDATE SET session=excluded.session
`
// storeCacheIdentityKeysQuery is the identity-key counterpart of
// storeCacheSessionsQuery; conflicting rows have identity_info overwritten.
storeCacheIdentityKeysQuery = `
INSERT INTO whatsmeow_identity_keys (our_jid, their_id, identity_info) VALUES %s
ON CONFLICT (our_jid, their_id) DO UPDATE SET identity_info=excluded.identity_info
`
)

// GetSessions loads the stored sessions for all of the given signal addresses
// with a single query and returns them keyed by address.
//
// Addresses that have no stored session are absent from the result. The
// returned map is never nil, so callers may safely write to it. The signature
// has no error result, so database failures are logged and whatever was read
// before the failure is returned.
//
// NOTE(review): databases cap the number of bind parameters per statement;
// very large address lists may need to be split into chunks. Confirm the limit
// for the drivers in use.
func (s *SQLStore) GetSessions(addresses []string) (final map[string][]byte) {
	final = make(map[string][]byte, len(addresses))
	if len(addresses) == 0 {
		// "IN ()" is not valid SQL and there is nothing to look up anyway.
		return final
	}

	// Build the query in a byte buffer to avoid quadratic string concatenation.
	query := make([]byte, 0, len(getCacheSessionsQuery)+2+6*len(addresses))
	query = append(query, getCacheSessionsQuery...)
	query = append(query, '(')
	queryParams := make([]interface{}, 0, len(addresses)+1)
	queryParams = append(queryParams, s.JID)
	for index, address := range addresses {
		if index > 0 {
			query = append(query, ',')
		}
		// $1 is our_jid, so the address placeholders start at $2.
		query = fmt.Appendf(query, "$%d", index+2)
		queryParams = append(queryParams, address)
	}
	query = append(query, ')')

	rows, err := s.db.Query(string(query), queryParams...)
	if err != nil {
		s.log.Errorf("Failed to query sessions for cache: %v", err)
		return final
	}
	// Always release the underlying connection, even on early exit.
	defer rows.Close()

	for rows.Next() {
		var id string
		var session []byte
		if err = rows.Scan(&id, &session); err != nil {
			// Skip the bad row instead of caching a bogus empty entry.
			s.log.Errorf("Failed to scan session row for cache: %v", err)
			continue
		}
		final[id] = session
	}
	if err = rows.Err(); err != nil {
		s.log.Errorf("Failed to iterate session rows for cache: %v", err)
	}
	return final
}

// GetIdentityKeys loads the stored identity keys for all of the given signal
// addresses with a single query and returns them keyed by address.
//
// Addresses that have no stored key are absent from the result. The returned
// map is never nil, so callers may safely write to it. The signature has no
// error result, so database failures are logged and whatever was read before
// the failure is returned. Rows whose key is not exactly 32 bytes long are
// logged and skipped.
//
// NOTE(review): databases cap the number of bind parameters per statement;
// very large address lists may need to be split into chunks. Confirm the limit
// for the drivers in use.
func (s *SQLStore) GetIdentityKeys(addresses []string) (final map[string][32]byte) {
	final = make(map[string][32]byte, len(addresses))
	if len(addresses) == 0 {
		// "IN ()" is not valid SQL and there is nothing to look up anyway.
		return final
	}

	// Build the query in a byte buffer to avoid quadratic string concatenation.
	query := make([]byte, 0, len(getCacheIdentityKeysQuery)+2+6*len(addresses))
	query = append(query, getCacheIdentityKeysQuery...)
	query = append(query, '(')
	queryParams := make([]interface{}, 0, len(addresses)+1)
	queryParams = append(queryParams, s.JID)
	for index, address := range addresses {
		if index > 0 {
			query = append(query, ',')
		}
		// $1 is our_jid, so the address placeholders start at $2.
		query = fmt.Appendf(query, "$%d", index+2)
		queryParams = append(queryParams, address)
	}
	query = append(query, ')')

	rows, err := s.db.Query(string(query), queryParams...)
	if err != nil {
		s.log.Errorf("Failed to query identity keys for cache: %v", err)
		return final
	}
	// Always release the underlying connection, even on early exit.
	defer rows.Close()

	for rows.Next() {
		var id string
		var key []byte
		if err = rows.Scan(&id, &key); err != nil {
			s.log.Errorf("Failed to scan identity key row for cache: %v", err)
			continue
		}
		// Converting a slice shorter than 32 bytes to *[32]byte panics, so
		// validate the length before converting.
		if len(key) != 32 {
			s.log.Errorf("Ignoring identity key of %s with invalid length %d", id, len(key))
			continue
		}
		final[id] = *(*[32]byte)(key)
	}
	if err = rows.Err(); err != nil {
		s.log.Errorf("Failed to iterate identity key rows for cache: %v", err)
	}
	return final
}

// StoreSessions upserts all of the given sessions (keyed by signal address)
// for this device in a single multi-row INSERT statement.
//
// An empty map is a no-op and returns nil. Any error from executing the
// statement is returned unchanged.
//
// NOTE(review): each session uses three bind parameters; databases cap the
// number of parameters per statement, so very large maps may need to be split
// into chunks. Confirm the limit for the drivers in use.
func (s *SQLStore) StoreSessions(sessions map[string][]byte) error {
	if len(sessions) == 0 {
		// "VALUES" with no tuples is not valid SQL.
		return nil
	}

	queryValues := make([]byte, 0, 16*len(sessions))
	queryParams := make([]interface{}, 0, len(sessions)*3)
	for address, session := range sessions {
		if len(queryParams) > 0 {
			queryValues = append(queryValues, ',')
		}
		// Placeholders are 1-based while the parameter slice is 0-based, so
		// the next tuple's first placeholder is len(queryParams)+1.
		base := len(queryParams)
		queryValues = fmt.Appendf(queryValues, "($%d, $%d, $%d)", base+1, base+2, base+3)
		queryParams = append(queryParams, s.JID, address, session)
	}

	query := fmt.Sprintf(storeCacheSessionsQuery, queryValues)
	_, err := s.db.Exec(query, queryParams...)
	return err
}

// StoreIdentityKeys upserts all of the given identity keys (keyed by signal
// address) for this device in a single multi-row INSERT statement.
//
// An empty map is a no-op and returns nil. Any error from executing the
// statement is returned unchanged.
//
// NOTE(review): each key uses three bind parameters; databases cap the number
// of parameters per statement, so very large maps may need to be split into
// chunks. Confirm the limit for the drivers in use.
func (s *SQLStore) StoreIdentityKeys(identityKeys map[string][32]byte) error {
	if len(identityKeys) == 0 {
		// "VALUES" with no tuples is not valid SQL.
		return nil
	}

	queryValues := make([]byte, 0, 16*len(identityKeys))
	queryParams := make([]interface{}, 0, len(identityKeys)*3)
	for address, key := range identityKeys {
		if len(queryParams) > 0 {
			queryValues = append(queryValues, ',')
		}
		// Placeholders are 1-based while the parameter slice is 0-based, so
		// the next tuple's first placeholder is len(queryParams)+1.
		base := len(queryParams)
		queryValues = fmt.Appendf(queryValues, "($%d, $%d, $%d)", base+1, base+2, base+3)
		// Copy the key bytes: slicing the range variable directly would alias
		// one shared array across iterations on Go versions before 1.22.
		keyBytes := append([]byte(nil), key[:]...)
		queryParams = append(queryParams, s.JID, address, keyBytes)
	}

	query := fmt.Sprintf(storeCacheIdentityKeysQuery, queryValues)
	_, err := s.db.Exec(query, queryParams...)
	return err
}
12 changes: 12 additions & 0 deletions store/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,13 @@ type PrivacyTokenStore interface {
GetPrivacyToken(user types.JID) (*PrivacyToken, error)
}

type CacheStore interface {
GetSessions(addresses []string) map[string][]byte
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of using primitives (string/byte) we should try using already defined types, like types.JID (likely there's something for Session/IdentityKeys as well)

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IdentityKeys and Session types are both []byte in the code which is the reason I chose that

GetIdentityKeys(addresses []string) map[string][32]byte
StoreSessions(sessions map[string][]byte) error
StoreIdentityKeys(identityKeys map[string][32]byte) error
}

type Device struct {
Log waLog.Logger

Expand Down Expand Up @@ -154,9 +161,14 @@ type Device struct {
ChatSettings ChatSettingsStore
MsgSecrets MsgSecretStore
PrivacyTokens PrivacyTokenStore
Cache CacheStore
Container DeviceContainer

DatabaseErrorHandler func(device *Device, action string, attemptIndex int, err error) (retry bool)

//Cache to temporarily save sessions and identity keys for faster group send
SessionsCache map[string][]byte
IdentityKeysCache map[string][32]byte
Comment on lines +177 to +184
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like too many Caches here, I lost the idea when we use SessionCache, when we use IdentityKeysCache and when just Cache.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hello
Sorry for the delay, I didn't check my github in a while
So the idea of CacheStore is to have a way that implements the methods to get all the required IdentityKeys and Sessions at once
SessionsCache and IdentityCache save the result of said query. While doing it on redis will certainly improve speed, I would be forcing the implementation of the library to be dependent on that DB

}

func (device *Device) handleDatabaseError(attemptIndex int, err error, action string, args ...interface{}) bool {
Expand Down