feat: add signal repository + tests

This commit is contained in:
Adhiraj Singh
2023-03-18 12:25:47 +05:30
parent 2eea17fe9f
commit fe1d0649b5
21 changed files with 500 additions and 206 deletions

View File

@@ -54,7 +54,7 @@ export const makeBusinessSocket = (config: SocketConfig) => {
tag: 'product_catalog',
attrs: {
jid,
allow_shop_source: 'true'
'allow_shop_source': 'true'
},
content: queryParamNodes
}
@@ -72,13 +72,13 @@ export const makeBusinessSocket = (config: SocketConfig) => {
to: S_WHATSAPP_NET,
type: 'get',
xmlns: 'w:biz:catalog',
smax_id: '35'
'smax_id': '35'
},
content: [
{
tag: 'collections',
attrs: {
biz_jid: jid,
'biz_jid': jid,
},
content: [
{
@@ -116,7 +116,7 @@ export const makeBusinessSocket = (config: SocketConfig) => {
to: S_WHATSAPP_NET,
type: 'get',
xmlns: 'fb:thrift_iq',
smax_id: '5'
'smax_id': '5'
},
content: [
{

View File

@@ -1,5 +1,5 @@
import { proto } from '../../WAProto'
import { GroupMetadata, ParticipantAction, SocketConfig, WAMessageKey, WAMessageStubType } from '../Types'
import { GroupMetadata, GroupParticipant, ParticipantAction, SocketConfig, WAMessageKey, WAMessageStubType } from '../Types'
import { generateMessageID, unixTimestampSeconds } from '../Utils'
import { BinaryNode, getBinaryNodeChild, getBinaryNodeChildren, getBinaryNodeChildString, jidEncode, jidNormalizedUser } from '../WABinary'
import { makeChatsSocket } from './chats'
@@ -278,7 +278,7 @@ export const extractGroupMetadata = (result: BinaryNode) => {
({ attrs }) => {
return {
id: attrs.jid,
admin: attrs.type || null as any,
admin: (attrs.type || null) as GroupParticipant['admin'],
}
}
),

View File

@@ -22,8 +22,9 @@ export const makeMessagesRecvSocket = (config: SocketConfig) => {
ev,
authState,
ws,
query,
processingMutex,
signalRepository,
query,
upsertMessage,
resyncAppState,
onUnexpectedError,
@@ -543,7 +544,11 @@ export const makeMessagesRecvSocket = (config: SocketConfig) => {
}
const handleMessage = async(node: BinaryNode) => {
const { fullMessage: msg, category, author, decrypt } = decryptMessageNode(node, authState)
const { fullMessage: msg, category, author, decrypt } = decryptMessageNode(
node,
authState.creds.me!.id,
signalRepository
)
if(shouldIgnoreJid(msg.key.remoteJid!)) {
logger.debug({ key: msg.key }, 'ignored message')
await sendMessageAck(node)

View File

@@ -4,7 +4,7 @@ import NodeCache from 'node-cache'
import { proto } from '../../WAProto'
import { DEFAULT_CACHE_TTLS, WA_DEFAULT_EPHEMERAL } from '../Defaults'
import { AnyMessageContent, MediaConnInfo, MessageReceiptType, MessageRelayOptions, MiscMessageGenerationOptions, SocketConfig, WAMessageKey } from '../Types'
import { aggregateMessageKeysNotFromMe, assertMediaContent, bindWaitForEvent, decryptMediaRetryData, encodeSignedDeviceIdentity, encodeWAMessage, encryptMediaRetryRequest, encryptSenderKeyMsgSignalProto, encryptSignalProto, extractDeviceJids, generateMessageID, generateWAMessage, getStatusCodeForMediaRetry, getUrlFromDirectPath, getWAUploadToServer, jidToSignalProtocolAddress, parseAndInjectE2ESessions, unixTimestampSeconds } from '../Utils'
import { aggregateMessageKeysNotFromMe, assertMediaContent, bindWaitForEvent, decryptMediaRetryData, encodeSignedDeviceIdentity, encodeWAMessage, encryptMediaRetryRequest, extractDeviceJids, generateMessageID, generateWAMessage, getStatusCodeForMediaRetry, getUrlFromDirectPath, getWAUploadToServer, parseAndInjectE2ESessions, unixTimestampSeconds } from '../Utils'
import { getUrlInfo } from '../Utils/link-preview'
import { areJidsSameUser, BinaryNode, BinaryNodeAttributes, getBinaryNodeChild, getBinaryNodeChildren, isJidGroup, isJidUser, jidDecode, jidEncode, jidNormalizedUser, JidWithDevice, S_WHATSAPP_NET } from '../WABinary'
import { makeGroupsSocket } from './groups'
@@ -22,6 +22,7 @@ export const makeMessagesSocket = (config: SocketConfig) => {
ev,
authState,
processingMutex,
signalRepository,
upsertMessage,
query,
fetchPrivacySettings,
@@ -215,10 +216,14 @@ export const makeMessagesSocket = (config: SocketConfig) => {
if(force) {
jidsRequiringFetch = jids
} else {
const addrs = jids.map(jid => jidToSignalProtocolAddress(jid).toString())
const addrs = jids.map(jid => (
signalRepository
.jidToSignalProtocolAddress(jid)
))
const sessions = await authState.keys.get('session', addrs)
for(const jid of jids) {
const signalId = jidToSignalProtocolAddress(jid).toString()
const signalId = signalRepository
.jidToSignalProtocolAddress(jid)
if(!sessions[signalId]) {
jidsRequiringFetch.push(jid)
}
@@ -247,7 +252,7 @@ export const makeMessagesSocket = (config: SocketConfig) => {
}
]
})
await parseAndInjectE2ESessions(result, authState)
await parseAndInjectE2ESessions(result, signalRepository)
didFetchNewSession = true
}
@@ -267,7 +272,8 @@ export const makeMessagesSocket = (config: SocketConfig) => {
const nodes = await Promise.all(
jids.map(
async jid => {
const { type, ciphertext } = await encryptSignalProto(jid, bytes, authState)
const { type, ciphertext } = await signalRepository
.encryptMessage({ jid, data: bytes })
if(type === 'pkmsg') {
shouldIncludeDeviceIdentity = true
}
@@ -365,11 +371,12 @@ export const makeMessagesSocket = (config: SocketConfig) => {
const patched = await patchMessageBeforeSending(message, devices.map(d => jidEncode(d.user, 's.whatsapp.net', d.device)))
const bytes = encodeWAMessage(patched)
const { ciphertext, senderKeyDistributionMessageKey } = await encryptSenderKeyMsgSignalProto(
destinationJid,
bytes,
meId,
authState
const { ciphertext, senderKeyDistributionMessage } = await signalRepository.encryptGroupMessage(
{
group: destinationJid,
data: bytes,
meId,
}
)
const senderKeyJids: string[] = []
@@ -390,7 +397,7 @@ export const makeMessagesSocket = (config: SocketConfig) => {
const senderKeyMsg: proto.IMessage = {
senderKeyDistributionMessage: {
axolotlSenderKeyDistributionMessage: senderKeyDistributionMessageKey,
axolotlSenderKeyDistributionMessage: senderKeyDistributionMessage,
groupId: destinationJid
}
}

View File

@@ -29,6 +29,7 @@ export const makeSocket = ({
transactionOpts,
qrTimeout,
options,
makeSignalRepository
}: SocketConfig) => {
const ws = new WebSocket(waWebSocketUrl, undefined, {
origin: DEFAULT_ORIGIN,
@@ -48,6 +49,7 @@ export const makeSocket = ({
const { creds } = authState
// add transaction capability
const keys = addTransactionCapability(authState.keys, logger, transactionOpts)
const signalRepository = makeSignalRepository({ creds, keys })
let lastDateRecv: Date
let epoch = 1
@@ -90,24 +92,26 @@ export const makeSocket = ({
}
/** log & process any unexpected errors */
const onUnexpectedError = (error: Error, msg: string) => {
const onUnexpectedError = (err: Error | Boom, msg: string) => {
logger.error(
{ trace: error.stack, output: (error as any).output },
{ err },
`unexpected error in '${msg}'`
)
}
/** await the next incoming message */
const awaitNextMessage = async(sendMsg?: Uint8Array) => {
const awaitNextMessage = async<T>(sendMsg?: Uint8Array) => {
if(ws.readyState !== ws.OPEN) {
throw new Boom('Connection Closed', { statusCode: DisconnectReason.connectionClosed })
throw new Boom('Connection Closed', {
statusCode: DisconnectReason.connectionClosed
})
}
let onOpen: (data: any) => void
let onOpen: (data: T) => void
let onClose: (err: Error) => void
const result = promiseTimeout<any>(connectTimeoutMs, (resolve, reject) => {
onOpen = (data: any) => resolve(data)
const result = promiseTimeout<T>(connectTimeoutMs, (resolve, reject) => {
onOpen = resolve
onClose = mapWebSocketError(reject)
ws.on('frame', onOpen)
ws.on('close', onClose)
@@ -132,11 +136,11 @@ export const makeSocket = ({
* @param json query that was sent
* @param timeoutMs timeout after which the promise will reject
*/
const waitForMessage = async(msgId: string, timeoutMs = defaultQueryTimeoutMs) => {
const waitForMessage = async<T>(msgId: string, timeoutMs = defaultQueryTimeoutMs) => {
let onRecv: (json) => void
let onErr: (err) => void
try {
const result = await promiseTimeout(timeoutMs,
const result = await promiseTimeout<T>(timeoutMs,
(resolve, reject) => {
onRecv = resolve
onErr = err => {
@@ -148,7 +152,7 @@ export const makeSocket = ({
ws.off('error', onErr)
},
)
return result as any
return result
} finally {
ws.off(`TAG:${msgId}`, onRecv!)
ws.off('close', onErr!) // if the socket closes, you'll never receive the message
@@ -186,7 +190,7 @@ export const makeSocket = ({
const init = proto.HandshakeMessage.encode(helloMsg).finish()
const result = await awaitNextMessage(init)
const result = await awaitNextMessage<Uint8Array>(init)
const handshake = proto.HandshakeMessage.decode(result)
logger.trace({ handshake }, 'handshake recv from WA Web')
@@ -591,6 +595,7 @@ export const makeSocket = ({
ws,
ev,
authState: { creds, keys },
signalRepository,
get user() {
return authState.creds.me
},