Add progress_file support to DCT extraction

- Added progress_file parameter to extract_from_dct, _extract_scipy_dct_safe, _extract_jpegio
- Progress writes at key phases: loading, extracting, decoding, complete
- Updated extract_from_image and _extract_dct to pass through progress_file
- Updated decode(), decode_file(), decode_text() with progress_file param
- Progress JSON format: {current, total, percent, phase}

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
Aaron D. Lee
2026-01-09 22:01:20 -05:00
parent e9e4d1aab9
commit c0fe85ac83
3 changed files with 56 additions and 9 deletions

View File

@@ -1157,7 +1157,11 @@ def _embed_jpegio(
pass pass
def extract_from_dct(stego_image: bytes, seed: bytes) -> bytes: def extract_from_dct(
stego_image: bytes,
seed: bytes,
progress_file: str | None = None,
) -> bytes:
"""Extract data from DCT stego image.""" """Extract data from DCT stego image."""
img = Image.open(io.BytesIO(stego_image)) img = Image.open(io.BytesIO(stego_image))
fmt = img.format fmt = img.format
@@ -1165,16 +1169,22 @@ def extract_from_dct(stego_image: bytes, seed: bytes) -> bytes:
if fmt == "JPEG" and HAS_JPEGIO: if fmt == "JPEG" and HAS_JPEGIO:
try: try:
return _extract_jpegio(stego_image, seed) return _extract_jpegio(stego_image, seed, progress_file)
except ValueError: except ValueError:
pass pass
_check_scipy() _check_scipy()
return _extract_scipy_dct_safe(stego_image, seed) return _extract_scipy_dct_safe(stego_image, seed, progress_file)
def _extract_scipy_dct_safe(stego_image: bytes, seed: bytes) -> bytes: def _extract_scipy_dct_safe(
stego_image: bytes,
seed: bytes,
progress_file: str | None = None,
) -> bytes:
"""Extract using safe DCT operations with vectorized processing.""" """Extract using safe DCT operations with vectorized processing."""
_write_progress(progress_file, 0, 100, "loading")
img = Image.open(io.BytesIO(stego_image)) img = Image.open(io.BytesIO(stego_image))
width, height = img.size width, height = img.size
mode = img.mode mode = img.mode
@@ -1207,6 +1217,9 @@ def _extract_scipy_dct_safe(stego_image: bytes, seed: bytes) -> bytes:
embed_rows = np.array([pos[0] for pos in DEFAULT_EMBED_POSITIONS]) embed_rows = np.array([pos[0] for pos in DEFAULT_EMBED_POSITIONS])
embed_cols = np.array([pos[1] for pos in DEFAULT_EMBED_POSITIONS]) embed_cols = np.array([pos[1] for pos in DEFAULT_EMBED_POSITIONS])
# Progress reporting interval
PROGRESS_INTERVAL = 2000 # Report every N blocks
block_idx = 0 block_idx = 0
while block_idx < len(block_order): while block_idx < len(block_order):
# Determine batch size (may be smaller at end) # Determine batch size (may be smaller at end)
@@ -1236,6 +1249,10 @@ def _extract_scipy_dct_safe(stego_image: bytes, seed: bytes) -> bytes:
del blocks, dct_blocks, coeffs, quantized del blocks, dct_blocks, coeffs, quantized
block_idx = batch_end block_idx = batch_end
# Report progress
if progress_file and block_idx % PROGRESS_INTERVAL < BATCH_SIZE:
_write_progress(progress_file, block_idx, num_blocks, "extracting")
# Check if we have enough bits (early exit) # Check if we have enough bits (early exit)
if len(all_bits) >= HEADER_SIZE * 8: if len(all_bits) >= HEADER_SIZE * 8:
try: try:
@@ -1249,6 +1266,8 @@ def _extract_scipy_dct_safe(stego_image: bytes, seed: bytes) -> bytes:
del padded del padded
gc.collect() gc.collect()
_write_progress(progress_file, 80, 100, "decoding")
# Try RS-protected format first (has 24-byte length prefix: 3 copies of 8-byte header) # Try RS-protected format first (has 24-byte length prefix: 3 copies of 8-byte header)
if HAS_REEDSOLO and len(all_bits) >= RS_LENGTH_PREFIX_SIZE * 8: if HAS_REEDSOLO and len(all_bits) >= RS_LENGTH_PREFIX_SIZE * 8:
# Extract length prefix (24 bytes: 3 copies of 8-byte header for majority voting) # Extract length prefix (24 bytes: 3 copies of 8-byte header for majority voting)
@@ -1312,6 +1331,7 @@ def _extract_scipy_dct_safe(stego_image: bytes, seed: bytes) -> bytes:
# Extract data # Extract data
data = raw_payload[HEADER_SIZE : HEADER_SIZE + data_length] data = raw_payload[HEADER_SIZE : HEADER_SIZE + data_length]
_write_progress(progress_file, 100, 100, "complete")
return data return data
except (ValueError, struct.error): except (ValueError, struct.error):
pass # Fall through to legacy format pass # Fall through to legacy format
@@ -1327,13 +1347,20 @@ def _extract_scipy_dct_safe(stego_image: bytes, seed: bytes) -> bytes:
] ]
) )
_write_progress(progress_file, 100, 100, "complete")
return data return data
def _extract_jpegio(stego_image: bytes, seed: bytes) -> bytes: def _extract_jpegio(
stego_image: bytes,
seed: bytes,
progress_file: str | None = None,
) -> bytes:
"""Extract using jpegio for JPEG images.""" """Extract using jpegio for JPEG images."""
import os import os
_write_progress(progress_file, 0, 100, "loading")
# Normalize JPEG to avoid crashes with quality=100 images # Normalize JPEG to avoid crashes with quality=100 images
# (shouldn't happen with stego images, but be defensive) # (shouldn't happen with stego images, but be defensive)
stego_image = _normalize_jpeg_for_jpegio(stego_image) stego_image = _normalize_jpeg_for_jpegio(stego_image)
@@ -1347,6 +1374,8 @@ def _extract_jpegio(stego_image: bytes, seed: bytes) -> bytes:
all_positions = _jpegio_get_usable_positions(coef_array) all_positions = _jpegio_get_usable_positions(coef_array)
order = _jpegio_generate_order(len(all_positions), seed) order = _jpegio_generate_order(len(all_positions), seed)
_write_progress(progress_file, 20, 100, "extracting")
# Try RS-protected format first (has 24-byte length prefix: 3 copies for majority voting) # Try RS-protected format first (has 24-byte length prefix: 3 copies for majority voting)
if HAS_REEDSOLO and len(all_positions) >= RS_LENGTH_PREFIX_SIZE * 8: if HAS_REEDSOLO and len(all_positions) >= RS_LENGTH_PREFIX_SIZE * 8:
# Extract length prefix (24 bytes: 3 copies of 8-byte header) # Extract length prefix (24 bytes: 3 copies of 8-byte header)
@@ -1410,9 +1439,11 @@ def _extract_jpegio(stego_image: bytes, seed: bytes) -> bytes:
) )
try: try:
_write_progress(progress_file, 70, 100, "decoding")
raw_payload = _rs_decode(rs_encoded) raw_payload = _rs_decode(rs_encoded)
_, flags, data_length = _jpegio_parse_header(raw_payload[:HEADER_SIZE]) _, flags, data_length = _jpegio_parse_header(raw_payload[:HEADER_SIZE])
data = raw_payload[HEADER_SIZE : HEADER_SIZE + data_length] data = raw_payload[HEADER_SIZE : HEADER_SIZE + data_length]
_write_progress(progress_file, 100, 100, "complete")
return data return data
except (ValueError, struct.error): except (ValueError, struct.error):
pass # Fall through to legacy format pass # Fall through to legacy format
@@ -1450,6 +1481,7 @@ def _extract_jpegio(stego_image: bytes, seed: bytes) -> bytes:
] ]
) )
_write_progress(progress_file, 100, 100, "complete")
return data return data
finally: finally:

