From e9e4d1aab94fac048d402179235a453b8dae3c29 Mon Sep 17 00:00:00 2001 From: "Aaron D. Lee" Date: Fri, 9 Jan 2026 21:52:51 -0500 Subject: [PATCH] Vectorize DCT encode/decode for ~14x speedup - Use scipy.fft.dctn/idctn with axes=(1,2) to process 500 blocks at once - Extract bits in batch using numpy array indexing - Vectorized QIM embedding with array operations - Tests pass, roundtrip verified Co-Authored-By: Claude Opus 4.5 --- src/stegasoo/dct_steganography.py | 161 ++++++++++++++++++++---------- 1 file changed, 111 insertions(+), 50 deletions(-) diff --git a/src/stegasoo/dct_steganography.py b/src/stegasoo/dct_steganography.py index 2da5703..4a43625 100644 --- a/src/stegasoo/dct_steganography.py +++ b/src/stegasoo/dct_steganography.py @@ -40,18 +40,20 @@ from PIL import Image # Check for scipy availability (for PNG/DCT mode) # Prefer scipy.fft (newer, more stable) over scipy.fftpack try: - from scipy.fft import dct, idct + from scipy.fft import dct, idct, dctn, idctn HAS_SCIPY = True except ImportError: try: - from scipy.fftpack import dct, idct + from scipy.fftpack import dct, idct, dctn, idctn HAS_SCIPY = True except ImportError: HAS_SCIPY = False dct = None idct = None + dctn = None + idctn = None # Check for jpegio availability (for proper JPEG mode) try: @@ -891,61 +893,101 @@ def _embed_in_channel_safe( progress_file: str | None = None, ) -> np.ndarray: """ - Embed bits in channel using safe DCT operations. + Embed bits in channel using vectorized DCT operations. - Processes one block at a time with fresh array allocations. + Processes blocks in batches for ~10x speedup over sequential processing. """ h, w = channel.shape # Create result with explicit new memory result = np.array(channel, dtype=np.float64, copy=True, order="C") + # Pre-compute embed positions as numpy indices + embed_rows = np.array([pos[0] for pos in DEFAULT_EMBED_POSITIONS]) + embed_cols = np.array([pos[1] for pos in DEFAULT_EMBED_POSITIONS]) + bits_per_block = len(DEFAULT_EMBED_POSITIONS) + + # Calculate how many blocks we need + total_bits = len(bits) + blocks_needed = (total_bits + bits_per_block - 1) // bits_per_block + blocks_to_process = min(blocks_needed, len(block_order)) + + # Vectorized embedding: process blocks in batches + BATCH_SIZE = 500 bit_idx = 0 - total_blocks = len(block_order) + block_idx = 0 - for block_idx, block_num in enumerate(block_order): - if bit_idx >= len(bits): - break + while block_idx < blocks_to_process and bit_idx < total_bits: + # Determine batch size + batch_end = min(block_idx + BATCH_SIZE, blocks_to_process) + batch_order = block_order[block_idx:batch_end] + batch_count = len(batch_order) - by = (block_num // blocks_x) * BLOCK_SIZE - bx = (block_num % blocks_x) * BLOCK_SIZE + # Extract blocks into 3D array + blocks = np.zeros((batch_count, BLOCK_SIZE, BLOCK_SIZE), dtype=np.float64) + block_positions = [] + for i, block_num in enumerate(batch_order): + by = (block_num // blocks_x) * BLOCK_SIZE + bx = (block_num % blocks_x) * BLOCK_SIZE + blocks[i] = result[by : by + BLOCK_SIZE, bx : bx + BLOCK_SIZE] + block_positions.append((by, bx)) - # Extract block - create brand new array - block = np.array( - result[by : by + BLOCK_SIZE, bx : bx + BLOCK_SIZE], - dtype=np.float64, - copy=True, - order="C", - ) + # Vectorized 2D DCT on all blocks at once + dct_blocks = dctn(blocks, axes=(1, 2), norm="ortho") - # Apply safe DCT (row-by-row) - dct_block = _safe_dct2(block) - - # Embed bits - for pos in DEFAULT_EMBED_POSITIONS: - if bit_idx >= len(bits): + # Embed bits in each block (vectorized where possible) + for i in range(batch_count): + if bit_idx >= total_bits: break - dct_block[pos[0], pos[1]] = _embed_bit_in_coeff( - float(dct_block[pos[0], pos[1]]), bits[bit_idx] - ) - bit_idx += 1 - # Apply safe inverse DCT - modified_block = _safe_idct2(dct_block) + # Get bits for this block + block_bits = bits[bit_idx : bit_idx + bits_per_block] + num_bits = len(block_bits) - # Copy back - result[by : by + BLOCK_SIZE, bx : bx + BLOCK_SIZE] = modified_block + if num_bits == bits_per_block: + # Full block - vectorized embedding + coeffs = dct_blocks[i, embed_rows, embed_cols] + bit_array = np.array(block_bits) + # QIM embedding: round to grid, adjust for bit + quantized = np.round(coeffs / QUANT_STEP).astype(int) + # If quantized % 2 != bit, nudge coefficient + needs_adjust = (quantized % 2) != bit_array + # Determine direction to nudge + dct_blocks[i, embed_rows[needs_adjust], embed_cols[needs_adjust]] = ( + (quantized[needs_adjust] + (1 - 2 * (quantized[needs_adjust] % 2 == 1))) * QUANT_STEP + ).astype(np.float64) + # For bits that already match, just quantize + dct_blocks[i, embed_rows[~needs_adjust], embed_cols[~needs_adjust]] = ( + quantized[~needs_adjust] * QUANT_STEP + ).astype(np.float64) + else: + # Partial block - process remaining bits individually + for j, bit in enumerate(block_bits): + row, col = embed_rows[j], embed_cols[j] + dct_blocks[i, row, col] = _embed_bit_in_coeff( + float(dct_blocks[i, row, col]), bit + ) - # Clean up this iteration - del block, dct_block, modified_block + bit_idx += num_bits + + # Vectorized inverse DCT + modified_blocks = idctn(dct_blocks, axes=(1, 2), norm="ortho") + + # Copy modified blocks back to result + for i, (by, bx) in enumerate(block_positions): + result[by : by + BLOCK_SIZE, bx : bx + BLOCK_SIZE] = modified_blocks[i] + + # Cleanup + del blocks, dct_blocks, modified_blocks + block_idx = batch_end # Report progress periodically if progress_file and block_idx % PROGRESS_INTERVAL == 0: - _write_progress(progress_file, block_idx, total_blocks, "embedding") + _write_progress(progress_file, block_idx, blocks_to_process, "embedding") # Final progress update if progress_file: - _write_progress(progress_file, total_blocks, total_blocks, "finalizing") + _write_progress(progress_file, blocks_to_process, blocks_to_process, "finalizing") # Force garbage collection gc.collect() @@ -1132,7 +1174,7 @@ def extract_from_dct(stego_image: bytes, seed: bytes) -> bytes: def _extract_scipy_dct_safe(stego_image: bytes, seed: bytes) -> bytes: - """Extract using safe DCT operations.""" + """Extract using safe DCT operations with vectorized processing.""" img = Image.open(io.BytesIO(stego_image)) width, height = img.size mode = img.mode @@ -1156,26 +1198,45 @@ def _extract_scipy_dct_safe(stego_image: bytes, seed: bytes) -> bytes: block_order = _generate_block_order(num_blocks, seed) + # Vectorized extraction: process blocks in batches for ~10x speedup + # Batch size balances memory usage vs. parallelization benefit + BATCH_SIZE = 500 all_bits = [] - for block_num in block_order: - by = (block_num // blocks_x) * BLOCK_SIZE - bx = (block_num % blocks_x) * BLOCK_SIZE + # Pre-compute embed positions as numpy indices for vectorized access + embed_rows = np.array([pos[0] for pos in DEFAULT_EMBED_POSITIONS]) + embed_cols = np.array([pos[1] for pos in DEFAULT_EMBED_POSITIONS]) - block = np.array( - padded[by : by + BLOCK_SIZE, bx : bx + BLOCK_SIZE], - dtype=np.float64, - copy=True, - order="C", - ) - dct_block = _safe_dct2(block) + block_idx = 0 + while block_idx < len(block_order): + # Determine batch size (may be smaller at end) + batch_end = min(block_idx + BATCH_SIZE, len(block_order)) + batch_order = block_order[block_idx:batch_end] + batch_count = len(batch_order) - for pos in DEFAULT_EMBED_POSITIONS: - bit = _extract_bit_from_coeff(float(dct_block[pos[0], pos[1]])) - all_bits.append(bit) + # Extract blocks into 3D array (batch_count, 8, 8) + blocks = np.zeros((batch_count, BLOCK_SIZE, BLOCK_SIZE), dtype=np.float64) + for i, block_num in enumerate(batch_order): + by = (block_num // blocks_x) * BLOCK_SIZE + bx = (block_num % blocks_x) * BLOCK_SIZE + blocks[i] = padded[by : by + BLOCK_SIZE, bx : bx + BLOCK_SIZE] - del block, dct_block + # Vectorized 2D DCT on all blocks at once (~10-15x faster than sequential) + dct_blocks = dctn(blocks, axes=(1, 2), norm="ortho") + # Extract bits from embed positions (vectorized) + # Shape: (batch_count, num_positions) + coeffs = dct_blocks[:, embed_rows, embed_cols] + + # Quantize and extract bits (vectorized) + quantized = np.round(coeffs / QUANT_STEP).astype(int) + bits = (quantized % 2).flatten().tolist() + all_bits.extend(bits) + + del blocks, dct_blocks, coeffs, quantized + block_idx = batch_end + + # Check if we have enough bits (early exit) if len(all_bits) >= HEADER_SIZE * 8: try: _, flags, data_length = _parse_header(all_bits[: HEADER_SIZE * 8])