mirror of
https://github.com/FranP-code/Baileys.git
synced 2025-10-13 00:32:22 +00:00
fix: auth store transactions + tests
This commit is contained in:
92
src/Tests/test.key-store.ts
Normal file
92
src/Tests/test.key-store.ts
Normal file
@@ -0,0 +1,92 @@
|
||||
import { addTransactionCapability, delay } from '../Utils'
|
||||
import logger from '../Utils/logger'
|
||||
import { makeMockSignalKeyStore } from './utils'
|
||||
|
||||
logger.level = 'trace'
|
||||
|
||||
describe('Key Store w Transaction Tests', () => {
|
||||
|
||||
const rawStore = makeMockSignalKeyStore()
|
||||
const store = addTransactionCapability(
|
||||
rawStore,
|
||||
logger,
|
||||
{
|
||||
maxCommitRetries: 1,
|
||||
delayBetweenTriesMs: 10
|
||||
}
|
||||
)
|
||||
|
||||
it('should use transaction cache when mutated', async() => {
|
||||
const key = '123'
|
||||
const value = new Uint8Array(1)
|
||||
const ogGet = rawStore.get
|
||||
await store.transaction(
|
||||
async() => {
|
||||
await store.set({ 'session': { [key]: value } })
|
||||
|
||||
rawStore.get = () => {
|
||||
throw new Error('should not have been called')
|
||||
}
|
||||
|
||||
const { [key]: stored } = await store.get('session', [key])
|
||||
expect(stored).toEqual(new Uint8Array(1))
|
||||
}
|
||||
)
|
||||
|
||||
rawStore.get = ogGet
|
||||
})
|
||||
|
||||
it('should not commit a failed transaction', async() => {
|
||||
const key = 'abcd'
|
||||
await expect(
|
||||
store.transaction(
|
||||
async() => {
|
||||
await store.set({ 'session': { [key]: new Uint8Array(1) } })
|
||||
throw new Error('fail')
|
||||
}
|
||||
)
|
||||
).rejects.toThrowError(
|
||||
'fail'
|
||||
)
|
||||
|
||||
const { [key]: stored } = await store.get('session', [key])
|
||||
expect(stored).toBeUndefined()
|
||||
})
|
||||
|
||||
it('should handle overlapping transactions', async() => {
|
||||
// promise to let transaction 2
|
||||
// know that transaction 1 has started
|
||||
let promiseResolve: () => void
|
||||
const promise = new Promise<void>(resolve => {
|
||||
promiseResolve = resolve
|
||||
})
|
||||
|
||||
store.transaction(
|
||||
async() => {
|
||||
await store.set({
|
||||
'session': {
|
||||
'1': new Uint8Array(1)
|
||||
}
|
||||
})
|
||||
// wait for the other transaction to start
|
||||
await delay(5)
|
||||
// reolve the promise to let the other transaction continue
|
||||
promiseResolve()
|
||||
}
|
||||
)
|
||||
|
||||
await store.transaction(
|
||||
async() => {
|
||||
await promise
|
||||
await delay(5)
|
||||
|
||||
expect(store.isInTransaction()).toBe(true)
|
||||
}
|
||||
)
|
||||
|
||||
expect(store.isInTransaction()).toBe(false)
|
||||
// ensure that the transaction were committed
|
||||
const { ['1']: stored } = await store.get('session', ['1'])
|
||||
expect(stored).toEqual(new Uint8Array(1))
|
||||
})
|
||||
})
|
||||
@@ -1,6 +1,36 @@
|
||||
import { SignalDataTypeMap, SignalKeyStore } from '../Types'
|
||||
import { jidEncode } from '../WABinary'
|
||||
|
||||
|
||||
export function randomJid() {
|
||||
return jidEncode(Math.floor(Math.random() * 1000000), Math.random() < 0.5 ? 's.whatsapp.net' : 'g.us')
|
||||
}
|
||||
|
||||
export function makeMockSignalKeyStore(): SignalKeyStore {
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
const store: { [_: string]: any } = {}
|
||||
|
||||
return {
|
||||
get(type, ids) {
|
||||
const data: { [_: string]: SignalDataTypeMap[typeof type] } = { }
|
||||
for(const id of ids) {
|
||||
const item = store[getUniqueId(type, id)]
|
||||
if(typeof item !== 'undefined') {
|
||||
data[id] = item
|
||||
}
|
||||
}
|
||||
|
||||
return data
|
||||
},
|
||||
set(data) {
|
||||
for(const type in data) {
|
||||
for(const id in data[type]) {
|
||||
store[getUniqueId(type, id)] = data[type][id]
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
function getUniqueId(type: string, id: string) {
|
||||
return `${type}.${id}`
|
||||
}
|
||||
}
|
||||
@@ -87,19 +87,20 @@ export const addTransactionCapability = (
|
||||
logger: Logger,
|
||||
{ maxCommitRetries, delayBetweenTriesMs }: TransactionCapabilityOptions
|
||||
): SignalKeyStoreWithTransaction => {
|
||||
let inTransaction = false
|
||||
// number of queries made to the DB during the transaction
|
||||
// only there for logging purposes
|
||||
let dbQueriesInTransaction = 0
|
||||
let transactionCache: SignalDataSet = { }
|
||||
let mutations: SignalDataSet = { }
|
||||
|
||||
let transactionsInProgress = 0
|
||||
|
||||
return {
|
||||
get: async(type, ids) => {
|
||||
if(inTransaction) {
|
||||
if(isInTransaction()) {
|
||||
const dict = transactionCache[type]
|
||||
const idsRequiringFetch = dict
|
||||
? ids.filter(item => typeof dict[item] !== 'undefined')
|
||||
? ids.filter(item => typeof dict[item] === 'undefined')
|
||||
: ids
|
||||
// only fetch if there are any items to fetch
|
||||
if(idsRequiringFetch.length) {
|
||||
@@ -128,7 +129,7 @@ export const addTransactionCapability = (
|
||||
}
|
||||
},
|
||||
set: data => {
|
||||
if(inTransaction) {
|
||||
if(isInTransaction()) {
|
||||
logger.trace({ types: Object.keys(data) }, 'caching in transaction')
|
||||
for(const key in data) {
|
||||
transactionCache[key] = transactionCache[key] || { }
|
||||
@@ -141,18 +142,18 @@ export const addTransactionCapability = (
|
||||
return state.set(data)
|
||||
}
|
||||
},
|
||||
isInTransaction: () => inTransaction,
|
||||
isInTransaction,
|
||||
async transaction(work) {
|
||||
let result: Awaited<ReturnType<typeof work>>
|
||||
// if we're already in a transaction,
|
||||
// just execute what needs to be executed -- no commit required
|
||||
if(inTransaction) {
|
||||
result = await work()
|
||||
} else {
|
||||
transactionsInProgress += 1
|
||||
if(transactionsInProgress === 1) {
|
||||
logger.trace('entering transaction')
|
||||
inTransaction = true
|
||||
try {
|
||||
result = await work()
|
||||
}
|
||||
|
||||
try {
|
||||
result = await work()
|
||||
// commit if this is the outermost transaction
|
||||
if(transactionsInProgress === 1) {
|
||||
if(Object.keys(mutations).length) {
|
||||
logger.trace('committing transaction')
|
||||
// retry mechanism to ensure we've some recovery
|
||||
@@ -172,16 +173,23 @@ export const addTransactionCapability = (
|
||||
} else {
|
||||
logger.trace('no mutations in transaction')
|
||||
}
|
||||
} finally {
|
||||
inTransaction = false
|
||||
}
|
||||
} finally {
|
||||
transactionsInProgress -= 1
|
||||
if(transactionsInProgress === 0) {
|
||||
transactionCache = { }
|
||||
mutations = { }
|
||||
dbQueriesInTransaction = 0
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
}
|
||||
|
||||
function isInTransaction() {
|
||||
return transactionsInProgress > 0
|
||||
}
|
||||
}
|
||||
|
||||
export const initAuthCreds = (): AuthenticationCreds => {
|
||||
|
||||
Reference in New Issue
Block a user