diff --git a/src/Tests/test.key-store.ts b/src/Tests/test.key-store.ts new file mode 100644 index 0000000..f6b86b2 --- /dev/null +++ b/src/Tests/test.key-store.ts @@ -0,0 +1,92 @@ +import { addTransactionCapability, delay } from '../Utils' +import logger from '../Utils/logger' +import { makeMockSignalKeyStore } from './utils' + +logger.level = 'trace' + +describe('Key Store w Transaction Tests', () => { + + const rawStore = makeMockSignalKeyStore() + const store = addTransactionCapability( + rawStore, + logger, + { + maxCommitRetries: 1, + delayBetweenTriesMs: 10 + } + ) + + it('should use transaction cache when mutated', async() => { + const key = '123' + const value = new Uint8Array(1) + const ogGet = rawStore.get + await store.transaction( + async() => { + await store.set({ 'session': { [key]: value } }) + + rawStore.get = () => { + throw new Error('should not have been called') + } + + const { [key]: stored } = await store.get('session', [key]) + expect(stored).toEqual(new Uint8Array(1)) + } + ) + + rawStore.get = ogGet + }) + + it('should not commit a failed transaction', async() => { + const key = 'abcd' + await expect( + store.transaction( + async() => { + await store.set({ 'session': { [key]: new Uint8Array(1) } }) + throw new Error('fail') + } + ) + ).rejects.toThrowError( + 'fail' + ) + + const { [key]: stored } = await store.get('session', [key]) + expect(stored).toBeUndefined() + }) + + it('should handle overlapping transactions', async() => { + // promise to let transaction 2 + // know that transaction 1 has started + let promiseResolve: () => void + const promise = new Promise(resolve => { + promiseResolve = resolve + }) + + store.transaction( + async() => { + await store.set({ + 'session': { + '1': new Uint8Array(1) + } + }) + // wait for the other transaction to start + await delay(5) + // reolve the promise to let the other transaction continue + promiseResolve() + } + ) + + await store.transaction( + async() => { + await promise + await delay(5) + + expect(store.isInTransaction()).toBe(true) + } + ) + + expect(store.isInTransaction()).toBe(false) + // ensure that the transaction were committed + const { ['1']: stored } = await store.get('session', ['1']) + expect(stored).toEqual(new Uint8Array(1)) + }) +}) \ No newline at end of file diff --git a/src/Tests/utils.ts b/src/Tests/utils.ts index 9d5ce04..bcd1469 100644 --- a/src/Tests/utils.ts +++ b/src/Tests/utils.ts @@ -1,6 +1,36 @@ +import { SignalDataTypeMap, SignalKeyStore } from '../Types' import { jidEncode } from '../WABinary' - export function randomJid() { return jidEncode(Math.floor(Math.random() * 1000000), Math.random() < 0.5 ? 's.whatsapp.net' : 'g.us') +} + +export function makeMockSignalKeyStore(): SignalKeyStore { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const store: { [_: string]: any } = {} + + return { + get(type, ids) { + const data: { [_: string]: SignalDataTypeMap[typeof type] } = { } + for(const id of ids) { + const item = store[getUniqueId(type, id)] + if(typeof item !== 'undefined') { + data[id] = item + } + } + + return data + }, + set(data) { + for(const type in data) { + for(const id in data[type]) { + store[getUniqueId(type, id)] = data[type][id] + } + } + }, + } + + function getUniqueId(type: string, id: string) { + return `${type}.${id}` + } } \ No newline at end of file diff --git a/src/Utils/auth-utils.ts b/src/Utils/auth-utils.ts index 5ef31c9..aa176b6 100644 --- a/src/Utils/auth-utils.ts +++ b/src/Utils/auth-utils.ts @@ -87,19 +87,20 @@ export const addTransactionCapability = ( logger: Logger, { maxCommitRetries, delayBetweenTriesMs }: TransactionCapabilityOptions ): SignalKeyStoreWithTransaction => { - let inTransaction = false // number of queries made to the DB during the transaction // only there for logging purposes let dbQueriesInTransaction = 0 let transactionCache: SignalDataSet = { } let mutations: SignalDataSet = { } + let transactionsInProgress = 0 + return { get: async(type, ids) => { - if(inTransaction) { + if(isInTransaction()) { const dict = transactionCache[type] const idsRequiringFetch = dict - ? ids.filter(item => typeof dict[item] !== 'undefined') + ? ids.filter(item => typeof dict[item] === 'undefined') : ids // only fetch if there are any items to fetch if(idsRequiringFetch.length) { @@ -128,7 +129,7 @@ export const addTransactionCapability = ( } }, set: data => { - if(inTransaction) { + if(isInTransaction()) { logger.trace({ types: Object.keys(data) }, 'caching in transaction') for(const key in data) { transactionCache[key] = transactionCache[key] || { } @@ -141,18 +142,18 @@ export const addTransactionCapability = ( return state.set(data) } }, - isInTransaction: () => inTransaction, + isInTransaction, async transaction(work) { let result: Awaited> - // if we're already in a transaction, - // just execute what needs to be executed -- no commit required - if(inTransaction) { - result = await work() - } else { + transactionsInProgress += 1 + if(transactionsInProgress === 1) { logger.trace('entering transaction') - inTransaction = true - try { - result = await work() + } + + try { + result = await work() + // commit if this is the outermost transaction + if(transactionsInProgress === 1) { if(Object.keys(mutations).length) { logger.trace('committing transaction') // retry mechanism to ensure we've some recovery @@ -172,16 +173,23 @@ export const addTransactionCapability = ( } else { logger.trace('no mutations in transaction') } - } finally { - inTransaction = false + } + } finally { + transactionsInProgress -= 1 + if(transactionsInProgress === 0) { transactionCache = { } mutations = { } dbQueriesInTransaction = 0 } } + return result } } + + function isInTransaction() { + return transactionsInProgress > 0 + } } export const initAuthCreds = (): AuthenticationCreds => {