Merge pull request #2472 from moskoweb/poll

Repull:  support poll message
This commit is contained in:
Adhiraj Singh
2023-03-02 21:17:46 +04:00
committed by GitHub
9 changed files with 297 additions and 59 deletions

View File

@@ -1,6 +1,6 @@
import { Boom } from '@hapi/boom' import { Boom } from '@hapi/boom'
import NodeCache from 'node-cache' import NodeCache from 'node-cache'
import makeWASocket, { AnyMessageContent, delay, DisconnectReason, fetchLatestBaileysVersion, makeCacheableSignalKeyStore, makeInMemoryStore, useMultiFileAuthState } from '../src' import makeWASocket, { AnyMessageContent, delay, DisconnectReason, fetchLatestBaileysVersion, getAggregateVotesInPollMessage, makeCacheableSignalKeyStore, makeInMemoryStore, proto, useMultiFileAuthState, WAMessageContent, WAMessageKey } from '../src'
import MAIN_LOGGER from '../src/Utils/logger' import MAIN_LOGGER from '../src/Utils/logger'
const logger = MAIN_LOGGER.child({ }) const logger = MAIN_LOGGER.child({ })
@@ -43,18 +43,8 @@ const startSock = async() => {
// ignore all broadcast messages -- to receive the same // ignore all broadcast messages -- to receive the same
// comment the line below out // comment the line below out
// shouldIgnoreJid: jid => isJidBroadcast(jid), // shouldIgnoreJid: jid => isJidBroadcast(jid),
// implement to handle retries // implement to handle retries & poll updates
getMessage: async key => { getMessage,
if(store) {
const msg = await store.loadMessage(key.remoteJid!, key.id!)
return msg?.message || undefined
}
// only if store is present
return {
conversation: 'hello'
}
}
}) })
store?.bind(sock.ev) store?.bind(sock.ev)
@@ -126,7 +116,24 @@ const startSock = async() => {
// messages updated like status delivered, message deleted etc. // messages updated like status delivered, message deleted etc.
if(events['messages.update']) { if(events['messages.update']) {
console.log(events['messages.update']) console.log(
JSON.stringify(events['messages.update'], undefined, 2)
)
for(const { key, update } of events['messages.update']) {
if(update.pollUpdates) {
const pollCreation = await getMessage(key)
if(pollCreation) {
console.log(
'got poll update, aggregation: ',
getAggregateVotesInPollMessage({
message: pollCreation,
pollUpdates: update.pollUpdates,
})
)
}
}
}
} }
if(events['message-receipt.update']) { if(events['message-receipt.update']) {
@@ -165,6 +172,16 @@ const startSock = async() => {
) )
return sock return sock
async function getMessage(key: WAMessageKey): Promise<WAMessageContent | undefined> {
if(store) {
const msg = await store.loadMessage(key.remoteJid!, key.id!)
return msg?.message || undefined
}
// only if store is present
return proto.Message.fromObject({})
}
} }
startSock() startSock()

View File

