diff --git a/src/Socket/messages-send.ts b/src/Socket/messages-send.ts index c6112fb..37c51d7 100644 --- a/src/Socket/messages-send.ts +++ b/src/Socket/messages-send.ts @@ -234,6 +234,14 @@ export const makeMessagesSocket = (config: SocketConfig) => { const createParticipantNodes = async(jids: string[], bytes: Buffer) => { await assertSessions(jids, false) + + if(authState.keys.isInTransaction()) { + await authState.keys.prefetch( + 'session', + jids.map(jid => jidToSignalProtocolAddress(jid).toString()) + ) + } + const nodes = await Promise.all( jids.map( async jid => { @@ -278,133 +286,137 @@ export const makeMessagesSocket = (config: SocketConfig) => { devices.push({ user, device }) } - if(isGroup) { - const { ciphertext, senderKeyDistributionMessageKey } = await encryptSenderKeyMsgSignalProto(destinationJid, encodedMsg, meId, authState) - - const [groupData, senderKeyMap] = await Promise.all([ - (async() => { - let groupData = cachedGroupMetadata ? await cachedGroupMetadata(jid) : undefined - if(!groupData) groupData = await groupMetadata(jid) - return groupData - })(), - (async() => { - const result = await authState.keys.get('sender-key-memory', [jid]) - return result[jid] || { } - })() - ]) - - if(!participant) { - const participantsList = groupData.participants.map(p => p.id) - const additionalDevices = await getUSyncDevices(participantsList, false) - devices.push(...additionalDevices) - } - - const senderKeyJids: string[] = [] - // ensure a connection is established with every device - for(const {user, device} of devices) { - const jid = jidEncode(user, 's.whatsapp.net', device) - if(!senderKeyMap[jid]) { - senderKeyJids.push(jid) - // store that this person has had the sender keys sent to them - senderKeyMap[jid] = true - } - } - // if there are some participants with whom the session has not been established - // if there are, we re-send the senderkey - if(senderKeyJids.length) { - logger.debug({ senderKeyJids }, 'sending new sender key') - - const encSenderKeyMsg = encodeWAMessage({ - senderKeyDistributionMessage: { - axolotlSenderKeyDistributionMessage: senderKeyDistributionMessageKey, - groupId: destinationJid - } - }) - - participants.push( - ...(await createParticipantNodes(senderKeyJids, encSenderKeyMsg)) - ) - } - - binaryNodeContent.push({ - tag: 'enc', - attrs: { v: '2', type: 'skmsg' }, - content: ciphertext - }) - - await authState.keys.set({ 'sender-key-memory': { [jid]: senderKeyMap } }) - } else { - const { user: meUser } = jidDecode(meId) - - const encodedMeMsg = encodeWAMessage({ - deviceSentMessage: { - destinationJid, - message - } - }) - - if(!participant) { - devices.push({ user }) - devices.push({ user: meUser }) - - const additionalDevices = await getUSyncDevices([ meId, jid ], true) - devices.push(...additionalDevices) - } - - const meJids: string[] = [] - const otherJids: string[] = [] - for(const { user, device } of devices) { - const jid = jidEncode(user, 's.whatsapp.net', device) - const isMe = user === meUser - if(isMe) meJids.push(jid) - else otherJids.push(jid) - } - - const [meNodes, otherNodes] = await Promise.all([ - createParticipantNodes(meJids, encodedMeMsg), - createParticipantNodes(otherJids, encodedMsg) - ]) - participants.push(...meNodes) - participants.push(...otherNodes) - } - - if(participants.length) { - binaryNodeContent.push({ - tag: 'participants', - attrs: { }, - content: participants - }) - } - - const stanza: BinaryNode = { - tag: 'message', - attrs: { - id: msgId, - type: 'text', - to: destinationJid, - ...(additionalAttributes || {}) - }, - content: binaryNodeContent - } + await authState.keys.transaction( + async() => { + if(isGroup) { + const { ciphertext, senderKeyDistributionMessageKey } = await encryptSenderKeyMsgSignalProto(destinationJid, encodedMsg, meId, authState) + + const [groupData, senderKeyMap] = await Promise.all([ + (async() => { + let groupData = cachedGroupMetadata ? await cachedGroupMetadata(jid) : undefined + if(!groupData) groupData = await groupMetadata(jid) + return groupData + })(), + (async() => { + const result = await authState.keys.get('sender-key-memory', [jid]) + return result[jid] || { } + })() + ]) - const shouldHaveIdentity = !!participants.find( - participant => (participant.content! as BinaryNode[]).find(n => n.attrs.type === 'pkmsg') + if(!participant) { + const participantsList = groupData.participants.map(p => p.id) + const additionalDevices = await getUSyncDevices(participantsList, false) + devices.push(...additionalDevices) + } + + const senderKeyJids: string[] = [] + // ensure a connection is established with every device + for(const {user, device} of devices) { + const jid = jidEncode(user, 's.whatsapp.net', device) + if(!senderKeyMap[jid]) { + senderKeyJids.push(jid) + // store that this person has had the sender keys sent to them + senderKeyMap[jid] = true + } + } + // if there are some participants with whom the session has not been established + // if there are, we re-send the senderkey + if(senderKeyJids.length) { + logger.debug({ senderKeyJids }, 'sending new sender key') + + const encSenderKeyMsg = encodeWAMessage({ + senderKeyDistributionMessage: { + axolotlSenderKeyDistributionMessage: senderKeyDistributionMessageKey, + groupId: destinationJid + } + }) + + participants.push( + ...(await createParticipantNodes(senderKeyJids, encSenderKeyMsg)) + ) + } + + binaryNodeContent.push({ + tag: 'enc', + attrs: { v: '2', type: 'skmsg' }, + content: ciphertext + }) + + await authState.keys.set({ 'sender-key-memory': { [jid]: senderKeyMap } }) + } else { + const { user: meUser } = jidDecode(meId) + + const encodedMeMsg = encodeWAMessage({ + deviceSentMessage: { + destinationJid, + message + } + }) + + if(!participant) { + devices.push({ user }) + devices.push({ user: meUser }) + + const additionalDevices = await getUSyncDevices([ meId, jid ], true) + devices.push(...additionalDevices) + } + + const meJids: string[] = [] + const otherJids: string[] = [] + for(const { user, device } of devices) { + const jid = jidEncode(user, 's.whatsapp.net', device) + const isMe = user === meUser + if(isMe) meJids.push(jid) + else otherJids.push(jid) + } + + const [meNodes, otherNodes] = await Promise.all([ + createParticipantNodes(meJids, encodedMeMsg), + createParticipantNodes(otherJids, encodedMsg) + ]) + participants.push(...meNodes) + participants.push(...otherNodes) + } + + if(participants.length) { + binaryNodeContent.push({ + tag: 'participants', + attrs: { }, + content: participants + }) + } + + const stanza: BinaryNode = { + tag: 'message', + attrs: { + id: msgId, + type: 'text', + to: destinationJid, + ...(additionalAttributes || {}) + }, + content: binaryNodeContent + } + + const shouldHaveIdentity = !!participants.find( + participant => (participant.content! as BinaryNode[]).find(n => n.attrs.type === 'pkmsg') + ) + + if(shouldHaveIdentity) { + (stanza.content as BinaryNode[]).push({ + tag: 'device-identity', + attrs: { }, + content: proto.ADVSignedDeviceIdentity.encode(authState.creds.account).finish() + }) + + logger.debug({ jid }, 'adding device identity') + } + + logger.debug({ msgId }, `sending message to ${participants.length} devices`) + + await sendNode(stanza) + } ) - if(shouldHaveIdentity) { - (stanza.content as BinaryNode[]).push({ - tag: 'device-identity', - attrs: { }, - content: proto.ADVSignedDeviceIdentity.encode(authState.creds.account).finish() - }) - - logger.debug({ jid }, 'adding device identity') - } - - logger.debug({ msgId }, `sending message to ${participants.length} devices`) - - await sendNode(stanza) - return msgId } diff --git a/src/Socket/socket.ts b/src/Socket/socket.ts index 102cdcc..c09459e 100644 --- a/src/Socket/socket.ts +++ b/src/Socket/socket.ts @@ -5,7 +5,7 @@ import WebSocket from "ws" import { randomBytes } from 'crypto' import { proto } from '../../WAProto' import { DisconnectReason, SocketConfig, BaileysEventEmitter, ConnectionState, AuthenticationCreds } from "../Types" -import { Curve, generateRegistrationNode, configureSuccessfulPairing, generateLoginNode, encodeBigEndian, promiseTimeout, generateOrGetPreKeys, xmppSignedPreKey, xmppPreKey, getPreKeys, makeNoiseHandler, useSingleFileAuthState } from "../Utils" +import { Curve, generateRegistrationNode, configureSuccessfulPairing, generateLoginNode, encodeBigEndian, promiseTimeout, generateOrGetPreKeys, xmppSignedPreKey, xmppPreKey, getPreKeys, makeNoiseHandler, useSingleFileAuthState, addTransactionCapability } from "../Utils" import { DEFAULT_ORIGIN, DEF_TAG_PREFIX, DEF_CALLBACK_PREFIX, KEY_BUNDLE_TYPE } from "../Defaults" import { assertNodeErrorFree, BinaryNode, encodeBinaryNode, S_WHATSAPP_NET, getBinaryNodeChild } from '../WABinary' @@ -539,7 +539,11 @@ export const makeSocket = ({ return { ws, ev, - authState, + authState: { + creds, + // add capability + keys: addTransactionCapability(authState.keys, logger) + }, get user () { return authState.creds.me }, diff --git a/src/Types/Auth.ts b/src/Types/Auth.ts index c2cae16..1e9eedd 100644 --- a/src/Types/Auth.ts +++ b/src/Types/Auth.ts @@ -52,27 +52,19 @@ export type SignalDataTypeMap = { 'app-state-sync-version': LTHashState } -type SignalDataSet = { [T in keyof SignalDataTypeMap]?: { [id: string]: SignalDataTypeMap[T] | null } } +export type SignalDataSet = { [T in keyof SignalDataTypeMap]?: { [id: string]: SignalDataTypeMap[T] | null } } type Awaitable = T | Promise + export type SignalKeyStore = { get(type: T, ids: string[]): Awaitable<{ [id: string]: SignalDataTypeMap[T] }> set(data: SignalDataSet): Awaitable +} - /*getPreKey: (keyId: number) => Awaitable - setPreKey: (keyId: number, pair: KeyPair | null) => Awaitable - - getSession: (sessionId: string) => Awaitable - setSession: (sessionId: string, item: any | null) => Awaitable - - getSenderKey: (id: string) => Awaitable - setSenderKey: (id: string, item: any | null) => Awaitable - - getAppStateSyncKey: (id: string) => Awaitable - setAppStateSyncKey: (id: string, item: proto.IAppStateSyncKeyData | null) => Awaitable - - getAppStateSyncVersion: (name: WAPatchName) => Awaitable - setAppStateSyncVersion: (id: WAPatchName, item: LTHashState) => Awaitable*/ +export type SignalKeyStoreWithTransaction = SignalKeyStore & { + isInTransaction: () => boolean + transaction(exec: () => Promise): Promise + prefetch(type: T, ids: string[]): Promise } export type SignalAuthState = { diff --git a/src/Utils/auth-utils.ts b/src/Utils/auth-utils.ts index cf40474..39ea354 100644 --- a/src/Utils/auth-utils.ts +++ b/src/Utils/auth-utils.ts @@ -1,8 +1,97 @@ +import { Boom } from '@hapi/boom' import { randomBytes } from 'crypto' -import type { AuthenticationCreds, AuthenticationState, SignalDataTypeMap } from "../Types" +import type { Logger } from 'pino' +import type { AuthenticationCreds, AuthenticationState, SignalDataTypeMap, SignalDataSet, SignalKeyStore, SignalKeyStoreWithTransaction } from "../Types" import { Curve, signedKeyPair } from './crypto' import { generateRegistrationId, BufferJSON } from './generics' +const KEY_MAP: { [T in keyof SignalDataTypeMap]: string } = { + 'pre-key': 'preKeys', + 'session': 'sessions', + 'sender-key': 'senderKeys', + 'app-state-sync-key': 'appStateSyncKeys', + 'app-state-sync-version': 'appStateVersions', + 'sender-key-memory': 'senderKeyMemory' +} + +export const addTransactionCapability = (state: SignalKeyStore, logger: Logger): SignalKeyStoreWithTransaction => { + let inTransaction = false + let transactionCache: SignalDataSet = { } + let mutations: SignalDataSet = { } + + const prefetch = async(type: keyof SignalDataTypeMap, ids: string[]) => { + if(!inTransaction) { + throw new Boom('Cannot prefetch without transaction') + } + + const dict = transactionCache[type] + const idsRequiringFetch = ids.filter(item => !dict?.[item]) + const result = await state.get(type, idsRequiringFetch) + + transactionCache[type] = transactionCache[type] || { } + Object.assign(transactionCache[type], result) + } + + return { + get: async(type, ids) => { + if(inTransaction) { + await prefetch(type, ids) + return ids.reduce( + (dict, id) => { + const value = transactionCache[type]?.[id] + if(value) { + dict[id] = value + } + return dict + }, { } + ) + } else { + return state.get(type, ids) + } + }, + set: data => { + if(inTransaction) { + logger.trace({ types: Object.keys(data) }, `caching in transaction`) + for(const key in data) { + transactionCache[key] = transactionCache[key] || { } + Object.assign(transactionCache[key], data[key]) + + mutations[key] = mutations[key] || { } + Object.assign(mutations[key], data[key]) + } + } else { + return state.set(data) + } + }, + isInTransaction: () => inTransaction, + prefetch: (type, ids) => { + logger.trace({ type, ids }, `prefetching`) + return prefetch(type, ids) + }, + transaction: async(work) => { + if(inTransaction) { + await work() + } else { + logger.debug('entering transaction') + inTransaction = true + try { + await work() + if(Object.keys(mutations).length) { + logger.debug('committing transaction') + await state.set(mutations) + } else { + logger.debug('no mutations in transaction') + } + } finally { + inTransaction = false + transactionCache = { } + mutations = { } + } + } + } + } +} + export const initAuthCreds = (): AuthenticationCreds => { const identityKey = Curve.generateKeyPair() return { @@ -18,15 +107,6 @@ export const initAuthCreds = (): AuthenticationCreds => { } } -const KEY_MAP: { [T in keyof SignalDataTypeMap]: string } = { - 'pre-key': 'preKeys', - 'session': 'sessions', - 'sender-key': 'senderKeys', - 'app-state-sync-key': 'appStateSyncKeys', - 'app-state-sync-version': 'appStateVersions', - 'sender-key-memory': 'senderKeyMemory' -} - /** stores the full authentication state in a single JSON file */ export const useSingleFileAuthState = (filename: string): { state: AuthenticationState, saveState: () => void } => { // require fs here so that in case "fs" is not available -- the app does not crash