feat: cleaner auth state management + store SK keys

!BREAKING_CHANGE
This commit is contained in:
Adhiraj Singh
2021-12-11 17:54:38 +05:30
parent 792c4bf0a4
commit 2b8256d56b
8 changed files with 264 additions and 247 deletions

View File

@@ -20,6 +20,11 @@ export const makeChatsSocket = (config: SocketConfig) => {
const mutationMutex = makeMutex() const mutationMutex = makeMutex()
const getAppStateSyncKey = async(keyId: string) => {
const { [keyId]: key } = await authState.keys.get('app-state-sync-key', [keyId])
return key
}
const interactiveQuery = async(userNodes: BinaryNode[], queryNode: BinaryNode) => { const interactiveQuery = async(userNodes: BinaryNode[], queryNode: BinaryNode) => {
const result = await query({ const result = await query({
tag: 'iq', tag: 'iq',
@@ -175,7 +180,11 @@ export const makeChatsSocket = (config: SocketConfig) => {
const states = { } as { [T in WAPatchName]: LTHashState } const states = { } as { [T in WAPatchName]: LTHashState }
for(const name of collections) { for(const name of collections) {
let state: LTHashState = fromScratch ? undefined : await authState.keys.getAppStateSyncVersion(name) let state: LTHashState
if(!fromScratch) {
const result = await authState.keys.get('app-state-sync-version', [name])
state = result[name]
}
if(!state) state = newLTHashState() if(!state) state = newLTHashState()
states[name] = state states[name] = state
@@ -213,16 +222,16 @@ export const makeChatsSocket = (config: SocketConfig) => {
const name = key as WAPatchName const name = key as WAPatchName
const { patches, snapshot } = decoded[name] const { patches, snapshot } = decoded[name]
if(snapshot) { if(snapshot) {
const newState = await decodeSyncdSnapshot(name, snapshot, authState.keys.getAppStateSyncKey) const newState = await decodeSyncdSnapshot(name, snapshot, getAppStateSyncKey)
states[name] = newState states[name] = newState
logger.info(`restored state of ${name} from snapshot to v${newState.version}`) logger.info(`restored state of ${name} from snapshot to v${newState.version}`)
} }
// only process if there are syncd patches // only process if there are syncd patches
if(patches.length) { if(patches.length) {
const { newMutations, state: newState } = await decodePatches(name, patches, states[name], authState.keys.getAppStateSyncKey, true) const { newMutations, state: newState } = await decodePatches(name, patches, states[name], getAppStateSyncKey, true)
await authState.keys.setAppStateSyncVersion(name, newState) await authState.keys.set({ 'app-state-sync-version': { [name]: newState } })
logger.info(`synced ${name} to v${newState.version}`) logger.info(`synced ${name} to v${newState.version}`)
if(newMutations.length) { if(newMutations.length) {
@@ -415,12 +424,12 @@ export const makeChatsSocket = (config: SocketConfig) => {
logger.debug({ patch: patchCreate }, 'applying app patch') logger.debug({ patch: patchCreate }, 'applying app patch')
await resyncAppState([name]) await resyncAppState([name])
const initial = await authState.keys.getAppStateSyncVersion(name) const { [name]: initial } = await authState.keys.get('app-state-sync-version', [name])
const { patch, state } = await encodeSyncdPatch( const { patch, state } = await encodeSyncdPatch(
patchCreate, patchCreate,
authState.creds.myAppStateKeyId!, authState.creds.myAppStateKeyId!,
initial, initial,
authState.keys, getAppStateSyncKey,
) )
const node: BinaryNode = { const node: BinaryNode = {
@@ -456,10 +465,10 @@ export const makeChatsSocket = (config: SocketConfig) => {
} }
await query(node) await query(node)
await authState.keys.setAppStateSyncVersion(name, state) await authState.keys.set({ 'app-state-sync-version': { [name]: state } })
if(config.emitOwnEvents) { if(config.emitOwnEvents) {
const result = await decodePatches(name, [{ ...patch, version: { version: state.version }, }], initial, authState.keys.getAppStateSyncKey) const result = await decodePatches(name, [{ ...patch, version: { version: state.version }, }], initial, getAppStateSyncKey)
processSyncActions(result.newMutations) processSyncActions(result.newMutations)
} }
} }

View File

@@ -1,12 +1,11 @@
import { SocketConfig, WAMessageStubType, ParticipantAction, Chat, GroupMetadata, WAMessageKey } from "../Types" import { SocketConfig, WAMessageStubType, ParticipantAction, Chat, GroupMetadata, WAMessageKey } from "../Types"
import { decodeMessageStanza, encodeBigEndian, toNumber, downloadHistory, generateSignalPubKey, xmppPreKey, xmppSignedPreKey } from "../Utils" import { decodeMessageStanza, encodeBigEndian, toNumber, downloadHistory, generateSignalPubKey, xmppPreKey, xmppSignedPreKey } from "../Utils"
import { BinaryNode, jidDecode, jidEncode, isJidStatusBroadcast, areJidsSameUser, getBinaryNodeChildren, jidNormalizedUser, getAllBinaryNodeChildren, BinaryNodeAttributes } from '../WABinary' import { BinaryNode, jidDecode, jidEncode, isJidStatusBroadcast, areJidsSameUser, getBinaryNodeChildren, jidNormalizedUser, getAllBinaryNodeChildren, BinaryNodeAttributes, isJidGroup } from '../WABinary'
import { proto } from "../../WAProto" import { proto } from "../../WAProto"
import { KEY_BUNDLE_TYPE } from "../Defaults" import { KEY_BUNDLE_TYPE } from "../Defaults"
import { makeChatsSocket } from "./chats" import { makeChatsSocket } from "./chats"
import { extractGroupMetadata } from "./groups" import { extractGroupMetadata } from "./groups"
import { Boom } from "@hapi/boom"
const getStatusFromReceiptType = (type: string | undefined) => { const getStatusFromReceiptType = (type: string | undefined) => {
if(type === 'read' || type === 'read-self') { if(type === 'read' || type === 'read-self') {
@@ -25,7 +24,7 @@ export const makeMessagesRecvSocket = (config: SocketConfig) => {
ev, ev,
authState, authState,
ws, ws,
assertSession, assertSessions,
assertingPreKeys, assertingPreKeys,
sendNode, sendNode,
relayMessage, relayMessage,
@@ -146,12 +145,12 @@ export const makeMessagesRecvSocket = (config: SocketConfig) => {
if(keys?.length) { if(keys?.length) {
let newAppStateSyncKeyId = '' let newAppStateSyncKeyId = ''
for(const { keyData, keyId } of keys) { for(const { keyData, keyId } of keys) {
const str = Buffer.from(keyId.keyId!).toString('base64') const strKeyId = Buffer.from(keyId.keyId!).toString('base64')
logger.info({ str }, 'injecting new app state sync key') logger.info({ strKeyId }, 'injecting new app state sync key')
await authState.keys.setAppStateSyncKey(str, keyData) await authState.keys.set({ 'app-state-sync-key': { [strKeyId]: keyData } })
newAppStateSyncKeyId = str newAppStateSyncKeyId = strKeyId
} }
ev.emit('creds.update', { myAppStateKeyId: newAppStateSyncKeyId }) ev.emit('creds.update', { myAppStateKeyId: newAppStateSyncKeyId })
@@ -473,7 +472,11 @@ export const makeMessagesRecvSocket = (config: SocketConfig) => {
) )
const participant = key.participant || key.remoteJid const participant = key.participant || key.remoteJid
await assertSession(participant, true) await assertSessions([participant], true)
if(isJidGroup(key.remoteJid)) {
await authState.keys.set({ 'sender-key-memory': { [key.remoteJid]: null } })
}
logger.debug({ participant }, 'forced new session for retry recp') logger.debug({ participant }, 'forced new session for retry recp')

View File

@@ -2,7 +2,7 @@
import got from "got" import got from "got"
import { Boom } from "@hapi/boom" import { Boom } from "@hapi/boom"
import { SocketConfig, MediaConnInfo, AnyMessageContent, MiscMessageGenerationOptions, WAMediaUploadFunction, MessageRelayOptions } from "../Types" import { SocketConfig, MediaConnInfo, AnyMessageContent, MiscMessageGenerationOptions, WAMediaUploadFunction, MessageRelayOptions } from "../Types"
import { encodeWAMessage, generateMessageID, generateWAMessage, encryptSenderKeyMsgSignalProto, encryptSignalProto, extractDeviceJids, jidToSignalProtocolAddress, parseAndInjectE2ESession } from "../Utils" import { encodeWAMessage, generateMessageID, generateWAMessage, encryptSenderKeyMsgSignalProto, encryptSignalProto, extractDeviceJids, jidToSignalProtocolAddress, parseAndInjectE2ESessions } from "../Utils"
import { BinaryNode, getBinaryNodeChild, getBinaryNodeChildren, isJidGroup, jidDecode, jidEncode, jidNormalizedUser, S_WHATSAPP_NET, BinaryNodeAttributes, JidWithDevice, reduceBinaryNodeToDictionary } from '../WABinary' import { BinaryNode, getBinaryNodeChild, getBinaryNodeChildren, isJidGroup, jidDecode, jidEncode, jidNormalizedUser, S_WHATSAPP_NET, BinaryNodeAttributes, JidWithDevice, reduceBinaryNodeToDictionary } from '../WABinary'
import { proto } from "../../WAProto" import { proto } from "../../WAProto"
import { WA_DEFAULT_EPHEMERAL, DEFAULT_ORIGIN, MEDIA_PATH_MAP } from "../Defaults" import { WA_DEFAULT_EPHEMERAL, DEFAULT_ORIGIN, MEDIA_PATH_MAP } from "../Defaults"
@@ -189,15 +189,23 @@ export const makeMessagesSocket = (config: SocketConfig) => {
return deviceResults return deviceResults
} }
const assertSession = async(jid: string, force: boolean) => { const assertSessions = async(jids: string[], force: boolean) => {
const addr = jidToSignalProtocolAddress(jid).toString() let jidsRequiringFetch: string[] = []
const session = await authState.keys.getSession(addr) if(force) {
if(!session || force) { jidsRequiringFetch = jids
logger.debug({ jid }, `fetching session`) } else {
const identity: BinaryNode = { const addrs = jids.map(jid => jidToSignalProtocolAddress(jid).toString())
tag: 'user', const sessions = await authState.keys.get('session', addrs)
attrs: { jid, reason: 'identity' }, for(const jid of jids) {
const signalId = jidToSignalProtocolAddress(jid).toString()
if(!sessions[signalId]) {
jidsRequiringFetch.push(jid)
}
} }
}
if(jidsRequiringFetch.length) {
logger.debug({ jidsRequiringFetch }, `fetching sessions`)
const result = await query({ const result = await query({
tag: 'iq', tag: 'iq',
attrs: { attrs: {
@@ -209,30 +217,41 @@ export const makeMessagesSocket = (config: SocketConfig) => {
{ {
tag: 'key', tag: 'key',
attrs: { }, attrs: { },
content: [ identity ] content: jidsRequiringFetch.map(
jid => ({
tag: 'user',
attrs: { jid, reason: 'identity' },
})
)
} }
] ]
}) })
await parseAndInjectE2ESession(result, authState) await parseAndInjectE2ESessions(result, authState)
return true return true
} }
return false return false
} }
const createParticipantNode = async(jid: string, bytes: Buffer) => { const createParticipantNodes = async(jids: string[], bytes: Buffer) => {
await assertSession(jid, false) await assertSessions(jids, false)
const nodes = await Promise.all(
const { type, ciphertext } = await encryptSignalProto(jid, bytes, authState) jids.map(
const node: BinaryNode = { async jid => {
tag: 'to', const { type, ciphertext } = await encryptSignalProto(jid, bytes, authState)
attrs: { jid }, const node: BinaryNode = {
content: [{ tag: 'to',
tag: 'enc', attrs: { jid },
attrs: { v: '2', type }, content: [{
content: ciphertext tag: 'enc',
}] attrs: { v: '2', type },
} content: ciphertext
return node }]
}
return node
}
)
)
return nodes
} }
const relayMessage = async( const relayMessage = async(
@@ -248,10 +267,11 @@ export const makeMessagesSocket = (config: SocketConfig) => {
const encodedMsg = encodeWAMessage(message) const encodedMsg = encodeWAMessage(message)
const participants: BinaryNode[] = [] const participants: BinaryNode[] = []
let stanza: BinaryNode
const destinationJid = jidEncode(user, isGroup ? 'g.us' : 's.whatsapp.net') const destinationJid = jidEncode(user, isGroup ? 'g.us' : 's.whatsapp.net')
const binaryNodeContent: BinaryNode[] = []
const devices: JidWithDevice[] = [] const devices: JidWithDevice[] = []
if(participant) { if(participant) {
const { user, device } = jidDecode(participant) const { user, device } = jidDecode(participant)
@@ -261,8 +281,17 @@ export const makeMessagesSocket = (config: SocketConfig) => {
if(isGroup) { if(isGroup) {
const { ciphertext, senderKeyDistributionMessageKey } = await encryptSenderKeyMsgSignalProto(destinationJid, encodedMsg, meId, authState) const { ciphertext, senderKeyDistributionMessageKey } = await encryptSenderKeyMsgSignalProto(destinationJid, encodedMsg, meId, authState)
let groupData = cachedGroupMetadata ? await cachedGroupMetadata(jid) : undefined const [groupData, senderKeyMap] = await Promise.all([
if(!groupData) groupData = await groupMetadata(jid) (async() => {
let groupData = cachedGroupMetadata ? await cachedGroupMetadata(jid) : undefined
if(!groupData) groupData = await groupMetadata(jid)
return groupData
})(),
(async() => {
const result = await authState.keys.get('sender-key-memory', [jid])
return result[jid] || { }
})()
])
if(!participant) { if(!participant) {
const participantsList = groupData.participants.map(p => p.id) const participantsList = groupData.participants.map(p => p.id)
@@ -270,31 +299,31 @@ export const makeMessagesSocket = (config: SocketConfig) => {
devices.push(...additionalDevices) devices.push(...additionalDevices)
} }
const encSenderKeyMsg = encodeWAMessage({ const senderKeyJids: string[] = []
senderKeyDistributionMessage: { // ensure a connection is established with every device
axolotlSenderKeyDistributionMessage: senderKeyDistributionMessageKey,
groupId: destinationJid
}
})
for(const {user, device} of devices) { for(const {user, device} of devices) {
const jid = jidEncode(user, 's.whatsapp.net', device) const jid = jidEncode(user, 's.whatsapp.net', device)
const participant = await createParticipantNode(jid, encSenderKeyMsg) if(!senderKeyMap[jid]) {
participants.push(participant) senderKeyJids.push(jid)
// store that this person has had the sender keys sent to them
senderKeyMap[jid] = true
}
} }
// if there are some participants with whom the session has not been established
// if there are, we re-send the senderkey
if(senderKeyJids.length) {
logger.debug({ senderKeyJids }, 'sending new sender key')
const binaryNodeContent: BinaryNode[] = [] const encSenderKeyMsg = encodeWAMessage({
if( // if there are some participants with whom the session has not been established senderKeyDistributionMessage: {
// if there are, we overwrite the senderkey axolotlSenderKeyDistributionMessage: senderKeyDistributionMessageKey,
!!participants.find((p) => ( groupId: destinationJid
!!(p.content as BinaryNode[]).find(({ attrs }) => attrs.type == 'pkmsg') }
))
) {
binaryNodeContent.push({
tag: 'participants',
attrs: { },
content: participants
}) })
participants.push(
...(await createParticipantNodes(senderKeyJids, encSenderKeyMsg))
)
} }
binaryNodeContent.push({ binaryNodeContent.push({
@@ -303,25 +332,16 @@ export const makeMessagesSocket = (config: SocketConfig) => {
content: ciphertext content: ciphertext
}) })
stanza = { await authState.keys.set({ 'sender-key-memory': { [jid]: senderKeyMap } })
tag: 'message',
attrs: {
id: msgId,
type: 'text',
to: destinationJid
},
content: binaryNodeContent
}
} else { } else {
const { user: meUser } = jidDecode(meId) const { user: meUser } = jidDecode(meId)
const messageToMyself: proto.IMessage = { const encodedMeMsg = encodeWAMessage({
deviceSentMessage: { deviceSentMessage: {
destinationJid, destinationJid,
message message
} }
} })
const encodedMeMsg = encodeWAMessage(messageToMyself)
if(!participant) { if(!participant) {
devices.push({ user }) devices.push({ user })
@@ -331,47 +351,57 @@ export const makeMessagesSocket = (config: SocketConfig) => {
devices.push(...additionalDevices) devices.push(...additionalDevices)
} }
const meJids: string[] = []
const otherJids: string[] = []
for(const { user, device } of devices) { for(const { user, device } of devices) {
const jid = jidEncode(user, 's.whatsapp.net', device)
const isMe = user === meUser const isMe = user === meUser
participants.push( if(isMe) meJids.push(jid)
await createParticipantNode( else otherJids.push(jid)
jidEncode(user, 's.whatsapp.net', device),
isMe ? encodedMeMsg : encodedMsg
)
)
} }
stanza = { const [meNodes, otherNodes] = await Promise.all([
tag: 'message', createParticipantNodes(meJids, encodedMeMsg),
attrs: { createParticipantNodes(otherJids, encodedMsg)
id: msgId, ])
type: 'text', participants.push(...meNodes)
to: destinationJid, participants.push(...otherNodes)
...(additionalAttributes || {})
},
content: [
{
tag: 'participants',
attrs: { },
content: participants
},
]
}
} }
const shouldHaveIdentity = !!participants.find((p) => ( if(participants.length) {
!!(p.content as BinaryNode[]).find(({ attrs }) => attrs.type == 'pkmsg') binaryNodeContent.push({
)) tag: 'participants',
attrs: { },
content: participants
})
}
const stanza: BinaryNode = {
tag: 'message',
attrs: {
id: msgId,
type: 'text',
to: destinationJid,
...(additionalAttributes || {})
},
content: binaryNodeContent
}
const shouldHaveIdentity = !!participants.find(
participant => (participant.content! as BinaryNode[]).find(n => n.attrs.type === 'pkmsg')
)
if(shouldHaveIdentity) { if(shouldHaveIdentity) {
(stanza.content as BinaryNode[]).push({ (stanza.content as BinaryNode[]).push({
tag: 'device-identity', tag: 'device-identity',
attrs: { }, attrs: { },
content: proto.ADVSignedDeviceIdentity.encode(authState.creds.account).finish() content: proto.ADVSignedDeviceIdentity.encode(authState.creds.account).finish()
}) })
logger.debug({ jid }, 'adding device identity')
} }
logger.debug({ msgId }, `sending message to ${devices.length} devices`) logger.debug({ msgId }, `sending message to ${participants.length} devices`)
await sendNode(stanza) await sendNode(stanza)
@@ -427,7 +457,7 @@ export const makeMessagesSocket = (config: SocketConfig) => {
return { return {
...sock, ...sock,
assertSession, assertSessions,
relayMessage, relayMessage,
sendDeliveryReceipt, sendDeliveryReceipt,
sendReadReceipt, sendReadReceipt,

View File

@@ -196,10 +196,8 @@ export const makeSocket = ({
if(!creds.serverHasPreKeys) { if(!creds.serverHasPreKeys) {
update.serverHasPreKeys = true update.serverHasPreKeys = true
} }
await Promise.all( await authState.keys.set({ 'pre-key': newPreKeys })
Object.keys(newPreKeys).map(k => authState.keys.setPreKey(+k, newPreKeys[+k]))
)
const preKeys = await getPreKeys(authState.keys, preKeysRange[0], preKeysRange[0] + preKeysRange[1]) const preKeys = await getPreKeys(authState.keys, preKeysRange[0], preKeysRange[0] + preKeysRange[1])
await execute(preKeys) await execute(preKeys)
@@ -449,7 +447,7 @@ export const makeSocket = ({
const genPairQR = () => { const genPairQR = () => {
const ref = refs.shift() const ref = refs.shift()
if(!ref) { if(!ref) {
end(new Boom('QR refs attempts ended', { statusCode: DisconnectReason.restartRequired })) end(new Boom('QR refs attempts ended', { statusCode: DisconnectReason.timedOut }))
return return
} }

View File

@@ -43,9 +43,23 @@ export type AuthenticationCreds = SignalCreds & {
lastAccountSyncTimestamp?: number lastAccountSyncTimestamp?: number
} }
export type SignalDataTypeMap = {
'pre-key': KeyPair
'session': any
'sender-key': any
'sender-key-memory': { [jid: string]: boolean }
'app-state-sync-key': proto.IAppStateSyncKeyData
'app-state-sync-version': LTHashState
}
type SignalDataSet = { [T in keyof SignalDataTypeMap]?: { [id: string]: SignalDataTypeMap[T] | null } }
type Awaitable<T> = T | Promise<T> type Awaitable<T> = T | Promise<T>
export type SignalKeyStore = { export type SignalKeyStore = {
getPreKey: (keyId: number) => Awaitable<KeyPair> get<T extends keyof SignalDataTypeMap>(type: T, ids: string[]): Awaitable<{ [id: string]: SignalDataTypeMap[T] }>
set(data: SignalDataSet): Awaitable<void>
/*getPreKey: (keyId: number) => Awaitable<KeyPair>
setPreKey: (keyId: number, pair: KeyPair | null) => Awaitable<void> setPreKey: (keyId: number, pair: KeyPair | null) => Awaitable<void>
getSession: (sessionId: string) => Awaitable<any> getSession: (sessionId: string) => Awaitable<any>
@@ -58,7 +72,7 @@ export type SignalKeyStore = {
setAppStateSyncKey: (id: string, item: proto.IAppStateSyncKeyData | null) => Awaitable<void> setAppStateSyncKey: (id: string, item: proto.IAppStateSyncKeyData | null) => Awaitable<void>
getAppStateSyncVersion: (name: WAPatchName) => Awaitable<LTHashState> getAppStateSyncVersion: (name: WAPatchName) => Awaitable<LTHashState>
setAppStateSyncVersion: (id: WAPatchName, item: LTHashState) => Awaitable<void> setAppStateSyncVersion: (id: WAPatchName, item: LTHashState) => Awaitable<void>*/
} }
export type SignalAuthState = { export type SignalAuthState = {

View File

@@ -1,86 +1,8 @@
import { randomBytes } from 'crypto' import { randomBytes } from 'crypto'
import { proto } from '../../WAProto' import type { AuthenticationCreds, AuthenticationState, SignalDataTypeMap } from "../Types"
import type { SignalKeyStore, AuthenticationCreds, KeyPair, LTHashState, AuthenticationState } from "../Types"
import { Curve, signedKeyPair } from './crypto' import { Curve, signedKeyPair } from './crypto'
import { generateRegistrationId, BufferJSON } from './generics' import { generateRegistrationId, BufferJSON } from './generics'
export const initInMemoryKeyStore = (
{ preKeys, sessions, senderKeys, appStateSyncKeys, appStateVersions }: {
preKeys?: { [k: number]: KeyPair },
sessions?: { [k: string]: any },
senderKeys?: { [k: string]: any }
appStateSyncKeys?: { [k: string]: proto.IAppStateSyncKeyData },
appStateVersions?: { [k: string]: LTHashState },
} = { },
save: (data: any) => void
) => {
preKeys = preKeys || { }
sessions = sessions || { }
senderKeys = senderKeys || { }
appStateSyncKeys = appStateSyncKeys || { }
appStateVersions = appStateVersions || { }
const keyData = {
preKeys,
sessions,
senderKeys,
appStateSyncKeys,
appStateVersions,
}
return {
...keyData,
getPreKey: keyId => preKeys[keyId],
setPreKey: (keyId, pair) => {
if(pair) preKeys[keyId] = pair
else delete preKeys[keyId]
save(keyData)
},
getSession: id => sessions[id],
setSession: (id, item) => {
if(item) sessions[id] = item
else delete sessions[id]
save(keyData)
},
getSenderKey: id => {
return senderKeys[id]
},
setSenderKey: (id, item) => {
if(item) senderKeys[id] = item
else delete senderKeys[id]
save(keyData)
},
getAppStateSyncKey: id => {
const obj = appStateSyncKeys[id]
if(obj) {
return proto.AppStateSyncKeyData.fromObject(obj)
}
},
setAppStateSyncKey: (id, item) => {
if(item) appStateSyncKeys[id] = item
else delete appStateSyncKeys[id]
save(keyData)
},
getAppStateSyncVersion: id => {
const obj = appStateVersions[id]
if(obj) {
return obj
}
},
setAppStateSyncVersion: (id, item) => {
if(item) appStateVersions[id] = item
else delete appStateVersions[id]
save(keyData)
}
} as SignalKeyStore
}
export const initAuthCreds = (): AuthenticationCreds => { export const initAuthCreds = (): AuthenticationCreds => {
const identityKey = Curve.generateKeyPair() const identityKey = Curve.generateKeyPair()
return { return {
@@ -95,12 +17,22 @@ export const initAuthCreds = (): AuthenticationCreds => {
serverHasPreKeys: false serverHasPreKeys: false
} }
} }
const KEY_MAP: { [T in keyof SignalDataTypeMap]: string } = {
'pre-key': 'preKeys',
'session': 'sessions',
'sender-key': 'senderKeys',
'app-state-sync-key': 'appStateSyncKeys',
'app-state-sync-version': 'appStateVersions',
'sender-key-memory': 'senderKeyMemory'
}
/** stores the full authentication state in a single JSON file */ /** stores the full authentication state in a single JSON file */
export const useSingleFileAuthState = (filename: string) => { export const useSingleFileAuthState = (filename: string): { state: AuthenticationState, saveState: () => void } => {
// require fs here so that in case "fs" is not available -- the app does not crash // require fs here so that in case "fs" is not available -- the app does not crash
const { readFileSync, writeFileSync, existsSync } = require('fs') const { readFileSync, writeFileSync, existsSync } = require('fs')
let creds: AuthenticationCreds
let state: AuthenticationState = undefined let keys: any = { }
// save the authentication state to a file // save the authentication state to a file
const saveState = () => { const saveState = () => {
@@ -108,26 +40,48 @@ export const useSingleFileAuthState = (filename: string) => {
writeFileSync( writeFileSync(
filename, filename,
// BufferJSON replacer utility saves buffers nicely // BufferJSON replacer utility saves buffers nicely
JSON.stringify(state, BufferJSON.replacer, 2) JSON.stringify({ creds, keys }, BufferJSON.replacer, 2)
) )
} }
if(existsSync(filename)) { if(existsSync(filename)) {
const { creds, keys } = JSON.parse( const result = JSON.parse(
readFileSync(filename, { encoding: 'utf-8' }), readFileSync(filename, { encoding: 'utf-8' }),
BufferJSON.reviver BufferJSON.reviver
) )
state = { creds = result.creds
creds: creds, keys = result.keys
// stores pre-keys, session & other keys in a JSON object
// we deserialize it here
keys: initInMemoryKeyStore(keys, saveState)
}
} else { } else {
const creds = initAuthCreds() creds = initAuthCreds()
const keys = initInMemoryKeyStore({ }, saveState) keys = { }
state = { creds: creds, keys: keys }
} }
return { state, saveState } return {
state: {
creds,
keys: {
get: (type, ids) => {
const key = KEY_MAP[type]
return ids.reduce(
(dict, id) => {
const value = keys[key]?.[id]
if(value) {
dict[id] = value
}
return dict
}, { }
)
},
set: (data) => {
for(const _key in data) {
const key = KEY_MAP[_key as keyof SignalDataTypeMap]
keys[key] = keys[key] || { }
Object.assign(keys[key], data[_key])
}
saveState()
}
}
},
saveState
}
} }

View File

@@ -1,12 +1,14 @@
import { Boom } from '@hapi/boom' import { Boom } from '@hapi/boom'
import { aesDecrypt, hmacSign, aesEncrypt, hkdf } from "./crypto" import { aesDecrypt, hmacSign, aesEncrypt, hkdf } from "./crypto"
import { AuthenticationState, WAPatchCreate, ChatMutation, WAPatchName, LTHashState, ChatModification, SignalKeyStore } from "../Types" import { WAPatchCreate, ChatMutation, WAPatchName, LTHashState, ChatModification, SignalKeyStore } from "../Types"
import { proto } from '../../WAProto' import { proto } from '../../WAProto'
import { LT_HASH_ANTI_TAMPERING } from './lt-hash' import { LT_HASH_ANTI_TAMPERING } from './lt-hash'
import { BinaryNode, getBinaryNodeChild, getBinaryNodeChildren } from '../WABinary' import { BinaryNode, getBinaryNodeChild, getBinaryNodeChildren } from '../WABinary'
import { toNumber } from './generics' import { toNumber } from './generics'
import { downloadContentFromMessage, } from './messages-media' import { downloadContentFromMessage, } from './messages-media'
type FetchAppStateSyncKey = (keyId: string) => Promise<proto.IAppStateSyncKeyData> | proto.IAppStateSyncKeyData
const mutationKeys = (keydata: Uint8Array) => { const mutationKeys = (keydata: Uint8Array) => {
const expanded = hkdf(keydata, 160, { info: 'WhatsApp Mutation Keys' }) const expanded = hkdf(keydata, 160, { info: 'WhatsApp Mutation Keys' })
return { return {
@@ -112,9 +114,9 @@ export const encodeSyncdPatch = async(
{ type, index, syncAction, apiVersion, operation }: WAPatchCreate, { type, index, syncAction, apiVersion, operation }: WAPatchCreate,
myAppStateKeyId: string, myAppStateKeyId: string,
state: LTHashState, state: LTHashState,
keys: SignalKeyStore getAppStateSyncKey: FetchAppStateSyncKey
) => { ) => {
const key = !!myAppStateKeyId ? await keys.getAppStateSyncKey(myAppStateKeyId) : undefined const key = !!myAppStateKeyId ? await getAppStateSyncKey(myAppStateKeyId) : undefined
if(!key) { if(!key) {
throw new Boom(`myAppStateKey ("${myAppStateKeyId}") not present`, { statusCode: 404 }) throw new Boom(`myAppStateKey ("${myAppStateKeyId}") not present`, { statusCode: 404 })
} }
@@ -175,7 +177,7 @@ export const encodeSyncdPatch = async(
export const decodeSyncdMutations = async( export const decodeSyncdMutations = async(
msgMutations: (proto.ISyncdMutation | proto.ISyncdRecord)[], msgMutations: (proto.ISyncdMutation | proto.ISyncdRecord)[],
initialState: LTHashState, initialState: LTHashState,
getAppStateSyncKey: SignalKeyStore['getAppStateSyncKey'], getAppStateSyncKey: FetchAppStateSyncKey,
validateMacs: boolean validateMacs: boolean
) => { ) => {
const keyCache: { [_: string]: ReturnType<typeof mutationKeys> } = { } const keyCache: { [_: string]: ReturnType<typeof mutationKeys> } = { }
@@ -247,7 +249,7 @@ export const decodeSyncdPatch = async(
msg: proto.ISyncdPatch, msg: proto.ISyncdPatch,
name: WAPatchName, name: WAPatchName,
initialState: LTHashState, initialState: LTHashState,
getAppStateSyncKey: SignalKeyStore['getAppStateSyncKey'], getAppStateSyncKey: FetchAppStateSyncKey,
validateMacs: boolean validateMacs: boolean
) => { ) => {
if(validateMacs) { if(validateMacs) {
@@ -334,7 +336,7 @@ export const downloadExternalPatch = async(blob: proto.IExternalBlobReference) =
export const decodeSyncdSnapshot = async( export const decodeSyncdSnapshot = async(
name: WAPatchName, name: WAPatchName,
snapshot: proto.ISyncdSnapshot, snapshot: proto.ISyncdSnapshot,
getAppStateSyncKey: SignalKeyStore['getAppStateSyncKey'], getAppStateSyncKey: FetchAppStateSyncKey,
validateMacs: boolean = true validateMacs: boolean = true
) => { ) => {
const newState = newLTHashState() const newState = newLTHashState()
@@ -370,7 +372,7 @@ export const decodePatches = async(
name: WAPatchName, name: WAPatchName,
syncds: proto.ISyncdPatch[], syncds: proto.ISyncdPatch[],
initial: LTHashState, initial: LTHashState,
getAppStateSyncKey: SignalKeyStore['getAppStateSyncKey'], getAppStateSyncKey: FetchAppStateSyncKey,
validateMacs: boolean = true validateMacs: boolean = true
) => { ) => {
const successfulMutations: ChatMutation[] = [] const successfulMutations: ChatMutation[] = []

View File

@@ -3,7 +3,7 @@ import { encodeBigEndian } from "./generics"
import { Curve } from "./crypto" import { Curve } from "./crypto"
import { SenderKeyDistributionMessage, GroupSessionBuilder, SenderKeyRecord, SenderKeyName, GroupCipher } from '../../WASignalGroup' import { SenderKeyDistributionMessage, GroupSessionBuilder, SenderKeyRecord, SenderKeyName, GroupCipher } from '../../WASignalGroup'
import { SignalIdentity, SignalKeyStore, SignedKeyPair, KeyPair, SignalAuthState, AuthenticationCreds } from "../Types/Auth" import { SignalIdentity, SignalKeyStore, SignedKeyPair, KeyPair, SignalAuthState, AuthenticationCreds } from "../Types/Auth"
import { assertNodeErrorFree, BinaryNode, getBinaryNodeChild, getBinaryNodeChildBuffer, getBinaryNodeChildUInt, jidDecode, JidWithDevice } from "../WABinary" import { assertNodeErrorFree, BinaryNode, getBinaryNodeChild, getBinaryNodeChildBuffer, getBinaryNodeChildUInt, jidDecode, JidWithDevice, getBinaryNodeChildren } from "../WABinary"
import { proto } from "../../WAProto" import { proto } from "../../WAProto"
export const generateSignalPubKey = (pubKey: Uint8Array | Buffer) => { export const generateSignalPubKey = (pubKey: Uint8Array | Buffer) => {
@@ -33,13 +33,12 @@ export const createSignalIdentity = (
} }
} }
export const getPreKeys = async({ getPreKey }: SignalKeyStore, min: number, limit: number) => { export const getPreKeys = async({ get }: SignalKeyStore, min: number, limit: number) => {
const dict: { [id: number]: KeyPair } = { } const idList: string[] = []
for(let id = min; id < limit;id++) { for(let id = min; id < limit;id++) {
const key = await getPreKey(id) idList.push(id.toString())
if(key) dict[+id] = key
} }
return dict return get('pre-key', idList)
} }
export const generateOrGetPreKeys = (creds: AuthenticationCreds, range: number) => { export const generateOrGetPreKeys = (creds: AuthenticationCreds, range: number) => {
@@ -84,20 +83,21 @@ export const xmppPreKey = (pair: KeyPair, id: number): BinaryNode => (
) )
export const signalStorage = ({ creds, keys }: SignalAuthState) => ({ export const signalStorage = ({ creds, keys }: SignalAuthState) => ({
loadSession: async id => { loadSession: async (id: string) => {
const sess = await keys.getSession(id) const { [id]: sess } = await keys.get('session', [id])
if(sess) { if(sess) {
return libsignal.SessionRecord.deserialize(sess) return libsignal.SessionRecord.deserialize(sess)
} }
}, },
storeSession: async(id, session) => { storeSession: async(id, session) => {
await keys.setSession(id, session.serialize()) await keys.set({ 'session': { [id]: session.serialize() } })
}, },
isTrustedIdentity: () => { isTrustedIdentity: () => {
return true return true
}, },
loadPreKey: async(id: number) => { loadPreKey: async(id: number | string) => {
const key = await keys.getPreKey(id) const keyId = id.toString()
const { [keyId]: key } = await keys.get('pre-key', [keyId])
if(key) { if(key) {
return { return {
privKey: Buffer.from(key.private), privKey: Buffer.from(key.private),
@@ -105,7 +105,7 @@ export const signalStorage = ({ creds, keys }: SignalAuthState) => ({
} }
} }
}, },
removePreKey: (id: number) => keys.setPreKey(id, null), removePreKey: (id: number) => keys.set({ 'pre-key': { [id]: null } }),
loadSignedPreKey: (keyId: number) => { loadSignedPreKey: (keyId: number) => {
const key = creds.signedPreKey const key = creds.signedPreKey
return { return {
@@ -113,12 +113,12 @@ export const signalStorage = ({ creds, keys }: SignalAuthState) => ({
pubKey: Buffer.from(key.keyPair.public) pubKey: Buffer.from(key.keyPair.public)
} }
}, },
loadSenderKey: async(keyId) => { loadSenderKey: async(keyId: string) => {
const key = await keys.getSenderKey(keyId) const { [keyId]: key } = await keys.get('sender-key', [keyId])
if(key) return new SenderKeyRecord(key) if(key) return new SenderKeyRecord(key)
}, },
storeSenderKey: async(keyId, key) => { storeSenderKey: async(keyId, key) => {
await keys.setSenderKey(keyId, key.serialize()) await keys.set({ 'sender-key': { [keyId]: key.serialize() } })
}, },
getOurRegistrationId: () => ( getOurRegistrationId: () => (
creds.registrationId creds.registrationId
@@ -148,10 +148,10 @@ export const processSenderKeyMessage = async(
const senderName = jidToSignalSenderKeyName(item.groupId, authorJid) const senderName = jidToSignalSenderKeyName(item.groupId, authorJid)
const senderMsg = new SenderKeyDistributionMessage(null, null, null, null, item.axolotlSenderKeyDistributionMessage) const senderMsg = new SenderKeyDistributionMessage(null, null, null, null, item.axolotlSenderKeyDistributionMessage)
const senderKey = await auth.keys.getSenderKey(senderName) const { [senderName]: senderKey } = await auth.keys.get('sender-key', [senderName])
if(!senderKey) { if(!senderKey) {
const record = new SenderKeyRecord() const record = new SenderKeyRecord()
await auth.keys.setSenderKey(senderName, record) await auth.keys.set({ 'sender-key': { [senderName]: record } })
} }
await builder.process(senderName, senderMsg) await builder.process(senderName, senderMsg)
} }
@@ -188,10 +188,10 @@ export const encryptSenderKeyMsgSignalProto = async(group: string, data: Uint8Ar
const senderName = jidToSignalSenderKeyName(group, meId) const senderName = jidToSignalSenderKeyName(group, meId)
const builder = new GroupSessionBuilder(storage) const builder = new GroupSessionBuilder(storage)
const senderKey = await auth.keys.getSenderKey(senderName) const { [senderName]: senderKey } = await auth.keys.get('sender-key', [senderName])
if(!senderKey) { if(!senderKey) {
const record = new SenderKeyRecord() const record = new SenderKeyRecord()
await auth.keys.setSenderKey(senderName, record) await auth.keys.set({ 'sender-key': { [senderName]: record } })
} }
const senderKeyDistributionMessage = await builder.create(senderName) const senderKeyDistributionMessage = await builder.create(senderName)
@@ -202,7 +202,7 @@ export const encryptSenderKeyMsgSignalProto = async(group: string, data: Uint8Ar
} }
} }
export const parseAndInjectE2ESession = async(node: BinaryNode, auth: SignalAuthState) => { export const parseAndInjectE2ESessions = async(node: BinaryNode, auth: SignalAuthState) => {
const extractKey = (key: BinaryNode) => ( const extractKey = (key: BinaryNode) => (
key ? ({ key ? ({
keyId: getBinaryNodeChildUInt(key, 'id', 3), keyId: getBinaryNodeChildUInt(key, 'id', 3),
@@ -212,23 +212,30 @@ export const parseAndInjectE2ESession = async(node: BinaryNode, auth: SignalAuth
signature: getBinaryNodeChildBuffer(key, 'signature'), signature: getBinaryNodeChildBuffer(key, 'signature'),
}) : undefined }) : undefined
) )
node = getBinaryNodeChild(getBinaryNodeChild(node, 'list'), 'user') const nodes = getBinaryNodeChildren(getBinaryNodeChild(node, 'list'), 'user')
assertNodeErrorFree(node) for(const node of nodes) {
assertNodeErrorFree(node)
const signedKey = getBinaryNodeChild(node, 'skey')
const key = getBinaryNodeChild(node, 'key')
const identity = getBinaryNodeChildBuffer(node, 'identity')
const jid = node.attrs.jid
const registrationId = getBinaryNodeChildUInt(node, 'registration', 4)
const device = {
registrationId,
identityKey: generateSignalPubKey(identity),
signedPreKey: extractKey(signedKey),
preKey: extractKey(key)
} }
const cipher = new libsignal.SessionBuilder(signalStorage(auth), jidToSignalProtocolAddress(jid)) await Promise.all(
await cipher.initOutgoing(device) nodes.map(
async node => {
const signedKey = getBinaryNodeChild(node, 'skey')
const key = getBinaryNodeChild(node, 'key')
const identity = getBinaryNodeChildBuffer(node, 'identity')
const jid = node.attrs.jid
const registrationId = getBinaryNodeChildUInt(node, 'registration', 4)
const device = {
registrationId,
identityKey: generateSignalPubKey(identity),
signedPreKey: extractKey(signedKey),
preKey: extractKey(key)
}
const cipher = new libsignal.SessionBuilder(signalStorage(auth), jidToSignalProtocolAddress(jid))
await cipher.initOutgoing(device)
}
)
)
} }
export const extractDeviceJids = (result: BinaryNode, myJid: string, excludeZeroDevices: boolean) => { export const extractDeviceJids = (result: BinaryNode, myJid: string, excludeZeroDevices: boolean) => {