@@ -245,8 +245,10 @@ export const makeChatsSocket = (config: SocketConfig) => {
const website = getBinaryNodeChild(profiles, 'website') const website = getBinaryNodeChild(profiles, 'website')
const email = getBinaryNodeChild(profiles, 'email') const email = getBinaryNodeChild(profiles, 'email')
const category = getBinaryNodeChild(getBinaryNodeChild(profiles, 'categories'), 'category') const category = getBinaryNodeChild(getBinaryNodeChild(profiles, 'categories'), 'category')
const business_hours = getBinaryNodeChild(profiles, 'business_hours') const businessHours = getBinaryNodeChild(profiles, 'business_hours')
const business_hours_config = business_hours && getBinaryNodeChildren(business_hours, 'business_hours_config') const businessHoursConfig = businessHours
? getBinaryNodeChildren(businessHours, 'business_hours_config')
: undefined
const websiteStr = website?.content?.toString() const websiteStr = website?.content?.toString()
return { return {
wid: profiles.attrs?.jid, wid: profiles.attrs?.jid,
@@ -255,9 +257,9 @@ export const makeChatsSocket = (config: SocketConfig) => {
website: websiteStr ? [websiteStr] : [], website: websiteStr ? [websiteStr] : [],
email: email?.content?.toString(), email: email?.content?.toString(),
category: category?.content?.toString(), category: category?.content?.toString(),
business_hours: { 'business_hours': {
timezone: business_hours?.attrs?.timezone, timezone: businessHours?.attrs?.timezone,
business_config: business_hours_config?.map(({ attrs }) => attrs as unknown as WABusinessHoursConfig) 'business_config': businessHoursConfig?.map(({ attrs }) => attrs as unknown as WABusinessHoursConfig)
} }
} }
} }
@@ -599,7 +601,7 @@ export const makeChatsSocket = (config: SocketConfig) => {
attrs: { attrs: {
name, name,
version: (state.version - 1).toString(), version: (state.version - 1).toString(),
return_snapshot: 'false' 'return_snapshot': 'false'
}, },
content: [ content: [
{ {
@@ -762,6 +764,7 @@ export const makeChatsSocket = (config: SocketConfig) => {
keyStore: authState.keys, keyStore: authState.keys,
logger, logger,
options: config.options, options: config.options,
getMessage: config.getMessage,
} }
) )
]) ])

View File

