From 2b8256d56b234b01c5a0b9b569fe7c1239761a9a Mon Sep 17 00:00:00 2001 From: Adhiraj Singh Date: Sat, 11 Dec 2021 17:54:38 +0530 Subject: [PATCH] feat: cleaner auth state management + store SK keys !BREAKING_CHANGE --- src/Socket/chats.ts | 25 +++-- src/Socket/messages-recv.ts | 19 ++-- src/Socket/messages-send.ts | 204 +++++++++++++++++++++--------------- src/Socket/socket.ts | 8 +- src/Types/Auth.ts | 18 +++- src/Utils/auth-utils.ts | 142 +++++++++---------------- src/Utils/chat-utils.ts | 16 +-- src/Utils/signal.ts | 79 +++++++------- 8 files changed, 264 insertions(+), 247 deletions(-) diff --git a/src/Socket/chats.ts b/src/Socket/chats.ts index e9dbf26..73fd1f8 100644 --- a/src/Socket/chats.ts +++ b/src/Socket/chats.ts @@ -20,6 +20,11 @@ export const makeChatsSocket = (config: SocketConfig) => { const mutationMutex = makeMutex() + const getAppStateSyncKey = async(keyId: string) => { + const { [keyId]: key } = await authState.keys.get('app-state-sync-key', [keyId]) + return key + } + const interactiveQuery = async(userNodes: BinaryNode[], queryNode: BinaryNode) => { const result = await query({ tag: 'iq', @@ -175,7 +180,11 @@ export const makeChatsSocket = (config: SocketConfig) => { const states = { } as { [T in WAPatchName]: LTHashState } for(const name of collections) { - let state: LTHashState = fromScratch ? undefined : await authState.keys.getAppStateSyncVersion(name) + let state: LTHashState + if(!fromScratch) { + const result = await authState.keys.get('app-state-sync-version', [name]) + state = result[name] + } if(!state) state = newLTHashState() states[name] = state @@ -213,16 +222,16 @@ export const makeChatsSocket = (config: SocketConfig) => { const name = key as WAPatchName const { patches, snapshot } = decoded[name] if(snapshot) { - const newState = await decodeSyncdSnapshot(name, snapshot, authState.keys.getAppStateSyncKey) + const newState = await decodeSyncdSnapshot(name, snapshot, getAppStateSyncKey) states[name] = newState logger.info(`restored state of ${name} from snapshot to v${newState.version}`) } // only process if there are syncd patches if(patches.length) { - const { newMutations, state: newState } = await decodePatches(name, patches, states[name], authState.keys.getAppStateSyncKey, true) + const { newMutations, state: newState } = await decodePatches(name, patches, states[name], getAppStateSyncKey, true) - await authState.keys.setAppStateSyncVersion(name, newState) + await authState.keys.set({ 'app-state-sync-version': { [name]: newState } }) logger.info(`synced ${name} to v${newState.version}`) if(newMutations.length) { @@ -415,12 +424,12 @@ export const makeChatsSocket = (config: SocketConfig) => { logger.debug({ patch: patchCreate }, 'applying app patch') await resyncAppState([name]) - const initial = await authState.keys.getAppStateSyncVersion(name) + const { [name]: initial } = await authState.keys.get('app-state-sync-version', [name]) const { patch, state } = await encodeSyncdPatch( patchCreate, authState.creds.myAppStateKeyId!, initial, - authState.keys, + getAppStateSyncKey, ) const node: BinaryNode = { @@ -456,10 +465,10 @@ export const makeChatsSocket = (config: SocketConfig) => { } await query(node) - await authState.keys.setAppStateSyncVersion(name, state) + await authState.keys.set({ 'app-state-sync-version': { [name]: state } }) if(config.emitOwnEvents) { - const result = await decodePatches(name, [{ ...patch, version: { version: state.version }, }], initial, authState.keys.getAppStateSyncKey) + const result = await decodePatches(name, [{ ...patch, version: { version: state.version }, }], initial, getAppStateSyncKey) processSyncActions(result.newMutations) } } diff --git a/src/Socket/messages-recv.ts b/src/Socket/messages-recv.ts index 3bc863f..fa98b8e 100644 --- a/src/Socket/messages-recv.ts +++ b/src/Socket/messages-recv.ts @@ -1,12 +1,11 @@ import { SocketConfig, WAMessageStubType, ParticipantAction, Chat, GroupMetadata, WAMessageKey } from "../Types" import { decodeMessageStanza, encodeBigEndian, toNumber, downloadHistory, generateSignalPubKey, xmppPreKey, xmppSignedPreKey } from "../Utils" -import { BinaryNode, jidDecode, jidEncode, isJidStatusBroadcast, areJidsSameUser, getBinaryNodeChildren, jidNormalizedUser, getAllBinaryNodeChildren, BinaryNodeAttributes } from '../WABinary' +import { BinaryNode, jidDecode, jidEncode, isJidStatusBroadcast, areJidsSameUser, getBinaryNodeChildren, jidNormalizedUser, getAllBinaryNodeChildren, BinaryNodeAttributes, isJidGroup } from '../WABinary' import { proto } from "../../WAProto" import { KEY_BUNDLE_TYPE } from "../Defaults" import { makeChatsSocket } from "./chats" import { extractGroupMetadata } from "./groups" -import { Boom } from "@hapi/boom" const getStatusFromReceiptType = (type: string | undefined) => { if(type === 'read' || type === 'read-self') { @@ -25,7 +24,7 @@ export const makeMessagesRecvSocket = (config: SocketConfig) => { ev, authState, ws, - assertSession, + assertSessions, assertingPreKeys, sendNode, relayMessage, @@ -146,12 +145,12 @@ export const makeMessagesRecvSocket = (config: SocketConfig) => { if(keys?.length) { let newAppStateSyncKeyId = '' for(const { keyData, keyId } of keys) { - const str = Buffer.from(keyId.keyId!).toString('base64') + const strKeyId = Buffer.from(keyId.keyId!).toString('base64') - logger.info({ str }, 'injecting new app state sync key') - await authState.keys.setAppStateSyncKey(str, keyData) + logger.info({ strKeyId }, 'injecting new app state sync key') + await authState.keys.set({ 'app-state-sync-key': { [strKeyId]: keyData } }) - newAppStateSyncKeyId = str + newAppStateSyncKeyId = strKeyId } ev.emit('creds.update', { myAppStateKeyId: newAppStateSyncKeyId }) @@ -473,7 +472,11 @@ export const makeMessagesRecvSocket = (config: SocketConfig) => { ) const participant = key.participant || key.remoteJid - await assertSession(participant, true) + await assertSessions([participant], true) + + if(isJidGroup(key.remoteJid)) { + await authState.keys.set({ 'sender-key-memory': { [key.remoteJid]: null } }) + } logger.debug({ participant }, 'forced new session for retry recp') diff --git a/src/Socket/messages-send.ts b/src/Socket/messages-send.ts index db63873..c6112fb 100644 --- a/src/Socket/messages-send.ts +++ b/src/Socket/messages-send.ts @@ -2,7 +2,7 @@ import got from "got" import { Boom } from "@hapi/boom" import { SocketConfig, MediaConnInfo, AnyMessageContent, MiscMessageGenerationOptions, WAMediaUploadFunction, MessageRelayOptions } from "../Types" -import { encodeWAMessage, generateMessageID, generateWAMessage, encryptSenderKeyMsgSignalProto, encryptSignalProto, extractDeviceJids, jidToSignalProtocolAddress, parseAndInjectE2ESession } from "../Utils" +import { encodeWAMessage, generateMessageID, generateWAMessage, encryptSenderKeyMsgSignalProto, encryptSignalProto, extractDeviceJids, jidToSignalProtocolAddress, parseAndInjectE2ESessions } from "../Utils" import { BinaryNode, getBinaryNodeChild, getBinaryNodeChildren, isJidGroup, jidDecode, jidEncode, jidNormalizedUser, S_WHATSAPP_NET, BinaryNodeAttributes, JidWithDevice, reduceBinaryNodeToDictionary } from '../WABinary' import { proto } from "../../WAProto" import { WA_DEFAULT_EPHEMERAL, DEFAULT_ORIGIN, MEDIA_PATH_MAP } from "../Defaults" @@ -189,15 +189,23 @@ export const makeMessagesSocket = (config: SocketConfig) => { return deviceResults } - const assertSession = async(jid: string, force: boolean) => { - const addr = jidToSignalProtocolAddress(jid).toString() - const session = await authState.keys.getSession(addr) - if(!session || force) { - logger.debug({ jid }, `fetching session`) - const identity: BinaryNode = { - tag: 'user', - attrs: { jid, reason: 'identity' }, + const assertSessions = async(jids: string[], force: boolean) => { + let jidsRequiringFetch: string[] = [] + if(force) { + jidsRequiringFetch = jids + } else { + const addrs = jids.map(jid => jidToSignalProtocolAddress(jid).toString()) + const sessions = await authState.keys.get('session', addrs) + for(const jid of jids) { + const signalId = jidToSignalProtocolAddress(jid).toString() + if(!sessions[signalId]) { + jidsRequiringFetch.push(jid) + } } + } + + if(jidsRequiringFetch.length) { + logger.debug({ jidsRequiringFetch }, `fetching sessions`) const result = await query({ tag: 'iq', attrs: { @@ -209,30 +217,41 @@ export const makeMessagesSocket = (config: SocketConfig) => { { tag: 'key', attrs: { }, - content: [ identity ] + content: jidsRequiringFetch.map( + jid => ({ + tag: 'user', + attrs: { jid, reason: 'identity' }, + }) + ) } ] }) - await parseAndInjectE2ESession(result, authState) + await parseAndInjectE2ESessions(result, authState) return true } return false } - const createParticipantNode = async(jid: string, bytes: Buffer) => { - await assertSession(jid, false) - - const { type, ciphertext } = await encryptSignalProto(jid, bytes, authState) - const node: BinaryNode = { - tag: 'to', - attrs: { jid }, - content: [{ - tag: 'enc', - attrs: { v: '2', type }, - content: ciphertext - }] - } - return node + const createParticipantNodes = async(jids: string[], bytes: Buffer) => { + await assertSessions(jids, false) + const nodes = await Promise.all( + jids.map( + async jid => { + const { type, ciphertext } = await encryptSignalProto(jid, bytes, authState) + const node: BinaryNode = { + tag: 'to', + attrs: { jid }, + content: [{ + tag: 'enc', + attrs: { v: '2', type }, + content: ciphertext + }] + } + return node + } + ) + ) + return nodes } const relayMessage = async( @@ -248,10 +267,11 @@ export const makeMessagesSocket = (config: SocketConfig) => { const encodedMsg = encodeWAMessage(message) const participants: BinaryNode[] = [] - let stanza: BinaryNode const destinationJid = jidEncode(user, isGroup ? 'g.us' : 's.whatsapp.net') + const binaryNodeContent: BinaryNode[] = [] + const devices: JidWithDevice[] = [] if(participant) { const { user, device } = jidDecode(participant) @@ -261,8 +281,17 @@ export const makeMessagesSocket = (config: SocketConfig) => { if(isGroup) { const { ciphertext, senderKeyDistributionMessageKey } = await encryptSenderKeyMsgSignalProto(destinationJid, encodedMsg, meId, authState) - let groupData = cachedGroupMetadata ? await cachedGroupMetadata(jid) : undefined - if(!groupData) groupData = await groupMetadata(jid) + const [groupData, senderKeyMap] = await Promise.all([ + (async() => { + let groupData = cachedGroupMetadata ? await cachedGroupMetadata(jid) : undefined + if(!groupData) groupData = await groupMetadata(jid) + return groupData + })(), + (async() => { + const result = await authState.keys.get('sender-key-memory', [jid]) + return result[jid] || { } + })() + ]) if(!participant) { const participantsList = groupData.participants.map(p => p.id) @@ -270,31 +299,31 @@ export const makeMessagesSocket = (config: SocketConfig) => { devices.push(...additionalDevices) } - const encSenderKeyMsg = encodeWAMessage({ - senderKeyDistributionMessage: { - axolotlSenderKeyDistributionMessage: senderKeyDistributionMessageKey, - groupId: destinationJid - } - }) - + const senderKeyJids: string[] = [] + // ensure a connection is established with every device for(const {user, device} of devices) { const jid = jidEncode(user, 's.whatsapp.net', device) - const participant = await createParticipantNode(jid, encSenderKeyMsg) - participants.push(participant) + if(!senderKeyMap[jid]) { + senderKeyJids.push(jid) + // store that this person has had the sender keys sent to them + senderKeyMap[jid] = true + } } + // if there are some participants with whom the session has not been established + // if there are, we re-send the senderkey + if(senderKeyJids.length) { + logger.debug({ senderKeyJids }, 'sending new sender key') - const binaryNodeContent: BinaryNode[] = [] - if( // if there are some participants with whom the session has not been established - // if there are, we overwrite the senderkey - !!participants.find((p) => ( - !!(p.content as BinaryNode[]).find(({ attrs }) => attrs.type == 'pkmsg') - )) - ) { - binaryNodeContent.push({ - tag: 'participants', - attrs: { }, - content: participants + const encSenderKeyMsg = encodeWAMessage({ + senderKeyDistributionMessage: { + axolotlSenderKeyDistributionMessage: senderKeyDistributionMessageKey, + groupId: destinationJid + } }) + + participants.push( + ...(await createParticipantNodes(senderKeyJids, encSenderKeyMsg)) + ) } binaryNodeContent.push({ @@ -303,25 +332,16 @@ export const makeMessagesSocket = (config: SocketConfig) => { content: ciphertext }) - stanza = { - tag: 'message', - attrs: { - id: msgId, - type: 'text', - to: destinationJid - }, - content: binaryNodeContent - } + await authState.keys.set({ 'sender-key-memory': { [jid]: senderKeyMap } }) } else { const { user: meUser } = jidDecode(meId) - const messageToMyself: proto.IMessage = { + const encodedMeMsg = encodeWAMessage({ deviceSentMessage: { destinationJid, message } - } - const encodedMeMsg = encodeWAMessage(messageToMyself) + }) if(!participant) { devices.push({ user }) @@ -331,47 +351,57 @@ export const makeMessagesSocket = (config: SocketConfig) => { devices.push(...additionalDevices) } + const meJids: string[] = [] + const otherJids: string[] = [] for(const { user, device } of devices) { + const jid = jidEncode(user, 's.whatsapp.net', device) const isMe = user === meUser - participants.push( - await createParticipantNode( - jidEncode(user, 's.whatsapp.net', device), - isMe ? encodedMeMsg : encodedMsg - ) - ) + if(isMe) meJids.push(jid) + else otherJids.push(jid) } - stanza = { - tag: 'message', - attrs: { - id: msgId, - type: 'text', - to: destinationJid, - ...(additionalAttributes || {}) - }, - content: [ - { - tag: 'participants', - attrs: { }, - content: participants - }, - ] - } + const [meNodes, otherNodes] = await Promise.all([ + createParticipantNodes(meJids, encodedMeMsg), + createParticipantNodes(otherJids, encodedMsg) + ]) + participants.push(...meNodes) + participants.push(...otherNodes) } - const shouldHaveIdentity = !!participants.find((p) => ( - !!(p.content as BinaryNode[]).find(({ attrs }) => attrs.type == 'pkmsg') - )) + if(participants.length) { + binaryNodeContent.push({ + tag: 'participants', + attrs: { }, + content: participants + }) + } + + const stanza: BinaryNode = { + tag: 'message', + attrs: { + id: msgId, + type: 'text', + to: destinationJid, + ...(additionalAttributes || {}) + }, + content: binaryNodeContent + } + const shouldHaveIdentity = !!participants.find( + participant => (participant.content! as BinaryNode[]).find(n => n.attrs.type === 'pkmsg') + ) + if(shouldHaveIdentity) { (stanza.content as BinaryNode[]).push({ tag: 'device-identity', attrs: { }, content: proto.ADVSignedDeviceIdentity.encode(authState.creds.account).finish() }) + + logger.debug({ jid }, 'adding device identity') } - logger.debug({ msgId }, `sending message to ${devices.length} devices`) + logger.debug({ msgId }, `sending message to ${participants.length} devices`) await sendNode(stanza) @@ -427,7 +457,7 @@ export const makeMessagesSocket = (config: SocketConfig) => { return { ...sock, - assertSession, + assertSessions, relayMessage, sendDeliveryReceipt, sendReadReceipt, diff --git a/src/Socket/socket.ts b/src/Socket/socket.ts index 3fb7657..102cdcc 100644 --- a/src/Socket/socket.ts +++ b/src/Socket/socket.ts @@ -196,10 +196,8 @@ export const makeSocket = ({ if(!creds.serverHasPreKeys) { update.serverHasPreKeys = true } - - await Promise.all( - Object.keys(newPreKeys).map(k => authState.keys.setPreKey(+k, newPreKeys[+k])) - ) + + await authState.keys.set({ 'pre-key': newPreKeys }) const preKeys = await getPreKeys(authState.keys, preKeysRange[0], preKeysRange[0] + preKeysRange[1]) await execute(preKeys) @@ -449,7 +447,7 @@ export const makeSocket = ({ const genPairQR = () => { const ref = refs.shift() if(!ref) { - end(new Boom('QR refs attempts ended', { statusCode: DisconnectReason.restartRequired })) + end(new Boom('QR refs attempts ended', { statusCode: DisconnectReason.timedOut })) return } diff --git a/src/Types/Auth.ts b/src/Types/Auth.ts index 7a3e749..c2cae16 100644 --- a/src/Types/Auth.ts +++ b/src/Types/Auth.ts @@ -43,9 +43,23 @@ export type AuthenticationCreds = SignalCreds & { lastAccountSyncTimestamp?: number } +export type SignalDataTypeMap = { + 'pre-key': KeyPair + 'session': any + 'sender-key': any + 'sender-key-memory': { [jid: string]: boolean } + 'app-state-sync-key': proto.IAppStateSyncKeyData + 'app-state-sync-version': LTHashState +} + +type SignalDataSet = { [T in keyof SignalDataTypeMap]?: { [id: string]: SignalDataTypeMap[T] | null } } + type Awaitable = T | Promise export type SignalKeyStore = { - getPreKey: (keyId: number) => Awaitable + get(type: T, ids: string[]): Awaitable<{ [id: string]: SignalDataTypeMap[T] }> + set(data: SignalDataSet): Awaitable + + /*getPreKey: (keyId: number) => Awaitable setPreKey: (keyId: number, pair: KeyPair | null) => Awaitable getSession: (sessionId: string) => Awaitable @@ -58,7 +72,7 @@ export type SignalKeyStore = { setAppStateSyncKey: (id: string, item: proto.IAppStateSyncKeyData | null) => Awaitable getAppStateSyncVersion: (name: WAPatchName) => Awaitable - setAppStateSyncVersion: (id: WAPatchName, item: LTHashState) => Awaitable + setAppStateSyncVersion: (id: WAPatchName, item: LTHashState) => Awaitable*/ } export type SignalAuthState = { diff --git a/src/Utils/auth-utils.ts b/src/Utils/auth-utils.ts index 16e6b4a..cf40474 100644 --- a/src/Utils/auth-utils.ts +++ b/src/Utils/auth-utils.ts @@ -1,86 +1,8 @@ import { randomBytes } from 'crypto' -import { proto } from '../../WAProto' -import type { SignalKeyStore, AuthenticationCreds, KeyPair, LTHashState, AuthenticationState } from "../Types" +import type { AuthenticationCreds, AuthenticationState, SignalDataTypeMap } from "../Types" import { Curve, signedKeyPair } from './crypto' import { generateRegistrationId, BufferJSON } from './generics' -export const initInMemoryKeyStore = ( - { preKeys, sessions, senderKeys, appStateSyncKeys, appStateVersions }: { - preKeys?: { [k: number]: KeyPair }, - sessions?: { [k: string]: any }, - senderKeys?: { [k: string]: any } - appStateSyncKeys?: { [k: string]: proto.IAppStateSyncKeyData }, - appStateVersions?: { [k: string]: LTHashState }, - } = { }, - save: (data: any) => void -) => { - - preKeys = preKeys || { } - sessions = sessions || { } - senderKeys = senderKeys || { } - appStateSyncKeys = appStateSyncKeys || { } - appStateVersions = appStateVersions || { } - - const keyData = { - preKeys, - sessions, - senderKeys, - appStateSyncKeys, - appStateVersions, - } - - return { - ...keyData, - getPreKey: keyId => preKeys[keyId], - setPreKey: (keyId, pair) => { - if(pair) preKeys[keyId] = pair - else delete preKeys[keyId] - - save(keyData) - }, - getSession: id => sessions[id], - setSession: (id, item) => { - if(item) sessions[id] = item - else delete sessions[id] - - save(keyData) - }, - getSenderKey: id => { - return senderKeys[id] - }, - setSenderKey: (id, item) => { - if(item) senderKeys[id] = item - else delete senderKeys[id] - - save(keyData) - }, - getAppStateSyncKey: id => { - const obj = appStateSyncKeys[id] - if(obj) { - return proto.AppStateSyncKeyData.fromObject(obj) - } - }, - setAppStateSyncKey: (id, item) => { - if(item) appStateSyncKeys[id] = item - else delete appStateSyncKeys[id] - - save(keyData) - }, - getAppStateSyncVersion: id => { - const obj = appStateVersions[id] - if(obj) { - return obj - } - }, - setAppStateSyncVersion: (id, item) => { - if(item) appStateVersions[id] = item - else delete appStateVersions[id] - - save(keyData) - } - } as SignalKeyStore -} - export const initAuthCreds = (): AuthenticationCreds => { const identityKey = Curve.generateKeyPair() return { @@ -95,12 +17,22 @@ export const initAuthCreds = (): AuthenticationCreds => { serverHasPreKeys: false } } + +const KEY_MAP: { [T in keyof SignalDataTypeMap]: string } = { + 'pre-key': 'preKeys', + 'session': 'sessions', + 'sender-key': 'senderKeys', + 'app-state-sync-key': 'appStateSyncKeys', + 'app-state-sync-version': 'appStateVersions', + 'sender-key-memory': 'senderKeyMemory' +} + /** stores the full authentication state in a single JSON file */ -export const useSingleFileAuthState = (filename: string) => { +export const useSingleFileAuthState = (filename: string): { state: AuthenticationState, saveState: () => void } => { // require fs here so that in case "fs" is not available -- the app does not crash const { readFileSync, writeFileSync, existsSync } = require('fs') - - let state: AuthenticationState = undefined + let creds: AuthenticationCreds + let keys: any = { } // save the authentication state to a file const saveState = () => { @@ -108,26 +40,48 @@ export const useSingleFileAuthState = (filename: string) => { writeFileSync( filename, // BufferJSON replacer utility saves buffers nicely - JSON.stringify(state, BufferJSON.replacer, 2) + JSON.stringify({ creds, keys }, BufferJSON.replacer, 2) ) } if(existsSync(filename)) { - const { creds, keys } = JSON.parse( + const result = JSON.parse( readFileSync(filename, { encoding: 'utf-8' }), BufferJSON.reviver ) - state = { - creds: creds, - // stores pre-keys, session & other keys in a JSON object - // we deserialize it here - keys: initInMemoryKeyStore(keys, saveState) - } + creds = result.creds + keys = result.keys } else { - const creds = initAuthCreds() - const keys = initInMemoryKeyStore({ }, saveState) - state = { creds: creds, keys: keys } + creds = initAuthCreds() + keys = { } } - return { state, saveState } + return { + state: { + creds, + keys: { + get: (type, ids) => { + const key = KEY_MAP[type] + return ids.reduce( + (dict, id) => { + const value = keys[key]?.[id] + if(value) { + dict[id] = value + } + return dict + }, { } + ) + }, + set: (data) => { + for(const _key in data) { + const key = KEY_MAP[_key as keyof SignalDataTypeMap] + keys[key] = keys[key] || { } + Object.assign(keys[key], data[_key]) + } + saveState() + } + } + }, + saveState + } } \ No newline at end of file diff --git a/src/Utils/chat-utils.ts b/src/Utils/chat-utils.ts index 908f96b..7d62295 100644 --- a/src/Utils/chat-utils.ts +++ b/src/Utils/chat-utils.ts @@ -1,12 +1,14 @@ import { Boom } from '@hapi/boom' import { aesDecrypt, hmacSign, aesEncrypt, hkdf } from "./crypto" -import { AuthenticationState, WAPatchCreate, ChatMutation, WAPatchName, LTHashState, ChatModification, SignalKeyStore } from "../Types" +import { WAPatchCreate, ChatMutation, WAPatchName, LTHashState, ChatModification, SignalKeyStore } from "../Types" import { proto } from '../../WAProto' import { LT_HASH_ANTI_TAMPERING } from './lt-hash' import { BinaryNode, getBinaryNodeChild, getBinaryNodeChildren } from '../WABinary' import { toNumber } from './generics' import { downloadContentFromMessage, } from './messages-media' +type FetchAppStateSyncKey = (keyId: string) => Promise | proto.IAppStateSyncKeyData + const mutationKeys = (keydata: Uint8Array) => { const expanded = hkdf(keydata, 160, { info: 'WhatsApp Mutation Keys' }) return { @@ -112,9 +114,9 @@ export const encodeSyncdPatch = async( { type, index, syncAction, apiVersion, operation }: WAPatchCreate, myAppStateKeyId: string, state: LTHashState, - keys: SignalKeyStore + getAppStateSyncKey: FetchAppStateSyncKey ) => { - const key = !!myAppStateKeyId ? await keys.getAppStateSyncKey(myAppStateKeyId) : undefined + const key = !!myAppStateKeyId ? await getAppStateSyncKey(myAppStateKeyId) : undefined if(!key) { throw new Boom(`myAppStateKey ("${myAppStateKeyId}") not present`, { statusCode: 404 }) } @@ -175,7 +177,7 @@ export const encodeSyncdPatch = async( export const decodeSyncdMutations = async( msgMutations: (proto.ISyncdMutation | proto.ISyncdRecord)[], initialState: LTHashState, - getAppStateSyncKey: SignalKeyStore['getAppStateSyncKey'], + getAppStateSyncKey: FetchAppStateSyncKey, validateMacs: boolean ) => { const keyCache: { [_: string]: ReturnType } = { } @@ -247,7 +249,7 @@ export const decodeSyncdPatch = async( msg: proto.ISyncdPatch, name: WAPatchName, initialState: LTHashState, - getAppStateSyncKey: SignalKeyStore['getAppStateSyncKey'], + getAppStateSyncKey: FetchAppStateSyncKey, validateMacs: boolean ) => { if(validateMacs) { @@ -334,7 +336,7 @@ export const downloadExternalPatch = async(blob: proto.IExternalBlobReference) = export const decodeSyncdSnapshot = async( name: WAPatchName, snapshot: proto.ISyncdSnapshot, - getAppStateSyncKey: SignalKeyStore['getAppStateSyncKey'], + getAppStateSyncKey: FetchAppStateSyncKey, validateMacs: boolean = true ) => { const newState = newLTHashState() @@ -370,7 +372,7 @@ export const decodePatches = async( name: WAPatchName, syncds: proto.ISyncdPatch[], initial: LTHashState, - getAppStateSyncKey: SignalKeyStore['getAppStateSyncKey'], + getAppStateSyncKey: FetchAppStateSyncKey, validateMacs: boolean = true ) => { const successfulMutations: ChatMutation[] = [] diff --git a/src/Utils/signal.ts b/src/Utils/signal.ts index 158cbf2..88533d8 100644 --- a/src/Utils/signal.ts +++ b/src/Utils/signal.ts @@ -3,7 +3,7 @@ import { encodeBigEndian } from "./generics" import { Curve } from "./crypto" import { SenderKeyDistributionMessage, GroupSessionBuilder, SenderKeyRecord, SenderKeyName, GroupCipher } from '../../WASignalGroup' import { SignalIdentity, SignalKeyStore, SignedKeyPair, KeyPair, SignalAuthState, AuthenticationCreds } from "../Types/Auth" -import { assertNodeErrorFree, BinaryNode, getBinaryNodeChild, getBinaryNodeChildBuffer, getBinaryNodeChildUInt, jidDecode, JidWithDevice } from "../WABinary" +import { assertNodeErrorFree, BinaryNode, getBinaryNodeChild, getBinaryNodeChildBuffer, getBinaryNodeChildUInt, jidDecode, JidWithDevice, getBinaryNodeChildren } from "../WABinary" import { proto } from "../../WAProto" export const generateSignalPubKey = (pubKey: Uint8Array | Buffer) => { @@ -33,13 +33,12 @@ export const createSignalIdentity = ( } } -export const getPreKeys = async({ getPreKey }: SignalKeyStore, min: number, limit: number) => { - const dict: { [id: number]: KeyPair } = { } +export const getPreKeys = async({ get }: SignalKeyStore, min: number, limit: number) => { + const idList: string[] = [] for(let id = min; id < limit;id++) { - const key = await getPreKey(id) - if(key) dict[+id] = key + idList.push(id.toString()) } - return dict + return get('pre-key', idList) } export const generateOrGetPreKeys = (creds: AuthenticationCreds, range: number) => { @@ -84,20 +83,21 @@ export const xmppPreKey = (pair: KeyPair, id: number): BinaryNode => ( ) export const signalStorage = ({ creds, keys }: SignalAuthState) => ({ - loadSession: async id => { - const sess = await keys.getSession(id) + loadSession: async (id: string) => { + const { [id]: sess } = await keys.get('session', [id]) if(sess) { return libsignal.SessionRecord.deserialize(sess) } }, storeSession: async(id, session) => { - await keys.setSession(id, session.serialize()) + await keys.set({ 'session': { [id]: session.serialize() } }) }, isTrustedIdentity: () => { return true }, - loadPreKey: async(id: number) => { - const key = await keys.getPreKey(id) + loadPreKey: async(id: number | string) => { + const keyId = id.toString() + const { [keyId]: key } = await keys.get('pre-key', [keyId]) if(key) { return { privKey: Buffer.from(key.private), @@ -105,7 +105,7 @@ export const signalStorage = ({ creds, keys }: SignalAuthState) => ({ } } }, - removePreKey: (id: number) => keys.setPreKey(id, null), + removePreKey: (id: number) => keys.set({ 'pre-key': { [id]: null } }), loadSignedPreKey: (keyId: number) => { const key = creds.signedPreKey return { @@ -113,12 +113,12 @@ export const signalStorage = ({ creds, keys }: SignalAuthState) => ({ pubKey: Buffer.from(key.keyPair.public) } }, - loadSenderKey: async(keyId) => { - const key = await keys.getSenderKey(keyId) + loadSenderKey: async(keyId: string) => { + const { [keyId]: key } = await keys.get('sender-key', [keyId]) if(key) return new SenderKeyRecord(key) }, storeSenderKey: async(keyId, key) => { - await keys.setSenderKey(keyId, key.serialize()) + await keys.set({ 'sender-key': { [keyId]: key.serialize() } }) }, getOurRegistrationId: () => ( creds.registrationId @@ -148,10 +148,10 @@ export const processSenderKeyMessage = async( const senderName = jidToSignalSenderKeyName(item.groupId, authorJid) const senderMsg = new SenderKeyDistributionMessage(null, null, null, null, item.axolotlSenderKeyDistributionMessage) - const senderKey = await auth.keys.getSenderKey(senderName) + const { [senderName]: senderKey } = await auth.keys.get('sender-key', [senderName]) if(!senderKey) { const record = new SenderKeyRecord() - await auth.keys.setSenderKey(senderName, record) + await auth.keys.set({ 'sender-key': { [senderName]: record } }) } await builder.process(senderName, senderMsg) } @@ -188,10 +188,10 @@ export const encryptSenderKeyMsgSignalProto = async(group: string, data: Uint8Ar const senderName = jidToSignalSenderKeyName(group, meId) const builder = new GroupSessionBuilder(storage) - const senderKey = await auth.keys.getSenderKey(senderName) + const { [senderName]: senderKey } = await auth.keys.get('sender-key', [senderName]) if(!senderKey) { const record = new SenderKeyRecord() - await auth.keys.setSenderKey(senderName, record) + await auth.keys.set({ 'sender-key': { [senderName]: record } }) } const senderKeyDistributionMessage = await builder.create(senderName) @@ -202,7 +202,7 @@ export const encryptSenderKeyMsgSignalProto = async(group: string, data: Uint8Ar } } -export const parseAndInjectE2ESession = async(node: BinaryNode, auth: SignalAuthState) => { +export const parseAndInjectE2ESessions = async(node: BinaryNode, auth: SignalAuthState) => { const extractKey = (key: BinaryNode) => ( key ? ({ keyId: getBinaryNodeChildUInt(key, 'id', 3), @@ -212,23 +212,30 @@ export const parseAndInjectE2ESession = async(node: BinaryNode, auth: SignalAuth signature: getBinaryNodeChildBuffer(key, 'signature'), }) : undefined ) - node = getBinaryNodeChild(getBinaryNodeChild(node, 'list'), 'user') - assertNodeErrorFree(node) - - const signedKey = getBinaryNodeChild(node, 'skey') - const key = getBinaryNodeChild(node, 'key') - const identity = getBinaryNodeChildBuffer(node, 'identity') - const jid = node.attrs.jid - const registrationId = getBinaryNodeChildUInt(node, 'registration', 4) - - const device = { - registrationId, - identityKey: generateSignalPubKey(identity), - signedPreKey: extractKey(signedKey), - preKey: extractKey(key) + const nodes = getBinaryNodeChildren(getBinaryNodeChild(node, 'list'), 'user') + for(const node of nodes) { + assertNodeErrorFree(node) } - const cipher = new libsignal.SessionBuilder(signalStorage(auth), jidToSignalProtocolAddress(jid)) - await cipher.initOutgoing(device) + await Promise.all( + nodes.map( + async node => { + const signedKey = getBinaryNodeChild(node, 'skey') + const key = getBinaryNodeChild(node, 'key') + const identity = getBinaryNodeChildBuffer(node, 'identity') + const jid = node.attrs.jid + const registrationId = getBinaryNodeChildUInt(node, 'registration', 4) + + const device = { + registrationId, + identityKey: generateSignalPubKey(identity), + signedPreKey: extractKey(signedKey), + preKey: extractKey(key) + } + const cipher = new libsignal.SessionBuilder(signalStorage(auth), jidToSignalProtocolAddress(jid)) + await cipher.initOutgoing(device) + } + ) + ) } export const extractDeviceJids = (result: BinaryNode, myJid: string, excludeZeroDevices: boolean) => {