feat: implement partial media downloads

This commit is contained in:
Adhiraj Singh
2021-12-02 11:38:24 +05:30
parent cd9c313e47
commit b5ac28d426
4 changed files with 152 additions and 9 deletions

11
jest.config.js Normal file
View File

@@ -0,0 +1,11 @@
module.exports = {
"roots": [
"<rootDir>/src"
],
"testMatch": [
"**/Tests/test.*.+(ts|tsx|js)",
],
"transform": {
"^.+\\.(ts|tsx)$": "ts-jest"
},
}

View File

@@ -0,0 +1,64 @@
import { MediaType, DownloadableMessage } from '../Types'
import { downloadContentFromMessage } from '../Utils'
import { proto } from '../../WAProto'
import { readFileSync } from 'fs'
type TestVector = {
type: MediaType
message: DownloadableMessage
plaintext: Buffer
}
const TEST_VECTORS: TestVector[] = [
{
type: 'image',
message: proto.ImageMessage.decode(
Buffer.from(
'Ck1odHRwczovL21tZy53aGF0c2FwcC5uZXQvZC9mL0FwaHR4WG9fWXZZcDZlUVNSa0tjOHE5d2ozVUpleWdoY3poM3ExX3I0ektnLmVuYxIKaW1hZ2UvanBlZyIgKTuVFyxDc6mTm4GXPlO3Z911Wd8RBeTrPLSWAEdqW8MomcUBQiB7wH5a4nXMKyLOT0A2nFgnnM/DUH8YjQf8QtkCIekaSkogTB+BXKCWDFrmNzozY0DCPn0L4VKd7yG1ZbZwbgRhzVc=',
'base64'
)
),
plaintext: readFileSync('./Media/cat.jpeg')
}
]
describe('Media Download Tests', () => {
it('should download a full encrypted media correctly', async() => {
for(const { type, message, plaintext } of TEST_VECTORS) {
const readPipe = await downloadContentFromMessage(message, type)
let buffer = Buffer.alloc(0)
for await(const read of readPipe) {
buffer = Buffer.concat([ buffer, read ])
}
expect(buffer).toEqual(plaintext)
}
})
it('should download an encrypted media correctly piece', async() => {
for(const { type, message, plaintext } of TEST_VECTORS) {
// check all edge cases
const ranges = [
{ startByte: 51, endByte: plaintext.length-100 }, // random numbers
{ startByte: 1024, endByte: 2038 }, // larger random multiples of 16
{ startByte: 1, endByte: plaintext.length-1 } // borders
]
for(const range of ranges) {
const readPipe = await downloadContentFromMessage(message, type, range)
let buffer = Buffer.alloc(0)
for await(const read of readPipe) {
buffer = Buffer.concat([ buffer, read ])
}
const hex = buffer.toString('hex')
const expectedHex = plaintext.slice(range.startByte || 0, range.endByte || undefined).toString('hex')
expect(hex).toBe(expectedHex)
console.log('success on ', range)
}
}
})
})

View File

@@ -23,6 +23,8 @@ export type WAMediaUpload = Buffer | { url: URL | string } | { stream: Readable
/** Set of message types that are supported by the library */
export type MessageType = keyof proto.Message
export type DownloadableMessage = { mediaKey?: Uint8Array, directPath?: string, url?: string }
export type MediaConnInfo = {
auth: string
ttl: number

View File

@@ -10,7 +10,7 @@ import { URL } from 'url'
import { join } from 'path'
import { once } from 'events'
import got, { Options, Response } from 'got'
import { MessageType, WAMessageContent, WAProto, WAGenericMediaMessage, WAMediaUpload, MediaType } from '../Types'
import { MessageType, WAMessageContent, WAProto, WAGenericMediaMessage, WAMediaUpload, MediaType, DownloadableMessage } from '../Types'
import { generateMessageID } from './generics'
import { hkdf } from './crypto'
import { DEFAULT_ORIGIN } from '../Defaults'
@@ -223,30 +223,96 @@ export const encryptedStream = async(media: WAMediaUpload, mediaType: MediaType,
didSaveToTmpPath
}
}
const DEF_HOST = 'mmg.whatsapp.net'
const AES_CHUNK_SIZE = 16
const toSmallestChunkSize = (num: number) => {
return Math.floor(num / AES_CHUNK_SIZE) * AES_CHUNK_SIZE
}
type MediaDownloadOptions = {
startByte?: number
endByte?: number
}
export const downloadContentFromMessage = async(
{ mediaKey, directPath, url }: { mediaKey?: Uint8Array, directPath?: string, url?: string },
type: MediaType
{ mediaKey, directPath, url }: DownloadableMessage,
type: MediaType,
{ startByte, endByte }: MediaDownloadOptions = { }
) => {
const downloadUrl = url || `https://${DEF_HOST}${directPath}`
let bytesFetched = 0
let startChunk = 0
let firstBlockIsIV = false
// if a start byte is specified -- then we need to fetch the previous chunk as that will form the IV
if(startByte) {
const chunk = toSmallestChunkSize(startByte || 0)
if(chunk) {
startChunk = chunk-AES_CHUNK_SIZE
bytesFetched = chunk
firstBlockIsIV = true
}
}
const endChunk = endByte ? toSmallestChunkSize(endByte || 0)+AES_CHUNK_SIZE : undefined
let rangeHeader: string | undefined = undefined
if(startChunk || endChunk) {
rangeHeader = `bytes=${startChunk}-`
if(endChunk) rangeHeader += endChunk
}
// download the message
const fetched = await getGotStream(downloadUrl, {
headers: { Origin: DEFAULT_ORIGIN }
headers: {
Origin: DEFAULT_ORIGIN,
Range: rangeHeader
}
})
let remainingBytes = Buffer.from([])
const { cipherKey, iv } = getMediaKeys(mediaKey, type)
const aes = Crypto.createDecipheriv("aes-256-cbc", cipherKey, iv)
let aes: Crypto.Decipher
const pushBytes = (bytes: Buffer, push: (bytes: Buffer) => void) => {
if(startByte || endByte) {
const start = bytesFetched >= startByte ? undefined : Math.max(startByte-bytesFetched, 0)
const end = bytesFetched+bytes.length < endByte ? undefined : Math.max(endByte-bytesFetched, 0)
push(bytes.slice(start, end))
bytesFetched += bytes.length
} else {
push(bytes)
}
}
const output = new Transform({
transform(chunk, _, callback) {
let data = Buffer.concat([remainingBytes, chunk])
const decryptLength =
Math.floor(data.length / 16) * 16
const decryptLength = toSmallestChunkSize(data.length)
remainingBytes = data.slice(decryptLength)
data = data.slice(0, decryptLength)
if(!aes) {
let ivValue = iv
if(firstBlockIsIV) {
ivValue = data.slice(0, AES_CHUNK_SIZE)
data = data.slice(AES_CHUNK_SIZE)
}
aes = Crypto.createDecipheriv("aes-256-cbc", cipherKey, ivValue)
// if an end byte that is not EOF is specified
// stop auto padding (PKCS7) -- otherwise throws an error for decryption
if(endByte) {
aes.setAutoPadding(false)
}
}
try {
this.push(aes.update(data))
pushBytes(aes.update(data), b => this.push(b))
callback()
} catch(error) {
callback(error)
@@ -254,7 +320,7 @@ export const downloadContentFromMessage = async(
},
final(callback) {
try {
this.push(aes.final())
pushBytes(aes.final(), b => this.push(b))
callback()
} catch(error) {
callback(error)