@@ -81,6 +81,14 @@ type WithDimensions = {
height?: number height?: number
} }
export type PollMessageOptions = {
name: string
selectableCount?: number
values: string[]
/** 32 byte message secret to encrypt poll selections */
messageSecret?: Uint8Array
}
export type MediaType = keyof typeof MEDIA_HKDF_KEY_MAPPING export type MediaType = keyof typeof MEDIA_HKDF_KEY_MAPPING
export type AnyMediaMessageContent = ( export type AnyMediaMessageContent = (
({ ({
@@ -129,6 +137,9 @@ export type AnyRegularMessageContent = (
} }
& Mentionable & Buttonable & Templatable & Listable) & Mentionable & Buttonable & Templatable & Listable)
| AnyMediaMessageContent | AnyMediaMessageContent
| ({
poll: PollMessageOptions
} & Mentionable & Buttonable & Templatable)
| { | {
contacts: { contacts: {
displayName?: string displayName?: string

View File

@@ -41,6 +41,13 @@ export const BufferJSON = {
} }
} }
export const getKeyAuthor = (
key: proto.IMessageKey | undefined | null,
meId: string = 'me'
) => (
(key?.fromMe ? meId : key?.participant || key?.remoteJid) || ''
)
export const writeRandomPadMax16 = (msg: Uint8Array) => { export const writeRandomPadMax16 = (msg: Uint8Array) => {
const pad = randomBytes(1) const pad = randomBytes(1)
pad[0] &= 0xf pad[0] &= 0xf

View File

@@ -1,5 +1,6 @@
import { Boom } from '@hapi/boom' import { Boom } from '@hapi/boom'
import axios from 'axios' import axios from 'axios'
import { randomBytes } from 'crypto'
import { promises as fs } from 'fs' import { promises as fs } from 'fs'
import { Logger } from 'pino' import { Logger } from 'pino'
import { proto } from '../../WAProto' import { proto } from '../../WAProto'
@@ -18,13 +19,13 @@ import {
WAMediaUpload, WAMediaUpload,
WAMessage, WAMessage,
WAMessageContent, WAMessageContent,
WAMessageKey,
WAMessageStatus, WAMessageStatus,
WAProto, WAProto,
WATextMessage, WATextMessage,
} from '../Types' } from '../Types'
import { isJidGroup, jidNormalizedUser } from '../WABinary' import { isJidGroup, jidNormalizedUser } from '../WABinary'
import { generateMessageID, unixTimestampSeconds } from './generics' import { sha256 } from './crypto'
import { generateMessageID, getKeyAuthor, unixTimestampSeconds } from './generics'
import { downloadContentFromMessage, encryptedStream, generateThumbnail, getAudioDuration, MediaDownloadOptions } from './messages-media' import { downloadContentFromMessage, encryptedStream, generateThumbnail, getAudioDuration, MediaDownloadOptions } from './messages-media'
type MediaUploadData = { type MediaUploadData = {
@@ -172,7 +173,7 @@ export const prepareWAMessageMedia = async(
const { const {
thumbnail, thumbnail,
originalImageDimensions originalImageDimensions
} = await generateThumbnail(bodyPath!, mediaType as any, options) } = await generateThumbnail(bodyPath!, mediaType as 'image' | 'video', options)
uploadData.jpegThumbnail = thumbnail uploadData.jpegThumbnail = thumbnail
if(!uploadData.width && originalImageDimensions) { if(!uploadData.width && originalImageDimensions) {
uploadData.width = originalImageDimensions.width uploadData.width = originalImageDimensions.width
@@ -379,6 +380,33 @@ export const generateWAMessageContent = async(
}) })
} else if('listReply' in message) { } else if('listReply' in message) {
m.listResponseMessage = { ...message.listReply } m.listResponseMessage = { ...message.listReply }
} else if('poll' in message) {
message.poll.selectableCount ||= 0
if(!Array.isArray(message.poll.values)) {
throw new Boom('Invalid poll values', { statusCode: 400 })
}
if(
message.poll.selectableCount < 0
|| message.poll.selectableCount > message.poll.values.length
) {
throw new Boom(
`poll.selectableCount in poll should be >= 0 and <= ${message.poll.values.length}`,
{ statusCode: 400 }
)
}
m.messageContextInfo = {
// encKey
messageSecret: message.poll.messageSecret || randomBytes(32),
}
m.pollCreationMessage = {
name: message.poll.name,
selectableOptionsCount: message.poll.selectableCount,
options: message.poll.values.map(optionName => ({ optionName })),
}
} else { } else {
m = await prepareWAMessageMedia( m = await prepareWAMessageMedia(
message, message,
@@ -468,9 +496,11 @@ export const generateWAMessageFromContent = (
message: WAMessageContent, message: WAMessageContent,
options: MessageGenerationOptionsFromContent options: MessageGenerationOptionsFromContent
) => { ) => {
// set timestamp to now
// if not specified
if(!options.timestamp) { if(!options.timestamp) {
options.timestamp = new Date() options.timestamp = new Date()
} // set timestamp to now }
const key = Object.keys(message)[0] const key = Object.keys(message)[0]
const timestamp = unixTimestampSeconds(options.timestamp) const timestamp = unixTimestampSeconds(options.timestamp)
@@ -573,31 +603,31 @@ export const getContentType = (content: WAProto.IMessage | undefined) => {
* @returns * @returns
*/ */
export const normalizeMessageContent = (content: WAMessageContent | null | undefined): WAMessageContent | undefined => { export const normalizeMessageContent = (content: WAMessageContent | null | undefined): WAMessageContent | undefined => {
if(!content) { if(!content) {
return undefined return undefined
} }
// set max iterations to prevent an infinite loop // set max iterations to prevent an infinite loop
for(let i = 0;i < 5;i++) { for(let i = 0;i < 5;i++) {
const inner = getFutureProofMessage(content) const inner = getFutureProofMessage(content)
if(!inner) { if(!inner) {
break break
} }
content = inner.message content = inner.message
} }
return content! return content!
function getFutureProofMessage(message: typeof content) { function getFutureProofMessage(message: typeof content) {
return ( return (
message?.ephemeralMessage message?.ephemeralMessage
|| message?.viewOnceMessage || message?.viewOnceMessage
|| message?.documentWithCaptionMessage || message?.documentWithCaptionMessage
|| message?.viewOnceMessageV2 || message?.viewOnceMessageV2
|| message?.editedMessage || message?.editedMessage
) )
} }
} }
/** /**
@@ -664,10 +694,6 @@ export const updateMessageWithReceipt = (msg: Pick<WAMessage, 'userReceipt'>, re
} }
} }
const getKeyAuthor = (key: WAMessageKey | undefined | null) => (
(key?.fromMe ? 'me' : key?.participant || key?.remoteJid) || ''
)
/** Update the message with a new reaction */ /** Update the message with a new reaction */
export const updateMessageWithReaction = (msg: Pick<WAMessage, 'reactions'>, reaction: proto.IReaction) => { export const updateMessageWithReaction = (msg: Pick<WAMessage, 'reactions'>, reaction: proto.IReaction) => {
const authorID = getKeyAuthor(reaction.key) const authorID = getKeyAuthor(reaction.key)
@@ -681,6 +707,73 @@ export const updateMessageWithReaction = (msg: Pick<WAMessage, 'reactions'>, rea
msg.reactions = reactions msg.reactions = reactions
} }
/** Update the message with a new poll update */
export const updateMessageWithPollUpdate = (
msg: Pick<WAMessage, 'pollUpdates'>,
update: proto.IPollUpdate
) => {
const authorID = getKeyAuthor(update.pollUpdateMessageKey)
const reactions = (msg.pollUpdates || [])
.filter(r => getKeyAuthor(r.pollUpdateMessageKey) !== authorID)
if(update.vote?.selectedOptions?.length) {
reactions.push(update)
}
msg.pollUpdates = reactions
}
type VoteAggregation = {
name: string
voters: string[]
}
/**
* Aggregates all poll updates in a poll.
* @param msg the poll creation message
* @param meId your jid
* @returns A list of options & their voters
*/
export function getAggregateVotesInPollMessage(
{ message, pollUpdates }: Pick<WAMessage, 'pollUpdates' | 'message'>,
meId?: string
) {
const opts = message?.pollCreationMessage?.options || []
const voteHashMap = opts.reduce((acc, opt) => {
const hash = sha256(Buffer.from(opt.optionName || '')).toString()
acc[hash] = {
name: opt.optionName || '',
voters: []
}
return acc
}, {} as { [_: string]: VoteAggregation })
for(const update of pollUpdates || []) {
const { vote } = update
if(!vote) {
continue
}
for(const option of vote.selectedOptions || []) {
const hash = option.toString()
let data = voteHashMap[hash]
if(!data) {
voteHashMap[hash] = {
name: 'Unknown',
voters: []
}
data = voteHashMap[hash]
}
voteHashMap[hash].voters.push(
getKeyAuthor(update.pollUpdateMessageKey, meId)
)
}
}
return Object.values(voteHashMap)
}
/** Given a list of message keys, aggregates them by chat & sender. Useful for sending read receipts in bulk */ /** Given a list of message keys, aggregates them by chat & sender. Useful for sending read receipts in bulk */
export const aggregateMessageKeysNotFromMe = (keys: proto.IMessageKey[]) => { export const aggregateMessageKeysNotFromMe = (keys: proto.IMessageKey[]) => {
const keyMap: { [id: string]: { jid: string, participant: string | undefined, messageIds: string[] } } = { } const keyMap: { [id: string]: { jid: string, participant: string | undefined, messageIds: string[] } } = { }

View File

@@ -1,17 +1,21 @@
import { AxiosRequestConfig } from 'axios' import { AxiosRequestConfig } from 'axios'
import type { Logger } from 'pino' import type { Logger } from 'pino'
import { proto } from '../../WAProto' import { proto } from '../../WAProto'
import { AuthenticationCreds, BaileysEventEmitter, Chat, GroupMetadata, ParticipantAction, SignalKeyStoreWithTransaction, WAMessageStubType } from '../Types' import { AuthenticationCreds, BaileysEventEmitter, Chat, GroupMetadata, ParticipantAction, SignalKeyStoreWithTransaction, SocketConfig, WAMessageStubType } from '../Types'
import { downloadAndProcessHistorySyncNotification, getContentType, normalizeMessageContent, toNumber } from '../Utils' import { getContentType, normalizeMessageContent } from '../Utils/messages'
import { areJidsSameUser, isJidBroadcast, isJidStatusBroadcast, jidNormalizedUser } from '../WABinary' import { areJidsSameUser, isJidBroadcast, isJidStatusBroadcast, jidNormalizedUser } from '../WABinary'
import { aesDecryptGCM, hmacSign } from './crypto'
import { getKeyAuthor, toNumber } from './generics'
import { downloadAndProcessHistorySyncNotification } from './history'
type ProcessMessageContext = { type ProcessMessageContext = {
shouldProcessHistoryMsg: boolean shouldProcessHistoryMsg: boolean
creds: AuthenticationCreds creds: AuthenticationCreds
keyStore: SignalKeyStoreWithTransaction keyStore: SignalKeyStoreWithTransaction
ev: BaileysEventEmitter ev: BaileysEventEmitter
getMessage: SocketConfig['getMessage']
logger?: Logger logger?: Logger
options: AxiosRequestConfig<any> options: AxiosRequestConfig<{}>
} }
const REAL_MSG_STUB_TYPES = new Set([ const REAL_MSG_STUB_TYPES = new Set([
@@ -33,7 +37,14 @@ export const cleanMessage = (message: proto.IWebMessageInfo, meId: string) => {
const content = normalizeMessageContent(message.message) const content = normalizeMessageContent(message.message)
// if the message has a reaction, ensure fromMe & remoteJid are from our perspective // if the message has a reaction, ensure fromMe & remoteJid are from our perspective
if(content?.reactionMessage) { if(content?.reactionMessage) {
const msgKey = content.reactionMessage.key! normaliseKey(content.reactionMessage.key!)
}
if(content?.pollUpdateMessage) {
normaliseKey(content.pollUpdateMessage.pollCreationMessageKey!)
}
function normaliseKey(msgKey: proto.IMessageKey) {
// if the reaction is from another user // if the reaction is from another user
// we've to correctly map the key to this user's perspective // we've to correctly map the key to this user's perspective
if(!message.key.fromMe) { if(!message.key.fromMe) {
@@ -66,6 +77,7 @@ export const isRealMessage = (message: proto.IWebMessageInfo, meId: string) => {
&& hasSomeContent && hasSomeContent
&& !normalizedContent?.protocolMessage && !normalizedContent?.protocolMessage
&& !normalizedContent?.reactionMessage && !normalizedContent?.reactionMessage
&& !normalizedContent?.pollUpdateMessage
} }
export const shouldIncrementChatUnread = (message: proto.IWebMessageInfo) => ( export const shouldIncrementChatUnread = (message: proto.IWebMessageInfo) => (
@@ -88,6 +100,54 @@ export const getChatId = ({ remoteJid, participant, fromMe }: proto.IMessageKey)
return remoteJid! return remoteJid!
} }
type PollContext = {
/** normalised jid of the person that created the poll */
pollCreatorJid: string
/** ID of the poll creation message */
pollMsgId: string
/** poll creation message enc key */
pollEncKey: Uint8Array
/** jid of the person that voted */
voterJid: string
}
/**
* Decrypt a poll vote
* @param vote encrypted vote
* @param ctx additional info about the poll required for decryption
* @returns list of SHA256 options
*/
export function decryptPollVote(
{ encPayload, encIv }: proto.Message.IPollEncValue,
{
pollCreatorJid,
pollMsgId,
pollEncKey,
voterJid,
}: PollContext
) {
const sign = Buffer.concat(
[
toBinary(pollMsgId),
toBinary(pollCreatorJid),
toBinary(voterJid),
toBinary('Poll Vote'),
new Uint8Array([1])
]
)
const key0 = hmacSign(pollEncKey, new Uint8Array(32), 'sha256')
const decKey = hmacSign(sign, key0, 'sha256')
const aad = toBinary(`${pollMsgId}\u0000${voterJid}`)
const decrypted = aesDecryptGCM(encPayload!, decKey, encIv!, aad)
return proto.Message.PollVoteMessage.decode(decrypted)
function toBinary(txt: string) {
return Buffer.from(txt)
}
}
const processMessage = async( const processMessage = async(
message: proto.IWebMessageInfo, message: proto.IWebMessageInfo,
{ {
@@ -96,7 +156,8 @@ const processMessage = async(
creds, creds,
keyStore, keyStore,
logger, logger,
options options,
getMessage
}: ProcessMessageContext }: ProcessMessageContext
) => { ) => {
const meId = creds.me!.id const meId = creds.me!.id
@@ -273,6 +334,52 @@ const processMessage = async(
emitGroupUpdate({ inviteCode: code }) emitGroupUpdate({ inviteCode: code })
break break
} }
} else if(content?.pollUpdateMessage) {
const creationMsgKey = content.pollUpdateMessage.pollCreationMessageKey!
// we need to fetch the poll creation message to get the poll enc key
const pollMsg = await getMessage(creationMsgKey)
if(pollMsg) {
const meIdNormalised = jidNormalizedUser(meId)
const pollCreatorJid = getKeyAuthor(creationMsgKey, meIdNormalised)
const voterJid = getKeyAuthor(message.key!, meIdNormalised)
const pollEncKey = pollMsg.messageContextInfo?.messageSecret!
try {
const voteMsg = decryptPollVote(
content.pollUpdateMessage.vote!,
{
pollEncKey,
pollCreatorJid,
pollMsgId: creationMsgKey.id!,
voterJid,
}
)
ev.emit('messages.update', [
{
key: creationMsgKey,
update: {
pollUpdates: [
{
pollUpdateMessageKey: message.key,
vote: voteMsg,
senderTimestampMs: message.messageTimestamp,
}
]
}
}
])
} catch(err) {
logger?.warn(
{ err, creationMsgKey },
'failed to decrypt poll vote'
)
}
} else {
logger?.warn(
{ creationMsgKey },
'poll creation message not found, cannot decrypt update'
)
}
} }
if(Object.keys(chat).length > 1) { if(Object.keys(chat).length > 1) {

View File

@@ -3452,7 +3452,7 @@ json-stable-stringify-without-jsonify@^1.0.1:
json5@2.x, json5@^2.2.1: json5@2.x, json5@^2.2.1:
version "2.2.3" version "2.2.3"
resolved "https://registry.yarnpkg.com/json5/-/json5-2.2.3.tgz#78cd6f1a19bdc12b73db5ad0c61efd66c1e29283" resolved "https://registry.npmjs.org/json5/-/json5-2.2.3.tgz"
integrity sha512-XmOWe7eyHYH14cLdVPoyg+GOH3rYX++KpzrylJwSW98t3Nk+U8XOl8FWKOgwtzdb8lXGf6zYwDUzeHMWfxasyg== integrity sha512-XmOWe7eyHYH14cLdVPoyg+GOH3rYX++KpzrylJwSW98t3Nk+U8XOl8FWKOgwtzdb8lXGf6zYwDUzeHMWfxasyg==
jsonc-parser@^3.0.0: jsonc-parser@^3.0.0:
@@ -3496,7 +3496,7 @@ levn@~0.3.0:
"libsignal@git+https://github.com/adiwajshing/libsignal-node": "libsignal@git+https://github.com/adiwajshing/libsignal-node":
version "2.0.1" version "2.0.1"
resolved "git+https://github.com/adiwajshing/libsignal-node.git#11dbd962ea108187c79a7c46fe4d6f790e23da97" resolved "git+ssh://git@github.com/adiwajshing/libsignal-node.git#11dbd962ea108187c79a7c46fe4d6f790e23da97"
dependencies: dependencies:
curve25519-js "^0.0.4" curve25519-js "^0.0.4"
protobufjs "6.8.8" protobufjs "6.8.8"