diff --git a/send.go b/send.go index b8b8b3225..20ea74730 100644 --- a/send.go +++ b/send.go @@ -1017,10 +1017,21 @@ func (cli *Client) makeDeviceIdentityNode() waBinary.Node { } } +// Reducing the number of I/Os to the DB, improving performance by len(allDevices)*time(I/O) which is heavily noticable on large groups +// one I/O to the db is 100ms on average, so sending in a group of 1000 people with at least 2 devices takes around 3 minutes to encrypt. +// now it should take less than 4 seconds on sqlite, Less than 1 second on MSSQL and POSTGRES func (cli *Client) encryptMessageForDevices(ctx context.Context, allDevices []types.JID, ownID types.JID, id string, msgPlaintext, dsmPlaintext []byte, encAttrs waBinary.Attrs) ([]waBinary.Node, bool) { includeIdentity := false participantNodes := make([]waBinary.Node, 0, len(allDevices)) var retryDevices []types.JID + //Cache all sessions and identity keys relative to this message to decrease query time + //This will reduce the number of queries to the db from 3*len(allDevices) -> 2 queries only, heavily improving performance + var addresses []string + for _, jid := range allDevices { + addresses = append(addresses, jid.SignalAddress().String()) + } + cli.Store.SessionsCache = cli.Store.Cache.GetSessions(addresses) + cli.Store.IdentityKeysCache = cli.Store.Cache.GetIdentityKeys(addresses) for _, jid := range allDevices { plaintext := msgPlaintext if jid.User == ownID.User && dsmPlaintext != nil { @@ -1044,6 +1055,13 @@ func (cli *Client) encryptMessageForDevices(ctx context.Context, allDevices []ty } } if len(retryDevices) > 0 { + //By Making the commands to the db be done in a separate go routine, + //A bug will appear when none of the devices in "allDevices" have any sessions in the db, + //causing the program to use the old approach to get the session and not finding any + //causing the program to crash + //by filling the cashe with at least one element, the prekeys collected in bundles will also be stored in cache, avoiding this issue + cli.Store.SessionsCache[ownID.SignalAddress().String()] = []byte{} + cli.Store.IdentityKeysCache[ownID.SignalAddress().String()] = [32]byte{} bundles, err := cli.fetchPreKeys(ctx, retryDevices) if err != nil { cli.Log.Warnf("Failed to fetch prekeys for %v to retry encryption: %v", retryDevices, err) @@ -1069,7 +1087,21 @@ func (cli *Client) encryptMessageForDevices(ctx context.Context, allDevices []ty } } } + //Remove the dummy key + delete(cli.Store.SessionsCache, ownID.SignalAddress().String()) + delete(cli.Store.IdentityKeysCache, ownID.SignalAddress().String()) + } + //Store All Sessions at once. This decreases the number of commands to the database from 2*len(allDevices) -> 2 commands only + //An alternate, faster method would be using "go StoreSessions" which is incompatible with file databases such as sqlite + if len(cli.Store.SessionsCache) > 0 { + cli.Store.Cache.StoreSessions(cli.Store.SessionsCache) + } + if len(cli.Store.IdentityKeysCache) > 0 { + cli.Store.Cache.StoreIdentityKeys(cli.Store.IdentityKeysCache) } + //clear the cache once the encryption is done to release memory + clear(cli.Store.SessionsCache) + clear(cli.Store.IdentityKeysCache) return participantNodes, includeIdentity } diff --git a/store/signal.go b/store/signal.go index 96cb2b361..deea5f67f 100644 --- a/store/signal.go +++ b/store/signal.go @@ -34,6 +34,11 @@ func (device *Device) GetLocalRegistrationId() uint32 { } func (device *Device) SaveIdentity(address *protocol.SignalAddress, identityKey *identity.Key) { + //If device has a Cache, store it there to insert them all at once later + if device.IdentityKeysCache != nil && len(device.IdentityKeysCache) > 0 { + device.IdentityKeysCache[address.String()] = identityKey.PublicKey().PublicKey() + return + } for i := 0; ; i++ { err := device.Identities.PutIdentity(address.String(), identityKey.PublicKey().PublicKey()) if err == nil || !device.handleDatabaseError(i, err, "save identity of %s", address.String()) { @@ -43,6 +48,13 @@ func (device *Device) SaveIdentity(address *protocol.SignalAddress, identityKey } func (device *Device) IsTrustedIdentity(address *protocol.SignalAddress, identityKey *identity.Key) bool { + //Check if device has a Cache. If not use the default method + if device.IdentityKeysCache != nil && len(device.IdentityKeysCache) > 0 { + if cache, ok := device.IdentityKeysCache[address.String()]; ok { + return cache == identityKey.PublicKey().PublicKey() + } + return true + } for i := 0; ; i++ { isTrusted, err := device.Identities.IsTrustedIdentity(address.String(), identityKey.PublicKey().PublicKey()) if err == nil || !device.handleDatabaseError(i, err, "check if %s's identity is trusted", address.String()) { @@ -87,6 +99,18 @@ func (device *Device) ContainsPreKey(preKeyID uint32) bool { } func (device *Device) LoadSession(address *protocol.SignalAddress) *record.Session { + //Check if device has a Cache. If not use the default method + if device.SessionsCache != nil && len(device.SessionsCache) > 0 { + if rawSess, ok := device.SessionsCache[address.String()]; ok { + sess, err := record.NewSessionFromBytes(rawSess, SignalProtobufSerializer.Session, SignalProtobufSerializer.State) + if err != nil { + device.Log.Errorf("Failed to deserialize session with %s: %v", address.String(), err) + return record.NewSession(SignalProtobufSerializer.Session, SignalProtobufSerializer.State) + } + return sess + } + return record.NewSession(SignalProtobufSerializer.Session, SignalProtobufSerializer.State) + } var rawSess []byte var err error for i := 0; ; i++ { @@ -114,6 +138,11 @@ func (device *Device) GetSubDeviceSessions(name string) []uint32 { } func (device *Device) StoreSession(address *protocol.SignalAddress, record *record.Session) { + //If device has a Cache, store it there to insert them all at once later + if device.SessionsCache != nil && len(device.SessionsCache) > 0 { + device.SessionsCache[address.String()] = record.Serialize() + return + } for i := 0; ; i++ { err := device.Sessions.PutSession(address.String(), record.Serialize()) if err == nil || !device.handleDatabaseError(i, err, "store session with %s", address.String()) { @@ -123,6 +152,13 @@ func (device *Device) StoreSession(address *protocol.SignalAddress, record *reco } func (device *Device) ContainsSession(remoteAddress *protocol.SignalAddress) bool { + //Check if device has a Cache. If not use the default method + if device.SessionsCache != nil && len(device.SessionsCache) > 0 { + if _, ok := device.SessionsCache[remoteAddress.String()]; ok { + return true + } + return false + } for i := 0; ; i++ { hasSession, err := device.Sessions.HasSession(remoteAddress.String()) if err == nil || !device.handleDatabaseError(i, err, "check if store has session for %s", remoteAddress.String()) { diff --git a/store/sqlstore/container.go b/store/sqlstore/container.go index d49adcdf2..a818c7c29 100644 --- a/store/sqlstore/container.go +++ b/store/sqlstore/container.go @@ -136,6 +136,7 @@ func (c *Container) scanDevice(row scannable) (*store.Device, error) { device.ChatSettings = innerStore device.MsgSecrets = innerStore device.PrivacyTokens = innerStore + device.Cache = innerStore device.Container = c device.Initialized = true @@ -258,6 +259,7 @@ func (c *Container) PutDevice(device *store.Device) error { device.ChatSettings = innerStore device.MsgSecrets = innerStore device.PrivacyTokens = innerStore + device.Cache = innerStore device.Initialized = true } return err diff --git a/store/sqlstore/store.go b/store/sqlstore/store.go index a575de0d0..59693bb93 100644 --- a/store/sqlstore/store.go +++ b/store/sqlstore/store.go @@ -771,3 +771,112 @@ func (s *SQLStore) GetPrivacyToken(user types.JID) (*store.PrivacyToken, error) return &token, nil } } + +const ( + getCacheSessionsQuery = `SELECT their_id, session FROM whatsmeow_sessions WHERE our_jid=$1 AND their_id IN ` + getCacheIdentityKeysQuery = `SELECT their_id, identity_info FROM whatsmeow_identity_keys WHERE our_jid=$1 AND their_id IN ` + storeCacheSessionsQuery = ` + INSERT INTO whatsmeow_sessions (our_jid, their_id, session) VALUES %s + ON CONFLICT (our_jid, their_id) DO UPDATE SET session=excluded.session + ` + storeCacheIdentityKeysQuery = ` + INSERT INTO whatsmeow_identity_keys (our_jid, their_id, identity_info) VALUES %s + ON CONFLICT (our_jid, their_id) DO UPDATE SET identity_info=excluded.identity_info + ` +) + +func (s *SQLStore) GetSessions(addresses []string) (final map[string][]byte) { + query := getCacheSessionsQuery + "(" + queryParams := make([]interface{}, len(addresses)+1) + queryParams[0] = s.JID + final = make(map[string][]byte) + for index, address := range addresses { + if index > 0 { + query += "," + } + query += fmt.Sprintf("$%d", index+2) + queryParams[index+1] = address + } + query += ")" + rows, err := s.db.Query(query, queryParams...) + if err != nil { + s.log.Errorf(err.Error()) + return + } + for rows.Next() { + var session []byte + var id string + rows.Scan(&id, &session) + final[id] = session + } + return +} + +func (s *SQLStore) GetIdentityKeys(addresses []string) (final map[string][32]byte) { + query := getCacheIdentityKeysQuery + "(" + queryParams := make([]interface{}, len(addresses)+1) + queryParams[0] = s.JID + final = make(map[string][32]byte) + for index, address := range addresses { + if index > 0 { + query += "," + } + query += fmt.Sprintf("$%d", index+2) + queryParams[index+1] = address + } + query += ")" + rows, err := s.db.Query(query, queryParams...) + if err != nil { + s.log.Errorf(err.Error()) + return + } + for rows.Next() { + var session []byte + var id string + rows.Scan(&id, &session) + final[id] = *(*[32]byte)(session) + } + return +} + +// This Could be better implemented with bulk insert approach +func (s *SQLStore) StoreSessions(sessions map[string][]byte) error { + queryValues := "" + queryParams := make([]interface{}, len(sessions)*3) + cnt := 0 + for address, session := range sessions { + if len(queryValues) > 0 { + queryValues += "," + } + counter := cnt * 3 + queryValues += fmt.Sprintf("($%d, $%d, $%d)", counter+1, counter+2, counter+3) + queryParams[counter+1] = s.JID + queryParams[counter+2] = address + queryParams[counter+3] = session[:] + cnt++ + } + query := fmt.Sprintf(storeCacheSessionsQuery, queryValues) + _, err := s.db.Exec(query, queryParams...) + return err +} + +// This Could be better implemented with bulk insert approach +func (s *SQLStore) StoreIdentityKeys(identityKeys map[string][32]byte) error { + queryValues := "" + queryParams := make([]interface{}, len(identityKeys)*3) + cnt := 0 + for address, key := range identityKeys { + if len(queryValues) > 0 { + queryValues += "," + } + counter := cnt * 3 + queryValues += fmt.Sprintf("($%d, $%d, $%d)", counter+1, counter+2, counter+3) + queryParams[counter+1] = s.JID + queryParams[counter+2] = address + queryParams[counter+3] = key[:] + cnt++ + } + query := fmt.Sprintf(storeCacheIdentityKeysQuery, queryValues) + _, err := s.db.Exec(query, queryParams...) + return err +} diff --git a/store/store.go b/store/store.go index 85bf1c57f..2fc4d8ed8 100644 --- a/store/store.go +++ b/store/store.go @@ -126,6 +126,12 @@ type PrivacyTokenStore interface { GetPrivacyToken(user types.JID) (*PrivacyToken, error) } +type CacheStore interface { + GetSessions(addresses []string) map[string][]byte + GetIdentityKeys(addresses []string) map[string][32]byte + StoreSessions(sessions map[string][]byte) error + StoreIdentityKeys(identityKeys map[string][32]byte) error +} type AllStores interface { IdentityStore SessionStore @@ -168,9 +174,14 @@ type Device struct { ChatSettings ChatSettingsStore MsgSecrets MsgSecretStore PrivacyTokens PrivacyTokenStore + Cache CacheStore Container DeviceContainer DatabaseErrorHandler func(device *Device, action string, attemptIndex int, err error) (retry bool) + + //Cache to Temporary save sessions and identity keys for faster group send + SessionsCache map[string][]byte + IdentityKeysCache map[string][32]byte } func (device *Device) handleDatabaseError(attemptIndex int, err error, action string, args ...interface{}) bool {