View File

@@ -33,6 +33,7 @@ def decode(
rsa_password: str | None = None, rsa_password: str | None = None,
embed_mode: str = EMBED_MODE_AUTO, embed_mode: str = EMBED_MODE_AUTO,
channel_key: str | bool | None = None, channel_key: str | bool | None = None,
progress_file: str | None = None,
) -> DecodeResult: ) -> DecodeResult:
""" """
Decode a message or file from a stego image. Decode a message or file from a stego image.
@@ -45,6 +46,7 @@ def decode(
rsa_key_data: Optional RSA key bytes (if used during encoding) rsa_key_data: Optional RSA key bytes (if used during encoding)
rsa_password: Optional RSA key password rsa_password: Optional RSA key password
embed_mode: 'auto' (default), 'lsb', or 'dct' embed_mode: 'auto' (default), 'lsb', or 'dct'
progress_file: Optional path to write progress JSON for UI polling
channel_key: Channel key for deployment/group isolation: channel_key: Channel key for deployment/group isolation:
- None or "auto": Use server's configured key - None or "auto": Use server's configured key
- str: Use this specific channel key - str: Use this specific channel key
@@ -101,6 +103,7 @@ def decode(
stego_image, stego_image,
pixel_key, pixel_key,
embed_mode=embed_mode, embed_mode=embed_mode,
progress_file=progress_file,
) )
if not encrypted: if not encrypted:
@@ -126,6 +129,7 @@ def decode_file(
rsa_password: str | None = None, rsa_password: str | None = None,
embed_mode: str = EMBED_MODE_AUTO, embed_mode: str = EMBED_MODE_AUTO,
channel_key: str | bool | None = None, channel_key: str | bool | None = None,
progress_file: str | None = None,
) -> Path: ) -> Path:
""" """
Decode a file from a stego image and save it. Decode a file from a stego image and save it.
@@ -140,6 +144,7 @@ def decode_file(
rsa_password: Optional RSA key password rsa_password: Optional RSA key password
embed_mode: 'auto', 'lsb', or 'dct' embed_mode: 'auto', 'lsb', or 'dct'
channel_key: Channel key parameter (see decode()) channel_key: Channel key parameter (see decode())
progress_file: Optional path to write progress JSON for UI polling
Returns: Returns:
Path where file was saved Path where file was saved
@@ -156,6 +161,7 @@ def decode_file(
rsa_password, rsa_password,
embed_mode, embed_mode,
channel_key, channel_key,
progress_file,
) )
if not result.is_file: if not result.is_file:
@@ -184,6 +190,7 @@ def decode_text(
rsa_password: str | None = None, rsa_password: str | None = None,
embed_mode: str = EMBED_MODE_AUTO, embed_mode: str = EMBED_MODE_AUTO,
channel_key: str | bool | None = None, channel_key: str | bool | None = None,
progress_file: str | None = None,
) -> str: ) -> str:
""" """
Decode a text message from a stego image. Decode a text message from a stego image.
@@ -199,6 +206,7 @@ def decode_text(
rsa_password: Optional RSA key password rsa_password: Optional RSA key password
embed_mode: 'auto', 'lsb', or 'dct' embed_mode: 'auto', 'lsb', or 'dct'
channel_key: Channel key parameter (see decode()) channel_key: Channel key parameter (see decode())
progress_file: Optional path to write progress JSON for UI polling
Returns: Returns:
Decoded message string Decoded message string
@@ -215,6 +223,7 @@ def decode_text(
rsa_password, rsa_password,
embed_mode, embed_mode,
channel_key, channel_key,
progress_file,
) )
if result.is_file: if result.is_file:

