feat: mutex processing in a chat to preserve order of events

This commit is contained in:
Adhiraj Singh
2022-01-22 14:07:06 +05:30
parent a06f639774
commit 1f2a6641f3
4 changed files with 182 additions and 141 deletions

View File

@@ -2,7 +2,7 @@ import { Boom } from '@hapi/boom'
import { proto } from '../../WAProto' import { proto } from '../../WAProto'
import { AppStateChunk, Chat, ChatModification, ChatMutation, Contact, LTHashState, PresenceData, SocketConfig, WABusinessHoursConfig, WABusinessProfile, WAMediaUpload, WAPatchCreate, WAPatchName, WAPresence } from '../Types' import { AppStateChunk, Chat, ChatModification, ChatMutation, Contact, LTHashState, PresenceData, SocketConfig, WABusinessHoursConfig, WABusinessProfile, WAMediaUpload, WAPatchCreate, WAPatchName, WAPresence } from '../Types'
import { chatModificationToAppPatch, decodePatches, decodeSyncdSnapshot, encodeSyncdPatch, extractSyncdPatches, generateProfilePicture, newLTHashState, toNumber } from '../Utils' import { chatModificationToAppPatch, decodePatches, decodeSyncdSnapshot, encodeSyncdPatch, extractSyncdPatches, generateProfilePicture, newLTHashState, toNumber } from '../Utils'
import makeMutex from '../Utils/make-mutex' import { makeMutex } from '../Utils/make-mutex'
import { BinaryNode, getBinaryNodeChild, getBinaryNodeChildren, jidNormalizedUser, reduceBinaryNodeToDictionary, S_WHATSAPP_NET } from '../WABinary' import { BinaryNode, getBinaryNodeChild, getBinaryNodeChildren, jidNormalizedUser, reduceBinaryNodeToDictionary, S_WHATSAPP_NET } from '../WABinary'
import { makeMessagesSocket } from './messages-send' import { makeMessagesSocket } from './messages-send'

View File

