Add progress_file support to DCT extraction

- Added progress_file parameter to extract_from_dct, _extract_scipy_dct_safe, _extract_jpegio
- Progress is written at key phases: loading, extracting, decoding, complete
- Updated extract_from_image and _extract_dct to pass through progress_file
- Updated decode(), decode_file(), decode_text() with progress_file param
- Progress JSON format: {current, total, percent, phase}

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
Aaron D. Lee
2026-01-09 22:01:20 -05:00
parent e9e4d1aab9
commit c0fe85ac83
3 changed files with 56 additions and 9 deletions

View File

@@ -1157,7 +1157,11 @@ def _embed_jpegio(
pass
def extract_from_dct(stego_image: bytes, seed: bytes) -> bytes:
def extract_from_dct(
stego_image: bytes,
seed: bytes,
progress_file: str | None = None,
) -> bytes:
"""Extract data from DCT stego image."""
img = Image.open(io.BytesIO(stego_image))
fmt = img.format
@@ -1165,16 +1169,22 @@ def extract_from_dct(stego_image: bytes, seed: bytes) -> bytes:
if fmt == "JPEG" and HAS_JPEGIO:
try:
return _extract_jpegio(stego_image, seed)
return _extract_jpegio(stego_image, seed, progress_file)
except ValueError:
pass
_check_scipy()
return _extract_scipy_dct_safe(stego_image, seed)
return _extract_scipy_dct_safe(stego_image, seed, progress_file)
def _extract_scipy_dct_safe(stego_image: bytes, seed: bytes) -> bytes:
def _extract_scipy_dct_safe(
stego_image: bytes,
seed: bytes,
progress_file: str | None = None,
) -> bytes:
"""Extract using safe DCT operations with vectorized processing."""
_write_progress(progress_file, 0, 100, "loading")
img = Image.open(io.BytesIO(stego_image))
width, height = img.size
mode = img.mode
@@ -1207,6 +1217,9 @@ def _extract_scipy_dct_safe(stego_image: bytes, seed: bytes) -> bytes:
embed_rows = np.array([pos[0] for pos in DEFAULT_EMBED_POSITIONS])
embed_cols = np.array([pos[1] for pos in DEFAULT_EMBED_POSITIONS])
# Progress reporting interval
PROGRESS_INTERVAL = 2000 # Report every N blocks
block_idx = 0
while block_idx < len(block_order):
# Determine batch size (may be smaller at end)
@@ -1236,6 +1249,10 @@ def _extract_scipy_dct_safe(stego_image: bytes, seed: bytes) -> bytes:
del blocks, dct_blocks, coeffs, quantized
block_idx = batch_end
# Report progress
if progress_file and block_idx % PROGRESS_INTERVAL < BATCH_SIZE:
_write_progress(progress_file, block_idx, num_blocks, "extracting")
# Check if we have enough bits (early exit)
if len(all_bits) >= HEADER_SIZE * 8:
try:
@@ -1249,6 +1266,8 @@ def _extract_scipy_dct_safe(stego_image: bytes, seed: bytes) -> bytes:
del padded
gc.collect()
_write_progress(progress_file, 80, 100, "decoding")
# Try RS-protected format first (has 24-byte length prefix: 3 copies of 8-byte header)
if HAS_REEDSOLO and len(all_bits) >= RS_LENGTH_PREFIX_SIZE * 8:
# Extract length prefix (24 bytes: 3 copies of 8-byte header for majority voting)
@@ -1312,6 +1331,7 @@ def _extract_scipy_dct_safe(stego_image: bytes, seed: bytes) -> bytes:
# Extract data
data = raw_payload[HEADER_SIZE : HEADER_SIZE + data_length]
_write_progress(progress_file, 100, 100, "complete")
return data
except (ValueError, struct.error):
pass # Fall through to legacy format
@@ -1327,13 +1347,20 @@ def _extract_scipy_dct_safe(stego_image: bytes, seed: bytes) -> bytes:
]
)
_write_progress(progress_file, 100, 100, "complete")
return data
def _extract_jpegio(stego_image: bytes, seed: bytes) -> bytes:
def _extract_jpegio(
stego_image: bytes,
seed: bytes,
progress_file: str | None = None,
) -> bytes:
"""Extract using jpegio for JPEG images."""
import os
_write_progress(progress_file, 0, 100, "loading")
# Normalize JPEG to avoid crashes with quality=100 images
# (shouldn't happen with stego images, but be defensive)
stego_image = _normalize_jpeg_for_jpegio(stego_image)
@@ -1347,6 +1374,8 @@ def _extract_jpegio(stego_image: bytes, seed: bytes) -> bytes:
all_positions = _jpegio_get_usable_positions(coef_array)
order = _jpegio_generate_order(len(all_positions), seed)
_write_progress(progress_file, 20, 100, "extracting")
# Try RS-protected format first (has 24-byte length prefix: 3 copies for majority voting)
if HAS_REEDSOLO and len(all_positions) >= RS_LENGTH_PREFIX_SIZE * 8:
# Extract length prefix (24 bytes: 3 copies of 8-byte header)
@@ -1410,9 +1439,11 @@ def _extract_jpegio(stego_image: bytes, seed: bytes) -> bytes:
)
try:
_write_progress(progress_file, 70, 100, "decoding")
raw_payload = _rs_decode(rs_encoded)
_, flags, data_length = _jpegio_parse_header(raw_payload[:HEADER_SIZE])
data = raw_payload[HEADER_SIZE : HEADER_SIZE + data_length]
_write_progress(progress_file, 100, 100, "complete")
return data
except (ValueError, struct.error):
pass # Fall through to legacy format
@@ -1450,6 +1481,7 @@ def _extract_jpegio(stego_image: bytes, seed: bytes) -> bytes:
]
)
_write_progress(progress_file, 100, 100, "complete")
return data
finally:

