From b5ac28d4263c665f2d373c67daf9376912844306 Mon Sep 17 00:00:00 2001 From: Adhiraj Singh Date: Thu, 2 Dec 2021 11:38:24 +0530 Subject: [PATCH] feat: implement partial media downloads --- jest.config.js | 11 +++++ src/Tests/test.media-download.ts | 64 ++++++++++++++++++++++++ src/Types/Message.ts | 2 + src/Utils/messages-media.ts | 84 ++++++++++++++++++++++++++++---- 4 files changed, 152 insertions(+), 9 deletions(-) create mode 100644 jest.config.js create mode 100644 src/Tests/test.media-download.ts diff --git a/jest.config.js b/jest.config.js new file mode 100644 index 0000000..af9e3d4 --- /dev/null +++ b/jest.config.js @@ -0,0 +1,11 @@ +module.exports = { + "roots": [ + "<rootDir>/src" + ], + "testMatch": [ + "**/Tests/test.*.+(ts|tsx|js)", + ], + "transform": { + "^.+\\.(ts|tsx)$": "ts-jest" + }, +} \ No newline at end of file diff --git a/src/Tests/test.media-download.ts b/src/Tests/test.media-download.ts new file mode 100644 index 0000000..a0185a6 --- /dev/null +++ b/src/Tests/test.media-download.ts @@ -0,0 +1,64 @@ +import { MediaType, DownloadableMessage } from '../Types' +import { downloadContentFromMessage } from '../Utils' +import { proto } from '../../WAProto' +import { readFileSync } from 'fs' + +type TestVector = { + type: MediaType + message: DownloadableMessage + plaintext: Buffer +} + +const TEST_VECTORS: TestVector[] = [ + { + type: 'image', + message: proto.ImageMessage.decode( + Buffer.from( + 'Ck1odHRwczovL21tZy53aGF0c2FwcC5uZXQvZC9mL0FwaHR4WG9fWXZZcDZlUVNSa0tjOHE5d2ozVUpleWdoY3poM3ExX3I0ektnLmVuYxIKaW1hZ2UvanBlZyIgKTuVFyxDc6mTm4GXPlO3Z911Wd8RBeTrPLSWAEdqW8MomcUBQiB7wH5a4nXMKyLOT0A2nFgnnM/DUH8YjQf8QtkCIekaSkogTB+BXKCWDFrmNzozY0DCPn0L4VKd7yG1ZbZwbgRhzVc=', + 'base64' + ) + ), + plaintext: readFileSync('./Media/cat.jpeg') + } +] + +describe('Media Download Tests', () => { + + it('should download a full encrypted media correctly', async() => { + for(const { type, message, plaintext } of TEST_VECTORS) { + const readPipe = await 
downloadContentFromMessage(message, type) + + let buffer = Buffer.alloc(0) + for await(const read of readPipe) { + buffer = Buffer.concat([ buffer, read ]) + } + + expect(buffer).toEqual(plaintext) + } + }) + + it('should download an encrypted media correctly piece', async() => { + for(const { type, message, plaintext } of TEST_VECTORS) { + // check all edge cases + const ranges = [ + { startByte: 51, endByte: plaintext.length-100 }, // random numbers + { startByte: 1024, endByte: 2038 }, // start on a 16-byte block boundary, end off-boundary + { startByte: 1, endByte: plaintext.length-1 } // borders + ] + for(const range of ranges) { + const readPipe = await downloadContentFromMessage(message, type, range) + + let buffer = Buffer.alloc(0) + for await(const read of readPipe) { + buffer = Buffer.concat([ buffer, read ]) + } + + const hex = buffer.toString('hex') + const expectedHex = plaintext.slice(range.startByte || 0, range.endByte || undefined).toString('hex') + expect(hex).toBe(expectedHex) + + console.log('success on ', range) + } + } + }) +}) \ No newline at end of file diff --git a/src/Types/Message.ts b/src/Types/Message.ts index ac4ccb5..fa3bc14 100644 --- a/src/Types/Message.ts +++ b/src/Types/Message.ts @@ -23,6 +23,8 @@ export type WAMediaUpload = Buffer | { url: URL | string } | { stream: Readable /** Set of message types that are supported by the library */ export type MessageType = keyof proto.Message +export type DownloadableMessage = { mediaKey?: Uint8Array, directPath?: string, url?: string } + export type MediaConnInfo = { auth: string ttl: number diff --git a/src/Utils/messages-media.ts b/src/Utils/messages-media.ts index a32342a..78b9df2 100644 --- a/src/Utils/messages-media.ts +++ b/src/Utils/messages-media.ts @@ -10,7 +10,7 @@ import { URL } from 'url' import { join } from 'path' import { once } from 'events' import got, { Options, Response } from 'got' -import { MessageType, WAMessageContent, WAProto, WAGenericMediaMessage, WAMediaUpload, MediaType } from 
'../Types' +import { MessageType, WAMessageContent, WAProto, WAGenericMediaMessage, WAMediaUpload, MediaType, DownloadableMessage } from '../Types' import { generateMessageID } from './generics' import { hkdf } from './crypto' import { DEFAULT_ORIGIN } from '../Defaults' @@ -223,30 +223,96 @@ export const encryptedStream = async(media: WAMediaUpload, mediaType: MediaType, didSaveToTmpPath } } + const DEF_HOST = 'mmg.whatsapp.net' +const AES_CHUNK_SIZE = 16 + +const toSmallestChunkSize = (num: number) => { + return Math.floor(num / AES_CHUNK_SIZE) * AES_CHUNK_SIZE +} + +type MediaDownloadOptions = { + startByte?: number + endByte?: number +} + export const downloadContentFromMessage = async( - { mediaKey, directPath, url }: { mediaKey?: Uint8Array, directPath?: string, url?: string }, - type: MediaType + { mediaKey, directPath, url }: DownloadableMessage, + type: MediaType, + { startByte, endByte }: MediaDownloadOptions = { } ) => { const downloadUrl = url || `https://${DEF_HOST}${directPath}` + let bytesFetched = 0 + let startChunk = 0 + let firstBlockIsIV = false + // if a start byte is specified -- then we need to fetch the previous chunk as that will form the IV + if(startByte) { + const chunk = toSmallestChunkSize(startByte || 0) + if(chunk) { + startChunk = chunk-AES_CHUNK_SIZE + bytesFetched = chunk + + firstBlockIsIV = true + } + } + const endChunk = endByte ? 
toSmallestChunkSize(endByte || 0)+AES_CHUNK_SIZE : undefined + let rangeHeader: string | undefined = undefined + if(startChunk || endChunk) { + rangeHeader = `bytes=${startChunk}-` + if(endChunk) rangeHeader += endChunk + } // download the message const fetched = await getGotStream(downloadUrl, { - headers: { Origin: DEFAULT_ORIGIN } + headers: { + Origin: DEFAULT_ORIGIN, + Range: rangeHeader + } }) + let remainingBytes = Buffer.from([]) const { cipherKey, iv } = getMediaKeys(mediaKey, type) - const aes = Crypto.createDecipheriv("aes-256-cbc", cipherKey, iv) + + let aes: Crypto.Decipher + + const pushBytes = (bytes: Buffer, push: (bytes: Buffer) => void) => { + if(startByte || endByte) { + const start = bytesFetched >= startByte ? undefined : Math.max(startByte-bytesFetched, 0) + const end = bytesFetched+bytes.length < endByte ? undefined : Math.max(endByte-bytesFetched, 0) + + push(bytes.slice(start, end)) + + bytesFetched += bytes.length + } else { + push(bytes) + } + } const output = new Transform({ transform(chunk, _, callback) { let data = Buffer.concat([remainingBytes, chunk]) - const decryptLength = - Math.floor(data.length / 16) * 16 + + const decryptLength = toSmallestChunkSize(data.length) remainingBytes = data.slice(decryptLength) data = data.slice(0, decryptLength) + if(!aes) { + let ivValue = iv + if(firstBlockIsIV) { + ivValue = data.slice(0, AES_CHUNK_SIZE) + data = data.slice(AES_CHUNK_SIZE) + } + + aes = Crypto.createDecipheriv("aes-256-cbc", cipherKey, ivValue) + // if an end byte that is not EOF is specified + // stop auto padding (PKCS7) -- otherwise throws an error for decryption + if(endByte) { + aes.setAutoPadding(false) + } + + } + try { - this.push(aes.update(data)) + pushBytes(aes.update(data), b => this.push(b)) callback() } catch(error) { callback(error) @@ -254,7 +320,7 @@ export const downloadContentFromMessage = async( }, final(callback) { try { - this.push(aes.final()) + pushBytes(aes.final(), b => this.push(b)) callback() } 
catch(error) { callback(error)