@@ -3,6 +3,7 @@ import { proto } from '../../WAProto'
import { KEY_BUNDLE_TYPE } from '../Defaults' import { KEY_BUNDLE_TYPE } from '../Defaults'
import { Chat, GroupMetadata, MessageUserReceipt, ParticipantAction, SocketConfig, WAMessageStubType } from '../Types' import { Chat, GroupMetadata, MessageUserReceipt, ParticipantAction, SocketConfig, WAMessageStubType } from '../Types'
import { decodeMessageStanza, downloadAndProcessHistorySyncNotification, encodeBigEndian, generateSignalPubKey, toNumber, xmppPreKey, xmppSignedPreKey } from '../Utils' import { decodeMessageStanza, downloadAndProcessHistorySyncNotification, encodeBigEndian, generateSignalPubKey, toNumber, xmppPreKey, xmppSignedPreKey } from '../Utils'
import { makeKeyedMutex } from '../Utils/make-mutex'
import { areJidsSameUser, BinaryNode, BinaryNodeAttributes, getAllBinaryNodeChildren, getBinaryNodeChildren, isJidGroup, jidDecode, jidEncode, jidNormalizedUser } from '../WABinary' import { areJidsSameUser, BinaryNode, BinaryNodeAttributes, getAllBinaryNodeChildren, getBinaryNodeChildren, isJidGroup, jidDecode, jidEncode, jidNormalizedUser } from '../WABinary'
import { makeChatsSocket } from './chats' import { makeChatsSocket } from './chats'
import { extractGroupMetadata } from './groups' import { extractGroupMetadata } from './groups'
@@ -37,6 +38,9 @@ export const makeMessagesRecvSocket = (config: SocketConfig) => {
resyncMainAppState, resyncMainAppState,
} = sock } = sock
/** the mutex ensures that the notifications (receipts, messages etc.) are processed in order */
const processingMutex = makeKeyedMutex()
const msgRetryMap = config.msgRetryCounterMap || { } const msgRetryMap = config.msgRetryCounterMap || { }
const historyCache = new Set<string>() const historyCache = new Set<string>()
@@ -338,24 +342,30 @@ export const makeMessagesRecvSocket = (config: SocketConfig) => {
} }
// recv a message // recv a message
ws.on('CB:message', async(stanza: BinaryNode) => { ws.on('CB:message', (stanza: BinaryNode) => {
const msg = await decodeMessageStanza(stanza, authState) const { fullMessage: msg, decryptionTask } = decodeMessageStanza(stanza, authState)
// message failed to decrypt processingMutex.mutex(
if(msg.messageStubType === proto.WebMessageInfo.WebMessageInfoStubType.CIPHERTEXT) { msg.key.remoteJid!,
logger.error( async() => {
{ msgId: msg.key.id, params: msg.messageStubParameters }, await decryptionTask
'failure in decrypting message' // message failed to decrypt
) if(msg.messageStubType === proto.WebMessageInfo.WebMessageInfoStubType.CIPHERTEXT) {
await sendRetryRequest(stanza) logger.error(
} else { { msgId: msg.key.id, params: msg.messageStubParameters },
await sendMessageAck(stanza, { class: 'receipt' }) 'failure in decrypting message'
// no type in the receipt => message delivered )
await sendReceipt(msg.key.remoteJid!, msg.key.participant, [msg.key.id!], undefined) await sendRetryRequest(stanza)
logger.debug({ msg: msg.key }, 'sent delivery receipt') } else {
} await sendMessageAck(stanza, { class: 'receipt' })
// no type in the receipt => message delivered
msg.key.remoteJid = jidNormalizedUser(msg.key.remoteJid!) await sendReceipt(msg.key.remoteJid!, msg.key.participant, [msg.key.id!], undefined)
ev.emit('messages.upsert', { messages: [msg], type: stanza.attrs.offline ? 'append' : 'notify' }) logger.debug({ msg: msg.key }, 'sent delivery receipt')
}
msg.key.remoteJid = jidNormalizedUser(msg.key.remoteJid!)
ev.emit('messages.upsert', { messages: [msg], type: stanza.attrs.offline ? 'append' : 'notify' })
}
)
}) })
ws.on('CB:ack,class:message', async(node: BinaryNode) => { ws.on('CB:ack,class:message', async(node: BinaryNode) => {
@@ -428,82 +438,92 @@ export const makeMessagesRecvSocket = (config: SocketConfig) => {
participant: attrs.participant participant: attrs.participant
} }
const status = getStatusFromReceiptType(attrs.type) await processingMutex.mutex(
if( remoteJid,
typeof status !== 'undefined' && async() => {
( const status = getStatusFromReceiptType(attrs.type)
// basically, we only want to know when a message from us has been delivered to/read by the other person if(
// or another device of ours has read some messages typeof status !== 'undefined' &&
status > proto.WebMessageInfo.WebMessageInfoStatus.DELIVERY_ACK || (
!isNodeFromMe // basically, we only want to know when a message from us has been delivered to/read by the other person
) // or another device of ours has read some messages
) { status > proto.WebMessageInfo.WebMessageInfoStatus.DELIVERY_ACK ||
if(isJidGroup(remoteJid)) { !isNodeFromMe
const updateKey: keyof MessageUserReceipt = status === proto.WebMessageInfo.WebMessageInfoStatus.DELIVERY_ACK ? 'receiptTimestamp' : 'readTimestamp' )
ev.emit( ) {
'message-receipt.update', if(isJidGroup(remoteJid)) {
ids.map(id => ({ const updateKey: keyof MessageUserReceipt = status === proto.WebMessageInfo.WebMessageInfoStatus.DELIVERY_ACK ? 'receiptTimestamp' : 'readTimestamp'
key: { ...key, id }, ev.emit(
receipt: { 'message-receipt.update',
userJid: jidNormalizedUser(attrs.participant), ids.map(id => ({
[updateKey]: +attrs.t key: { ...key, id },
} receipt: {
})) userJid: jidNormalizedUser(attrs.participant),
) [updateKey]: +attrs.t
} else { }
ev.emit( }))
'messages.update', )
ids.map(id => ({ } else {
key: { ...key, id }, ev.emit(
update: { status } 'messages.update',
})) ids.map(id => ({
) key: { ...key, id },
} update: { status }
}))
)
}
}
if(attrs.type === 'retry') {
// correctly set who is asking for the retry
key.participant = key.participant || attrs.from
if(key.fromMe) {
try {
logger.debug({ attrs }, 'recv retry request')
await sendMessagesAgain(key, ids)
} catch(error) {
logger.error({ key, ids, trace: error.stack }, 'error in sending message again')
shouldAck = false
} }
} else {
logger.info({ attrs, key }, 'recv retry for not fromMe message')
}
}
if(shouldAck) { if(attrs.type === 'retry') {
await sendMessageAck(node, { class: 'receipt', type: attrs.type }) // correctly set who is asking for the retry
} key.participant = key.participant || attrs.from
if(key.fromMe) {
try {
logger.debug({ attrs }, 'recv retry request')
await sendMessagesAgain(key, ids)
} catch(error) {
logger.error({ key, ids, trace: error.stack }, 'error in sending message again')
shouldAck = false
}
} else {
logger.info({ attrs, key }, 'recv retry for not fromMe message')
}
}
if(shouldAck) {
await sendMessageAck(node, { class: 'receipt', type: attrs.type })
}
}
)
} }
ws.on('CB:receipt', handleReceipt) ws.on('CB:receipt', handleReceipt)
ws.on('CB:notification', async(node: BinaryNode) => { ws.on('CB:notification', async(node: BinaryNode) => {
await sendMessageAck(node, { class: 'notification', type: node.attrs.type }) const remoteJid = node.attrs.from
processingMutex.mutex(
const msg = processNotification(node) remoteJid,
if(msg) { () => {
const fromMe = areJidsSameUser(node.attrs.participant || node.attrs.from, authState.creds.me!.id) const msg = processNotification(node)
msg.key = { if(msg) {
remoteJid: node.attrs.from, const fromMe = areJidsSameUser(node.attrs.participant || node.attrs.from, authState.creds.me!.id)
fromMe, msg.key = {
participant: node.attrs.participant, remoteJid: node.attrs.from,
id: node.attrs.id, fromMe,
...(msg.key || {}) participant: node.attrs.participant,
id: node.attrs.id,
...(msg.key || {})
}
msg.messageTimestamp = +node.attrs.t
const fullMsg = proto.WebMessageInfo.fromObject(msg)
ev.emit('messages.upsert', { messages: [fullMsg], type: 'append' })
}
} }
msg.messageTimestamp = +node.attrs.t )
const fullMsg = proto.WebMessageInfo.fromObject(msg) await sendMessageAck(node, { class: 'notification', type: node.attrs.type })
ev.emit('messages.upsert', { messages: [fullMsg], type: 'append' })
}
}) })
ev.on('messages.upsert', async({ messages, type }) => { ev.on('messages.upsert', async({ messages, type }) => {
@@ -520,7 +540,11 @@ export const makeMessagesRecvSocket = (config: SocketConfig) => {
} }
} }
await processMessage(msg, chat) await processingMutex.mutex(
'p-' + chat.id!,
() => processMessage(msg, chat)
)
if(!!msg.message && !msg.message!.protocolMessage) { if(!!msg.message && !msg.message!.protocolMessage) {
chat.conversationTimestamp = toNumber(msg.messageTimestamp) chat.conversationTimestamp = toNumber(msg.messageTimestamp)
if(!msg.key.fromMe) { if(!msg.key.fromMe) {

View File

@@ -7,7 +7,7 @@ import { decryptGroupSignalProto, decryptSignalProto, processSenderKeyMessage }
type MessageType = 'chat' | 'peer_broadcast' | 'other_broadcast' | 'group' | 'direct_peer_status' | 'other_status' type MessageType = 'chat' | 'peer_broadcast' | 'other_broadcast' | 'group' | 'direct_peer_status' | 'other_status'
export const decodeMessageStanza = async(stanza: BinaryNode, auth: AuthenticationState) => { export const decodeMessageStanza = (stanza: BinaryNode, auth: AuthenticationState) => {
//const deviceIdentity = (stanza.content as BinaryNodeM[])?.find(m => m.tag === 'device-identity') //const deviceIdentity = (stanza.content as BinaryNodeM[])?.find(m => m.tag === 'device-identity')
//const deviceIdentityBytes = deviceIdentity ? deviceIdentity.content as Buffer : undefined //const deviceIdentityBytes = deviceIdentity ? deviceIdentity.content as Buffer : undefined
@@ -81,48 +81,51 @@ export const decodeMessageStanza = async(stanza: BinaryNode, auth: Authenticatio
fullMessage.status = proto.WebMessageInfo.WebMessageInfoStatus.SERVER_ACK fullMessage.status = proto.WebMessageInfo.WebMessageInfoStatus.SERVER_ACK
} }
if(Array.isArray(stanza.content)) { return {
for(const { tag, attrs, content } of stanza.content) { fullMessage,
if(tag !== 'enc') { decryptionTask: (async() => {
continue if(Array.isArray(stanza.content)) {
for(const { tag, attrs, content } of stanza.content) {
if(tag !== 'enc') {
continue
}
if(!(content instanceof Uint8Array)) {
continue
}
let msgBuffer: Buffer
try {
const e2eType = attrs.type
switch (e2eType) {
case 'skmsg':
msgBuffer = await decryptGroupSignalProto(sender, author, content, auth)
break
case 'pkmsg':
case 'msg':
const user = isJidUser(sender) ? sender : author
msgBuffer = await decryptSignalProto(user, e2eType, content as Buffer, auth)
break
}
let msg: proto.IMessage = proto.Message.decode(unpadRandomMax16(msgBuffer))
msg = msg.deviceSentMessage?.message || msg
if(msg.senderKeyDistributionMessage) {
await processSenderKeyMessage(author, msg.senderKeyDistributionMessage, auth)
}
if(fullMessage.message) {
Object.assign(fullMessage.message, msg)
} else {
fullMessage.message = msg
}
} catch(error) {
fullMessage.messageStubType = proto.WebMessageInfo.WebMessageInfoStubType.CIPHERTEXT
fullMessage.messageStubParameters = [error.message]
}
}
} }
})()
if(!(content instanceof Uint8Array)) {
continue
}
let msgBuffer: Buffer
try {
const e2eType = attrs.type
switch (e2eType) {
case 'skmsg':
msgBuffer = await decryptGroupSignalProto(sender, author, content, auth)
break
case 'pkmsg':
case 'msg':
const user = isJidUser(sender) ? sender : author
msgBuffer = await decryptSignalProto(user, e2eType, content as Buffer, auth)
break
}
let msg: proto.IMessage = proto.Message.decode(unpadRandomMax16(msgBuffer))
msg = msg.deviceSentMessage?.message || msg
if(msg.senderKeyDistributionMessage) {
await processSenderKeyMessage(author, msg.senderKeyDistributionMessage, auth)
}
if(fullMessage.message) {
Object.assign(fullMessage.message, msg)
} else {
fullMessage.message = msg
}
} catch(error) {
fullMessage.messageStubType = proto.WebMessageInfo.WebMessageInfoStubType.CIPHERTEXT
fullMessage.messageStubParameters = [error.message]
}
}
} }
return fullMessage
} }

View File

@@ -1,22 +1,36 @@
export const makeMutex = () => {
export default () => {
let task = Promise.resolve() as Promise<any> let task = Promise.resolve() as Promise<any>
return { return {
mutex<T>(code: () => Promise<T>):Promise<T> { mutex<T>(code: () => Promise<T> | T): Promise<T> {
task = (async() => { task = (async() => {
// wait for the previous task to complete // wait for the previous task to complete
// if there is an error, we swallow so as to not block the queue // if there is an error, we swallow so as to not block the queue
try { try {
await task await task
} catch{ } } catch{ }
// execute the current task // execute the current task
return code() return code()
})() })()
// we replace the existing task, appending the new piece of execution to it // we replace the existing task, appending the new piece of execution to it
// so the next task will have to wait for this one to finish // so the next task will have to wait for this one to finish
return task return task
}, },
}
}
export type Mutex = ReturnType<typeof makeMutex>
export const makeKeyedMutex = () => {
const map: { [id: string]: Mutex } = {}
return {
mutex<T>(key: string, task: () => Promise<T> | T): Promise<T> {
if(!map[key]) {
map[key] = makeMutex()
}
return map[key].mutex(task)
}
} }
} }