View File

@@ -33,6 +33,7 @@ def decode(
rsa_password: str | None = None,
embed_mode: str = EMBED_MODE_AUTO,
channel_key: str | bool | None = None,
progress_file: str | None = None,
) -> DecodeResult:
"""
Decode a message or file from a stego image.
@@ -45,6 +46,7 @@ def decode(
rsa_key_data: Optional RSA key bytes (if used during encoding)
rsa_password: Optional RSA key password
embed_mode: 'auto' (default), 'lsb', or 'dct'
progress_file: Optional path to write progress JSON for UI polling
channel_key: Channel key for deployment/group isolation:
- None or "auto": Use server's configured key
- str: Use this specific channel key
@@ -101,6 +103,7 @@ def decode(
stego_image,
pixel_key,
embed_mode=embed_mode,
progress_file=progress_file,
)
if not encrypted:
@@ -126,6 +129,7 @@ def decode_file(
rsa_password: str | None = None,
embed_mode: str = EMBED_MODE_AUTO,
channel_key: str | bool | None = None,
progress_file: str | None = None,
) -> Path:
"""
Decode a file from a stego image and save it.
@@ -140,6 +144,7 @@ def decode_file(
rsa_password: Optional RSA key password
embed_mode: 'auto', 'lsb', or 'dct'
channel_key: Channel key parameter (see decode())
progress_file: Optional path to write progress JSON for UI polling
Returns:
Path where file was saved
@@ -156,6 +161,7 @@ def decode_file(
rsa_password,
embed_mode,
channel_key,
progress_file,
)
if not result.is_file:
@@ -184,6 +190,7 @@ def decode_text(
rsa_password: str | None = None,
embed_mode: str = EMBED_MODE_AUTO,
channel_key: str | bool | None = None,
progress_file: str | None = None,
) -> str:
"""
Decode a text message from a stego image.
@@ -199,6 +206,7 @@ def decode_text(
rsa_password: Optional RSA key password
embed_mode: 'auto', 'lsb', or 'dct'
channel_key: Channel key parameter (see decode())
progress_file: Optional path to write progress JSON for UI polling
Returns:
Decoded message string
@@ -215,6 +223,7 @@ def decode_text(
rsa_password,
embed_mode,
channel_key,
progress_file,
)
if result.is_file:

View File

@@ -839,6 +839,7 @@ def extract_from_image(
pixel_key: bytes,
bits_per_channel: int = 1,
embed_mode: str = EMBED_MODE_AUTO,
progress_file: str | None = None,
) -> bytes | None:
"""
Extract hidden data from a stego image.
@@ -848,6 +849,7 @@ def extract_from_image(
pixel_key: Key for pixel/coefficient selection (must match encoding)
bits_per_channel: Bits per channel (LSB mode only)
embed_mode: 'auto' (try both), 'lsb', or 'dct'
progress_file: Optional path to write progress JSON for UI polling
Returns:
Extracted data bytes, or None if extraction fails
@@ -863,7 +865,7 @@ def extract_from_image(
if has_dct_support():
debug.print("Auto-detect: LSB failed, trying DCT")
result = _extract_dct(image_data, pixel_key)
result = _extract_dct(image_data, pixel_key, progress_file)
if result is not None:
debug.print("Auto-detect: DCT extraction succeeded")
return result
@@ -875,18 +877,22 @@ def extract_from_image(
elif embed_mode == EMBED_MODE_DCT:
if not has_dct_support():
raise ImportError("scipy required for DCT mode")
return _extract_dct(image_data, pixel_key)
return _extract_dct(image_data, pixel_key, progress_file)
# EXPLICIT LSB MODE
else:
return _extract_lsb(image_data, pixel_key, bits_per_channel)
def _extract_dct(image_data: bytes, pixel_key: bytes) -> bytes | None:
def _extract_dct(
image_data: bytes,
pixel_key: bytes,
progress_file: str | None = None,
) -> bytes | None:
"""Extract using DCT mode."""
try:
dct_mod = _get_dct_module()
return dct_mod.extract_from_dct(image_data, pixel_key)
return dct_mod.extract_from_dct(image_data, pixel_key, progress_file)
except Exception as e:
debug.print(f"DCT extraction failed: {e}")
return None