diff --git a/src/Tests/test.key-store.ts b/src/Tests/test.key-store.ts new file mode 100644 index 0000000..f6b86b2 --- /dev/null +++ b/src/Tests/test.key-store.ts @@ -0,0 +1,92 @@ +import { addTransactionCapability, delay } from '../Utils' +import logger from '../Utils/logger' +import { makeMockSignalKeyStore } from './utils' + +logger.level = 'trace' + +describe('Key Store w Transaction Tests', () => { + + const rawStore = makeMockSignalKeyStore() + const store = addTransactionCapability( + rawStore, + logger, + { + maxCommitRetries: 1, + delayBetweenTriesMs: 10 + } + ) + + it('should use transaction cache when mutated', async() => { + const key = '123' + const value = new Uint8Array(1) + const ogGet = rawStore.get + await store.transaction( + async() => { + await store.set({ 'session': { [key]: value } }) + + rawStore.get = () => { + throw new Error('should not have been called') + } + + const { [key]: stored } = await store.get('session', [key]) + expect(stored).toEqual(new Uint8Array(1)) + } + ) + + rawStore.get = ogGet + }) + + it('should not commit a failed transaction', async() => { + const key = 'abcd' + await expect( + store.transaction( + async() => { + await store.set({ 'session': { [key]: new Uint8Array(1) } }) + throw new Error('fail') + } + ) + ).rejects.toThrowError( + 'fail' + ) + + const { [key]: stored } = await store.get('session', [key]) + expect(stored).toBeUndefined() + }) + + it('should handle overlapping transactions', async() => { + // promise to let transaction 2 + // know that transaction 1 has started + let promiseResolve: () => void + const promise = new Promise(resolve => { + promiseResolve = resolve + }) + + store.transaction( + async() => { + await store.set({ + 'session': { + '1': new Uint8Array(1) + } + }) + // wait for the other transaction to start + await delay(5) + // reolve the promise to let the other transaction continue + promiseResolve() + } + ) + + await store.transaction( + async() => { + await promise + await delay(5) + + expect(store.isInTransaction()).toBe(true) + } + ) + + expect(store.isInTransaction()).toBe(false) + // ensure that the transaction were committed + const { ['1']: stored } = await store.get('session', ['1']) + expect(stored).toEqual(new Uint8Array(1)) + }) +}) \ No newline at end of file diff --git a/src/Tests/utils.ts b/src/Tests/utils.ts index 9d5ce04..bcd1469 100644 --- a/src/Tests/utils.ts +++ b/src/Tests/utils.ts @@ -1,6 +1,36 @@ +import { SignalDataTypeMap, SignalKeyStore } from '../Types' import { jidEncode } from '../WABinary' - export function randomJid() { return jidEncode(Math.floor(Math.random() * 1000000), Math.random() < 0.5 ? 's.whatsapp.net' : 'g.us') +} + +export function makeMockSignalKeyStore(): SignalKeyStore { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const store: { [_: string]: any } = {} + + return { + get(type, ids) { + const data: { [_: string]: SignalDataTypeMap[typeof type] } = { } + for(const id of ids) { + const item = store[getUniqueId(type, id)] + if(typeof item !== 'undefined') { + data[id] = item + } + } + + return data + }, + set(data) { + for(const type in data) { + for(const id in data[type]) { + store[getUniqueId(type, id)] = data[type][id] + } + } + }, + } + + function getUniqueId(type: string, id: string) { + return `${type}.${id}` + } } \ No newline at end of file diff --git a/src/Types/Auth.ts b/src/Types/Auth.ts index 72bde7c..6c06333 100644 --- a/src/Types/Auth.ts +++ b/src/Types/Auth.ts @@ -82,7 +82,7 @@ export type SignalKeyStore = { export type SignalKeyStoreWithTransaction = SignalKeyStore & { isInTransaction: () => boolean - transaction(exec: () => Promise): Promise + transaction(exec: () => Promise): Promise } export type TransactionCapabilityOptions = { @@ -92,7 +92,7 @@ export type TransactionCapabilityOptions = { export type SignalAuthState = { creds: SignalCreds - keys: SignalKeyStore + keys: SignalKeyStore | SignalKeyStoreWithTransaction } export type AuthenticationState = { diff --git a/src/Utils/auth-utils.ts b/src/Utils/auth-utils.ts index a4f2e73..4f41e12 100644 --- a/src/Utils/auth-utils.ts +++ b/src/Utils/auth-utils.ts @@ -86,39 +86,33 @@ export const addTransactionCapability = ( logger: Logger, { maxCommitRetries, delayBetweenTriesMs }: TransactionCapabilityOptions ): SignalKeyStoreWithTransaction => { - let inTransaction = false // number of queries made to the DB during the transaction // only there for logging purposes let dbQueriesInTransaction = 0 let transactionCache: SignalDataSet = { } let mutations: SignalDataSet = { } - /** - * prefetches some data and stores in memory, - * useful if these data points will be used together often - * */ - const prefetch = async(type: T, ids: string[]) => { - const dict = transactionCache[type] - const idsRequiringFetch = dict - ? ids.filter(item => typeof dict[item] !== 'undefined') - : ids - // only fetch if there are any items to fetch - if(idsRequiringFetch.length) { - dbQueriesInTransaction += 1 - const result = await state.get(type, idsRequiringFetch) - - transactionCache[type] ||= {} - transactionCache[type] = Object.assign( - transactionCache[type]!, - result - ) - } - } + let transactionsInProgress = 0 return { get: async(type, ids) => { - if(inTransaction) { - await prefetch(type, ids) + if(isInTransaction()) { + const dict = transactionCache[type] + const idsRequiringFetch = dict + ? ids.filter(item => typeof dict[item] === 'undefined') + : ids + // only fetch if there are any items to fetch + if(idsRequiringFetch.length) { + dbQueriesInTransaction += 1 + const result = await state.get(type, idsRequiringFetch) + + transactionCache[type] ||= {} + Object.assign( + transactionCache[type]!, + result + ) + } + return ids.reduce( (dict, id) => { const value = transactionCache[type]?.[id] @@ -134,7 +128,7 @@ export const addTransactionCapability = ( } }, set: data => { - if(inTransaction) { + if(isInTransaction()) { logger.trace({ types: Object.keys(data) }, 'caching in transaction') for(const key in data) { transactionCache[key] = transactionCache[key] || { } @@ -147,17 +141,18 @@ export const addTransactionCapability = ( return state.set(data) } }, - isInTransaction: () => inTransaction, - transaction: async(work) => { - // if we're already in a transaction, - // just execute what needs to be executed -- no commit required - if(inTransaction) { - await work() - } else { + isInTransaction, + async transaction(work) { + let result: Awaited> + transactionsInProgress += 1 + if(transactionsInProgress === 1) { logger.trace('entering transaction') - inTransaction = true - try { - await work() + } + + try { + result = await work() + // commit if this is the outermost transaction + if(transactionsInProgress === 1) { if(Object.keys(mutations).length) { logger.trace('committing transaction') // retry mechanism to ensure we've some recovery @@ -177,15 +172,23 @@ export const addTransactionCapability = ( } else { logger.trace('no mutations in transaction') } - } finally { - inTransaction = false + } + } finally { + transactionsInProgress -= 1 + if(transactionsInProgress === 0) { transactionCache = { } mutations = { } dbQueriesInTransaction = 0 } } + + return result } } + + function isInTransaction() { + return transactionsInProgress > 0 + } } export const initAuthCreds = (): AuthenticationCreds => {