View File

@@ -839,6 +839,7 @@ def extract_from_image(
pixel_key: bytes, pixel_key: bytes,
bits_per_channel: int = 1, bits_per_channel: int = 1,
embed_mode: str = EMBED_MODE_AUTO, embed_mode: str = EMBED_MODE_AUTO,
progress_file: str | None = None,
) -> bytes | None: ) -> bytes | None:
""" """
Extract hidden data from a stego image. Extract hidden data from a stego image.
@@ -848,6 +849,7 @@ def extract_from_image(
pixel_key: Key for pixel/coefficient selection (must match encoding) pixel_key: Key for pixel/coefficient selection (must match encoding)
bits_per_channel: Bits per channel (LSB mode only) bits_per_channel: Bits per channel (LSB mode only)
embed_mode: 'auto' (try both), 'lsb', or 'dct' embed_mode: 'auto' (try both), 'lsb', or 'dct'
progress_file: Optional path to write progress JSON for UI polling
Returns: Returns:
Extracted data bytes, or None if extraction fails Extracted data bytes, or None if extraction fails
@@ -863,7 +865,7 @@ def extract_from_image(
if has_dct_support(): if has_dct_support():
debug.print("Auto-detect: LSB failed, trying DCT") debug.print("Auto-detect: LSB failed, trying DCT")
result = _extract_dct(image_data, pixel_key) result = _extract_dct(image_data, pixel_key, progress_file)
if result is not None: if result is not None:
debug.print("Auto-detect: DCT extraction succeeded") debug.print("Auto-detect: DCT extraction succeeded")
return result return result
@@ -875,18 +877,22 @@ def extract_from_image(
elif embed_mode == EMBED_MODE_DCT: elif embed_mode == EMBED_MODE_DCT:
if not has_dct_support(): if not has_dct_support():
raise ImportError("scipy required for DCT mode") raise ImportError("scipy required for DCT mode")
return _extract_dct(image_data, pixel_key) return _extract_dct(image_data, pixel_key, progress_file)
# EXPLICIT LSB MODE # EXPLICIT LSB MODE
else: else:
return _extract_lsb(image_data, pixel_key, bits_per_channel) return _extract_lsb(image_data, pixel_key, bits_per_channel)
def _extract_dct(image_data: bytes, pixel_key: bytes) -> bytes | None: def _extract_dct(
image_data: bytes,
pixel_key: bytes,
progress_file: str | None = None,
) -> bytes | None:
"""Extract using DCT mode.""" """Extract using DCT mode."""
try: try:
dct_mod = _get_dct_module() dct_mod = _get_dct_module()
return dct_mod.extract_from_dct(image_data, pixel_key) return dct_mod.extract_from_dct(image_data, pixel_key, progress_file)
except Exception as e: except Exception as e:
debug.print(f"DCT extraction failed: {e}") debug.print(f"DCT extraction failed: {e}")
return None return None